direct.nn.iterdualnet package
Submodules
direct.nn.iterdualnet.config module
- class direct.nn.iterdualnet.config.IterDualNetConfig(model_name='???', engine_name=None, num_iter=10, image_normunet=False, kspace_normunet=False, image_unet_num_filters=8, image_unet_num_pool_layers=4, image_unet_dropout=0.0, kspace_unet_num_filters=8, kspace_unet_num_pool_layers=4, kspace_unet_dropout=0.0, image_no_parameter_sharing=True, kspace_no_parameter_sharing=False, compute_per_coil=True)
Bases: ModelConfig
- num_iter = 10
- image_normunet = False
- kspace_normunet = False
- image_unet_num_filters = 8
- image_unet_num_pool_layers = 4
- image_unet_dropout = 0.0
- kspace_unet_num_filters = 8
- kspace_unet_num_pool_layers = 4
- kspace_unet_dropout = 0.0
- image_no_parameter_sharing = True
- kspace_no_parameter_sharing = False
- compute_per_coil = True
- __init__(model_name='???', engine_name=None, num_iter=10, image_normunet=False, kspace_normunet=False, image_unet_num_filters=8, image_unet_num_pool_layers=4, image_unet_dropout=0.0, kspace_unet_num_filters=8, kspace_unet_num_pool_layers=4, kspace_unet_dropout=0.0, image_no_parameter_sharing=True, kspace_no_parameter_sharing=False, compute_per_coil=True)
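The prefixed fields above (image_unet_* and kspace_unet_*) hold separate U-Net settings for the image and k-space models. The IterDualNet constructor receives U-Net options through **kwargs. The following is a minimal sketch of how those prefixed options could be routed, assuming each model reads its own prefixed keys and falls back to the defaults listed above. The helper unet_settings and the constant UNET_DEFAULTS are hypothetical and are not part of DIRECT.

```python
from typing import Any, Dict

# Defaults copied from the IterDualNetConfig field list above.
UNET_DEFAULTS: Dict[str, Any] = {"num_filters": 8, "num_pool_layers": 4, "dropout": 0.0}


def unet_settings(prefix: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: collect `<prefix>_unet_*` options, using defaults when absent."""
    return {
        name: kwargs.get(f"{prefix}_unet_{name}", default)
        for name, default in UNET_DEFAULTS.items()
    }


# Only two options are overridden; everything else falls back to the defaults.
config_kwargs = {"image_unet_num_filters": 16, "kspace_unet_dropout": 0.1}
image_cfg = unet_settings("image", config_kwargs)
kspace_cfg = unet_settings("kspace", config_kwargs)
print(image_cfg)   # {'num_filters': 16, 'num_pool_layers': 4, 'dropout': 0.0}
print(kspace_cfg)  # {'num_filters': 8, 'num_pool_layers': 4, 'dropout': 0.1}
```

Keeping the two prefixes separate lets the image and k-space U-Nets be sized independently from a single flat configuration.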
direct.nn.iterdualnet.iterdualnet module
- class direct.nn.iterdualnet.iterdualnet.IterDualNet(forward_operator, backward_operator, num_iter=10, image_normunet=False, kspace_normunet=False, image_no_parameter_sharing=True, kspace_no_parameter_sharing=True, compute_per_coil=True, **kwargs)
Bases: Module

Iterative Dual Network iteratively solves the following problem
\[
\min_{x} \|A(x) - y\|_2^2 + \lambda_I \|x - D_I(x)\|_2^2 + \lambda_F \|x - \mathcal{Q}(D_F(f))\|_2^2,
\quad
\begin{cases}
\mathcal{Q} = \mathcal{F}^{-1}, \; f = \mathcal{F}(x) & \text{if } \mathtt{compute\_per\_coil} \text{ is False}, \\
\mathcal{Q} = \mathcal{R} \circ \mathcal{F}^{-1}, \; f = (\mathcal{F} \circ \mathcal{E})(x) & \text{otherwise},
\end{cases}
\]
by unrolling a gradient descent scheme. Here \(\mathcal{F}\) is the Fourier transform. \(\mathcal{E}\) and \(\mathcal{R}\) are the expand and reduce operators: using the sensitivity maps, \(\mathcal{E}\) maps an image to coil images and \(\mathcal{R}\) combines coil images into a single image. \(D_I\) and \(D_F\) are trainable U-Nets operating in the image and k-space domains, respectively.
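To make the unrolled scheme concrete, here is a small NumPy sketch of one gradient step on the objective above for a single 2D image. This is not the DIRECT implementation. The setup makes the following simplifications:

- The sensitivity maps, mask and sizes are random toy data.
- The trained U-Nets \(D_I\) and \(D_F\) are replaced by identity placeholders.
- The weights and step size are fixed scalars.
- The k-space branch uses the per-coil form that compute_per_coil=True describes: expand, FFT, \(D_F\), inverse FFT, reduce.

An orthonormal FFT is used so that the inverse FFT is the adjoint of the FFT, and the gradient of the data term is \(2A^{H}(A(x) - y)\).

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W = 4, 8, 8  # toy sizes: coils, height, width

# Hypothetical sensitivity maps, normalized so that sum_c |S_c|^2 = 1 per pixel.
sens = rng.standard_normal((C, H, W)) + 1j * rng.standard_normal((C, H, W))
sens /= np.sqrt(np.sum(np.abs(sens) ** 2, axis=0, keepdims=True))
mask = rng.random((H, W)) < 0.4  # hypothetical sampling mask, True = sampled


def fft2(x):
    # Orthonormal 2D FFT over the last two axes, so ifft2 is its adjoint.
    return np.fft.fft2(x, norm="ortho")


def ifft2(k):
    return np.fft.ifft2(k, norm="ortho")


def expand(x):
    # E: image (H, W) -> coil images (C, H, W)
    return sens * x


def reduce_op(coil_images):
    # R: coil images (C, H, W) -> image (H, W); the adjoint of E
    return np.sum(np.conj(sens) * coil_images, axis=0)


def A(x):
    # Forward model: expand to coils, Fourier transform, keep sampled entries.
    return mask * fft2(expand(x))


def A_adj(y):
    # Adjoint of A: mask, inverse FFT, combine coils.
    return reduce_op(ifft2(mask * y))


def gradient_step(x, y, eta, lam_i, lam_f, denoise_image, denoise_kspace):
    # Per-coil k-space branch: f = F(E(x)), Q = R o F^{-1}.
    f = fft2(expand(x))
    q = reduce_op(ifft2(denoise_kspace(f)))
    # Denoiser outputs are treated as constants when differentiating.
    grad = 2 * A_adj(A(x) - y) + 2 * lam_i * (x - denoise_image(x)) + 2 * lam_f * (x - q)
    return x - eta * grad


def identity(z):
    # Placeholder for the trained U-Nets D_I and D_F.
    return z


x_true = rng.standard_normal((H, W)) + 1j * rng.standard_normal((H, W))
y = A(x_true)  # simulated undersampled multi-coil k-space

x = A_adj(y)  # zero-filled starting image
residuals = [np.linalg.norm(A(x) - y)]
for _ in range(10):
    x = gradient_step(x, y, eta=0.5, lam_i=0.1, lam_f=0.1,
                      denoise_image=identity, denoise_kspace=identity)
    residuals.append(np.linalg.norm(A(x) - y))
print([round(float(r), 4) for r in residuals])
```

With identity denoisers and normalized sensitivity maps, both regularization gradients vanish. The loop therefore reduces to plain gradient descent on the data term. Because \(\|A\| \le 1\) in this setup, a step size of 0.5 keeps the residual from increasing. In the trained network, the U-Nets and the weights replace these placeholders.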
- __init__(forward_operator, backward_operator, num_iter=10, image_normunet=False, kspace_normunet=False, image_no_parameter_sharing=True, kspace_no_parameter_sharing=True, compute_per_coil=True, **kwargs)
Inits IterDualNet.
- Parameters:
  - forward_operator (Callable) – Forward operator.
  - backward_operator (Callable) – Backward operator.
  - num_iter (int) – Number of iterations. Default: 10.
  - image_normunet (bool) – If True, NormUNet is used for the image model. Default: False.
  - kspace_normunet (bool) – If True, NormUNet is used for the k-space model. Default: False.
  - image_no_parameter_sharing (bool) – If False, a single image model is shared across all iterations. Default: True.
  - kspace_no_parameter_sharing (bool) – If False, a single k-space model is shared across all iterations. Default: True.
  - compute_per_coil (bool) – If True, \(f\) is transformed into multi-coil k-space. Default: True.
  - **kwargs – Keyword arguments for the U-Net models.
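The *_no_parameter_sharing flags decide whether each iteration gets its own model or all iterations reuse one. The plain-Python sketch below shows that choice. ToyModel is a hypothetical stand-in for a U-Net, and build_models and model_for_iteration are hypothetical helper names.

```python
from typing import Callable, List


class ToyModel:
    """Hypothetical stand-in for a U-Net; only object identity matters here."""


def build_models(factory: Callable[[], object], num_iter: int, no_parameter_sharing: bool) -> List[object]:
    # One independent model per iteration, or a single model reused by every iteration.
    count = num_iter if no_parameter_sharing else 1
    return [factory() for _ in range(count)]


def model_for_iteration(models: List[object], iter_idx: int, no_parameter_sharing: bool) -> object:
    # With sharing, every iteration picks the same (first and only) model.
    return models[iter_idx] if no_parameter_sharing else models[0]


image_models = build_models(ToyModel, num_iter=10, no_parameter_sharing=True)
kspace_models = build_models(ToyModel, num_iter=10, no_parameter_sharing=False)
print(len(image_models), len(kspace_models))  # 10 1
```

In a PyTorch module the list would typically be wrapped in torch.nn.ModuleList so that every model's parameters are registered. Sharing one model keeps the parameter count independent of num_iter, at the cost of using the same denoiser at every iteration.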
- forward(masked_kspace, sampling_mask, sensitivity_map)
Computes the forward pass of IterDualNet.
- Parameters:
  - masked_kspace (Tensor) – Masked k-space.
  - sampling_mask (Tensor) – Sampling mask.
  - sensitivity_map (Tensor) – Sensitivity map.
- Returns:
  out_image – Output image of shape (N, height, width, complex=2).
- Return type:
  torch.Tensor
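The returned out_image stores complex values in a trailing axis of size 2. This sketch assumes that axis holds the real part followed by the imaginary part. Under that assumption, converting to and from a complex array takes a couple of lines. The NumPy sketch uses hypothetical helper names, and random data stands in for a real reconstruction.

```python
import numpy as np


def complex_to_channels(z: np.ndarray) -> np.ndarray:
    # Complex array (...) -> real array (..., 2); last axis = (real, imaginary).
    return np.stack([z.real, z.imag], axis=-1)


def channels_to_complex(t: np.ndarray) -> np.ndarray:
    # Real array (..., 2) -> complex array (...).
    return t[..., 0] + 1j * t[..., 1]


# Random stand-in for out_image with shape (N, height, width, complex=2).
out = np.random.default_rng(0).standard_normal((2, 4, 5, 2))
image = channels_to_complex(out)
magnitude = np.abs(image)  # magnitude image, commonly used for display
print(image.shape, magnitude.shape)  # (2, 4, 5) (2, 4, 5)
```

For PyTorch tensors, torch.view_as_complex and torch.view_as_real perform the same conversion as views, without copying. view_as_complex expects a floating-point tensor whose last dimension has size 2 and stride 1.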
direct.nn.iterdualnet.iterdualnet_engine module
- class direct.nn.iterdualnet.iterdualnet_engine.IterDualNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
- __init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Inits IterDualNetEngine.
- Parameters:
  - cfg (BaseConfig) – Configuration file.
  - model (Module) – Model.
  - device (str) – Device. Can be “cuda:{idx}” or “cpu”.
  - forward_operator (Optional[Callable]) – The forward operator. Default: None.
  - backward_operator (Optional[Callable]) – The backward operator. Default: None.
  - mixed_precision (bool) – Use mixed precision. Default: False.
  - **models (Module) – Additional models.