direct.nn.varsplitnet package
Submodules
direct.nn.varsplitnet.config module
- class direct.nn.varsplitnet.config.MRIVarSplitNetConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps_reg: int = 8, num_steps_dc: int = 8, image_init: str = 'sense', no_parameter_sharing: bool = True, kspace_no_parameter_sharing: bool = True, image_model_architecture: str = <ModelName.UNET: 'unet'>, kspace_model_architecture: Optional[str] = None, image_resnet_hidden_channels: Optional[int] = 128, image_resnet_num_blocks: Optional[int] = 15, image_resnet_batchnorm: Optional[bool] = True, image_resnet_scale: Optional[float] = 0.1, image_unet_num_filters: Optional[int] = 32, image_unet_num_pool_layers: Optional[int] = 4, image_unet_dropout: Optional[float] = 0.0, image_didn_hidden_channels: Optional[int] = 16, image_didn_num_dubs: Optional[int] = 6, image_didn_num_convs_recon: Optional[int] = 9, kspace_resnet_hidden_channels: Optional[int] = 64, kspace_resnet_num_blocks: Optional[int] = 1, kspace_resnet_batchnorm: Optional[bool] = True, kspace_resnet_scale: Optional[float] = 0.1, kspace_unet_num_filters: Optional[int] = 16, kspace_unet_num_pool_layers: Optional[int] = 4, kspace_unet_dropout: Optional[float] = 0.0, kspace_didn_hidden_channels: Optional[int] = 8, kspace_didn_num_dubs: Optional[int] = 6, kspace_didn_num_convs_recon: Optional[int] = 9, image_conv_hidden_channels: Optional[int] = 64, image_conv_n_convs: Optional[int] = 15, image_conv_activation: Optional[str] = <ActivationType.RELU: 'relu'>, image_conv_batchnorm: Optional[bool] = False, kspace_conv_hidden_channels: Optional[int] = 64, kspace_conv_n_convs: Optional[int] = 15, kspace_conv_activation: Optional[str] = <ActivationType.PRELU: 'prelu'>, kspace_conv_batchnorm: Optional[bool] = False)
Bases: ModelConfig
- image_conv_activation: Optional[str] = 'relu'
- image_conv_batchnorm: Optional[bool] = False
- image_conv_n_convs: Optional[int] = 15
- image_didn_num_convs_recon: Optional[int] = 9
- image_didn_num_dubs: Optional[int] = 6
- image_init: str = 'sense'
- image_model_architecture: str = 'unet'
- image_resnet_batchnorm: Optional[bool] = True
- image_resnet_num_blocks: Optional[int] = 15
- image_resnet_scale: Optional[float] = 0.1
- image_unet_dropout: Optional[float] = 0.0
- image_unet_num_filters: Optional[int] = 32
- image_unet_num_pool_layers: Optional[int] = 4
- kspace_conv_activation: Optional[str] = 'prelu'
- kspace_conv_batchnorm: Optional[bool] = False
- kspace_conv_n_convs: Optional[int] = 15
- kspace_didn_num_convs_recon: Optional[int] = 9
- kspace_didn_num_dubs: Optional[int] = 6
- kspace_model_architecture: Optional[str] = None
- kspace_no_parameter_sharing: bool = True
- kspace_resnet_batchnorm: Optional[bool] = True
- kspace_resnet_num_blocks: Optional[int] = 1
- kspace_resnet_scale: Optional[float] = 0.1
- kspace_unet_dropout: Optional[float] = 0.0
- kspace_unet_num_filters: Optional[int] = 16
- kspace_unet_num_pool_layers: Optional[int] = 4
- no_parameter_sharing: bool = True
- num_steps_dc: int = 8
- num_steps_reg: int = 8
direct.nn.varsplitnet.varsplitnet module
- class direct.nn.varsplitnet.varsplitnet.MRIVarSplitNet(forward_operator, backward_operator, num_steps_reg, num_steps_dc, image_init=InitType.SENSE, no_parameter_sharing=True, image_model_architecture=ModelName.UNET, kspace_no_parameter_sharing=True, kspace_model_architecture=None, **kwargs)
Bases: Module
MRI reconstruction network that solves the variable split optimisation problem.
It solves the following:
\[ \begin{aligned} z^{i} &= \arg \min_{z} \mu \|x^{i-1} - z\|_2^2 + \mathcal{R}(z), \\ x^{i} &= \arg \min_{x} \|\tilde{y} - A(x)\|_2^2 + \mu \|x - z^{i}\|_2^2, \end{aligned} \]
by unrolling both subproblems with gradient descent and replacing the regulariser \(\mathcal{R}\) with a neural network. More specifically, starting from \(z^{0} = x^{0} = \text{SENSE}(\tilde{y})\):
\[z^{i} = \alpha_{i-1} \times f_{\theta_{i-1}}\Big(\mu(z^{i-1} - x^{i-1}), z^{i-1}\Big), \quad i=1,\cdots,T_{reg},\]
where \(x^{i}\) is the output of
\[(x^{i})^{j} = (x^{i})^{j-1} - \beta_{j-1} \Big[ A^{*}\big( A( (x^{i})^{j-1} ) - \tilde{y} \big) + \mu ((x^{i})^{j-1} - z^{i}) \Big], \quad j=1,\cdots,T_{dc},\]
with \((x^{i})^{0} = x^{i-1}\), i.e. \(x^{i} = (x^{i})^{T_{dc}}\).
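The two nested update loops can be traced with a small NumPy sketch. It assumes a single coil with unit sensitivity (so the SENSE initialisation reduces to a zero-filled inverse FFT), an orthonormal 2D FFT followed by a binary mask as A, a scalar coupling weight mu, fixed step sizes alpha and beta, and a hypothetical toy_regulariser in place of the learned network f_theta. Like the equations, it omits the factor 2 from the gradients of the squared norms. It is not DIRECT's implementation, which uses learned networks and parameters.

```python
import numpy as np


def forward_op(x, mask):
    """A(x): orthonormal 2D FFT followed by undersampling (single coil, unit sensitivity)."""
    return mask * np.fft.fft2(x, norm="ortho")


def adjoint_op(k, mask):
    """A^*(k): re-apply the mask, then orthonormal inverse 2D FFT."""
    return np.fft.ifft2(mask * k, norm="ortho")


def toy_regulariser(grad, z):
    """Hypothetical stand-in for f_theta: unit gradient step on the coupling term, then shrinkage."""
    return 0.9 * (z - grad)


def varsplit_unrolled(y, mask, num_steps_reg=8, num_steps_dc=8, mu=0.5,
                      alpha=None, beta=None, f=toy_regulariser):
    """Run T_reg regularisation steps, each followed by T_dc data-consistency gradient steps."""
    alpha = np.ones(num_steps_reg) if alpha is None else alpha   # alpha_0 .. alpha_{T_reg - 1}
    beta = np.full(num_steps_dc, 0.5) if beta is None else beta  # beta_0 .. beta_{T_dc - 1}
    x = adjoint_op(y, mask)  # x^0: zero-filled image (SENSE for one coil with S = 1)
    z = x.copy()             # z^0 = x^0
    for i in range(num_steps_reg):
        # 0-based loop index: z^{i+1} = alpha_i * f(mu (z^i - x^i), z^i)
        z = alpha[i] * f(mu * (z - x), z)
        # Inner loop warm-starts from the previous x and pulls it toward the data and toward z.
        for j in range(num_steps_dc):
            grad = adjoint_op(forward_op(x, mask) - y, mask) + mu * (x - z)
            x = x - beta[j] * grad
    return x


rng = np.random.default_rng(0)
image = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))
mask = (rng.random((16, 16)) < 0.4).astype(float)  # random sampling pattern, roughly 40% of k-space
y = forward_op(image, mask)
recon = varsplit_unrolled(y, mask)
print(recon.shape)  # (16, 16)
```

With mu = 0 the coupling term vanishes and the zero-filled start is already data-consistent, so x never moves; the learned regulariser only influences the reconstruction through mu.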
- compute_model_per_coil(model, data)
Performs the forward pass of model separately for each coil.
- Parameters:
- model: nn.Module
Model to run.
- data: torch.Tensor
Multi-coil data of shape (batch, coil, complex=2, height, width).
- Returns:
- output: torch.Tensor
Computed output per coil.
- Return type:
Tensor
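compute_model_per_coil lets a network written for single images process multi-coil data by applying it to each coil slice independently. The NumPy sketch below shows that pattern, assuming a plain callable in place of an nn.Module and a NumPy array in place of a torch.Tensor; the analogous PyTorch call for the last step would be torch.stack(outputs, dim=1).

```python
from typing import Callable

import numpy as np


def compute_model_per_coil(model: Callable[[np.ndarray], np.ndarray], data: np.ndarray) -> np.ndarray:
    """Apply `model` to each coil slice data[:, c] and stack the results along the coil axis.

    data: (batch, coil, complex=2, height, width)
    """
    outputs = [model(data[:, coil]) for coil in range(data.shape[1])]  # each: (batch, ...)
    return np.stack(outputs, axis=1)


# Usage: a toy "model" that sums the real/imaginary channels of one coil image.
data = np.arange(2 * 3 * 2 * 4 * 5, dtype=float).reshape(2, 3, 2, 4, 5)
out = compute_model_per_coil(lambda t: t.sum(axis=1, keepdims=True), data)
print(out.shape)  # (2, 3, 1, 4, 5)
```

Because the same model is called for every coil, its parameters are shared across coils and the number of coils can vary between inputs.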
- forward(masked_kspace, sensitivity_map, sampling_mask, scaling_factor=None)
Computes the forward pass of MRIVarSplitNet.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of the acquired k-space.
- scaling_factor: torch.Tensor
Optional scaling factor. Default: None.
- Returns:
- image: torch.Tensor
Output image of shape (N, height, width, complex=2).
- Return type:
Tensor
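The class description initialises the reconstruction with a SENSE coil combination, which maps the documented input shape (N, coil, height, width, complex=2) to the output shape (N, height, width, complex=2). A minimal NumPy sketch of that reduction, x = sum over coils of conj(S_c) * IFFT(y_c), is given below. It assumes a plain orthonormal, uncentred 2D FFT; in MRIVarSplitNet the actual transform is whatever forward_operator and backward_operator are passed to the constructor. The helper names are hypothetical.

```python
import numpy as np


def to_complex(t: np.ndarray) -> np.ndarray:
    """Convert a real array with a trailing (real, imag) axis of size 2 to a complex array."""
    return t[..., 0] + 1j * t[..., 1]


def to_real(t: np.ndarray) -> np.ndarray:
    """Convert a complex array back to a real array with a trailing axis of size 2."""
    return np.stack([t.real, t.imag], axis=-1)


def sense_init(masked_kspace: np.ndarray, sensitivity_map: np.ndarray) -> np.ndarray:
    """Coil-combined image: (N, coil, H, W, 2) k-space and maps -> (N, H, W, 2) image."""
    kspace = to_complex(masked_kspace)   # (N, coil, H, W)
    sens = to_complex(sensitivity_map)   # (N, coil, H, W)
    coil_images = np.fft.ifft2(kspace, axes=(-2, -1), norm="ortho")
    combined = np.sum(np.conj(sens) * coil_images, axis=1)  # reduce over the coil axis
    return to_real(combined)


rng = np.random.default_rng(0)
kspace = rng.standard_normal((2, 3, 8, 8, 2))
sens = rng.standard_normal((2, 3, 8, 8, 2))
print(sense_init(kspace, sens).shape)  # (2, 8, 8, 2)
```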
- training: bool
direct.nn.varsplitnet.varsplitnet_engine module
- class direct.nn.varsplitnet.varsplitnet_engine.MRIVarSplitNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
MRIVarSplitNet Engine.