direct.nn.iterdualnet package

Submodules

direct.nn.iterdualnet.config module

class direct.nn.iterdualnet.config.IterDualNetConfig(model_name: str = '???', engine_name: str | None = None, num_iter: int = 10, image_normunet: bool = False, kspace_normunet: bool = False, image_unet_num_filters: int = 8, image_unet_num_pool_layers: int = 4, image_unet_dropout: float = 0.0, kspace_unet_num_filters: int = 8, kspace_unet_num_pool_layers: int = 4, kspace_unet_dropout: float = 0.0, image_no_parameter_sharing: bool = True, kspace_no_parameter_sharing: bool = False, compute_per_coil: bool = True)

Bases: ModelConfig

compute_per_coil: bool = True
image_no_parameter_sharing: bool = True
image_normunet: bool = False
image_unet_dropout: float = 0.0
image_unet_num_filters: int = 8
image_unet_num_pool_layers: int = 4
kspace_no_parameter_sharing: bool = False
kspace_normunet: bool = False
kspace_unet_dropout: float = 0.0
kspace_unet_num_filters: int = 8
kspace_unet_num_pool_layers: int = 4
num_iter: int = 10
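The sketch below mirrors the documented fields and defaults of IterDualNetConfig in a plain dataclass so that the mapping to IterDualNet's keyword arguments is easy to see. It is not DIRECT's class: IterDualNetConfigSketch and model_kwargs are hypothetical names, and the assumption that every field except model_name and engine_name is forwarded to the model constructor is ours. Note that the config default for kspace_no_parameter_sharing (False) differs from the default in the IterDualNet signature (True), so the value that reaches the model depends on how the model is constructed.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class IterDualNetConfigSketch:
    # Hypothetical stand-in mirroring the documented IterDualNetConfig fields and defaults
    model_name: str = "???"
    engine_name: Optional[str] = None
    num_iter: int = 10
    image_normunet: bool = False
    kspace_normunet: bool = False
    image_unet_num_filters: int = 8
    image_unet_num_pool_layers: int = 4
    image_unet_dropout: float = 0.0
    kspace_unet_num_filters: int = 8
    kspace_unet_num_pool_layers: int = 4
    kspace_unet_dropout: float = 0.0
    image_no_parameter_sharing: bool = True
    kspace_no_parameter_sharing: bool = False
    compute_per_coil: bool = True


def model_kwargs(cfg: IterDualNetConfigSketch) -> dict:
    # Assumption: drop the bookkeeping fields; the rest would be passed to IterDualNet(...)
    kwargs = asdict(cfg)
    kwargs.pop("model_name")
    kwargs.pop("engine_name")
    return kwargs


# Override a few fields, keep the remaining defaults
kwargs = model_kwargs(IterDualNetConfigSketch(num_iter=5, compute_per_coil=False))
```

Overriding fields at construction time leaves all other documented defaults in place, which is the usual pattern for dataclass-based configs.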

direct.nn.iterdualnet.iterdualnet module

class direct.nn.iterdualnet.iterdualnet.IterDualNet(forward_operator, backward_operator, num_iter=10, image_normunet=False, kspace_normunet=False, image_no_parameter_sharing=True, kspace_no_parameter_sharing=True, compute_per_coil=True, **kwargs)

Bases: Module

The Iterative Dual Network (IterDualNet) solves the following problem iteratively:

\[\min_{x} \|A(x) - y\|_2^2 + \lambda_I \|x - D_I(x)\|_2^2 + \lambda_F \|x - \mathcal{Q}(D_F(f))\|_2^2, \quad \left\{ \begin{array}{ll} \mathcal{Q} = \mathcal{F}^{-1},\ f = \mathcal{F}(x) & \text{if compute\_per\_coil is False,} \\ \mathcal{Q} = \mathcal{R} \circ \mathcal{F}^{-1},\ f = \mathcal{F} \circ \mathcal{E}(x) & \text{otherwise,} \end{array} \right.\]

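by unrolling num_iter steps of a gradient descent scheme. Here \(A\) is the undersampled multi-coil forward operator, \(y\) is the measured (masked) k-space, \(\mathcal{F}\) is the 2D Fourier transform, and \(\mathcal{E}\) and \(\mathcal{R}\) are the expand and reduce operators, which use the sensitivity maps to map an image to coil images and coil images back to a single image. \(D_I\) and \(D_F\) are trainable U-Nets operating in the image and k-space domains, respectively.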
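The unrolled scheme can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not DIRECT's implementation: the denoisers \(D_I\) and \(D_F\) are passed in as arbitrary callables (hypothetical stand-ins for the U-Nets), the step size and the weights \(\lambda_I, \lambda_F\) are fixed scalars chosen for illustration, the denoiser outputs are treated as constants when taking the gradient, the FFT is an uncentered orthonormal one, and all helper names (expand_op, reduce_op, forward_op, adjoint_op, iter_dual_sketch) are hypothetical.

```python
import numpy as np


def fft2c(x):
    # Orthonormal 2D FFT over the last two (spatial) axes
    return np.fft.fft2(x, norm="ortho")


def ifft2c(k):
    # Orthonormal 2D inverse FFT over the last two (spatial) axes
    return np.fft.ifft2(k, norm="ortho")


def expand_op(x, smaps):
    # E: image (H, W) -> coil images (C, H, W), weighting by each sensitivity map
    return smaps * x[None]


def reduce_op(coil_imgs, smaps):
    # R: coil images (C, H, W) -> image (H, W), computing sum_c conj(S_c) * y_c
    return np.sum(np.conj(smaps) * coil_imgs, axis=0)


def forward_op(x, mask, smaps):
    # A = M F E: undersampled multi-coil k-space of an image
    return mask * fft2c(expand_op(x, smaps))


def adjoint_op(k, mask, smaps):
    # A^H = R F^{-1} M
    return reduce_op(ifft2c(mask * k), smaps)


def iter_dual_sketch(y, mask, smaps, image_denoiser, kspace_denoiser,
                     num_iter=10, step=0.5, lam_i=0.1, lam_f=0.1,
                     compute_per_coil=True):
    x = adjoint_op(y, mask, smaps)  # zero-filled initial image
    for _ in range(num_iter):
        # Gradient of the data term (constant factor absorbed into the step)
        dc = adjoint_op(forward_op(x, mask, smaps) - y, mask, smaps)
        image_prior = image_denoiser(x)  # D_I(x)
        if compute_per_coil:
            f = fft2c(expand_op(x, smaps))  # f = F(E(x)), per-coil k-space
            kspace_prior = reduce_op(ifft2c(kspace_denoiser(f)), smaps)  # Q = R o F^{-1}
        else:
            f = fft2c(x)  # f = F(x)
            kspace_prior = ifft2c(kspace_denoiser(f))  # Q = F^{-1}
        # One gradient step on the full objective
        x = x - step * (dc + lam_i * (x - image_prior) + lam_f * (x - kspace_prior))
    return x
```

With identity denoisers and sensitivity maps normalized so that \(\sum_c |S_c|^2 = 1\), both prior terms vanish and the loop reduces to plain gradient descent on the data term, which makes a convenient sanity check.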

forward(masked_kspace, sampling_mask, sensitivity_map)

Computes the forward pass of IterDualNet.

Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

Returns:
out_image: torch.Tensor

Output image of shape (N, height, width, complex=2).

Return type:

Tensor
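The documented layouts store complex values as a trailing dimension of size 2 (real, imaginary). The NumPy sketch below builds dummy inputs with those shapes, checks them, and forms a zero-filled coil-combined image with the documented output shape (N, height, width, complex=2). It is only an illustration of the tensor layout: the helper names are hypothetical, the FFT is uncentered (an assumption; DIRECT's own transforms may differ), and it does not call IterDualNet itself.

```python
import numpy as np


def to_complex(x):
    # (..., 2) real/imaginary channels -> complex array (...)
    if x.shape[-1] != 2:
        raise ValueError("expected a trailing complex dimension of size 2")
    return x[..., 0] + 1j * x[..., 1]


def to_channels(z):
    # complex array (...) -> (..., 2) real/imaginary channels
    return np.stack([z.real, z.imag], axis=-1)


def check_forward_shapes(masked_kspace, sampling_mask, sensitivity_map):
    # Validate the documented input layouts; return the documented output shape
    n, _coils, h, w, two = masked_kspace.shape
    if two != 2:
        raise ValueError("masked_kspace must end in a complex dimension of size 2")
    if sampling_mask.shape != (n, 1, h, w, 1):
        raise ValueError("sampling_mask must have shape (N, 1, height, width, 1)")
    if sensitivity_map.shape != masked_kspace.shape:
        raise ValueError("sensitivity_map must have the same shape as masked_kspace")
    return (n, h, w, 2)


def zero_filled_sense(masked_kspace, sensitivity_map):
    # Inverse FFT per coil, then combine coils with conj(S): (N, C, H, W, 2) -> (N, H, W, 2)
    coil_images = np.fft.ifft2(to_complex(masked_kspace), norm="ortho")
    image = np.sum(np.conj(to_complex(sensitivity_map)) * coil_images, axis=1)
    return to_channels(image)


rng = np.random.default_rng(0)
N, C, H, W = 2, 8, 32, 32
mask = (rng.random((N, 1, H, W, 1)) < 0.3).astype(np.float32)
kspace = rng.standard_normal((N, C, H, W, 2)) * mask  # mask broadcasts over coils and complex dim
smaps = rng.standard_normal((N, C, H, W, 2))
out = zero_filled_sense(kspace, smaps)
```

The singleton coil and complex dimensions of sampling_mask are what allow it to broadcast directly against masked_kspace.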

training: bool

direct.nn.iterdualnet.iterdualnet_engine module

class direct.nn.iterdualnet.iterdualnet_engine.IterDualNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

forward_function(data)

Performs the model's forward pass given data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Tensor, None]
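A minimal sketch of what forward_function plausibly does, assuming data is a dictionary whose keys match the parameter names of IterDualNet.forward. The key names, the function name forward_function_sketch, and the dummy model are assumptions for illustration only. The None in the return value corresponds to the second element of the documented Tuple[Tensor, None] return type, presumably a k-space prediction that this model does not produce.

```python
from typing import Any, Callable, Dict, Optional, Tuple


def forward_function_sketch(
    model: Callable[..., Any], data: Dict[str, Any]
) -> Tuple[Any, Optional[Any]]:
    # Assumed keys: they mirror the parameters of IterDualNet.forward
    output_image = model(
        masked_kspace=data["masked_kspace"],
        sampling_mask=data["sampling_mask"],
        sensitivity_map=data["sensitivity_map"],
    )
    output_kspace = None  # no k-space output, matching Tuple[Tensor, None]
    return output_image, output_kspace


def dummy_model(masked_kspace, sampling_mask, sensitivity_map):
    # Hypothetical stand-in for IterDualNet that just echoes its inputs
    return (masked_kspace, sampling_mask, sensitivity_map)


batch = {"masked_kspace": "k", "sampling_mask": "m", "sensitivity_map": "s", "extra": 0}
image, kspace = forward_function_sketch(dummy_model, batch)
```

Passing the inputs by keyword means extra entries in the batch dictionary are simply ignored.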

Module contents