direct.nn.mwcnn package
Submodules
direct.nn.mwcnn.mwcnn module
- class direct.nn.mwcnn.mwcnn.ConvBlock(in_channels, out_channels, kernel_size, bias=True, batchnorm=False, activation=ReLU(inplace=True), scale=1.0)
Bases: Module
Convolution block for MWCNN as implemented in [1].
References
[1]Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
- forward(x)
Performs the forward pass of ConvBlock.
- Parameters:
- x: torch.Tensor
Input with shape (N, C, H, W).
- Returns:
- output: torch.Tensor
Output with shape (N, C’, H’, W’).
- Return type:
Tensor
- training: bool
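The shape notation above, (N, C, H, W) in and (N, C', H', W') out, can be illustrated with a small stand-alone sketch in plain PyTorch. This is not the library's code: the class name `SketchConvBlock`, the padding of `kernel_size // 2`, the `activation=None` default, and the choice to multiply the output by `scale` are assumptions made for illustration only.

```python
import torch
from torch import nn


class SketchConvBlock(nn.Module):
    """Hypothetical stand-in: convolution, optional batch norm, activation, scaling."""

    def __init__(self, in_channels, out_channels, kernel_size, bias=True,
                 batchnorm=False, activation=None, scale=1.0):
        super().__init__()
        layers = [
            # Assumption: pad so that odd kernel sizes preserve H and W.
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=kernel_size // 2, bias=bias)
        ]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(activation if activation is not None else nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)
        self.scale = scale

    def forward(self, x):
        # Assumption: the scale factor multiplies the block output.
        return self.net(x) * self.scale


block = SketchConvBlock(in_channels=2, out_channels=8, kernel_size=3)
x = torch.randn(1, 2, 16, 16)  # (N, C, H, W)
y = block(x)                   # (N, C', H', W')
```

Under these assumptions only the channel dimension changes (C = 2 becomes C' = 8) while H and W are preserved.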
- class direct.nn.mwcnn.mwcnn.DWT
Bases: Module
2D Discrete Wavelet Transform as implemented in [1].
References
[1]Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
- forward(x)
Computes DWT(x) given tensor x.
- Parameters:
- x: torch.Tensor
Input tensor.
- Returns:
- out: torch.Tensor
DWT of x.
- Return type:
Tensor
- training: bool
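The entry above does not state the output shape. As a rough illustration, a single-level 2D Haar transform can be sketched in plain PyTorch as follows. The function name `haar_dwt`, the Haar filters, the 1/2 normalisation, the sub-band ordering, and the convention of stacking the four sub-bands along the channel axis are all assumptions of this sketch, not a description of the library's implementation.

```python
import torch


def haar_dwt(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical single-level 2D Haar DWT: (N, C, H, W) -> (N, 4C, H/2, W/2)."""
    # Pick the four pixels of every 2x2 block (H and W must be even).
    a = x[:, :, 0::2, 0::2] / 2  # top-left
    b = x[:, :, 1::2, 0::2] / 2  # bottom-left
    c = x[:, :, 0::2, 1::2] / 2  # top-right
    d = x[:, :, 1::2, 1::2] / 2  # bottom-right
    ll = a + b + c + d    # low-pass in both directions
    hl = -a - b + c + d   # right column minus left column
    lh = -a + b - c + d   # bottom row minus top row
    hh = a - b - c + d    # diagonal detail
    return torch.cat((ll, hl, lh, hh), dim=1)


x = torch.arange(16.0).reshape(1, 1, 4, 4)
bands = haar_dwt(x)  # shape (1, 4, 2, 2)
```

For this linear ramp input the diagonal band is zero everywhere, since each 2x2 block has no diagonal variation.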
- class direct.nn.mwcnn.mwcnn.DilatedConvBlock(in_channels, dilations, kernel_size, out_channels=None, bias=True, batchnorm=False, activation=ReLU(inplace=True), scale=1.0)
Bases: Module
Double dilated convolution block for MWCNN as implemented in [1].
References
[1]Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
- forward(x)
Performs the forward pass of DilatedConvBlock.
- Parameters:
- x: torch.Tensor
Input with shape (N, C, H, W).
- Returns:
- output: torch.Tensor
Output with shape (N, C’, H’, W’).
- Return type:
Tensor
- training: bool
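A "double dilated" block can be pictured as two stacked convolutions, each with its own dilation rate. The sketch below is a hedged illustration in plain PyTorch, not the library's code: the class name `SketchDilatedConvBlock`, treating `dilations` as a pair, falling back to `in_channels` when `out_channels` is `None`, and the padding formula are assumptions.

```python
import torch
from torch import nn


class SketchDilatedConvBlock(nn.Module):
    """Hypothetical stand-in: two convolutions with separate dilation rates."""

    def __init__(self, in_channels, dilations, kernel_size, out_channels=None, bias=True):
        super().__init__()
        # Assumption: keep the channel count when out_channels is not given.
        out_channels = in_channels if out_channels is None else out_channels

        def same_padding(dilation):
            # Padding that preserves H and W for odd kernel sizes.
            return dilation * (kernel_size - 1) // 2

        self.net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size, bias=bias,
                      dilation=dilations[0], padding=same_padding(dilations[0])),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias,
                      dilation=dilations[1], padding=same_padding(dilations[1])),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


block = SketchDilatedConvBlock(in_channels=4, dilations=(2, 1), kernel_size=3)
y = block(torch.randn(1, 4, 16, 16))
```

With a 3x3 kernel, a dilation of 2 widens the receptive field of the first convolution to 5x5 without adding parameters.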
- class direct.nn.mwcnn.mwcnn.IWT
Bases: Module
2D Inverse Wavelet Transform as implemented in [1].
References
[1]Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
- forward(x)
Computes IWT(x) given tensor x.
- Parameters:
- x: torch.Tensor
Input tensor.
- Returns:
- h: torch.Tensor
IWT of x.
- Return type:
Tensor
- training: bool
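An inverse wavelet transform undoes the forward transform, which a round-trip check makes concrete. The sketch below is self-contained and assumes Haar filters with a 1/2 normalisation and sub-bands stacked along the channel axis. The names `haar_dwt` and `haar_iwt` are hypothetical helpers, not the library's functions.

```python
import torch


def haar_dwt(x):
    """Hypothetical forward Haar transform: (N, C, H, W) -> (N, 4C, H/2, W/2)."""
    a, b = x[:, :, 0::2, 0::2] / 2, x[:, :, 1::2, 0::2] / 2
    c, d = x[:, :, 0::2, 1::2] / 2, x[:, :, 1::2, 1::2] / 2
    return torch.cat((a + b + c + d, -a - b + c + d, -a + b - c + d, a - b - c + d), dim=1)


def haar_iwt(x):
    """Hypothetical inverse Haar transform: (N, 4C, H, W) -> (N, C, 2H, 2W)."""
    n, c4, h, w = x.shape
    # Split the channel axis back into the four sub-bands.
    ll, hl, lh, hh = (band / 2 for band in torch.chunk(x, 4, dim=1))
    out = x.new_zeros(n, c4 // 4, 2 * h, 2 * w)
    # Each output pixel of a 2x2 block is a signed sum of the four sub-bands.
    out[:, :, 0::2, 0::2] = ll - hl - lh + hh  # top-left
    out[:, :, 1::2, 0::2] = ll - hl + lh - hh  # bottom-left
    out[:, :, 0::2, 1::2] = ll + hl - lh - hh  # top-right
    out[:, :, 1::2, 1::2] = ll + hl + lh + hh  # bottom-right
    return out


x = torch.randn(2, 3, 8, 8)
reconstructed = haar_iwt(haar_dwt(x))
```

Under these assumptions the reconstruction matches the input up to floating-point error.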
- class direct.nn.mwcnn.mwcnn.MWCNN(input_channels, first_conv_hidden_channels, num_scales=4, bias=True, batchnorm=False, activation=ReLU(inplace=True))
Bases: Module
Multi-level Wavelet CNN (MWCNN) as implemented in [1].
References
[1]Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.
- forward(input_tensor, res=False)
Computes the forward pass of MWCNN.
- Parameters:
- input_tensor: torch.Tensor
Input tensor.
- res: bool
If True, residual connection is applied to the output. Default: False.
- Returns:
- x: torch.Tensor
Output tensor.
- Return type:
Tensor
- training: bool
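To make the role of the `res` flag concrete, here is a toy single-scale sketch in plain PyTorch: a wavelet transform, a small convolutional body, an inverse transform, and an optional residual connection. It is not the MWCNN architecture itself. The class `OneLevelWaveletNet`, the Haar helpers, the layer sizes, and the reading of "residual connection" as adding the input tensor to the output are all assumptions of this sketch.

```python
import torch
from torch import nn


def haar_dwt(x):
    """Hypothetical forward Haar transform: (N, C, H, W) -> (N, 4C, H/2, W/2)."""
    a, b = x[:, :, 0::2, 0::2] / 2, x[:, :, 1::2, 0::2] / 2
    c, d = x[:, :, 0::2, 1::2] / 2, x[:, :, 1::2, 1::2] / 2
    return torch.cat((a + b + c + d, -a - b + c + d, -a + b - c + d, a - b - c + d), dim=1)


def haar_iwt(x):
    """Hypothetical inverse Haar transform: (N, 4C, H, W) -> (N, C, 2H, 2W)."""
    ll, hl, lh, hh = (band / 2 for band in torch.chunk(x, 4, dim=1))
    out = x.new_zeros(x.shape[0], x.shape[1] // 4, 2 * x.shape[2], 2 * x.shape[3])
    out[:, :, 0::2, 0::2] = ll - hl - lh + hh
    out[:, :, 1::2, 0::2] = ll - hl + lh - hh
    out[:, :, 0::2, 1::2] = ll + hl - lh - hh
    out[:, :, 1::2, 1::2] = ll + hl + lh + hh
    return out


class OneLevelWaveletNet(nn.Module):
    """Toy single-scale model: DWT, two convolutions, IWT."""

    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(4 * input_channels, hidden_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, 4 * input_channels, 3, padding=1),
        )

    def forward(self, input_tensor, res=False):
        x = haar_iwt(self.body(haar_dwt(input_tensor)))
        if res:
            # Assumption: the residual connection adds the input to the output.
            x = x + input_tensor
        return x


model = OneLevelWaveletNet(input_channels=2, hidden_channels=16)
inp = torch.randn(1, 2, 32, 32)
with torch.no_grad():
    plain = model(inp)
    residual = model(inp, res=True)
```

In this toy setting the two outputs differ exactly by the input tensor, so with `res=True` the body only has to learn a correction to the input.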