direct.nn.unet package
Submodules
direct.nn.unet.config module
- class direct.nn.unet.config.UnetModel2dConfig(model_name='???', engine_name=None, in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0)
Bases: ModelConfig
- in_channels = 2
- out_channels = 2
- num_filters = 16
- num_pool_layers = 4
- dropout_probability = 0.0
- __init__(model_name='???', engine_name=None, in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0)
- class direct.nn.unet.config.NormUnetModel2dConfig(model_name='???', engine_name=None)
Bases: ModelConfig
- in_channels = 2
- out_channels = 2
- num_filters = 16
- num_pool_layers = 4
- dropout_probability = 0.0
- norm_groups = 2
- class direct.nn.unet.config.Unet2dConfig(model_name='???', engine_name=None, num_filters=16, num_pool_layers=4, dropout_probability=0.0, skip_connection=False, normalized=False, image_initialization=InitType.ZERO_FILLED)
Bases: ModelConfig
- num_filters = 16
- num_pool_layers = 4
- dropout_probability = 0.0
- skip_connection = False
- normalized = False
- image_initialization = 'zero_filled'
- __init__(model_name='???', engine_name=None, num_filters=16, num_pool_layers=4, dropout_probability=0.0, skip_connection=False, normalized=False, image_initialization=InitType.ZERO_FILLED)
- class direct.nn.unet.config.UnetModel3dConfig(model_name='???', engine_name=None, in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0)
Bases: ModelConfig
- in_channels = 2
- out_channels = 2
- num_filters = 16
- num_pool_layers = 4
- dropout_probability = 0.0
- __init__(model_name='???', engine_name=None, in_channels=2, out_channels=2, num_filters=16, num_pool_layers=4, dropout_probability=0.0)
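The config classes above are containers of hyperparameter defaults. As a hedged illustration of overriding a few fields while keeping the rest, the sketch below uses a hypothetical stand-in dataclass, `UnetModel2dConfigSketch`, that copies the documented fields and defaults. It assumes the configs behave like standard Python dataclasses; the real class derives from DIRECT's `ModelConfig`, which is not reproduced here.

```python
from dataclasses import dataclass, replace
from typing import Optional


# Hypothetical stand-in for direct.nn.unet.config.UnetModel2dConfig.
# Field names and defaults mirror the documented ones; the base class is omitted.
@dataclass
class UnetModel2dConfigSketch:
    model_name: str = "???"
    engine_name: Optional[str] = None
    in_channels: int = 2
    out_channels: int = 2
    num_filters: int = 16
    num_pool_layers: int = 4
    dropout_probability: float = 0.0


default_cfg = UnetModel2dConfigSketch()
# Override only the fields that differ from the defaults; `replace` returns a new object.
deeper_cfg = replace(default_cfg, num_filters=32, num_pool_layers=5)
```

Here `deeper_cfg` keeps `in_channels=2` and `dropout_probability=0.0` from the defaults, while `default_cfg` itself is left unchanged.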
direct.nn.unet.unet_2d module
- class direct.nn.unet.unet_2d.ConvBlock(in_channels, out_channels, dropout_probability)
Bases: Module
U-Net convolutional block.
It consists of two convolution layers, each followed by instance normalization, LeakyReLU activation and dropout.
- __init__(in_channels, out_channels, dropout_probability)
Inits ConvBlock.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
dropout_probability (float) – Dropout probability.
- class direct.nn.unet.unet_2d.TransposeConvBlock(in_channels, out_channels)
Bases: Module
U-Net transpose convolutional block.
It consists of a single transposed convolution layer followed by instance normalization and LeakyReLU activation.
- __init__(in_channels, out_channels)
Inits TransposeConvBlock.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
- forward(input_data)
Performs forward pass of TransposeConvBlock.
- Parameters:
input_data (Tensor) – Input tensor.
- Return type:
Tensor
- Returns:
Output tensor.
- __repr__()
Representation of TransposeConvBlock.
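To see how these blocks change spatial size, the standard PyTorch output-length formulas for convolution (also valid for max pooling) and transposed convolution can be evaluated directly. The kernel, stride and padding values below are assumptions for illustration (3×3 convolutions with padding 1 in `ConvBlock`, 2×2 stride-2 transposed convolutions in `TransposeConvBlock`, 2×2 stride-2 pooling); the actual values are not stated in this reference.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length along one axis for Conv2d/Conv3d and MaxPool (PyTorch formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out(
    size: int, kernel: int, stride: int = 1, padding: int = 0, output_padding: int = 0, dilation: int = 1
) -> int:
    """Output length along one axis for ConvTranspose2d/ConvTranspose3d (PyTorch formula)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


same = conv_out(64, kernel=3, padding=1)  # assumed ConvBlock convolution: size preserved
pooled = conv_out(64, kernel=2, stride=2)  # assumed 2x2 pooling, stride 2: size halved
upsampled = conv_transpose_out(32, kernel=2, stride=2)  # assumed TransposeConvBlock: size doubled
```

Odd sizes round down under pooling (15 becomes 7), so the up-sampling path cannot restore them exactly without padding or cropping.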
- class direct.nn.unet.unet_2d.UnetModel2d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability)
Bases: Module
PyTorch implementation of a U-Net model based on [1].
References
- __init__(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability)
Inits UnetModel2d.
- Parameters:
in_channels (int) – Number of input channels to the U-Net.
out_channels (int) – Number of output channels of the U-Net.
num_filters (int) – Number of output channels of the first convolutional layer.
num_pool_layers (int) – Number of down-sampling and up-sampling layers (depth).
dropout_probability (float) – Dropout probability.
- forward(input_data)
Performs forward pass of UnetModel2d.
- Parameters:
input_data (Tensor) – Input tensor.
- Return type:
Tensor
- Returns:
Output tensor.
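The effect of `num_filters` and `num_pool_layers` on the encoder can be tabulated with a few lines of arithmetic. This sketch assumes a common U-Net layout that the reference above does not spell out: the first block maps `in_channels` to `num_filters`, channels double and spatial size halves at each pooling step, and the bottleneck sits below the deepest pooling layer. `unet_levels` is a hypothetical helper.

```python
def unet_levels(num_filters: int, num_pool_layers: int, size: int) -> list[tuple[int, int]]:
    """(channels, spatial size) at each encoder level, ending with the bottleneck.

    Assumes channels double and the spatial size halves at every pooling step.
    """
    levels = []
    channels = num_filters
    for _ in range(num_pool_layers):
        levels.append((channels, size))
        channels *= 2
        size //= 2
    levels.append((channels, size))  # bottleneck below the deepest pooling layer
    return levels


levels = unet_levels(num_filters=16, num_pool_layers=4, size=320)
```

With the documented defaults (`num_filters=16`, `num_pool_layers=4`), this yields encoder widths 16, 32, 64, 128 and a 256-channel bottleneck, while a 320-pixel axis shrinks to 20. Sizes divisible by `2**num_pool_layers` avoid rounding during pooling.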
- class direct.nn.unet.unet_2d.NormUnetModel2d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability, norm_groups=2)
Bases: Module
Implementation of a Normalized U-Net model.
- __init__(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability, norm_groups=2)
Inits NormUnetModel2d.
- Parameters:
in_channels (int) – Number of input channels to the U-Net.
out_channels (int) – Number of output channels of the U-Net.
num_filters (int) – Number of output channels of the first convolutional layer.
num_pool_layers (int) – Number of down-sampling and up-sampling layers.
dropout_probability (float) – Dropout probability.
norm_groups (int) – Number of normalization groups.
- static norm(input_data, groups)
Performs group normalization.
- Return type:
Tuple[Tensor, Tensor, Tensor]
- forward(input_data)
Performs forward pass of NormUnetModel2d.
- Parameters:
input_data (Tensor) – Input tensor.
- Return type:
Tensor
- Returns:
Output tensor.
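The normalized U-Net normalizes its input per channel group, runs the U-Net, and then reverts the normalization with the stored statistics. The pure-Python sketch below shows that round trip on a list of channels. It assumes contiguous channel grouping and the sample (unbiased) standard deviation, as `torch.Tensor.std` uses by default. It also omits an epsilon guard against zero variance. `group_norm` and `group_unnorm` are hypothetical helpers, not DIRECT's `norm`/`unnorm`.

```python
from statistics import fmean, stdev


def group_norm(channels: list[list[float]], groups: int):
    """Normalize contiguous channel groups to zero mean and unit std.

    Returns the normalized channels plus the per-group mean and std needed to undo it.
    """
    per_group = len(channels) // groups
    normed, means, stds = [], [], []
    for g in range(groups):
        group = channels[g * per_group:(g + 1) * per_group]
        values = [v for ch in group for v in ch]
        mean, std = fmean(values), stdev(values)  # sample (unbiased) std
        means.append(mean)
        stds.append(std)
        normed.extend([[(v - mean) / std for v in ch] for ch in group])
    return normed, means, stds


def group_unnorm(channels, means, stds, groups):
    """Invert group_norm using the stored per-group statistics."""
    per_group = len(channels) // groups
    return [
        [v * stds[i // per_group] + means[i // per_group] for v in ch]
        for i, ch in enumerate(channels)
    ]


data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]
normed, means, stds = group_norm(data, groups=2)
restored = group_unnorm(normed, means, stds, groups=2)
```

Keeping the mean and std lets the network operate on inputs of comparable scale while its output is returned in the original intensity range.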
- class direct.nn.unet.unet_2d.Unet2d(forward_operator, backward_operator, num_filters, num_pool_layers, dropout_probability, skip_connection=False, normalized=False, image_initialization=InitType.ZERO_FILLED, **kwargs)
Bases: Module
PyTorch implementation of a U-Net model for MRI reconstruction.
- __init__(forward_operator, backward_operator, num_filters, num_pool_layers, dropout_probability, skip_connection=False, normalized=False, image_initialization=InitType.ZERO_FILLED, **kwargs)
Inits Unet2d.
- Parameters:
forward_operator (Callable) – Forward operator.
backward_operator (Callable) – Backward operator.
num_filters (int) – Number of first-layer filters.
num_pool_layers (int) – Number of pooling layers.
dropout_probability (float) – Dropout probability.
skip_connection (bool) – If True, a skip connection is used for the output. Default: False.
normalized (bool) – If True, the normalized U-Net (NormUnetModel2d) is used. Default: False.
image_initialization (InitType) – Type of image initialization. Default: InitType.ZERO_FILLED.
kwargs (dict) – Additional keyword arguments.
- compute_sense_init(kspace, sensitivity_map)
Computes the SENSE initialization \(x_{\text{SENSE}}\):
\[x_{\text{SENSE}} = \sum_{k=1}^{n_c} {S^{k}}^* \times \mathcal{F}^{-1}(y^k),\]
where \(y^k\) denotes the k-space data from coil \(k\), \(S^k\) the sensitivity map of coil \(k\), \(\mathcal{F}^{-1}\) the backward operator and \(n_c\) the number of coils.
- Parameters:
kspace (Tensor) – Multi-coil k-space.
sensitivity_map (Tensor) – Coil sensitivity maps.
- Returns:
input_image (torch.Tensor) – SENSE initialization \(x_{\text{SENSE}}\).
- Return type:
torch.Tensor
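The SENSE formula can be checked on a toy 1-D problem. The sketch below simplifies what `Unet2d` does: it uses an orthonormal 1-D DFT written in pure Python instead of the 2-D FFT-based forward and backward operators, and it represents each coil as a plain list. `dft`, `idft` and `sense_init` are hypothetical helpers. When the sensitivities satisfy \(\sum_k |S^k|^2 = 1\) at every pixel, the coil combination recovers the underlying image exactly.

```python
import cmath
import math


def dft(signal: list[complex]) -> list[complex]:
    """Orthonormal 1-D DFT (stand-in for the forward operator)."""
    n = len(signal)
    return [
        sum(signal[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n)) / math.sqrt(n)
        for k in range(n)
    ]


def idft(spectrum: list[complex]) -> list[complex]:
    """Orthonormal 1-D inverse DFT (stand-in for the backward operator)."""
    n = len(spectrum)
    return [
        sum(spectrum[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / math.sqrt(n)
        for t in range(n)
    ]


def sense_init(kspace: list[list[complex]], sensitivity_map: list[list[complex]]) -> list[complex]:
    """x_SENSE[i] = sum over coils k of conj(S^k[i]) * IDFT(y^k)[i]."""
    coil_images = [idft(y) for y in kspace]
    n = len(coil_images[0])
    return [
        sum(s[i].conjugate() * img[i] for s, img in zip(sensitivity_map, coil_images))
        for i in range(n)
    ]


# Toy example: two coils whose sensitivities satisfy |S^1|^2 + |S^2|^2 = 1 at every pixel.
image = [1 + 0j, 2 + 1j, -1 + 0j, 0.5j]
sens = [[0.6 + 0j, 0.8 + 0j, 1 + 0j, 0j], [0.8j, 0.6 + 0j, 0j, 1j]]
kspace = [dft([s_i * x_i for s_i, x_i in zip(s, image)]) for s in sens]
x_sense = sense_init(kspace, sens)  # recovers `image` up to floating-point error
```

The conjugate of each sensitivity map weights each coil image by how strongly that coil sees a given pixel. This is why the inverse Fourier transform must be applied before the multiplication.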
- forward(masked_kspace, sensitivity_map=None)
Computes forward pass of Unet2d.
- Parameters:
masked_kspace (Tensor) – Masked multi-coil k-space.
sensitivity_map (Optional[Tensor]) – Coil sensitivity maps. Default: None.
- Returns:
output (torch.Tensor) – Output image of shape (N, height, width, complex=2).
- Return type:
torch.Tensor
direct.nn.unet.unet_3d module
Code for three-dimensional U-Net adapted from the 2D variant.
- class direct.nn.unet.unet_3d.ConvBlock3D(in_channels, out_channels, dropout_probability)
Bases: Module
3D U-Net convolutional block.
- __init__(in_channels, out_channels, dropout_probability)
Inits ConvBlock3D.
- Parameters:
in_channels (int) – Number of channels in the input tensor.
out_channels (int) – Number of channels produced by the convolutional layers.
dropout_probability (float) – Dropout probability applied after convolutional layers.
- forward(input_data)
Performs the forward pass of ConvBlock3D.
- Parameters:
input_data (Tensor) – Input data.
- Return type:
Tensor
- Returns:
Output tensor.
- class direct.nn.unet.unet_3d.TransposeConvBlock3D(in_channels, out_channels)
Bases: Module
3D U-Net transpose convolutional block.
- __init__(in_channels, out_channels)
Inits TransposeConvBlock3D.
- Parameters:
in_channels (int) – Number of channels in the input tensor.
out_channels (int) – Number of channels produced by the convolutional layers.
- forward(input_data)
Performs the forward pass of TransposeConvBlock3D.
- Parameters:
input_data (Tensor) – Input data.
- Return type:
Tensor
- Returns:
Output tensor.
- class direct.nn.unet.unet_3d.UnetModel3d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability)
Bases: Module
PyTorch implementation of a 3D U-Net model.
This class defines a 3D U-Net architecture consisting of down-sampling and up-sampling layers with 3D convolutional blocks. It extends direct.nn.unet.unet_2d.UnetModel2d to 3D volumes.
- __init__(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability)
Inits UnetModel3d.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
num_filters (int) – Number of output channels of the first convolutional layer.
num_pool_layers (int) – Number of down-sampling and up-sampling layers.
dropout_probability (float) – Dropout probability.
- forward(input_data)
Performs forward pass of UnetModel3d.
- Parameters:
input_data (Tensor) – Input tensor of shape (N, in_channels, slice/time, height, width).
- Return type:
Tensor
- Returns:
Output of shape (N, out_channels, slice/time, height, width).
- class direct.nn.unet.unet_3d.NormUnetModel3d(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability, norm_groups=2)
Bases: Module
Implementation of a Normalized U-Net model for 3D data.
It extends direct.nn.unet.unet_2d.NormUnetModel2d to 3D volumes.
- __init__(in_channels, out_channels, num_filters, num_pool_layers, dropout_probability, norm_groups=2)
Inits NormUnetModel3d.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
num_filters (int) – Number of output channels of the first convolutional layer.
num_pool_layers (int) – Number of down-sampling and up-sampling layers.
dropout_probability (float) – Dropout probability.
norm_groups (int) – Number of normalization groups.
- static norm(input_data, groups)
Applies group normalization for 3D data.
- Parameters:
input_data (Tensor) – The input tensor to normalize.
groups (int) – The number of groups to divide the tensor into for normalization.
- Return type:
tuple[Tensor, Tensor, Tensor]
- Returns:
A tuple containing the normalized tensor, the mean, and the standard deviation used for normalization.
- static unnorm(input_data, mean, std, groups)
Reverts the normalization applied to the 3D tensor.
- Parameters:
input_data (Tensor) – The normalized tensor to revert normalization on.
mean (Tensor) – The mean used during normalization.
std (Tensor) – The standard deviation used during normalization.
groups (int) – The number of groups the tensor was divided into during normalization.
- Return type:
Tensor
- Returns:
The tensor after reverting the normalization.
- static pad(input_data)
Applies padding to the input 3D tensor so that its dimensions are multiples of 16.
- Parameters:
input_data (Tensor) – The input tensor to pad.
- Return type:
tuple[Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]]
- Returns:
A tuple containing the padded tensor and a tuple with the padding applied to each dimension (height, width, depth) and the target dimensions after padding.
- static unpad(input_data, h_pad, w_pad, z_pad, h_mult, w_mult, z_mult)
Removes padding from the 3D input tensor, reverting it to its original dimensions before padding was applied.
This method is typically used after the model has processed the padded input.
- Parameters:
input_data (Tensor) – The tensor from which padding will be removed.
h_pad (list[int]) – Padding applied to the height, specified as [top, bottom].
w_pad (list[int]) – Padding applied to the width, specified as [left, right].
z_pad (list[int]) – Padding applied to the depth, specified as [front, back].
h_mult (int) – The height as computed in the pad method.
w_mult (int) – The width as computed in the pad method.
z_mult (int) – The depth as computed in the pad method.
- Return type:
Tensor
- Returns:
The tensor with padding removed, restored to its original dimensions.
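The pad/unpad pair can be understood one axis at a time: round the size up to the next multiple of 16, split the extra length into a before/after pair, and later slice it back out. The sketch below works on a plain list. It assumes zero fill, and it assumes any odd leftover goes after the data; DIRECT's choice of fill and split is not stated here. `pad_amounts`, `pad_1d` and `unpad_1d` are hypothetical helpers.

```python
def pad_amounts(size: int, multiple: int = 16) -> tuple[list[int], int]:
    """Return ([before, after], target) so that size + before + after == target."""
    target = -(-size // multiple) * multiple  # round size up to a multiple
    total = target - size
    return [total // 2, total - total // 2], target


def pad_1d(values: list[float], multiple: int = 16, fill: float = 0.0):
    """Pad one axis to a multiple of `multiple`; return padded values, padding and target size."""
    pad, target = pad_amounts(len(values), multiple)
    return [fill] * pad[0] + list(values) + [fill] * pad[1], pad, target


def unpad_1d(values: list[float], pad: list[int], target: int) -> list[float]:
    """Keep the slice [pad_before, target - pad_after), as the (pad, mult) arguments of unpad suggest."""
    return values[pad[0]: target - pad[1]]


original = [float(v) for v in range(20)]
padded, pad, target = pad_1d(original)
```

For a length of 20 this pads to 32 with `[6, 6]`, and `unpad_1d(padded, pad, target)` returns the original 20 values. Applying the same logic to height, width and depth gives the three `*_pad` lists and `*_mult` sizes that `unpad` expects.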
- direct.nn.unet.unet_3d.pad_to_pow_of_2(inp, k)
Pads the input tensor along the spatial dimensions (depth, height, width) to the nearest power of 2.
This is necessary for certain operations in the 3D U-Net architecture to maintain dimensionality.
- Parameters:
inp (Tensor) – The input tensor to be padded.
k (int) – The exponent to which the base of 2 is raised to determine the padding. Used to calculate the target dimension size as a power of 2.
- Return type:
tuple[Tensor, list[int]]
- Returns:
A tuple containing the padded tensor and a list of padding applied to each spatial dimension in the format [depth_front, depth_back, height_top, height_bottom, width_left, width_right].
Examples:
>>> import torch
>>> from direct.nn.unet.unet_3d import pad_to_pow_of_2
>>> inp = torch.rand(1, 1, 15, 15, 15)  # A random tensor with shape [1, 1, 15, 15, 15]
>>> padded_inp, padding = pad_to_pow_of_2(inp, 4)
>>> print(padded_inp.shape, padding)
torch.Size([…]), [1, 1, 1, 1, 1, 1]
direct.nn.unet.unet_engine module
Unet2d model engines for DIRECT.
This module contains engines for Unet2d models, both for supervised and self-supervised learning.
- class direct.nn.unet.unet_engine.Unet2dEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
Unet2d Model Engine.
- Parameters:
cfg (BaseConfig) – Configuration file.
model (Module) – Model.
device (str) – Device. Can be “cuda:{idx}” or “cpu”.
forward_operator (Optional[Callable]) – The forward operator. Default: None.
backward_operator (Optional[Callable]) – The backward operator. Default: None.
mixed_precision (bool) – Use mixed precision. Default: False.
**models (Module) – Additional models.
- __init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Inits Unet2dEngine.
- Parameters:
cfg (BaseConfig) – Configuration file.
model (Module) – Model.
device (str) – Device. Can be “cuda:{idx}” or “cpu”.
forward_operator (Optional[Callable]) – The forward operator. Default: None.
backward_operator (Optional[Callable]) – The backward operator. Default: None.
mixed_precision (bool) – Use mixed precision. Default: False.
**models (Module) – Additional models.
- forward_function(data)
Forward function for Unet2dEngine.
- Parameters:
data (dict[str, Any]) – Input data dictionary containing the following keys: “masked_kspace”, and “sensitivity_map” if image initialization is “sense”.
- Return type:
tuple[Tensor, None]
- Returns:
Prediction of image and None for k-space.
- class direct.nn.unet.unet_engine.Unet2dSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: SSLMRIModelEngine
SSL Unet2d Model Engine.
Used for supplementary experiments with the U-Net model trained with SSL in the JSSL paper [1].
- Parameters:
cfg (BaseConfig) – Configuration file.
model (Module) – Model.
device (str) – Device. Can be “cuda:{idx}” or “cpu”.
forward_operator (Optional[Callable]) – The forward operator. Default: None.
backward_operator (Optional[Callable]) – The backward operator. Default: None.
mixed_precision (bool) – Use mixed precision. Default: False.
**models (Module) – Additional models.
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- __init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Inits Unet2dSSLEngine.
- Parameters:
cfg (BaseConfig) – Configuration file.
model (Module) – Model.
device (str) – Device. Can be “cuda:{idx}” or “cpu”.
forward_operator (Optional[Callable]) – The forward operator. Default: None.
backward_operator (Optional[Callable]) – The backward operator. Default: None.
mixed_precision (bool) – Use mixed precision. Default: False.
**models (Module) – Additional models.
- forward_function(data)
Forward function for Unet2dSSLEngine.
- Parameters:
data (dict[str, Any]) – Input data dictionary containing the following keys: “input_kspace” if training, otherwise “masked_kspace”. Also contains “sensitivity_map” if image initialization is “sense”.
- Return type:
tuple[Tensor, None]
- Returns:
Prediction of image and None for k-space.
- class direct.nn.unet.unet_engine.Unet2dJSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: JSSLMRIModelEngine
JSSL Unet2d Model Engine.
Used for supplementary experiments with the U-Net model trained with JSSL in the JSSL paper [1].
- Parameters:
cfg (BaseConfig) – Configuration file.
model (Module) – Model.
device (str) – Device. Can be “cuda:{idx}” or “cpu”.
forward_operator (Optional[Callable]) – The forward operator. Default: None.
backward_operator (Optional[Callable]) – The backward operator. Default: None.
mixed_precision (bool) – Use mixed precision. Default: False.
**models (Module) – Additional models.
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- __init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Inits Unet2dJSSLEngine.
- Parameters:
cfg (BaseConfig) – Configuration file.
model (Module) – Model.
device (str) – Device. Can be “cuda:{idx}” or “cpu”.
forward_operator (Optional[Callable]) – The forward operator. Default: None.
backward_operator (Optional[Callable]) – The backward operator. Default: None.
mixed_precision (bool) – Use mixed precision. Default: False.
**models (Module) – Additional models.
- forward_function(data)
Forward function for Unet2dJSSLEngine.
- Parameters:
data (dict[str, Any]) – Input data dictionary containing the following keys: “is_ssl” indicating an SSL sample, “input_kspace” if SSL training, otherwise “masked_kspace”. Also contains “sensitivity_map” if image initialization is “sense”.
- Return type:
tuple[Tensor, None]
- Returns:
Prediction of image and None for k-space.
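The key selection described for the JSSL engine's `forward_function` can be summarized in a few lines. This is a sketch under stated assumptions, not DIRECT's implementation. The helper `select_unet_inputs` is hypothetical, and `"is_ssl"` is treated as a single boolean per batch. "SSL training" is read as `is_ssl` being true while the model is training. The chosen k-space is passed under the `masked_kspace` argument name of `Unet2d.forward`.

```python
from typing import Any


def select_unet_inputs(data: dict[str, Any], training: bool, image_initialization: str) -> dict[str, Any]:
    """Hypothetical helper: pick the k-space (and optional sensitivity map) fed to Unet2d."""
    # SSL samples use the input (partitioned) k-space during training; otherwise the masked k-space.
    use_input_kspace = training and bool(data.get("is_ssl", False))
    kspace_key = "input_kspace" if use_input_kspace else "masked_kspace"
    inputs = {"masked_kspace": data[kspace_key]}
    # Sensitivity maps are only needed for SENSE image initialization.
    if image_initialization == "sense":
        inputs["sensitivity_map"] = data["sensitivity_map"]
    return inputs


batch = {"is_ssl": True, "input_kspace": "A", "masked_kspace": "B", "sensitivity_map": "S"}
train_inputs = select_unet_inputs(batch, training=True, image_initialization="sense")
eval_inputs = select_unet_inputs(batch, training=False, image_initialization="zero_filled")
```

Setting `"is_ssl"` to true for every sample reduces this to the rule stated for `Unet2dSSLEngine`, and dropping the SSL branch entirely gives the `Unet2dEngine` behavior.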