direct.nn.mwcnn package

Submodules

direct.nn.mwcnn.mwcnn module

class direct.nn.mwcnn.mwcnn.DWT

Bases: Module

2D Discrete Wavelet Transform as implemented in [1].

References:

__init__()

Inits DWT.

forward(x)

Computes the DWT of the input tensor.

Parameters:

x (Tensor) – Input tensor.

Return type:

Tensor

Returns:

DWT of the input.
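The transform above can be sketched as a single-level orthonormal 2D Haar DWT. The Haar wavelet is an assumption (this page does not name the wavelet), and `haar_dwt` is a hypothetical stand-alone function, not DIRECT's `DWT.forward`. It splits every channel into four half-resolution sub-bands stacked along the channel axis, so a tensor of shape (N, C, H, W), with H and W assumed even, becomes (N, 4C, H/2, W/2).

```python
import torch


def haar_dwt(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical single-level orthonormal 2D Haar DWT.

    (N, C, H, W) -> (N, 4C, H/2, W/2); H and W must be even.
    Sub-bands are concatenated along the channel axis as (LL, HL, LH, HH).
    """
    # Polyphase components: p_ij holds the pixels at rows 2h + i, columns 2w + j
    p00, p10 = x[:, :, 0::2, 0::2], x[:, :, 1::2, 0::2]
    p01, p11 = x[:, :, 0::2, 1::2], x[:, :, 1::2, 1::2]
    ll = (p00 + p10 + p01 + p11) / 2   # scaled 2x2 sum: low-pass in both directions
    hl = (-p00 - p10 + p01 + p11) / 2  # odd columns minus even columns
    lh = (-p00 + p10 - p01 + p11) / 2  # odd rows minus even rows
    hh = (p00 - p10 - p01 + p11) / 2   # diagonal difference
    return torch.cat((ll, hl, lh, hh), dim=1)


x = torch.randn(2, 3, 8, 8)
print(haar_dwt(x).shape)  # torch.Size([2, 12, 4, 4])
```

Because of the 1/2 factor the four filters are orthonormal, so the transform preserves the total energy (sum of squares) of its input, and a constant image lands entirely in the LL sub-band.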

class direct.nn.mwcnn.mwcnn.IWT

Bases: Module

2D Inverse Wavelet Transform as implemented in [1].

References:

__init__()

Inits IWT.

forward(x)

Computes the IWT of the input tensor.

Parameters:

x (Tensor) – Input tensor.

Return type:

Tensor

Returns:

IWT of the input.
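Under the same Haar assumption, the inverse transform can be sketched as follows. `haar_iwt` is hypothetical, not DIRECT's `IWT.forward`: it splits the channels into the four sub-bands (so the channel count must be divisible by 4) and writes each reconstructed polyphase component back into its place on a grid of twice the height and width. Since the forward filter matrix is orthonormal, its inverse is simply its transpose. A compact `haar_dwt` is repeated so the block runs on its own and the round trip can be checked.

```python
import torch


def haar_dwt(x: torch.Tensor) -> torch.Tensor:
    # Orthonormal Haar DWT: (N, C, H, W) -> (N, 4C, H/2, W/2), order (LL, HL, LH, HH)
    p00, p10 = x[:, :, 0::2, 0::2], x[:, :, 1::2, 0::2]
    p01, p11 = x[:, :, 0::2, 1::2], x[:, :, 1::2, 1::2]
    return torch.cat((p00 + p10 + p01 + p11, -p00 - p10 + p01 + p11,
                      -p00 + p10 - p01 + p11, p00 - p10 - p01 + p11), dim=1) / 2


def haar_iwt(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical inverse Haar transform: (N, 4C, H, W) -> (N, C, 2H, 2W)."""
    ll, hl, lh, hh = torch.chunk(y, 4, dim=1)  # split channels into the four sub-bands
    n, c, h, w = ll.shape
    x = y.new_zeros((n, c, 2 * h, 2 * w))      # same dtype and device as the input
    # Each line applies one column of the (transposed) orthonormal filter matrix
    x[:, :, 0::2, 0::2] = (ll - hl - lh + hh) / 2
    x[:, :, 1::2, 0::2] = (ll - hl + lh - hh) / 2
    x[:, :, 0::2, 1::2] = (ll + hl - lh - hh) / 2
    x[:, :, 1::2, 1::2] = (ll + hl + lh + hh) / 2
    return x


x = torch.randn(2, 3, 8, 6, dtype=torch.float64)
print(torch.allclose(haar_iwt(haar_dwt(x)), x))  # True
```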

class direct.nn.mwcnn.mwcnn.ConvBlock(in_channels, out_channels, kernel_size, bias=True, batchnorm=False, activation=nn.ReLU(True), scale=1.0)

Bases: Module

Convolution block for MWCNN as implemented in [1].

References:

__init__(in_channels, out_channels, kernel_size, bias=True, batchnorm=False, activation=nn.ReLU(True), scale=1.0)

Inits ConvBlock.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • kernel_size (int) – Conv kernel size.

  • bias (bool) – Use convolution bias. Default: True.

  • batchnorm (bool) – Use batch normalization. Default: False.

  • activation (Module) – Activation function. Default: nn.ReLU(True).

  • scale (Optional[float]) – Scale. Default: 1.0.

forward(x)

Performs the forward pass of ConvBlock.

Parameters:

x (Tensor) – Input with shape (N, C, H, W).

Return type:

Tensor

Returns:

Output with shape (N, C’, H’, W’).
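A minimal sketch of what ConvBlock plausibly computes, under stated assumptions: the layer order is convolution, then optional batch normalization, then the activation; the convolution uses "same" padding of `kernel_size // 2` (exact for odd kernels); and `scale` acts as a plain multiplier on the block output. `ConvBlockSketch` is a hypothetical name, and this page documents only the constructor arguments, not the internals.

```python
import torch
from torch import nn


class ConvBlockSketch(nn.Module):
    """Hypothetical ConvBlock: conv -> [batchnorm] -> activation, times `scale`."""

    def __init__(self, in_channels, out_channels, kernel_size, bias=True,
                 batchnorm=False, activation=None, scale=1.0):
        super().__init__()
        # kernel_size // 2 padding keeps H and W unchanged for odd kernel sizes
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size,
                            padding=kernel_size // 2, bias=bias)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        # Build a fresh activation per block unless one is passed in explicitly
        layers.append(activation if activation is not None else nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)
        self.scale = scale  # assumption: a multiplicative factor on the output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x) * self.scale


block = ConvBlockSketch(2, 16, 3, batchnorm=True, scale=0.5)
print(block(torch.randn(4, 2, 32, 32)).shape)  # torch.Size([4, 16, 32, 32])
```

Defaulting `activation` to `None` and creating the ReLU inside `__init__` is a design choice of the sketch: it avoids every instance sharing the single module object that a default argument such as `nn.ReLU(True)` creates once at definition time.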

class direct.nn.mwcnn.mwcnn.DilatedConvBlock(in_channels, dilations, kernel_size, out_channels=None, bias=True, batchnorm=False, activation=nn.ReLU(True), scale=1.0)

Bases: Module

Double dilated convolution block for MWCNN as implemented in [1].

References:

__init__(in_channels, dilations, kernel_size, out_channels=None, bias=True, batchnorm=False, activation=nn.ReLU(True), scale=1.0)

Inits DilatedConvBlock.

Parameters:
  • in_channels (int) – Number of input channels.

  • dilations (Tuple[int, int]) – Dilation rates of the two convolutions.

  • kernel_size (int) – Conv kernel size.

  • out_channels (Optional[int]) – Number of output channels.

  • bias (bool) – Use convolution bias. Default: True.

  • batchnorm (bool) – Use batch normalization. Default: False.

  • activation (Module) – Activation function. Default: nn.ReLU(True).

  • scale (Optional[float]) – Scale. Default: 1.0.

forward(x)

Performs the forward pass of DilatedConvBlock.

Parameters:

x (Tensor) – Input with shape (N, C, H, W).

Return type:

Tensor

Returns:

Output with shape (N, C’, H’, W’).
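The two entries of `dilations` presumably set the dilation rates of the block's two stacked convolutions. The sketch below assumes that, plus three further details not stated on this page: `out_channels=None` falls back to `in_channels`, each convolution is followed by optional batch normalization and a ReLU, and each convolution is padded by `d * (kernel_size // 2)`. A k×k kernel with dilation d spans d(k - 1) + 1 pixels, so that padding preserves H and W for odd k. `DilatedConvBlockSketch` is hypothetical.

```python
import torch
from torch import nn


class DilatedConvBlockSketch(nn.Module):
    """Hypothetical double dilated conv block; not DIRECT's DilatedConvBlock."""

    def __init__(self, in_channels, dilations, kernel_size, out_channels=None,
                 bias=True, batchnorm=False, scale=1.0):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels  # assumption
        layers = []
        # First conv keeps the channel count, second maps to out_channels
        for c_in, c_out, d in ((in_channels, in_channels, dilations[0]),
                               (in_channels, out_channels, dilations[1])):
            # padding = d * (k // 2) keeps H and W unchanged for odd kernel sizes
            layers.append(nn.Conv2d(c_in, c_out, kernel_size, dilation=d,
                                    padding=d * (kernel_size // 2), bias=bias))
            if batchnorm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)
        self.scale = scale  # assumption: a multiplicative factor on the output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x) * self.scale


blk = DilatedConvBlockSketch(8, (2, 3), 3, out_channels=4)
print(blk(torch.randn(1, 8, 20, 20)).shape)  # torch.Size([1, 4, 20, 20])
```

With `kernel_size=3` and `dilations=(2, 3)`, the two layers together see a (1 + 2·2 + 3·2) = 11 pixel wide window while using only two 3×3 kernels, which is the usual motivation for dilation.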

class direct.nn.mwcnn.mwcnn.MWCNN(input_channels, first_conv_hidden_channels, num_scales=4, bias=True, batchnorm=False, activation=nn.ReLU(True))

Bases: Module

Multi-level Wavelet CNN (MWCNN) as implemented in [1].

References:

__init__(input_channels, first_conv_hidden_channels, num_scales=4, bias=True, batchnorm=False, activation=nn.ReLU(True))

Inits MWCNN.

Parameters:
  • input_channels (int) – Number of input channels.

  • first_conv_hidden_channels (int) – Number of output channels of the first convolution.

  • num_scales (int) – Number of scales. Default: 4.

  • bias (bool) – Convolution bias. If True, adds a learnable bias to the output. Default: True.

  • batchnorm (bool) – If True, a batchnorm layer is added after each convolution. Default: False.

  • activation (Module) – Activation function applied after each convolution. Default: nn.ReLU(True).

static pad(x)
Return type:

Tensor

static crop_to_shape(x, shape)
Return type:

Tensor
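This page gives no description for `pad` and `crop_to_shape`. A plausible reading, stated here as an assumption, is that `pad` makes the spatial size even so that each DWT halves it exactly, and `crop_to_shape` trims a tensor back to a target (H, W). The hypothetical functions below implement that reading, using reflection padding on the bottom and right edges only.

```python
import torch
import torch.nn.functional as F


def pad_to_even(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical: reflect-pad one pixel at the bottom/right where H or W is odd."""
    h, w = x.shape[-2:]
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    padding = (0, w % 2, 0, h % 2)
    return F.pad(x, padding, mode="reflect") if any(padding) else x


def crop_to_shape(x: torch.Tensor, shape) -> torch.Tensor:
    """Hypothetical: keep the top-left shape[0] x shape[1] spatial window."""
    return x[..., : shape[0], : shape[1]]


x = torch.randn(1, 2, 7, 9)
padded = pad_to_even(x)
print(padded.shape)                                    # torch.Size([1, 2, 8, 10])
print(torch.equal(crop_to_shape(padded, x.shape[-2:]), x))  # True
```

Padding only at the far edges keeps the original pixels at their original indices, which is what lets a plain top-left crop undo it.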

forward(input_tensor, res=False)

Computes forward pass of MWCNN.

Parameters:
  • input_tensor (Tensor) – Input tensor.

  • res (bool) – If True, a residual connection is applied to the output. Default: False.

Return type:

Tensor

Returns:

Output tensor.
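The forward pass documented above can be illustrated with a deliberately tiny MWCNN-style network. `TinyMWCNN` is hypothetical and is not DIRECT's `MWCNN`: a Haar DWT replaces pooling in the contracting path, the inverse transform replaces upsampling in the expanding path, features from matching scales are summed, the input is padded before and cropped after, and `res=True` adds the input to the output. The Haar wavelet, one convolution per scale, the channel widths, and the additive skips are all assumptions made for brevity.

```python
import torch
import torch.nn.functional as F
from torch import nn


def dwt(x):
    # Orthonormal Haar DWT: (N, C, H, W) -> (N, 4C, H/2, W/2), order (LL, HL, LH, HH)
    a, b = x[:, :, 0::2, 0::2], x[:, :, 1::2, 0::2]
    c, d = x[:, :, 0::2, 1::2], x[:, :, 1::2, 1::2]
    return torch.cat((a + b + c + d, -a - b + c + d, -a + b - c + d, a - b - c + d), dim=1) / 2


def iwt(y):
    # Inverse of dwt: (N, 4C, H, W) -> (N, C, 2H, 2W)
    ll, hl, lh, hh = torch.chunk(y / 2, 4, dim=1)
    n, c, h, w = ll.shape
    x = y.new_zeros((n, c, 2 * h, 2 * w))
    x[:, :, 0::2, 0::2] = ll - hl - lh + hh
    x[:, :, 1::2, 0::2] = ll - hl + lh - hh
    x[:, :, 0::2, 1::2] = ll + hl - lh - hh
    x[:, :, 1::2, 1::2] = ll + hl + lh + hh
    return x


def conv_relu(cin, cout):
    return nn.Sequential(nn.Conv2d(cin, cout, 3, padding=1), nn.ReLU(inplace=True))


class TinyMWCNN(nn.Module):
    """Hypothetical MWCNN-style network for illustration only."""

    def __init__(self, in_ch: int, hidden: int, num_scales: int = 2):
        super().__init__()
        self.head = conv_relu(in_ch, hidden)
        # Contracting path: DWT (4x channels, half resolution), then conv back to `hidden`
        self.down = nn.ModuleList(conv_relu(4 * hidden, hidden) for _ in range(num_scales))
        # Expanding path: conv to 4x channels, then IWT (double resolution)
        self.up = nn.ModuleList(conv_relu(hidden, 4 * hidden) for _ in range(num_scales))
        self.tail = nn.Conv2d(hidden, in_ch, 3, padding=1)

    def forward(self, x: torch.Tensor, res: bool = False) -> torch.Tensor:
        h, w = x.shape[-2:]
        m = 2 ** len(self.down)  # every DWT halves H and W, so pad to a multiple of 2**num_scales
        feat = self.head(F.pad(x, (0, (-w) % m, 0, (-h) % m), mode="reflect"))
        skips = []
        for down in self.down:
            skips.append(feat)
            feat = down(dwt(feat))
        for up, skip in zip(self.up, reversed(skips)):
            feat = iwt(up(feat)) + skip  # additive skip connection at the matching scale
        out = self.tail(feat)[..., :h, :w]  # crop back to the input size
        return out + x if res else out


model = TinyMWCNN(in_ch=2, hidden=8, num_scales=2)
print(model(torch.randn(1, 2, 13, 10), res=True).shape)  # torch.Size([1, 2, 13, 10])
```

Padding once to a multiple of 2**num_scales is a simplification chosen for the sketch; the documented static `pad` method may instead be applied at each scale.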

Module contents