direct.nn.unet package#
Submodules#
direct.nn.unet.config module#
- class direct.nn.unet.config.NormUnetModel2dConfig(model_name='???', engine_name=None)[source][source]#
Bases:
ModelConfig
- dropout_probability: float = 0.0#
- in_channels: int = 2#
- norm_groups: int = 2#
- num_filters: int = 16#
- num_pool_layers: int = 4#
- out_channels: int = 2#
- class direct.nn.unet.config.Unet2dConfig(model_name: str = '???', engine_name: Optional[str] = None, num_filters: int = 16, num_pool_layers: int = 4, dropout_probability: float = 0.0, skip_connection: bool = False, normalized: bool = False, image_initialization: direct.nn.types.InitType = <InitType.ZERO_FILLED: 'zero_filled'>)[source][source]#
Bases:
ModelConfig
- dropout_probability: float = 0.0#
- normalized: bool = False#
- num_filters: int = 16#
- num_pool_layers: int = 4#
- skip_connection: bool = False#
- class direct.nn.unet.config.UnetModel2dConfig(model_name: str = '???', engine_name: str | None = None, in_channels: int = 2, out_channels: int = 2, num_filters: int = 16, num_pool_layers: int = 4, dropout_probability: float = 0.0)[source][source]#
Bases:
ModelConfig
- dropout_probability: float = 0.0#
- in_channels: int = 2#
- num_filters: int = 16#
- num_pool_layers: int = 4#
- out_channels: int = 2#
- class direct.nn.unet.config.UnetModel3dConfig(model_name: str = '???', engine_name: str | None = None, in_channels: int = 2, out_channels: int = 2, num_filters: int = 16, num_pool_layers: int = 4, dropout_probability: float = 0.0)[source][source]#
Bases:
ModelConfig
- dropout_probability: float = 0.0#
- in_channels: int = 2#
- num_filters: int = 16#
- num_pool_layers: int = 4#
- out_channels: int = 2#
direct.nn.unet.unet_2d module#
- class direct.nn.unet.unet_2d.ConvBlock(in_channels, out_channels, dropout_probability)[source][source]#
Bases:
Module
U-Net convolutional block.
It consists of two convolution layers, each followed by instance normalization, LeakyReLU activation, and dropout.
- forward(input_data)[source][source]#
Performs the forward pass of ConvBlock.
- Parameters:
- input_data: torch.Tensor
- Returns:
- torch.Tensor
- Return type:
Tensor
- training: bool#
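As an illustration, a minimal usage sketch of ConvBlock, assuming the two convolutions preserve the spatial size (the usual U-Net choice of 3×3 kernels with padding 1):
>>> import torch
>>> from direct.nn.unet.unet_2d import ConvBlock
>>> block = ConvBlock(in_channels=2, out_channels=16, dropout_probability=0.0)
>>> x = torch.rand(1, 2, 64, 64)  # (N, in_channels, height, width)
>>> out = block(x)  # expected shape (1, 16, 64, 64) under the size-preserving assumption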
- class direct.nn.unet.unet_2d.NormUnetModel2d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability, norm_groups=2)[source][source]#
Bases:
Module
Implementation of a Normalized U-Net model.
- forward(input_data)[source][source]#
Performs forward pass of NormUnetModel2d.
- Parameters:
- input_data: torch.Tensor
- Returns:
- torch.Tensor
- Return type:
Tensor
- static norm(input_data, groups)[source][source]#
Performs group normalization.
- Return type:
Tuple[Tensor, Tensor, Tensor]
- static pad(input_data)[source][source]#
- Return type:
Tuple[Tensor, Tuple[List[int], List[int], int, int]]
- training: bool#
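To illustrate what the norm() round trip does, here is a minimal sketch with plain tensor operations; this is an analogue of the group normalization used by NormUnetModel2d (normalized tensor plus per-group mean and standard deviation), not the library's internal code:
>>> import torch
>>> x = torch.rand(1, 2, 64, 64)  # (N, C, height, width)
>>> groups = 2  # matches the norm_groups default
>>> g = x.reshape(1, groups, -1)  # flatten each group
>>> mean, std = g.mean(dim=-1, keepdim=True), g.std(dim=-1, keepdim=True)
>>> normed = ((g - mean) / std).reshape_as(x)  # normalized tensor, analogous to what norm() returns
>>> restored = (normed.reshape(1, groups, -1) * std + mean).reshape_as(x)  # revert the normalization
>>> torch.allclose(restored, x, atol=1e-5)
True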
- class direct.nn.unet.unet_2d.TransposeConvBlock(in_channels, out_channels)[source][source]#
Bases:
Module
U-Net Transpose Convolutional Block.
It consists of one transposed convolution layer followed by instance normalization and LeakyReLU activation.
- forward(input_data)[source][source]#
Performs forward pass of TransposeConvBlock.
- Parameters:
- input_data: torch.Tensor
- Returns:
- torch.Tensor
- Return type:
Tensor
- training: bool#
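A minimal usage sketch of TransposeConvBlock, assuming the transposed convolution uses stride 2 and therefore doubles the spatial dimensions (the usual U-Net up-sampling choice):
>>> import torch
>>> from direct.nn.unet.unet_2d import TransposeConvBlock
>>> up = TransposeConvBlock(in_channels=32, out_channels=16)
>>> x = torch.rand(1, 32, 32, 32)
>>> out = up(x)  # expected shape (1, 16, 64, 64) under the stride-2 assumption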
- class direct.nn.unet.unet_2d.Unet2d(forward_operator, backward_operator, num_filters, num_pool_layers, dropout_probability, skip_connection=False, normalized=False, image_initialization=InitType.ZERO_FILLED, **kwargs)[source][source]#
Bases:
Module
PyTorch implementation of a U-Net model for MRI Reconstruction.
- compute_sense_init(kspace, sensitivity_map)[source][source]#
Computes sense initialization \(x_{\text{SENSE}}\):
\[x_{\text{SENSE}} = \sum_{k=1}^{n_c} {S^{k}}^* \times y^k\]
where \(y^k\) denotes the data from coil \(k\).
- Parameters:
- kspace: torch.Tensor
k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- Returns:
- input_image: torch.Tensor
Sense initialization \(x_{\text{SENSE}}\).
- Return type:
Tensor
- forward(masked_kspace, sensitivity_map=None)[source][source]#
Computes forward pass of Unet2d.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2). Default: None.
- Returns:
- output: torch.Tensor
Output image of shape (N, height, width, complex=2).
- Return type:
Tensor
- training: bool#
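The SENSE initialization formula above can be sketched with plain complex arithmetic. This is an illustrative stand-in for compute_sense_init, assuming the coil images \(y^k\) have already been obtained from the k-space via the backward operator and that complex values are stored in a trailing dimension of size 2:
>>> import torch
>>> coil_images = torch.rand(1, 8, 64, 64, 2)  # y^k per coil, (N, coil, height, width, complex=2)
>>> sensitivity_map = torch.rand(1, 8, 64, 64, 2)  # S^k, same shape
>>> y = torch.view_as_complex(coil_images.contiguous())
>>> s = torch.view_as_complex(sensitivity_map.contiguous())
>>> x_sense = torch.view_as_real((s.conj() * y).sum(dim=1))  # sum_k conj(S^k) * y^k
>>> x_sense.shape
torch.Size([1, 64, 64, 2])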
- class direct.nn.unet.unet_2d.UnetModel2d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability)[source][source]#
Bases:
Module
PyTorch implementation of a U-Net model based on [1].
References
[1] Ronneberger, Olaf, et al. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015, edited by Nassir Navab et al., Springer International Publishing, 2015, pp. 234–41. Springer Link, https://doi.org/10.1007/978-3-319-24574-4_28.
- forward(input_data)[source][source]#
Performs forward pass of UnetModel2d.
- Parameters:
- input_data: torch.Tensor
- Returns:
- torch.Tensor
- Return type:
Tensor
- training: bool#
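A minimal usage sketch of UnetModel2d with the default values from UnetModel2dConfig; the input layout (N, in_channels, height, width) is assumed by analogy with the documented 3D variant, and the spatial size is chosen divisible by 2**num_pool_layers:
>>> import torch
>>> from direct.nn.unet.unet_2d import UnetModel2d
>>> model = UnetModel2d(in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0)
>>> x = torch.rand(1, 2, 64, 64)  # 64 is divisible by 2**4
>>> out = model(x)  # expected shape (1, 2, 64, 64)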
direct.nn.unet.unet_3d module#
Code for three-dimensional U-Net adapted from the 2D variant.
- class direct.nn.unet.unet_3d.ConvBlock3D(in_channels, out_channels, dropout_probability)[source][source]#
Bases:
Module
3D U-Net convolutional block.
- forward(input_data)[source][source]#
Performs the forward pass of ConvBlock3D.
- Parameters:
- input_data: torch.Tensor
Input data.
- Returns:
- torch.Tensor
- Return type:
Tensor
- training: bool#
- class direct.nn.unet.unet_3d.NormUnetModel3d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability, norm_groups=2)[source][source]#
Bases:
Module
Implementation of a Normalized U-Net model for 3D data.
This is an extension to 3D volumes of direct.nn.unet.unet_2d.NormUnetModel2d.
- forward(input_data)[source][source]#
Performs the forward pass of NormUnetModel3d.
- Parameters:
- input_data: torch.Tensor
Input tensor of shape (N, in_channels, slice/time, height, width).
- Returns:
- torch.Tensor
Output of shape (N, out_channels, slice/time, height, width).
- Return type:
Tensor
- static norm(input_data, groups)[source][source]#
Applies group normalization for 3D data.
- Parameters:
- input_data: torch.Tensor
The input tensor to normalize.
- groups: int
The number of groups to divide the tensor into for normalization.
- Returns:
- tuple[torch.Tensor, torch.Tensor, torch.Tensor]
A tuple containing the normalized tensor, the mean, and the standard deviation used for normalization.
- Return type:
tuple[Tensor, Tensor, Tensor]
- static pad(input_data)[source][source]#
Applies padding to the input 3D tensor to ensure its dimensions are multiples of 16.
- Parameters:
- input_data: torch.Tensor
The input tensor to pad.
- Returns:
- tuple[torch.Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]]
A tuple containing the padded tensor and a tuple with the padding applied to each dimension (height, width, depth) and the target dimensions after padding.
- Return type:
tuple[Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]]
- training: bool#
- static unnorm(input_data, mean, std, groups)[source][source]#
Reverts the normalization applied to the 3D tensor.
- Parameters:
- input_data: torch.Tensor
The normalized tensor to revert normalization on.
- mean: torch.Tensor
The mean used during normalization.
- std: torch.Tensor
The standard deviation used during normalization.
- groups: int
The number of groups the tensor was divided into during normalization.
- Returns:
- torch.Tensor
The tensor after reverting the normalization.
- Return type:
Tensor
- static unpad(input_data, h_pad, w_pad, z_pad, h_mult, w_mult, z_mult)[source][source]#
Removes padding from the 3D input tensor, reverting it to its original dimensions before padding was applied.
This method is typically used after the model has processed the padded input.
- Parameters:
- input_data: torch.Tensor
The tensor from which padding will be removed.
- h_pad: list[int]
Padding applied to the height, specified as [top, bottom].
- w_pad: list[int]
Padding applied to the width, specified as [left, right].
- z_pad: list[int]
Padding applied to the depth, specified as [front, back].
- h_mult: int
The height as computed in the pad method.
- w_mult: int
The width as computed in the pad method.
- z_mult: int
The depth as computed in the pad method.
- Returns:
- torch.Tensor
The tensor with padding removed, restored to its original dimensions.
- Return type:
Tensor
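A minimal usage sketch of NormUnetModel3d, assuming that forward applies pad() before the U-Net and unpad() afterwards so that spatial sizes that are not multiples of 16 are handled internally:
>>> import torch
>>> from direct.nn.unet.unet_3d import NormUnetModel3d
>>> model = NormUnetModel3d(in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0, norm_groups=2)
>>> x = torch.rand(1, 2, 10, 50, 50)  # (N, in_channels, slice/time, height, width)
>>> out = model(x)  # expected shape (1, 2, 10, 50, 50), assuming padding is undone internally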
- class direct.nn.unet.unet_3d.TransposeConvBlock3D(in_channels, out_channels)[source][source]#
Bases:
Module
3D U-Net Transpose Convolutional Block.
- forward(input_data)[source][source]#
Performs the forward pass of TransposeConvBlock3D.
- Parameters:
- input_data: torch.Tensor
Input data.
- Returns:
- torch.Tensor
- Return type:
Tensor
- training: bool#
- class direct.nn.unet.unet_3d.UnetModel3d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability)[source][source]#
Bases:
Module
PyTorch implementation of a 3D U-Net model.
This class defines a 3D U-Net architecture consisting of down-sampling and up-sampling layers with 3D convolutional blocks. This is an extension to 3D volumes of direct.nn.unet.unet_2d.UnetModel2d.
- forward(input_data)[source][source]#
Performs forward pass of UnetModel3d.
- Parameters:
- input_data: torch.Tensor
Input tensor of shape (N, in_channels, slice/time, height, width).
- Returns:
- torch.Tensor
Output of shape (N, out_channels, slice/time, height, width).
- Return type:
Tensor
- training: bool#
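A minimal usage sketch of UnetModel3d; power-of-2 spatial sizes are used so that any internal padding (see pad_to_pow_of_2 below) is a no-op:
>>> import torch
>>> from direct.nn.unet.unet_3d import UnetModel3d
>>> model = UnetModel3d(in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0)
>>> x = torch.rand(1, 2, 16, 64, 64)  # (N, in_channels, slice/time, height, width)
>>> out = model(x)  # expected shape (1, 2, 16, 64, 64)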
- direct.nn.unet.unet_3d.pad_to_pow_of_2(inp, k)[source][source]#
Pads the input tensor along the spatial dimensions (depth, height, width) to the nearest power of 2.
This is necessary for certain operations in the 3D U-Net architecture to maintain dimensionality.
- Parameters:
- inp: torch.Tensor
The input tensor to be padded.
- k: int
The exponent to which the base of 2 is raised to determine the padding. Used to calculate the target dimension size as a power of 2.
- Returns:
- tuple[torch.Tensor, list[int]]
A tuple containing the padded tensor and a list of padding applied to each spatial dimension in the format [depth_front, depth_back, height_top, height_bottom, width_left, width_right].
- Return type:
tuple[Tensor, list[int]]
Examples
>>> inp = torch.rand(1, 1, 15, 15, 15)  # A random tensor with shape [1, 1, 15, 15, 15]
>>> padded_inp, padding = pad_to_pow_of_2(inp, 4)
>>> print(padded_inp.shape, padding)
torch.Size([...]), [1, 1, 1, 1, 1, 1]
direct.nn.unet.unet_engine module#
Unet2d model engines for direct.
This module contains engines for Unet2d models, both for supervised and self-supervised learning.
- class direct.nn.unet.unet_engine.Unet2dEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
MRIModelEngine
Unet2d Model Engine.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
- forward_function(data)[source][source]#
Forward function for Unet2dEngine.
- Parameters:
- data: dict[str, Any]
Input data dictionary containing the following keys: “masked_kspace” and “sensitivity_map” if image initialization is “sense”.
- Returns:
- tuple[torch.Tensor, None]
Prediction of image and None for k-space.
- Return type:
tuple[Tensor, None]
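For orientation, a hypothetical data dictionary of the shape consumed by forward_function; the key names follow the description above and the sizes are placeholders:
>>> import torch
>>> data = {
...     "masked_kspace": torch.rand(1, 8, 64, 64, 2),  # (N, coil, height, width, complex=2)
...     "sensitivity_map": torch.rand(1, 8, 64, 64, 2),  # only needed when image initialization is "sense"
... }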
- class direct.nn.unet.unet_engine.Unet2dJSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
JSSLMRIModelEngine
JSSL Unet2d Model Engine.
Used for supplementary experiments with the U-Net model trained with JSSL in the JSSL paper [1].
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- forward_function(data)[source][source]#
Forward function for Unet2dJSSLEngine.
- Parameters:
- data: dict[str, Any]
Input data dictionary containing the following keys: “is_ssl” indicating SSL sample, “input_kspace” if SSL training, otherwise “masked_kspace”. Also contains “sensitivity_map” if image initialization is “sense”.
- Returns:
- tuple[torch.Tensor, None]
Prediction of image and None for k-space.
- Return type:
tuple[Tensor, None]
- class direct.nn.unet.unet_engine.Unet2dSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
SSLMRIModelEngine
SSL Unet2d Model Engine.
Used for supplementary experiments with the U-Net model trained with SSL in the JSSL paper [1].
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- forward_function(data)[source][source]#
Forward function for Unet2dSSLEngine.
- Parameters:
- data: dict[str, Any]
Input data dictionary containing the following keys: “input_kspace” if training, otherwise “masked_kspace”. Also contains “sensitivity_map” if image initialization is “sense”.
- Returns:
- tuple[torch.Tensor, None]
Prediction of image and None for k-space.
- Return type:
tuple[Tensor, None]