direct.nn.lpd package#

Submodules#

direct.nn.lpd.config module#

class direct.nn.lpd.config.LPDNetConfig(model_name='???', engine_name=None, num_iter=25, num_primal=5, num_dual=5, primal_model_architecture='MWCNN', dual_model_architecture='DIDN', primal_mwcnn_hidden_channels=16, primal_mwcnn_num_scales=4, primal_mwcnn_bias=True, primal_mwcnn_batchnorm=False, primal_unet_num_filters=8, primal_unet_num_pool_layers=4, primal_unet_dropout_probability=0.0, dual_conv_hidden_channels=16, dual_conv_n_convs=4, dual_conv_batchnorm=False, dual_didn_hidden_channels=64, dual_didn_num_dubs=6, dual_didn_num_convs_recon=9, dual_unet_num_filters=8, dual_unet_num_pool_layers=4, dual_unet_dropout_probability=0.0)[source]#

Bases: ModelConfig

num_iter = 25#
num_primal = 5#
num_dual = 5#
primal_model_architecture = 'MWCNN'#
dual_model_architecture = 'DIDN'#
primal_mwcnn_hidden_channels = 16#
primal_mwcnn_num_scales = 4#
primal_mwcnn_bias = True#
primal_mwcnn_batchnorm = False#
primal_unet_num_filters = 8#
primal_unet_num_pool_layers = 4#
primal_unet_dropout_probability = 0.0#
dual_conv_hidden_channels = 16#
dual_conv_n_convs = 4#
dual_conv_batchnorm = False#
dual_didn_hidden_channels = 64#
dual_didn_num_dubs = 6#
dual_didn_num_convs_recon = 9#
dual_unet_num_filters = 8#
dual_unet_num_pool_layers = 4#
dual_unet_dropout_probability = 0.0#
__init__(model_name='???', engine_name=None, num_iter=25, num_primal=5, num_dual=5, primal_model_architecture='MWCNN', dual_model_architecture='DIDN', primal_mwcnn_hidden_channels=16, primal_mwcnn_num_scales=4, primal_mwcnn_bias=True, primal_mwcnn_batchnorm=False, primal_unet_num_filters=8, primal_unet_num_pool_layers=4, primal_unet_dropout_probability=0.0, dual_conv_hidden_channels=16, dual_conv_n_convs=4, dual_conv_batchnorm=False, dual_didn_hidden_channels=64, dual_didn_num_dubs=6, dual_didn_num_convs_recon=9, dual_unet_num_filters=8, dual_unet_num_pool_layers=4, dual_unet_dropout_probability=0.0)#
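
As a point of orientation, a minimal sketch of instantiating this config directly in Python; the model_name value and the overridden fields below are illustrative assumptions, not values prescribed by this page:

from direct.nn.lpd.config import LPDNetConfig

# Override a few fields; all remaining attributes keep the defaults listed above.
cfg = LPDNetConfig(
    model_name="LPDNet",            # assumed value; '???' above marks a field without a default
    num_iter=10,
    dual_model_architecture="CONV",
)
print(cfg.num_iter, cfg.dual_model_architecture)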

direct.nn.lpd.lpd module#

class direct.nn.lpd.lpd.DualNet(num_dual, **kwargs)[source]#

Bases: Module

Dual Network for Learned Primal Dual Network.

__init__(num_dual, **kwargs)[source]#

Inits DualNet.

Parameters:
  • num_dual (int) – Number of duals for the LPD algorithm.

  • **kwargs – Keyword arguments.

static compute_model_per_coil(model, data)[source]#

Computes model per coil.

Parameters:
  • model (Module) – Model to apply to each coil.

  • data (Tensor) – Multi-coil input.

Return type:

Tensor

Returns:

Multi-coil output.
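
A minimal sketch of calling this static helper with a stand-in per-coil model; the (N, coil, channels, height, width) layout and the coil axis being dimension 1 are assumptions made for illustration:

import torch
import torch.nn as nn
from direct.nn.lpd.lpd import DualNet

model = nn.Conv2d(2, 2, kernel_size=3, padding=1)   # stand-in model applied to each coil
data = torch.randn(4, 8, 2, 64, 64)                 # assumed (N, coil, channels, height, width)
output = DualNet.compute_model_per_coil(model, data)
print(output.shape)                                  # per-coil outputs re-stacked along the coil axis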

forward(h, forward_f, g)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class direct.nn.lpd.lpd.PrimalNet(num_primal, **kwargs)[source]#

Bases: Module

Primal Network for Learned Primal Dual Network.

__init__(num_primal, **kwargs)[source]#

Inits PrimalNet.

Parameters:
  • num_primal (int) – Number of primals for the LPD algorithm.

  • **kwargs – Keyword arguments.

forward(f, backward_h)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class direct.nn.lpd.lpd.LPDNet(forward_operator, backward_operator, num_iter, num_primal, num_dual, primal_model_architecture='MWCNN', dual_model_architecture='DIDN', **kwargs)[source]#

Bases: Module

Learned Primal Dual network implementation inspired by [1].

References:

[1] Adler, Jonas, and Ozan Öktem. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging 37(6), 2018: 1322–1332.

__init__(forward_operator, backward_operator, num_iter, num_primal, num_dual, primal_model_architecture='MWCNN', dual_model_architecture='DIDN', **kwargs)[source]#

Inits LPDNet.

Parameters:
  • forward_operator (Callable) – Forward Operator.

  • backward_operator (Callable) – Backward Operator.

  • num_iter (int) – Number of unrolled iterations.

  • num_primal (int) – Number of primal networks.

  • num_dual (int) – Number of dual networks.

  • primal_model_architecture (str) – Primal model architecture. Currently only implemented for MWCNN and UNET. Default: ‘MWCNN’.

  • dual_model_architecture (str) – Dual model architecture. Currently only implemented for CONV, DIDN and UNET. Default: ‘DIDN’.

  • kwargs (dict) – Keyword arguments for model architectures.
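
A minimal construction sketch; supplying direct.data.transforms.fft2/ifft2 through functools.partial with fixed spatial dimensions is an assumption about how the operators are typically provided, not something prescribed by this page:

import functools
from direct.data import transforms as T
from direct.nn.lpd.lpd import LPDNet

forward_operator = functools.partial(T.fft2, dim=(2, 3))    # assumed operator setup
backward_operator = functools.partial(T.ifft2, dim=(2, 3))  # assumed operator setup

model = LPDNet(
    forward_operator=forward_operator,
    backward_operator=backward_operator,
    num_iter=10,
    num_primal=5,
    num_dual=5,
    primal_model_architecture="MWCNN",
    dual_model_architecture="DIDN",
)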

forward(masked_kspace, sensitivity_map, sampling_mask)[source]#

Computes forward pass of LPDNet.

Parameters:
  • masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).

  • sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, height, width, complex=2).

  • sampling_mask (Tensor) – Sampling mask of shape (N, 1, height, width, 1).

Returns:

Output image of shape (N, height, width, complex=2).

Return type:

torch.Tensor
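
Continuing the construction sketch above, a call to the forward pass might look as follows; the batch, coil and spatial sizes, and the k-space and sensitivity-map layouts, are illustrative assumptions consistent with the output shape documented here:

import torch

masked_kspace = torch.randn(1, 8, 128, 128, 2)                 # assumed (N, coil, height, width, complex=2)
sensitivity_map = torch.randn(1, 8, 128, 128, 2)               # assumed (N, coil, height, width, complex=2)
sampling_mask = (torch.rand(1, 1, 128, 128, 1) > 0.5).float()  # assumed (N, 1, height, width, 1)

output = model(masked_kspace, sensitivity_map, sampling_mask)  # model from the sketch above
print(output.shape)                                            # expected: (1, 128, 128, 2)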

direct.nn.lpd.lpd_engine module#

class direct.nn.lpd.lpd_engine.LPDNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: MRIModelEngine

LPDNet Engine.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits LPDNetEngine.

forward_function(data)[source]#

Performs the model’s forward pass given data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Tensor, None]
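
A minimal sketch of the contract this method is expected to satisfy, assuming data is the engine’s batch dictionary with the keys consumed by LPDNet.forward (the key names are an assumption); the second element of the returned tuple, the output k-space, is None as documented above:

def forward_function_sketch(engine, data):
    # Assumed batch keys; tensor shapes follow LPDNet.forward above.
    output_image = engine.model(
        masked_kspace=data["masked_kspace"],
        sensitivity_map=data["sensitivity_map"],
        sampling_mask=data["sampling_mask"],
    )
    return output_image, None  # Tuple[Tensor, None]: no output k-space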

Module contents#