direct.nn.mwcnn package

Submodules

direct.nn.mwcnn.mwcnn module

class direct.nn.mwcnn.mwcnn.ConvBlock(in_channels, out_channels, kernel_size, bias=True, batchnorm=False, activation=ReLU(inplace=True), scale=1.0)

Bases: Module

Convolution Block for MWCNN as implemented in [1].

References

[1]

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

forward(x)

Performs the forward pass of ConvBlock.

Parameters:
x: torch.Tensor

Input with shape (N, C, H, W).

Returns:
output: torch.Tensor

Output with shape (N, C’, H’, W’).

Return type:

Tensor

training: bool
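For orientation, here is a minimal PyTorch sketch of a block with this constructor signature. It assumes the block is a Conv2d, then an optional BatchNorm2d, then the activation, with the result multiplied by `scale`. The "same" padding (`kernel_size // 2`) and the class name `ConvBlockSketch` are illustrative assumptions, not taken from the DIRECT source.

```python
from typing import Optional

import torch
from torch import nn


class ConvBlockSketch(nn.Module):
    """Hypothetical conv block: Conv2d -> (optional) BatchNorm2d -> activation, times `scale`."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        bias: bool = True,
        batchnorm: bool = False,
        activation: Optional[nn.Module] = None,
        scale: float = 1.0,
    ) -> None:
        super().__init__()
        # Assumption: padding of kernel_size // 2 keeps H and W unchanged for odd kernels.
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=bias)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        # Build a fresh ReLU when no activation is given, instead of sharing a default instance.
        layers.append(activation if activation is not None else nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, C, H, W) -> (N, out_channels, H, W), scaled by a constant factor.
        return self.net(x) * self.scale


block = ConvBlockSketch(in_channels=2, out_channels=16, kernel_size=3, batchnorm=True)
out = block(torch.randn(4, 2, 32, 32))
print(out.shape)  # torch.Size([4, 16, 32, 32])
```

The sketch defaults `activation` to `None` and creates a new `nn.ReLU` inside the constructor, which avoids sharing one module instance across every block that relies on the default argument.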
class direct.nn.mwcnn.mwcnn.DWT

Bases: Module

2D Discrete Wavelet Transform as implemented in [1].

References

[1]

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

forward(x)

Computes DWT(x) given tensor x.

Parameters:
x: torch.Tensor

Input tensor.

Returns:
out: torch.Tensor

DWT of x.

Return type:

Tensor

training: bool
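The MWCNN paper builds its downsampling on the Haar wavelet. A minimal sketch of a single-level 2D Haar DWT, assuming an input of shape (N, C, H, W) with even H and W, splits every 2x2 pixel block into four subbands (LL, HL, LH, HH) and stacks them along the channel axis, giving (N, 4C, H/2, W/2). The subband ordering and the class name `HaarDWT` are assumptions made for illustration.

```python
import torch
from torch import nn


class HaarDWT(nn.Module):
    """Hypothetical single-level 2D Haar DWT: (N, C, H, W) -> (N, 4C, H/2, W/2)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-2] % 2 or x.shape[-1] % 2:
            raise ValueError("H and W must be even for a single-level Haar DWT")
        # Even/odd rows, pre-scaled by 1/2 so the overall transform is orthonormal.
        x_even_rows = x[:, :, 0::2, :] / 2
        x_odd_rows = x[:, :, 1::2, :] / 2
        # The four pixels of every 2x2 block.
        x1 = x_even_rows[:, :, :, 0::2]  # top-left
        x2 = x_odd_rows[:, :, :, 0::2]   # bottom-left
        x3 = x_even_rows[:, :, :, 1::2]  # top-right
        x4 = x_odd_rows[:, :, :, 1::2]   # bottom-right
        ll = x1 + x2 + x3 + x4    # low-pass in both directions
        hl = -x1 - x2 + x3 + x4   # horizontal detail
        lh = -x1 + x2 - x3 + x4   # vertical detail
        hh = x1 - x2 - x3 + x4    # diagonal detail
        # Stack the four subbands along the channel dimension.
        return torch.cat((ll, hl, lh, hh), dim=1)


y = HaarDWT()(torch.randn(1, 3, 8, 8))
print(y.shape)  # torch.Size([1, 12, 4, 4])
```

Because the four Haar basis vectors are mutually orthogonal with unit norm after the 1/2 scaling, this transform preserves the sum of squares of its input, so no information is lost by the downsampling.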
class direct.nn.mwcnn.mwcnn.DilatedConvBlock(in_channels, dilations, kernel_size, out_channels=None, bias=True, batchnorm=False, activation=ReLU(inplace=True), scale=1.0)

Bases: Module

Double dilated convolution block for MWCNN as implemented in [1].

References

[1]

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

forward(x)

Performs the forward pass of DilatedConvBlock.

Parameters:
x: torch.Tensor

Input with shape (N, C, H, W).

Returns:
output: torch.Tensor

Output with shape (N, C’, H’, W’).

Return type:

Tensor

training: bool
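A comparable sketch of a double dilated block chains two dilated convolutions, each followed by optional batch normalization and the activation. It assumes that `dilations` holds exactly two rates (one per convolution), that the first convolution keeps `in_channels` and the second maps to `out_channels` (defaulting to `in_channels`), and that a padding of `dilation * (kernel_size - 1) // 2` is used to preserve H and W for odd kernel sizes. None of these details is confirmed by this page, and `DilatedConvBlockSketch` is a hypothetical name.

```python
from typing import Optional, Sequence

import torch
from torch import nn


class DilatedConvBlockSketch(nn.Module):
    """Hypothetical block: two dilated convs, each followed by optional BatchNorm2d and the activation."""

    def __init__(
        self,
        in_channels: int,
        dilations: Sequence[int],
        kernel_size: int,
        out_channels: Optional[int] = None,
        bias: bool = True,
        batchnorm: bool = False,
        activation: Optional[nn.Module] = None,
        scale: float = 1.0,
    ) -> None:
        super().__init__()
        if len(dilations) != 2:
            raise ValueError("expected exactly two dilation rates")
        out_channels = in_channels if out_channels is None else out_channels
        act = activation if activation is not None else nn.ReLU(inplace=True)
        layers = []
        # Assumption: the first conv keeps in_channels, the second maps to out_channels.
        for c_in, c_out, d in ((in_channels, in_channels, dilations[0]), (in_channels, out_channels, dilations[1])):
            # "Same" padding for an odd kernel with dilation d.
            layers.append(nn.Conv2d(c_in, c_out, kernel_size, padding=d * (kernel_size - 1) // 2, dilation=d, bias=bias))
            if batchnorm:
                layers.append(nn.BatchNorm2d(c_out))
            # A stateless activation such as ReLU can safely be reused in both positions.
            layers.append(act)
        self.net = nn.Sequential(*layers)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x) * self.scale


block = DilatedConvBlockSketch(16, dilations=(2, 3), kernel_size=3)
print(block(torch.randn(1, 16, 24, 24)).shape)  # torch.Size([1, 16, 24, 24])
```

With a 3x3 kernel, stacking dilation rates d1 and d2 widens the receptive field to 1 + 2*d1 + 2*d2 pixels along each axis without adding parameters, which is the usual motivation for dilated blocks.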
class direct.nn.mwcnn.mwcnn.IWT

Bases: Module

2D Inverse Wavelet Transform as implemented in [1].

References

[1]

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

forward(x)

Computes IWT(x) given tensor x.

Parameters:
x: torch.Tensor

Input tensor.

Returns:
h: torch.Tensor

IWT of x.

Return type:

Tensor

training: bool
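A matching sketch of a single-level inverse Haar transform assumes the input stacks four subbands as [LL, HL, LH, HH] along the channel axis, so (N, 4C, H, W) becomes (N, C, 2H, 2W). Each output pixel in a 2x2 block is a signed sum of the four subbands; because the 2D Haar basis is orthonormal, the transpose of the forward transform is also its inverse. The class name `HaarIWT` and the subband order are illustrative assumptions.

```python
import torch
from torch import nn


class HaarIWT(nn.Module):
    """Hypothetical single-level inverse 2D Haar transform: (N, 4C, H, W) -> (N, C, 2H, 2W)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c4, h, w = x.shape
        if c4 % 4:
            raise ValueError("channel count must be divisible by 4")
        c = c4 // 4
        # Assumed channel layout: [LL, HL, LH, HH], each with c channels, pre-scaled by 1/2.
        ll, hl, lh, hh = (x[:, i * c:(i + 1) * c] / 2 for i in range(4))
        out = x.new_zeros((n, c, 2 * h, 2 * w))
        # Each pixel of a 2x2 block is a signed combination of the four subbands.
        out[:, :, 0::2, 0::2] = ll - hl - lh + hh  # top-left
        out[:, :, 1::2, 0::2] = ll - hl + lh - hh  # bottom-left
        out[:, :, 0::2, 1::2] = ll + hl - lh - hh  # top-right
        out[:, :, 1::2, 1::2] = ll + hl + lh + hh  # bottom-right
        return out


subbands = torch.zeros(1, 4, 2, 2)
subbands[:, 0] = 6.0  # constant LL band, all detail bands zero
print(HaarIWT()(subbands))  # a 1x1x4x4 tensor filled with 3.0
```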
class direct.nn.mwcnn.mwcnn.MWCNN(input_channels, first_conv_hidden_channels, num_scales=4, bias=True, batchnorm=False, activation=ReLU(inplace=True))

Bases: Module

Multi-level Wavelet CNN (MWCNN) as implemented in [1].

References

[1]

Liu, Pengju, et al. “Multi-Level Wavelet-CNN for Image Restoration.” ArXiv:1805.07071 [Cs], May 2018. arXiv.org, http://arxiv.org/abs/1805.07071.

static crop_to_shape(x, shape)
Return type:

Tensor

forward(input_tensor, res=False)

Computes the forward pass of MWCNN.

Parameters:
input_tensor: torch.Tensor

Input tensor.

res: bool

If True, a residual connection is applied to the output. Default: False.

Returns:
x: torch.Tensor

Output tensor.

Return type:

Tensor
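As a generic illustration of the `res` flag (not DIRECT's code), a residual output connection adds the network input back to the network output, so the wrapped model only has to predict a correction to its input. The wrapper below assumes a hypothetical `body` module whose output has the same shape as its input; `ResidualOutputSketch` is likewise a hypothetical name.

```python
import torch
from torch import nn


class ResidualOutputSketch(nn.Module):
    """Hypothetical wrapper: optionally adds the input to the output of `body`."""

    def __init__(self, body: nn.Module) -> None:
        super().__init__()
        self.body = body

    def forward(self, input_tensor: torch.Tensor, res: bool = False) -> torch.Tensor:
        x = self.body(input_tensor)
        if res:
            # Residual connection: requires body(input_tensor) to have the input's shape.
            x = x + input_tensor
        return x


model = ResidualOutputSketch(nn.Conv2d(2, 2, kernel_size=3, padding=1))
y = model(torch.randn(1, 2, 8, 8), res=True)
print(y.shape)  # torch.Size([1, 2, 8, 8])
```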

static pad(x)
Return type:

Tensor

training: bool
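The `pad` and `crop_to_shape` helpers are listed without descriptions. Since each Haar DWT halves H and W, one common pattern, assumed here rather than taken from the DIRECT source, is to pad the input so its spatial size divides evenly before downsampling and to crop the output back to the original size afterwards. The functions `pad_to_multiple` and `crop_spatial` are hypothetical names, and replicate padding on the bottom and right edges is one choice among several.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, multiple: int) -> torch.Tensor:
    """Pad the bottom/right of an (N, C, H, W) tensor so H and W are divisible by `multiple`."""
    pad_h = (-x.shape[-2]) % multiple
    pad_w = (-x.shape[-1]) % multiple
    if pad_h == 0 and pad_w == 0:
        return x
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")


def crop_spatial(x: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
    """Crop the bottom/right of an (N, C, H, W) tensor back to spatial size shape = (H, W)."""
    return x[..., : shape[0], : shape[1]]


x = torch.randn(1, 2, 13, 10)
padded = pad_to_multiple(x, 8)  # 8 = 2 ** 3, e.g. room for three halvings
print(padded.shape)  # torch.Size([1, 2, 16, 16])
restored = crop_spatial(padded, x.shape[-2:])
print(torch.equal(restored, x))  # True
```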

Module contents