direct.nn.iterdualnet package
Submodules
direct.nn.iterdualnet.config module
- class direct.nn.iterdualnet.config.IterDualNetConfig(model_name: str = '???', engine_name: str | None = None, num_iter: int = 10, image_normunet: bool = False, kspace_normunet: bool = False, image_unet_num_filters: int = 8, image_unet_num_pool_layers: int = 4, image_unet_dropout: float = 0.0, kspace_unet_num_filters: int = 8, kspace_unet_num_pool_layers: int = 4, kspace_unet_dropout: float = 0.0, image_no_parameter_sharing: bool = True, kspace_no_parameter_sharing: bool = False, compute_per_coil: bool = True)
Bases: ModelConfig
- compute_per_coil: bool = True
- image_no_parameter_sharing: bool = True
- image_normunet: bool = False
- image_unet_dropout: float = 0.0
- image_unet_num_filters: int = 8
- image_unet_num_pool_layers: int = 4
- kspace_no_parameter_sharing: bool = False
- kspace_normunet: bool = False
- kspace_unet_dropout: float = 0.0
- kspace_unet_num_filters: int = 8
- kspace_unet_num_pool_layers: int = 4
- num_iter: int = 10
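The signature suggests that IterDualNetConfig is a dataclass whose fields fall into two parallel groups: `image_*` settings for the image-domain U-Net and `kspace_*` settings for the k-space U-Net. The sketch below uses a hypothetical stand-in, `IterDualNetConfigSketch`, that copies the documented fields and defaults without importing DIRECT or its `ModelConfig` base class, plus a hypothetical helper `unet_kwargs` that collects one domain's U-Net settings. The default `'???'` is the OmegaConf marker for a mandatory value that must be filled in.

```python
from dataclasses import dataclass, fields, replace
from typing import Optional


@dataclass
class IterDualNetConfigSketch:
    # Hypothetical stand-in mirroring the documented IterDualNetConfig fields;
    # it does not inherit from DIRECT's ModelConfig.
    model_name: str = "???"  # OmegaConf placeholder for a mandatory value
    engine_name: Optional[str] = None
    num_iter: int = 10
    image_normunet: bool = False
    kspace_normunet: bool = False
    image_unet_num_filters: int = 8
    image_unet_num_pool_layers: int = 4
    image_unet_dropout: float = 0.0
    kspace_unet_num_filters: int = 8
    kspace_unet_num_pool_layers: int = 4
    kspace_unet_dropout: float = 0.0
    image_no_parameter_sharing: bool = True
    kspace_no_parameter_sharing: bool = False
    compute_per_coil: bool = True


def unet_kwargs(cfg, domain):
    # Hypothetical helper: gather "<domain>_unet_*" fields, with the prefix stripped.
    prefix = f"{domain}_unet_"
    return {
        f.name[len(prefix):]: getattr(cfg, f.name)
        for f in fields(cfg)
        if f.name.startswith(prefix)
    }
```

For example, `unet_kwargs(replace(IterDualNetConfigSketch(), kspace_unet_num_filters=16), "kspace")` yields `{"num_filters": 16, "num_pool_layers": 4, "dropout": 0.0}`. Because `dataclasses.replace` returns a modified copy, the original config keeps its defaults.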
direct.nn.iterdualnet.iterdualnet module
- class direct.nn.iterdualnet.iterdualnet.IterDualNet(forward_operator, backward_operator, num_iter=10, image_normunet=False, kspace_normunet=False, image_no_parameter_sharing=True, kspace_no_parameter_sharing=True, compute_per_coil=True, **kwargs)
Bases: Module
The Iterative Dual Network solves the following problem
\[\begin{split}\min_{x} \|A(x) - y\|_2^2 + \lambda_I \|x - D_I(x)\|_2^2 + \lambda_F \|x - \mathcal{Q}(D_F(f))\|_2^2, \quad \left\{ \begin{array}{ll} \mathcal{Q} = \mathcal{F}^{-1}, \; f = \mathcal{F}(x) & \text{if } \mathtt{compute\_per\_coil} \text{ is False}, \\ \mathcal{Q} = \mathcal{R} \circ \mathcal{F}^{-1}, \; f = (\mathcal{F} \circ \mathcal{E})(x) & \text{otherwise}, \end{array} \right.\end{split}\]
by unrolling a gradient-descent scheme, where \(\mathcal{E}\) and \(\mathcal{R}\) are the expand and reduce operators, which use the sensitivity maps. \(D_I\) and \(D_F\) are trainable U-Nets operating in the image and k-space domains, respectively.
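To make the unrolled scheme concrete, here is a small NumPy sketch of gradient descent on half of the objective above. It illustrates the idea and is not DIRECT's implementation. The following are all assumptions: the step size and the weights λ_I and λ_F are fixed numbers rather than learned parameters; `d_image` and `d_kspace` are arbitrary callables standing in for the U-Nets D_I and D_F; D_I(x) and Q(D_F(f)) are treated as constants when differentiating; A is taken as M F E, with a binary mask M and a centered orthonormal FFT F; and E and R multiply by the sensitivity maps and by their complex conjugates, respectively. All function names are hypothetical.

```python
import numpy as np


def fft2c(x):
    # Centered, orthonormal 2D FFT over the last two (height, width) axes.
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(x, axes=(-2, -1)), norm="ortho"), axes=(-2, -1))


def ifft2c(k):
    # Inverse of fft2c; because fft2c is unitary, this is also its adjoint.
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(k, axes=(-2, -1)), norm="ortho"), axes=(-2, -1))


def expand_op(x, sens):
    # E: image (H, W) -> coil images (C, H, W), weighted by each sensitivity map.
    return sens * x[None]


def reduce_op(coil_images, sens):
    # R: coil images (C, H, W) -> image (H, W); the adjoint of E.
    return np.sum(np.conj(sens) * coil_images, axis=0)


def forward_op(x, sens, mask):
    # A = M F E: image -> undersampled multi-coil k-space.
    return mask * fft2c(expand_op(x, sens))


def adjoint_op(k, sens, mask):
    # A^H = R F^{-1} M.
    return reduce_op(ifft2c(mask * k), sens)


def unrolled_dual_domain_gd(y, sens, mask, d_image, d_kspace, num_iter=10,
                            step=0.5, lam_i=0.1, lam_f=0.1, compute_per_coil=True):
    x = adjoint_op(y, sens, mask)  # zero-filled starting point
    for _ in range(num_iter):
        dc_grad = adjoint_op(forward_op(x, sens, mask) - y, sens, mask)  # A^H(Ax - y)
        image_term = x - d_image(x)                                       # x - D_I(x)
        if compute_per_coil:
            f = fft2c(expand_op(x, sens))             # f = (F o E)(x)
            q = reduce_op(ifft2c(d_kspace(f)), sens)  # Q = R o F^{-1}
        else:
            f = fft2c(x)                              # f = F(x)
            q = ifft2c(d_kspace(f))                   # Q = F^{-1}
        kspace_term = x - q                           # x - Q(D_F(f))
        x = x - step * (dc_grad + lam_i * image_term + lam_f * kspace_term)
    return x
```

With identity callables and sensitivity maps normalized so that Σ_c |S_c|² = 1, both regularization terms vanish, and the loop reduces to plain gradient descent on the data-consistency term. This makes a useful sanity check. In the network, trainable U-Nets replace the identities. The `*_no_parameter_sharing` flags presumably control whether each unrolled step gets its own U-Net.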
- forward(masked_kspace, sampling_mask, sensitivity_map)
Computes the forward pass of IterDualNet.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- Returns:
- out_image: torch.Tensor
Output image of shape (N, height, width, complex=2).
- Return type:
Tensor
- training: bool
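The tensor layouts documented for `forward` use a trailing axis of size 2 to hold the real and imaginary parts. The NumPy sketch below takes inputs with exactly those shapes and returns an array in the `out_image` layout by masking the k-space, applying an inverse FFT per coil, and combining the coils with the conjugate sensitivity maps (R F^{-1}). It is not IterDualNet's forward pass and only illustrates the shape contract. The uncentered FFT and all function names are hypothetical choices made for illustration.

```python
import numpy as np


def to_complex(x):
    # (..., 2) real/imaginary layout -> complex array with the last axis dropped.
    return x[..., 0] + 1j * x[..., 1]


def to_real(z):
    # Complex array -> (..., 2) real/imaginary layout.
    return np.stack([z.real, z.imag], axis=-1)


def zero_filled_coil_combination(masked_kspace, sampling_mask, sensitivity_map):
    # masked_kspace:   (N, coil, height, width, 2)
    # sampling_mask:   (N, 1, height, width, 1)
    # sensitivity_map: (N, coil, height, width, 2)
    kspace = to_complex(masked_kspace) * sampling_mask[..., 0]       # (N, coil, H, W); mask broadcasts over coils
    coil_images = np.fft.ifft2(kspace, axes=(-2, -1), norm="ortho")  # inverse FFT per coil (uncentered here)
    image = np.sum(np.conj(to_complex(sensitivity_map)) * coil_images, axis=1)  # reduce over the coil axis
    return to_real(image)                                            # (N, height, width, 2)
```

The coil axis disappears in the reduce step, which is why `out_image` has shape (N, height, width, complex=2), while the k-space and sensitivity inputs keep a coil axis.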
direct.nn.iterdualnet.iterdualnet_engine module
- class direct.nn.iterdualnet.iterdualnet_engine.IterDualNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine