direct.nn.varsplitnet package#

Submodules#

direct.nn.varsplitnet.config module#

class direct.nn.varsplitnet.config.MRIVarSplitNetConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps_reg: int = 8, num_steps_dc: int = 8, image_init: str = 'sense', no_parameter_sharing: bool = True, kspace_no_parameter_sharing: bool = True, image_model_architecture: str = 'unet', kspace_model_architecture: Optional[str] = None, image_resnet_hidden_channels: Optional[int] = 128, image_resnet_num_blocks: Optional[int] = 15, image_resnet_batchnorm: Optional[bool] = True, image_resnet_scale: Optional[float] = 0.1, image_unet_num_filters: Optional[int] = 32, image_unet_num_pool_layers: Optional[int] = 4, image_unet_dropout: Optional[float] = 0.0, image_didn_hidden_channels: Optional[int] = 16, image_didn_num_dubs: Optional[int] = 6, image_didn_num_convs_recon: Optional[int] = 9, kspace_resnet_hidden_channels: Optional[int] = 64, kspace_resnet_num_blocks: Optional[int] = 1, kspace_resnet_batchnorm: Optional[bool] = True, kspace_resnet_scale: Optional[float] = 0.1, kspace_unet_num_filters: Optional[int] = 16, kspace_unet_num_pool_layers: Optional[int] = 4, kspace_unet_dropout: Optional[float] = 0.0, kspace_didn_hidden_channels: Optional[int] = 8, kspace_didn_num_dubs: Optional[int] = 6, kspace_didn_num_convs_recon: Optional[int] = 9, image_conv_hidden_channels: Optional[int] = 64, image_conv_n_convs: Optional[int] = 15, image_conv_activation: Optional[str] = 'relu', image_conv_batchnorm: Optional[bool] = False, kspace_conv_hidden_channels: Optional[int] = 64, kspace_conv_n_convs: Optional[int] = 15, kspace_conv_activation: Optional[str] = 'prelu', kspace_conv_batchnorm: Optional[bool] = False)[source]#

Bases: ModelConfig

image_conv_activation: Optional[str] = 'relu'#
image_conv_batchnorm: Optional[bool] = False#
image_conv_hidden_channels: Optional[int] = 64#
image_conv_n_convs: Optional[int] = 15#
image_didn_hidden_channels: Optional[int] = 16#
image_didn_num_convs_recon: Optional[int] = 9#
image_didn_num_dubs: Optional[int] = 6#
image_init: str = 'sense'#
image_model_architecture: str = 'unet'#
image_resnet_batchnorm: Optional[bool] = True#
image_resnet_hidden_channels: Optional[int] = 128#
image_resnet_num_blocks: Optional[int] = 15#
image_resnet_scale: Optional[float] = 0.1#
image_unet_dropout: Optional[float] = 0.0#
image_unet_num_filters: Optional[int] = 32#
image_unet_num_pool_layers: Optional[int] = 4#
kspace_conv_activation: Optional[str] = 'prelu'#
kspace_conv_batchnorm: Optional[bool] = False#
kspace_conv_hidden_channels: Optional[int] = 64#
kspace_conv_n_convs: Optional[int] = 15#
kspace_didn_hidden_channels: Optional[int] = 8#
kspace_didn_num_convs_recon: Optional[int] = 9#
kspace_didn_num_dubs: Optional[int] = 6#
kspace_model_architecture: Optional[str] = None#
kspace_no_parameter_sharing: bool = True#
kspace_resnet_batchnorm: Optional[bool] = True#
kspace_resnet_hidden_channels: Optional[int] = 64#
kspace_resnet_num_blocks: Optional[int] = 1#
kspace_resnet_scale: Optional[float] = 0.1#
kspace_unet_dropout: Optional[float] = 0.0#
kspace_unet_num_filters: Optional[int] = 16#
kspace_unet_num_pool_layers: Optional[int] = 4#
no_parameter_sharing: bool = True#
num_steps_dc: int = 8#
num_steps_reg: int = 8#
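
A minimal sketch of constructing this config, assuming it behaves like a standard dataclass as the signature above suggests; the overridden values (including the "varsplitnet" model name, since the default '???' marks a missing/required field) are illustrative only:

```python
from direct.nn.varsplitnet.config import MRIVarSplitNetConfig

# Illustrative overrides; every field not passed keeps the default listed above.
cfg = MRIVarSplitNetConfig(
    model_name="varsplitnet",  # assumed name; the '???' default marks a required field
    num_steps_reg=4,
    num_steps_dc=4,
    image_unet_num_filters=16,
)
print(cfg.num_steps_reg, cfg.num_steps_dc)  # 4 4
```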

direct.nn.varsplitnet.varsplitnet module#

class direct.nn.varsplitnet.varsplitnet.MRIVarSplitNet(forward_operator, backward_operator, num_steps_reg, num_steps_dc, image_init=InitType.SENSE, no_parameter_sharing=True, image_model_architecture=ModelName.UNET, kspace_no_parameter_sharing=True, kspace_model_architecture=None, **kwargs)[source]#

Bases: Module

MRI reconstruction network that solves the variable-splitting optimisation problem.

It solves the following:

\[\begin{aligned}
z^{i-1} &= \arg \min_{z} \mu * ||x^{i-1} - z||_2^2 + \mathcal{R}(z)\\
x^{i} &= \arg \min_{x} ||y - A(x)||_2^2 + \mu * ||x - z^{i-1}||_2^2
\end{aligned}\]

by unrolling the two subproblems with gradient descent and replacing \(\mathcal{R}\) with a neural network. More specifically, for \(z^{0} = x^{0} = \text{SENSE}(\tilde{y})\):

\[z^{i} = \alpha_{i-1} \times f_{\theta_{i-1}}\Big(\mu(z^{i-1} - x^{i-1}), z^{i-1}\Big), \quad i=1,\cdots,T_{reg}\]

where \(x^{i}\) is the output of

\[(x^{i})^{j} = (x^{i})^{j-1} - \beta_{j-1} \Big[ A^{*}\big( A( (x^{i})^{j-1} ) - \tilde{y} \big) + \mu ((x^{i})^{j-1} - z^{i}) \Big], \quad j=1,\cdots,T_{dc},\]

i.e. \(x^{i} = (x^{i})^{T_{dc}}\).
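
The alternating scheme can be summarized in a short sketch. Everything here (the operators A and A_star, the regularizer networks reg_nets, the step sizes alphas and betas, and the weight mu) is a generic stand-in for the module's learned components, so this illustrates the unrolled loop rather than the library's implementation:

```python
import torch

def unrolled_var_split(y_tilde, A, A_star, reg_nets, alphas, betas, mu, num_steps_dc):
    """Illustrative unrolled variable-splitting loop (a sketch, not the library code)."""
    # SENSE-style initialization: z^0 = x^0.
    x = z = A_star(y_tilde)
    for i, f_theta in enumerate(reg_nets):  # T_reg regularization steps
        # Regularization step: z^i from (mu * (z^{i-1} - x^{i-1}), z^{i-1}).
        z = alphas[i] * f_theta(mu * (z - x), z)
        # Data consistency: T_dc gradient-descent steps on x^i.
        for j in range(num_steps_dc):
            grad = A_star(A(x) - y_tilde) + mu * (x - z)
            x = x - betas[j] * grad
    return x

# Toy usage with identity operators and a trivial "network" standing in for f_theta.
y = torch.randn(1, 8, 64, 64, 2)
out = unrolled_var_split(
    y_tilde=y, A=lambda v: v, A_star=lambda v: v,
    reg_nets=[lambda r, z: z - 0.1 * r] * 4,
    alphas=[1.0] * 4, betas=[0.5] * 4, mu=0.1, num_steps_dc=4,
)
print(out.shape)  # torch.Size([1, 8, 64, 64, 2])
```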

compute_model_per_coil(model, data)[source]#

Performs forward pass of model per coil.

Parameters:
model: nn.Module

Model to run.

data: torch.Tensor

Multi-coil data of shape (batch, coil, complex=2, height, width).

Returns:
output: torch.Tensor

Computed output per coil.

Return type:

Tensor
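
The per-coil pattern can be illustrated with a generic stand-in: run the model on each coil image separately and stack the results along the coil dimension. The helper and toy model below are hypothetical, not the library's implementation:

```python
import torch
from torch import nn

def apply_model_per_coil(model: nn.Module, data: torch.Tensor) -> torch.Tensor:
    """Sketch of a per-coil forward pass over data of shape (batch, coil, complex=2, height, width)."""
    outputs = [model(data[:, coil]) for coil in range(data.shape[1])]
    return torch.stack(outputs, dim=1)

# Toy 2-channel convolution standing in for the coil model.
toy_model = nn.Conv2d(2, 2, kernel_size=3, padding=1)
coil_data = torch.randn(1, 8, 2, 64, 64)  # (batch, coil, complex=2, height, width)
print(apply_model_per_coil(toy_model, coil_data).shape)  # torch.Size([1, 8, 2, 64, 64])
```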

forward(masked_kspace, sensitivity_map, sampling_mask, scaling_factor=None)[source]#

Computes forward pass of MRIVarSplitNet.

Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask.

scaling_factor: torch.Tensor

Scaling factor. Default: None.
Returns:
image: torch.Tensor

Output image of shape (N, height, width, complex=2).

Return type:

Tensor
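
A minimal end-to-end sketch. The `fft2`/`ifft2` helpers below are assumed stand-ins for the library's forward/backward operators (taking a tensor with a trailing complex dimension of size 2 and a `dim` keyword for the spatial axes), and the sampling-mask layout (N, 1, height, width, 1) is likewise an assumption; consult the library for the exact operator interface:

```python
import torch
from direct.nn.varsplitnet.varsplitnet import MRIVarSplitNet

def fft2(data, dim=(2, 3), **kwargs):
    # Stand-in forward operator: orthonormal FFT over the spatial axes of a
    # real-view complex tensor of shape (..., height, width, 2).
    kspace = torch.fft.fftn(torch.view_as_complex(data.contiguous()), dim=dim, norm="ortho")
    return torch.view_as_real(kspace)

def ifft2(data, dim=(2, 3), **kwargs):
    # Stand-in backward (adjoint) operator.
    image = torch.fft.ifftn(torch.view_as_complex(data.contiguous()), dim=dim, norm="ortho")
    return torch.view_as_real(image)

net = MRIVarSplitNet(fft2, ifft2, num_steps_reg=2, num_steps_dc=2)

masked_kspace = torch.randn(1, 8, 64, 64, 2)       # (N, coil, height, width, complex=2)
sensitivity_map = torch.randn(1, 8, 64, 64, 2)     # (N, coil, height, width, complex=2)
sampling_mask = torch.rand(1, 1, 64, 64, 1) > 0.5  # assumed mask layout

image = net(masked_kspace, sensitivity_map, sampling_mask)
print(image.shape)  # expected: torch.Size([1, 64, 64, 2])
```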

training: bool#

direct.nn.varsplitnet.varsplitnet_engine module#

class direct.nn.varsplitnet.varsplitnet_engine.MRIVarSplitNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: MRIModelEngine

MRIVarSplitNet Engine.

forward_function(data)[source]#

Performs the model's forward pass on data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Tensor, None]

Module contents#