direct.nn.unet package#

Submodules#

direct.nn.unet.config module#

class direct.nn.unet.config.NormUnetModel2dConfig(model_name='???', engine_name=None)[source][source]#

Bases: ModelConfig

dropout_probability: float = 0.0#
in_channels: int = 2#
norm_groups: int = 2#
num_filters: int = 16#
num_pool_layers: int = 4#
out_channels: int = 2#
class direct.nn.unet.config.Unet2dConfig(model_name: str = '???', engine_name: Optional[str] = None, num_filters: int = 16, num_pool_layers: int = 4, dropout_probability: float = 0.0, skip_connection: bool = False, normalized: bool = False, image_initialization: direct.nn.types.InitType = <InitType.ZERO_FILLED: 'zero_filled'>)[source][source]#

Bases: ModelConfig

dropout_probability: float = 0.0#
image_initialization: InitType = 'zero_filled'#
normalized: bool = False#
num_filters: int = 16#
num_pool_layers: int = 4#
skip_connection: bool = False#
class direct.nn.unet.config.UnetModel2dConfig(model_name: str = '???', engine_name: str | None = None, in_channels: int = 2, out_channels: int = 2, num_filters: int = 16, num_pool_layers: int = 4, dropout_probability: float = 0.0)[source][source]#

Bases: ModelConfig

dropout_probability: float = 0.0#
in_channels: int = 2#
num_filters: int = 16#
num_pool_layers: int = 4#
out_channels: int = 2#
class direct.nn.unet.config.UnetModel3dConfig(model_name: str = '???', engine_name: str | None = None, in_channels: int = 2, out_channels: int = 2, num_filters: int = 16, num_pool_layers: int = 4, dropout_probability: float = 0.0)[source][source]#

Bases: ModelConfig

dropout_probability: float = 0.0#
in_channels: int = 2#
num_filters: int = 16#
num_pool_layers: int = 4#
out_channels: int = 2#
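
As a quick illustration, these configs are plain dataclasses and can be instantiated directly; in practice they are usually populated from a YAML experiment configuration. The values below are hypothetical and only the defaults shown above come from the source.

>>> from direct.nn.unet.config import Unet2dConfig
>>> cfg = Unet2dConfig(num_filters=32, dropout_probability=0.05, normalized=True)
>>> cfg.num_filters, cfg.normalized
(32, True)
>>> cfg.image_initialization  # unchanged default, InitType.ZERO_FILLED ('zero_filled')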

direct.nn.unet.unet_2d module#

class direct.nn.unet.unet_2d.ConvBlock(in_channels, out_channels, dropout_probability)[source][source]#

Bases: Module

U-Net convolutional block.

It consists of two convolution layers each followed by instance normalization, LeakyReLU activation and dropout.

forward(input_data)[source][source]#

Performs the forward pass of ConvBlock.

Parameters:
input_data: torch.Tensor
Returns:
torch.Tensor
Return type:

Tensor

training: bool#
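
A minimal usage sketch of ConvBlock; the shapes are illustrative and assume the two convolutions preserve the spatial dimensions.

>>> import torch
>>> from direct.nn.unet.unet_2d import ConvBlock
>>> block = ConvBlock(in_channels=2, out_channels=16, dropout_probability=0.0)
>>> out = block(torch.rand(1, 2, 64, 64))
>>> out.shape  # expected: torch.Size([1, 16, 64, 64]) if H and W are preserved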
class direct.nn.unet.unet_2d.NormUnetModel2d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability, norm_groups=2)[source][source]#

Bases: Module

Implementation of a Normalized U-Net model.

forward(input_data)[source][source]#

Performs forward pass of NormUnetModel2d.

Parameters:
input_data: torch.Tensor
Returns:
torch.Tensor
Return type:

Tensor

static norm(input_data, groups)[source][source]#

Performs group normalization.

Return type:

Tuple[Tensor, Tensor, Tensor]

static pad(input_data)[source][source]#
Return type:

Tuple[Tensor, Tuple[List[int], List[int], int, int]]

training: bool#
static unnorm(input_data, mean, std, groups)[source][source]#
Return type:

Tensor

static unpad(input_data, h_pad, w_pad, h_mult, w_mult)[source][source]#
Return type:

Tensor
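
A usage sketch of NormUnetModel2d; it assumes the internal pad/unpad handle spatial sizes that are not multiples of 2**num_pool_layers and that the output spatial size matches the input. Values are illustrative only.

>>> import torch
>>> from direct.nn.unet.unet_2d import NormUnetModel2d
>>> model = NormUnetModel2d(in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0, norm_groups=2)
>>> out = model(torch.rand(1, 2, 100, 120))  # input is normalized and padded internally
>>> out.shape  # expected: torch.Size([1, 2, 100, 120])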

class direct.nn.unet.unet_2d.TransposeConvBlock(in_channels, out_channels)[source][source]#

Bases: Module

U-Net Transpose Convolutional Block.

It consists of one transpose convolution layer followed by instance normalization and LeakyReLU activation.

forward(input_data)[source][source]#

Performs forward pass of TransposeConvBlock.

Parameters:
input_data: torch.Tensor
Returns:
torch.Tensor
Return type:

Tensor

training: bool#
class direct.nn.unet.unet_2d.Unet2d(forward_operator, backward_operator, num_filters, num_pool_layers, dropout_probability, skip_connection=False, normalized=False, image_initialization=InitType.ZERO_FILLED, **kwargs)[source][source]#

Bases: Module

PyTorch implementation of a U-Net model for MRI Reconstruction.

compute_sense_init(kspace, sensitivity_map)[source][source]#

Computes sense initialization \(x_{\text{SENSE}}\):

\[x_{\text{SENSE}} = \sum_{k=1}^{n_c} {S^{k}}^* \times y^k\]

where \(y^k\) denotes the data from coil \(k\).

Parameters:
kspace: torch.Tensor

k-space of shape (N, coil, height, width, complex=2).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

Returns:
input_image: torch.Tensor

Sense initialization \(x_{\text{SENSE}}\).

Return type:

Tensor
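
A minimal sketch of this coil combination, assuming a plain orthonormal 2D inverse FFT as the backward operator; the engines pass DIRECT's own forward/backward operators instead, so conventions such as FFT centering may differ.

import torch

def sense_init_sketch(kspace, sensitivity_map):
    # kspace, sensitivity_map: (N, coil, height, width, complex=2)
    y = torch.view_as_complex(kspace.contiguous())            # (N, coil, H, W), complex dtype
    s = torch.view_as_complex(sensitivity_map.contiguous())   # (N, coil, H, W), complex dtype
    coil_images = torch.fft.ifft2(y, norm="ortho")            # y^k in image space
    combined = (s.conj() * coil_images).sum(dim=1)            # sum_k conj(S^k) * y^k
    return torch.view_as_real(combined)                       # (N, height, width, complex=2)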

forward(masked_kspace, sensitivity_map=None)[source][source]#

Computes forward pass of Unet2d.

Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2). Default: None.

Returns:
output: torch.Tensor

Output image of shape (N, height, width, complex=2).

Return type:

Tensor

training: bool#
class direct.nn.unet.unet_2d.UnetModel2d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability)[source][source]#

Bases: Module

PyTorch implementation of a U-Net model based on [1].

References

[1]

Ronneberger, Olaf, et al. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015, edited by Nassir Navab et al., Springer International Publishing, 2015, pp. 234–41. Springer Link, https://doi.org/10.1007/978-3-319-24574-4_28.

forward(input_data)[source][source]#

Performs forward pass of UnetModel2d.

Parameters:
input_data: torch.Tensor
Returns:
torch.Tensor
Return type:

Tensor

training: bool#
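
A usage sketch of UnetModel2d under the shapes documented above; spatial sizes divisible by 2**num_pool_layers are used here to keep the encoder and decoder shapes aligned (behaviour for other sizes is not asserted here).

>>> import torch
>>> from direct.nn.unet.unet_2d import UnetModel2d
>>> model = UnetModel2d(in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0)
>>> out = model(torch.rand(1, 2, 128, 128))
>>> out.shape  # expected: torch.Size([1, 2, 128, 128])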

direct.nn.unet.unet_3d module#

Code for three-dimensional U-Net adapted from the 2D variant.

class direct.nn.unet.unet_3d.ConvBlock3D(in_channels, out_channels, dropout_probability)[source][source]#

Bases: Module

3D U-Net convolutional block.

forward(input_data)[source][source]#

Performs the forward pass of ConvBlock3D.

Parameters:
input_data: torch.Tensor

Input data.

Returns:
torch.Tensor
Return type:

Tensor

training: bool#
class direct.nn.unet.unet_3d.NormUnetModel3d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability, norm_groups=2)[source][source]#

Bases: Module

Implementation of a Normalized U-Net model for 3D data.

This is an extension to 3D volumes of direct.nn.unet.unet_2d.NormUnetModel2d.

forward(input_data)[source][source]#

Performs the forward pass of NormUnetModel3d.

Parameters:
input_data: torch.Tensor

Input tensor of shape (N, in_channels, slice/time, height, width).

Returns:
torch.Tensor

Output of shape (N, out_channels, slice/time, height, width).

Return type:

Tensor

static norm(input_data, groups)[source][source]#

Applies group normalization for 3D data.

Parameters:
input_data: torch.Tensor

The input tensor to normalize.

groups: int

The number of groups to divide the tensor into for normalization.

Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]

A tuple containing the normalized tensor, the mean, and the standard deviation used for normalization.

Return type:

tuple[Tensor, Tensor, Tensor]
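
For intuition, group normalization over a 3D volume can be sketched as below; this is an illustrative approximation and not necessarily the exact reshaping used by NormUnetModel3d.norm.

import torch

def group_norm_sketch(input_data, groups):
    # input_data: (N, C, slice/time, height, width)
    n = input_data.shape[0]
    flat = input_data.reshape(n, groups, -1)        # pool channels and spatial dims per group
    mean = flat.mean(dim=-1, keepdim=True)
    std = flat.std(dim=-1, keepdim=True)
    normalized = ((flat - mean) / std).reshape(input_data.shape)
    return normalized, mean, std                    # mirrors the (tensor, mean, std) return documented above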

static pad(input_data)[source][source]#

Applies padding to the input 3D tensor to ensure its dimensions are multiples of 16.

Parameters:
input_data: torch.Tensor

The input tensor to pad.

Returns:
tuple[torch.Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]]

A tuple containing the padded tensor and a tuple with the padding applied to each dimension (height, width, depth) and the target dimensions after padding.

Return type:

tuple[Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]]

training: bool#
static unnorm(input_data, mean, std, groups)[source][source]#

Reverts the normalization applied to the 3D tensor.

Parameters:
input_data: torch.Tensor

The normalized tensor to revert normalization on.

mean: torch.Tensor

The mean used during normalization.

std: torch.Tensor

The standard deviation used during normalization.

groups: int

The number of groups the tensor was divided into during normalization.

Returns:
torch.Tensor

The tensor after reverting the normalization.

Return type:

Tensor

static unpad(input_data, h_pad, w_pad, z_pad, h_mult, w_mult, z_mult)[source][source]#

Removes padding from the 3D input tensor, reverting it to its original dimensions before padding was applied.

This method is typically used after the model has processed the padded input.

Parameters:
input_data: torch.Tensor

The tensor from which padding will be removed.

h_pad: list[int]

Padding applied to the height, specified as [top, bottom].

w_pad: list[int]

Padding applied to the width, specified as [left, right].

z_pad: list[int]

Padding applied to the depth, specified as [front, back].

h_mult: int

The padded height as computed in the pad method.

w_mult: int

The padded width as computed in the pad method.

z_mult: int

The padded depth as computed in the pad method.

The depth as computed in the pad method.

Returns:
torch.Tensor

The tensor with padding removed, restored to its original dimensions.

Return type:

Tensor

class direct.nn.unet.unet_3d.TransposeConvBlock3D(in_channels, out_channels)[source][source]#

Bases: Module

3D U-Net Transpose Convolutional Block.

forward(input_data)[source][source]#

Performs the forward pass of TransposeConvBlock3D.

Parameters:
input_data: torch.Tensor

Input data.

Returns:
torch.Tensor
Return type:

Tensor

training: bool#
class direct.nn.unet.unet_3d.UnetModel3d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability)[source][source]#

Bases: Module

PyTorch implementation of a 3D U-Net model.

This class defines a 3D U-Net architecture consisting of down-sampling and up-sampling layers with 3D convolutional blocks. This is an extension to 3D volumes of direct.nn.unet.unet_2d.UnetModel2d.

forward(input_data)[source][source]#

Performs forward pass of UnetModel3d.

Parameters:
input_data: torch.Tensor

Input tensor of shape (N, in_channels, slice/time, height, width).

Returns:
torch.Tensor

Output of shape (N, out_channels, slice/time, height, width).

Return type:

Tensor

training: bool#
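
A usage sketch of UnetModel3d with the 5D shape documented above; the module's pad_to_pow_of_2 helper (below) suggests awkward sizes are padded internally, but power-of-two-friendly spatial sizes are used here regardless. Values are illustrative only.

>>> import torch
>>> from direct.nn.unet.unet_3d import UnetModel3d
>>> model = UnetModel3d(in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0)
>>> out = model(torch.rand(1, 2, 16, 128, 128))  # (N, in_channels, slice/time, height, width)
>>> out.shape  # expected: torch.Size([1, 2, 16, 128, 128])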
direct.nn.unet.unet_3d.pad_to_pow_of_2(inp, k)[source][source]#

Pads the input tensor along the spatial dimensions (depth, height, width) to the nearest power of 2.

This is necessary for certain operations in the 3D U-Net architecture to maintain dimensionality.

Parameters:
inp: torch.Tensor

The input tensor to be padded.

k: int

The exponent to which the base of 2 is raised to determine the padding. Used to calculate the target dimension size as a power of 2.

Returns:
tuple[torch.Tensor, list[int]]

A tuple containing the padded tensor and a list of padding applied to each spatial dimension in the format [depth_front, depth_back, height_top, height_bottom, width_left, width_right].

Return type:

tuple[Tensor, list[int]]

Examples

>>> inp = torch.rand(1, 1, 15, 15, 15)  # A random tensor with shape [1, 1, 15, 15, 15]
>>> padded_inp, padding = pad_to_pow_of_2(inp, 4)
>>> print(padded_inp.shape, padding)
torch.Size([...]), [1, 1, 1, 1, 1, 1]

direct.nn.unet.unet_engine module#

Unet2d model engines for direct.

This module contains engines for Unet2d models, both for supervised and self-supervised learning.

class direct.nn.unet.unet_engine.Unet2dEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: MRIModelEngine

Unet2d Model Engine.

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

forward_function(data)[source][source]#

Forward function for Unet2dEngine.

Parameters:
data: dict[str, Any]

Input data dictionary containing the following keys: “masked_kspace” and “sensitivity_map” if image initialization is “sense”.

Returns:
tuple[torch.Tensor, None]

Prediction of image and None for k-space.

Return type:

tuple[Tensor, None]

class direct.nn.unet.unet_engine.Unet2dJSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: JSSLMRIModelEngine

JSSL Unet2d Model Engine.

Used for supplementary experiments for the U-Net model with JSSL in the JSSL paper [1].

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

References

[1]

Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.

forward_function(data)[source][source]#

Forward function for Unet2dJSSLEngine.

Parameters:
data: dict[str, Any]

Input data dictionary containing the following keys: “is_ssl” indicating SSL sample, “input_kspace” if SSL training, otherwise “masked_kspace”. Also contains “sensitivity_map” if image initialization is “sense”.

Returns:
tuple[torch.Tensor, None]

Prediction of image and None for k-space.

Return type:

tuple[Tensor, None]

class direct.nn.unet.unet_engine.Unet2dSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: SSLMRIModelEngine

SSL Unet2d Model Engine.

Used for supplementary experiments for the U-Net model with SSL in the JSSL paper [1].

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

References

[1]

Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.

forward_function(data)[source][source]#

Forward function for Unet2dSSLEngine.

Parameters:
data: dict[str, Any]

Input data dictionary containing the following keys: “input_kspace” if training, otherwise “masked_kspace”. Also contains “sensitivity_map” if image initialization is “sense”.

Returns:
tuple[torch.Tensor, None]

Prediction of image and None for k-space.

Return type:

tuple[Tensor, None]

Module contents#