direct.nn.vsharp package
Submodules
direct.nn.vsharp.config module
- class direct.nn.vsharp.config.VSharpNet3DConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps: int = 8, num_steps_dc_gd: int = 6, image_init: InitType = <InitType.SENSE: 'sense'>, no_parameter_sharing: bool = True, auxiliary_steps: int = -1, initializer_channels: tuple[int, ...] = (32, 32, 64, 64), initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), initializer_multiscale: int = 1, initializer_activation: ActivationType = <ActivationType.PRELU: 'prelu'>, unet_num_filters: int = 32, unet_num_pool_layers: int = 4, unet_dropout: float = 0.0, unet_norm: bool = False)
Bases: ModelConfig
- auxiliary_steps: int = -1
- image_init: InitType = 'sense'
- initializer_activation: ActivationType = 'prelu'
- initializer_channels: tuple[int, ...] = (32, 32, 64, 64)
- initializer_dilations: tuple[int, ...] = (1, 1, 2, 4)
- initializer_multiscale: int = 1
- no_parameter_sharing: bool = True
- num_steps: int = 8
- num_steps_dc_gd: int = 6
- unet_dropout: float = 0.0
- unet_norm: bool = False
- unet_num_filters: int = 32
- unet_num_pool_layers: int = 4
- class direct.nn.vsharp.config.VSharpNetConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps: int = 10, num_steps_dc_gd: int = 8, image_init: InitType = <InitType.SENSE: 'sense'>, no_parameter_sharing: bool = True, auxiliary_steps: int = 0, image_model_architecture: ModelName = <ModelName.UNET: 'unet'>, initializer_channels: tuple[int, ...] = (32, 32, 64, 64), initializer_dilations: tuple[int, ...] = (1, 1, 2, 4), initializer_multiscale: int = 1, initializer_activation: ActivationType = <ActivationType.PRELU: 'prelu'>, image_resnet_hidden_channels: int = 128, image_resnet_num_blocks: int = 15, image_resnet_batchnorm: bool = True, image_resnet_scale: float = 0.1, image_unet_num_filters: int = 32, image_unet_num_pool_layers: int = 4, image_unet_dropout: float = 0.0, image_didn_hidden_channels: int = 16, image_didn_num_dubs: int = 6, image_didn_num_convs_recon: int = 9, image_conv_hidden_channels: int = 64, image_conv_n_convs: int = 15, image_conv_activation: str = <ActivationType.RELU: 'relu'>, image_conv_batchnorm: bool = False)
Bases: ModelConfig
- auxiliary_steps: int = 0
- image_conv_activation: str = 'relu'
- image_conv_batchnorm: bool = False
- image_conv_n_convs: int = 15
- image_didn_num_convs_recon: int = 9
- image_didn_num_dubs: int = 6
- image_init: InitType = 'sense'
- image_model_architecture: ModelName = 'unet'
- image_resnet_batchnorm: bool = True
- image_resnet_num_blocks: int = 15
- image_resnet_scale: float = 0.1
- image_unet_dropout: float = 0.0
- image_unet_num_filters: int = 32
- image_unet_num_pool_layers: int = 4
- initializer_activation: ActivationType = 'prelu'
- initializer_channels: tuple[int, ...] = (32, 32, 64, 64)
- initializer_dilations: tuple[int, ...] = (1, 1, 2, 4)
- initializer_multiscale: int = 1
- no_parameter_sharing: bool = True
- num_steps: int = 10
- num_steps_dc_gd: int = 8
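The configuration classes above are plain collections of defaults, so overriding a few fields is the typical way to derive a smaller or larger model. The sketch below illustrates that pattern with a stand-in dataclass: `VSharpConfigSketch` is a hypothetical name, it mirrors only a handful of the documented `VSharpNetConfig` defaults, and the length check between `initializer_channels` and `initializer_dilations` is an assumption rather than documented behaviour.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class VSharpConfigSketch:
    """Hypothetical stand-in for a few VSharpNetConfig fields (not the real class)."""

    num_steps: int = 10
    num_steps_dc_gd: int = 8
    no_parameter_sharing: bool = True
    image_model_architecture: str = "unet"
    initializer_channels: tuple = (32, 32, 64, 64)
    initializer_dilations: tuple = (1, 1, 2, 4)

    def __post_init__(self):
        # Assumption: one dilation per initializer layer.
        if len(self.initializer_channels) != len(self.initializer_dilations):
            raise ValueError("initializer_channels and initializer_dilations differ in length")


# Start from the documented defaults and override only what changes.
default_cfg = VSharpConfigSketch()
small_cfg = replace(default_cfg, num_steps=4, num_steps_dc_gd=2)
```

Because the stand-in is frozen, `replace` returns a new object and `default_cfg` stays untouched, which keeps experiments with different step counts easy to compare.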
direct.nn.vsharp.vsharp module
This module provides the implementation of the vSHARP model.
More specifically, vSHARP is the variable Splitting Half-quadratic ADMM algorithm for Reconstruction of inverse-Problems model, as presented in [1].
References
[1] George Yiasemis et al., “VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction of Inverse Problems” (2023). https://arxiv.org/abs/2309.09954.
- class direct.nn.vsharp.vsharp.LagrangeMultipliersInitializer(in_channels, out_channels, channels, dilations, multiscale_depth=1, activation=ActivationType.PRELU)
Bases: Module
A convolutional neural network model that initializes the Lagrange multiplier of the VSharpNet [1].
More specifically, it produces an initial value for the Lagrange multiplier based on the zero-filled image:
\[u^0 = \mathcal{G}_{\psi}(x^0).\]
References
[1] George Yiasemis et al., “VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction of Inverse Problems” (2023). https://arxiv.org/abs/2309.09954.
- forward(x)
Forward pass of LagrangeMultipliersInitializer.
- Parameters:
- x: torch.Tensor
Input tensor of shape (batch_size, in_channels, height, width).
- Returns:
- torch.Tensor
Output tensor of shape (batch_size, out_channels, height, width).
- Return type:
Tensor
- training: bool
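The initializer takes a list of dilations and returns an output with the same spatial size as its input. The arithmetic below sketches what that implies for a stack of dilated convolutions: how large the receptive field becomes and how much padding each layer needs to preserve height and width. The 3x3 kernel, stride 1 and symmetric padding are assumptions made for illustration; the actual layer settings are not stated here, and `dilated_stack_summary` is a hypothetical helper.

```python
def dilated_stack_summary(dilations, kernel_size=3):
    """Receptive field and per-layer 'same' padding for stacked dilated convolutions.

    Assumes stride 1 and an odd kernel size (illustrative, not taken from the library).
    """
    receptive_field = 1
    paddings = []
    for dilation in dilations:
        # Each layer widens the receptive field by (kernel_size - 1) * dilation.
        receptive_field += (kernel_size - 1) * dilation
        # Symmetric padding that keeps the spatial size unchanged.
        paddings.append(dilation * (kernel_size - 1) // 2)
    return receptive_field, paddings


# Default initializer_dilations from the configuration classes.
rf, pads = dilated_stack_summary((1, 1, 2, 4))
print(rf, pads)  # 17 [1, 1, 2, 4]
```

Under these assumptions the default dilations give each output pixel a 17x17 view of the zero-filled image while using only four layers.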
- class direct.nn.vsharp.vsharp.LagrangeMultipliersInitializer3D(in_channels, out_channels, channels, dilations, multiscale_depth=1, activation=ActivationType.PRELU)
Bases: Module
A convolutional neural network model that initializes the Lagrange multiplier of VSharpNet3D.
This is an extension to 3D data of LagrangeMultipliersInitializer.
- forward(x)
Forward pass of LagrangeMultipliersInitializer3D.
- Parameters:
- x: torch.Tensor
Input tensor of shape (batch_size, in_channels, z, x, y).
- Returns:
- torch.Tensor
Output tensor of shape (batch_size, out_channels, z, x, y).
- Return type:
Tensor
- training: bool
- class direct.nn.vsharp.vsharp.VSharpNet(forward_operator, backward_operator, num_steps, num_steps_dc_gd, image_init=InitType.SENSE, no_parameter_sharing=True, image_model_architecture=ModelName.UNET, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, initializer_activation=ActivationType.PRELU, auxiliary_steps=0, **kwargs)
Bases: Module
Variable Splitting Half-quadratic ADMM algorithm for Reconstruction of Parallel MRI [1].
Variable Splitting Half Quadratic VSharpNet is a deep learning model that solves the augmented Lagrangian derivation of the variable half quadratic splitting problem using ADMM (Alternating Direction Method of Multipliers). It is specifically designed for solving inverse problems in magnetic resonance imaging (MRI).
The VSharpNet model incorporates an iterative optimization algorithm that consists of three steps: z-step, x-step, and u-step. These steps are detailed mathematically as follows:
\[z^{t+1} = \mathrm{argmin}_{z} \lambda \mathcal{G}(z) + \frac{\rho}{2} || x^{t} - z + \frac{u^t}{\rho} ||_2^2 \quad \mathrm{[z-step]}\]
\[x^{t+1} = \mathrm{argmin}_{x} \frac{1}{2} || \mathcal{A}_{\mathbf{U},\mathbf{S}}(x) - \tilde{y} ||_2^2 + \frac{\rho}{2} || x - z^{t+1} + \frac{u^t}{\rho} ||_2^2 \quad \mathrm{[x-step]}\]
\[u^{t+1} = u^t + \rho (x^{t+1} - z^{t+1}) \quad \mathrm{[u-step]}\]
During the z-step, the model minimizes the augmented Lagrangian function with respect to z, utilizing DL-based denoisers. In the x-step, it optimizes x by minimizing the data consistency term through unrolling a gradient descent scheme (DC-GD). The u-step involves updating the Lagrange multiplier u. These steps are iterated for a specified number of cycles.
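The three updates can be traced on a scalar toy problem. The sketch below is not the VSharpNet implementation: it assumes the forward operator is multiplication by a scalar `a`, takes \(\mathcal{G}(z) = |z|\) so that the z-step reduces to soft-thresholding (standing in for the learned denoiser), and uses a fixed step size for the unrolled gradient descent in the x-step. The names `toy_admm` and `soft_threshold` are hypothetical.

```python
def soft_threshold(v, tau):
    """Proximal operator of tau * |z|; a stand-in for the learned denoiser."""
    if v > tau:
        return v - tau
    if v < -tau:
        return v + tau
    return 0.0


def toy_admm(a, y, lam=1.0, rho=1.0, num_steps=100, num_steps_dc_gd=8, step_size=0.1):
    """Scalar ADMM for 0.5 * (a * x - y)**2 + lam * |z| subject to x = z."""
    x = z = u = 0.0
    for _step in range(num_steps):
        # z-step: argmin_z lam * |z| + rho / 2 * (x - z + u / rho)**2
        z = soft_threshold(x + u / rho, lam / rho)
        # x-step: a few gradient descent iterations on the data consistency
        # term plus the quadratic penalty (the DC-GD idea, fixed step size).
        for _gd in range(num_steps_dc_gd):
            grad = a * (a * x - y) + rho * (x - z + u / rho)
            x = x - step_size * grad
        # u-step: update the Lagrange multiplier.
        u = u + rho * (x - z)
    return x, z, u


x, z, u = toy_admm(a=2.0, y=3.0)
```

For `a=2`, `y=3`, `lam=1` the minimizer of `0.5 * (a * x - y)**2 + lam * |x|` is `(a * y - lam) / a**2 = 1.25`, and both `x` and `z` approach that value as the iterations proceed. In the model itself the denoiser is a network and the operator acts on multi-coil k-space, so this only illustrates the order and form of the updates.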
The model includes an initializer for Lagrange multipliers.
It also allows for outputting auxiliary steps.
VSharpNet is tailored for 2D MRI data reconstruction.
References
[1] George Yiasemis et al., “VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction of Inverse Problems” (2023). https://arxiv.org/abs/2309.09954.
- forward(masked_kspace, sensitivity_map, sampling_mask)
Computes forward pass of VSharpNet.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- Returns:
- out: list of torch.Tensors
List of output images of shape (N, height, width, complex=2).
- Return type:
list[Tensor]
- training: bool
- class direct.nn.vsharp.vsharp.VSharpNet3D(forward_operator, backward_operator, num_steps, num_steps_dc_gd, image_init=InitType.SENSE, no_parameter_sharing=True, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, initializer_activation=ActivationType.PRELU, auxiliary_steps=-1, unet_num_filters=32, unet_num_pool_layers=4, unet_dropout=0.0, unet_norm=False, **kwargs)
Bases: Module
VSharpNet 3D version using 3D U-Nets as denoisers.
This is an extension to 3D of VSharpNet. For the original paper refer to [1].
References
[1] George Yiasemis et al., “VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction of Inverse Problems” (2023). https://arxiv.org/abs/2309.09954.
- forward(masked_kspace, sensitivity_map, sampling_mask)
Computes forward pass of VSharpNet3D.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, slice, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, slice, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, 1 or slice, height, width, 1).
- Returns:
- out: list of torch.Tensors
List of output images each of shape (N, slice, height, width, complex=2).
- Return type:
list[Tensor]
- training: bool
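The shape conventions for the 3D forward pass, including the mask that may carry either one slice or all slices, can be written out as a small bookkeeping helper. This is a pure-Python sketch for checking shapes before building tensors; `vsharp3d_shapes` and `broadcasts_to` are hypothetical helpers and not part of the package.

```python
def vsharp3d_shapes(batch, coils, slices, height, width, mask_per_slice=False):
    """Shapes documented for VSharpNet3D.forward, as plain tuples."""
    kspace = (batch, coils, slices, height, width, 2)
    # The sampling mask has either a single slice entry or one per slice.
    mask_slices = slices if mask_per_slice else 1
    return {
        "masked_kspace": kspace,
        "sensitivity_map": kspace,
        "sampling_mask": (batch, 1, mask_slices, height, width, 1),
        "output": (batch, slices, height, width, 2),
    }


def broadcasts_to(shape, target):
    """True if `shape` can be broadcast against `target` of the same rank."""
    return len(shape) == len(target) and all(s in (1, t) for s, t in zip(shape, target))


shapes = vsharp3d_shapes(batch=2, coils=8, slices=10, height=64, width=64)
mask_ok = broadcasts_to(shapes["sampling_mask"], shapes["masked_kspace"])
```

Both mask variants broadcast against the k-space shape, which is why a mask shared across slices and a per-slice mask can be passed through the same argument.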
direct.nn.vsharp.vsharp_engine module
Engines for vSHARP 2D and 3D models [1].
Includes supervised, self-supervised and joint supervised and self-supervised learning [2] engines.
References
[1] George Yiasemis et al., “VSHARP: Variable Splitting Half-quadratic ADMM Algorithm for Reconstruction of Inverse Problems” (2023). https://arxiv.org/abs/2309.09954.
[2] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- class direct.nn.vsharp.vsharp_engine.VSharpNet3DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
VSharpNet 3D Model Engine.
- class direct.nn.vsharp.vsharp_engine.VSharpNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
VSharpNet 2D Model Engine.
- class direct.nn.vsharp.vsharp_engine.VSharpNetJSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: JSSLMRIModelEngine
Joint Supervised and Self-supervised Learning vSHARP Model 2D Engine.
Used for the main experiments in the JSSL paper [1].
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- forward_function(data)
Forward function for VSharpNetJSSLEngine.
- Return type:
None
- class direct.nn.vsharp.vsharp_engine.VSharpNetSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: SSLMRIModelEngine
Self-supervised Learning vSHARP Model 2D Engine.
Used for the main experiments for SSL in the JSSL paper [1].
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- forward_function(data)
Forward function for VSharpNetSSLEngine.
- Return type:
None