direct.nn.varsplitnet package
Submodules
direct.nn.varsplitnet.config module
- class direct.nn.varsplitnet.config.MRIVarSplitNetConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps_reg: int = 8, num_steps_dc: int = 8, image_init: str = 'sense', no_parameter_sharing: bool = True, kspace_no_parameter_sharing: bool = True, image_model_architecture: str = <ModelName.UNET: 'unet'>, kspace_model_architecture: Optional[str] = None, image_resnet_hidden_channels: Optional[int] = 128, image_resnet_num_blocks: Optional[int] = 15, image_resnet_batchnorm: Optional[bool] = True, image_resnet_scale: Optional[float] = 0.1, image_unet_num_filters: Optional[int] = 32, image_unet_num_pool_layers: Optional[int] = 4, image_unet_dropout: Optional[float] = 0.0, image_didn_hidden_channels: Optional[int] = 16, image_didn_num_dubs: Optional[int] = 6, image_didn_num_convs_recon: Optional[int] = 9, kspace_resnet_hidden_channels: Optional[int] = 64, kspace_resnet_num_blocks: Optional[int] = 1, kspace_resnet_batchnorm: Optional[bool] = True, kspace_resnet_scale: Optional[float] = 0.1, kspace_unet_num_filters: Optional[int] = 16, kspace_unet_num_pool_layers: Optional[int] = 4, kspace_unet_dropout: Optional[float] = 0.0, kspace_didn_hidden_channels: Optional[int] = 8, kspace_didn_num_dubs: Optional[int] = 6, kspace_didn_num_convs_recon: Optional[int] = 9, image_conv_hidden_channels: Optional[int] = 64, image_conv_n_convs: Optional[int] = 15, image_conv_activation: Optional[str] = <ActivationType.RELU: 'relu'>, image_conv_batchnorm: Optional[bool] = False, kspace_conv_hidden_channels: Optional[int] = 64, kspace_conv_n_convs: Optional[int] = 15, kspace_conv_activation: Optional[str] = <ActivationType.PRELU: 'prelu'>, kspace_conv_batchnorm: Optional[bool] = False)
- Bases: ModelConfig
- image_conv_activation: Optional[str] = 'relu'
- image_conv_batchnorm: Optional[bool] = False
- image_conv_hidden_channels: Optional[int] = 64
- image_conv_n_convs: Optional[int] = 15
- image_didn_hidden_channels: Optional[int] = 16
- image_didn_num_convs_recon: Optional[int] = 9
- image_didn_num_dubs: Optional[int] = 6
- image_init: str = 'sense'
- image_model_architecture: str = 'unet'
- image_resnet_batchnorm: Optional[bool] = True
- image_resnet_hidden_channels: Optional[int] = 128
- image_resnet_num_blocks: Optional[int] = 15
- image_resnet_scale: Optional[float] = 0.1
- image_unet_dropout: Optional[float] = 0.0
- image_unet_num_filters: Optional[int] = 32
- image_unet_num_pool_layers: Optional[int] = 4
- kspace_conv_activation: Optional[str] = 'prelu'
- kspace_conv_batchnorm: Optional[bool] = False
- kspace_conv_hidden_channels: Optional[int] = 64
- kspace_conv_n_convs: Optional[int] = 15
- kspace_didn_hidden_channels: Optional[int] = 8
- kspace_didn_num_convs_recon: Optional[int] = 9
- kspace_didn_num_dubs: Optional[int] = 6
- kspace_model_architecture: Optional[str] = None
- kspace_no_parameter_sharing: bool = True
- kspace_resnet_batchnorm: Optional[bool] = True
- kspace_resnet_hidden_channels: Optional[int] = 64
- kspace_resnet_num_blocks: Optional[int] = 1
- kspace_resnet_scale: Optional[float] = 0.1
- kspace_unet_dropout: Optional[float] = 0.0
- kspace_unet_num_filters: Optional[int] = 16
- kspace_unet_num_pool_layers: Optional[int] = 4
- no_parameter_sharing: bool = True
- num_steps_dc: int = 8
- num_steps_reg: int = 8
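The field names appear to follow a <domain>_<architecture>_<parameter> pattern (for example image_unet_num_filters or kspace_resnet_num_blocks), which suggests that only the fields matching the selected image_model_architecture (and kspace_model_architecture, when it is not None) are relevant to a given model. The sketch below illustrates that grouping using a hypothetical stand-in dataclass, VarSplitConfigSketch, holding a small subset of the real fields and defaults, together with a hypothetical helper, architecture_kwargs. It is not DIRECT's own config handling.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class VarSplitConfigSketch:
    """Hypothetical stand-in holding a few MRIVarSplitNetConfig fields and defaults."""

    image_model_architecture: str = "unet"
    kspace_model_architecture: Optional[str] = None
    image_unet_num_filters: Optional[int] = 32
    image_unet_num_pool_layers: Optional[int] = 4
    image_unet_dropout: Optional[float] = 0.0
    image_resnet_hidden_channels: Optional[int] = 128
    image_resnet_num_blocks: Optional[int] = 15
    kspace_unet_num_filters: Optional[int] = 16


def architecture_kwargs(cfg: Any, domain: str) -> Dict[str, Any]:
    """Return the '<domain>_<architecture>_*' fields of cfg with the prefix stripped.

    Hypothetical helper. Returns an empty dict when no architecture is selected
    for the domain (e.g. kspace_model_architecture=None).
    """
    architecture = getattr(cfg, f"{domain}_model_architecture")
    if architecture is None:
        return {}
    prefix = f"{domain}_{architecture}_"
    return {
        f.name[len(prefix):]: getattr(cfg, f.name)
        for f in fields(cfg)
        if f.name.startswith(prefix)
    }


cfg = VarSplitConfigSketch()
print(architecture_kwargs(cfg, "image"))   # {'num_filters': 32, 'num_pool_layers': 4, 'dropout': 0.0}
print(architecture_kwargs(cfg, "kspace"))  # {}
```

Switching image_model_architecture to "resnet" in this sketch would select the image_resnet_* fields instead, leaving the unet values unused.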
direct.nn.varsplitnet.varsplitnet module
- class direct.nn.varsplitnet.varsplitnet.MRIVarSplitNet(forward_operator, backward_operator, num_steps_reg, num_steps_dc, image_init=InitType.SENSE, no_parameter_sharing=True, image_model_architecture=ModelName.UNET, kspace_no_parameter_sharing=True, kspace_model_architecture=None, **kwargs)
- Bases: Module

MRI reconstruction network that solves the variable split optimisation problem

\[
\begin{aligned}
z^{i} &= \arg \min_{z} \mu \|x^{i-1} - z\|_2^2 + \mathcal{R}(z), \\
x^{i} &= \arg \min_{x} \|\tilde{y} - A(x)\|_2^2 + \mu \|x - z^{i}\|_2^2,
\end{aligned}
\]

where \(\tilde{y}\) is the masked k-space and \(A\) is the forward operator. Both subproblems are unrolled with gradient descent, and \(\mathcal{R}\) is replaced by a neural network. More specifically, starting from \(z^{0} = x^{0} = \text{SENSE}(\tilde{y})\):

\[
z^{i} = \alpha_{i-1} \times f_{\theta_{i-1}}\Big(\mu(z^{i-1} - x^{i-1}), z^{i-1}\Big), \quad i=1,\cdots,T_{reg},
\]

where \(x^{i}\) is the output of

\[
(x^{i})^{j} = (x^{i})^{j-1} - \beta_{j-1} \Big[ A^{*}\big( A( (x^{i})^{j-1} ) - \tilde{y} \big) + \mu \big((x^{i})^{j-1} - z^{i}\big) \Big], \quad j=1,\cdots,T_{dc},
\]

i.e. \(x^{i} = (x^{i})^{T_{dc}}\). The bracketed term is half the gradient of the \(x\)-subproblem objective, so the factor 2 is absorbed into the step size \(\beta_{j-1}\).

- compute_model_per_coil(model, data)
- Performs a forward pass of the model on each coil separately.
- Parameters:
- model: nn.Module
- Model to run. 
- data: torch.Tensor
- Multi-coil data of shape (batch, coil, complex=2, height, width). 
 
- Returns:
- output: torch.Tensor
- Computed output per coil. 
 
- Return type:
- Tensor
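A minimal NumPy sketch of per-coil application, assuming the (batch, coil, complex=2, height, width) layout described above with the coil axis at dimension 1. The helper name apply_per_coil and the loop-and-stack strategy are illustrative assumptions, not necessarily how DIRECT's torch implementation works.

```python
import numpy as np


def apply_per_coil(model, data, coil_dim=1):
    """Run model on each coil slice of data and restack along coil_dim.

    Hypothetical helper: assumes data has shape (batch, coil, 2, height, width)
    and that model maps a (batch, 2, height, width) array to the same shape.
    """
    outputs = []
    for idx in range(data.shape[coil_dim]):
        # An integer index removes the coil axis: (batch, 2, height, width).
        coil_data = np.take(data, idx, axis=coil_dim)
        outputs.append(model(coil_data))
    # Re-insert the coil axis at its original position.
    return np.stack(outputs, axis=coil_dim)


rng = np.random.default_rng(0)
data = rng.standard_normal((2, 4, 2, 8, 8))  # (batch, coil, complex=2, height, width)
out = apply_per_coil(lambda x: 2.0 * x, data)
print(out.shape)  # (2, 4, 2, 8, 8)
```

In this sketch the same model, and therefore the same weights, is applied to every coil, so the number of coils can differ between inputs.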
 
- forward(masked_kspace, sensitivity_map, sampling_mask, scaling_factor=None)
- Computes the forward pass of MRIVarSplitNet.
- Parameters:
- masked_kspace: torch.Tensor
- Masked k-space of shape (N, coil, height, width, complex=2). 
- sensitivity_map: torch.Tensor
- Sensitivity map of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
- scaling_factor: torch.Tensor
- Optional. Default: None.
 
- Returns:
- image: torch.Tensor
- Output image of shape (N, height, width, complex=2). 
 
- Return type:
- Tensor
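The variable-splitting iteration described for MRIVarSplitNet can be sketched in NumPy for a single 2D slice. This is an illustrative sketch, not DIRECT's implementation. It assumes the forward operator \(A(x) = M\,\mathcal{F}(S\,x)\) with coil sensitivities \(S\) normalised so that \(\sum_c |S_c|^2 = 1\), an orthonormal, non-centred 2D FFT \(\mathcal{F}\) and a binary mask \(M\); it takes \(\text{SENSE}(\tilde{y})\) to be \(A^{*}(\tilde{y})\); it warm-starts each inner loop from the previous \(x\); and it replaces the learned network \(f_{\theta}\) with a hypothetical placeholder_regularizer. All function names are illustrative.

```python
import numpy as np


def forward_op(x, sens, mask):
    """A(x): image (H, W) -> masked multi-coil k-space (C, H, W)."""
    return mask * np.fft.fft2(sens * x, norm="ortho")


def adjoint_op(k, sens, mask):
    """A^*(k): multi-coil k-space (C, H, W) -> coil-combined image (H, W)."""
    return np.sum(np.conj(sens) * np.fft.ifft2(mask * k, norm="ortho"), axis=0)


def data_consistency(x, z, y, sens, mask, mu, betas):
    """Inner loop, j = 1..T_dc: x <- x - beta_{j-1} * [A^*(A(x) - y) + mu * (x - z)]."""
    for beta in betas:
        residual = forward_op(x, sens, mask) - y
        x = x - beta * (adjoint_op(residual, sens, mask) + mu * (x - z))
    return x


def placeholder_regularizer(grad, z):
    """Hypothetical stand-in for the learned f_theta: a plain gradient step, no prior."""
    return z - grad


def var_split_unrolled(y, sens, mask, regularizer, mu, alphas, betas):
    """Outer loop, i = 1..T_reg (T_reg = len(alphas), T_dc = len(betas))."""
    x = adjoint_op(y, sens, mask)  # x^0 = SENSE(y): sensitivity-weighted coil combination
    z = x.copy()                   # z^0 = x^0
    for alpha in alphas:
        # The same regularizer is reused at every step in this sketch.
        z = alpha * regularizer(mu * (z - x), z)
        # Assumption: each inner loop is warm-started from the previous x.
        x = data_consistency(x, z, y, sens, mask, mu, betas)
    return x


# Usage on random data with a roughly 50% random sampling mask.
rng = np.random.default_rng(0)
n_coils, height, width = 4, 16, 16
sens = rng.standard_normal((n_coils, height, width)) + 1j * rng.standard_normal((n_coils, height, width))
sens /= np.sqrt(np.sum(np.abs(sens) ** 2, axis=0, keepdims=True))  # sum_c |S_c|^2 = 1
image = rng.standard_normal((height, width)) + 1j * rng.standard_normal((height, width))
mask = (rng.random((height, width)) < 0.5).astype(float)
y = forward_op(image, sens, mask)
recon = var_split_unrolled(y, sens, mask, placeholder_regularizer, mu=0.5,
                           alphas=[1.0] * 8, betas=[0.5] * 8)
print(recon.shape)  # (16, 16)
```

With a fully sampled mask and normalised sensitivities, \(A^{*}A\) is the identity, so in this sketch the SENSE initialisation already equals the image and the iterations leave it unchanged.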
 
 
direct.nn.varsplitnet.varsplitnet_engine module
- class direct.nn.varsplitnet.varsplitnet_engine.MRIVarSplitNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
- Bases: MRIModelEngine

MRIVarSplitNet Engine.