direct.nn.xpdnet package

Submodules

direct.nn.xpdnet.config module

class direct.nn.xpdnet.config.XPDNetConfig(model_name: str = '???', engine_name: str | None = None, num_primal: int = 5, num_dual: int = 1, num_iter: int = 10, use_primal_only: bool = True, kspace_model_architecture: str = 'CONV', dual_conv_hidden_channels: int = 16, dual_conv_n_convs: int = 4, dual_conv_batchnorm: bool = False, dual_didn_hidden_channels: int = 64, dual_didn_num_dubs: int = 6, dual_didn_num_convs_recon: int = 9, mwcnn_hidden_channels: int = 16, mwcnn_num_scales: int = 4, mwcnn_bias: bool = True, mwcnn_batchnorm: bool = False, normalize: bool = False)

Bases: ModelConfig

dual_conv_batchnorm: bool = False
dual_conv_hidden_channels: int = 16
dual_conv_n_convs: int = 4
dual_didn_hidden_channels: int = 64
dual_didn_num_convs_recon: int = 9
dual_didn_num_dubs: int = 6
kspace_model_architecture: str = 'CONV'
mwcnn_batchnorm: bool = False
mwcnn_bias: bool = True
mwcnn_hidden_channels: int = 16
mwcnn_num_scales: int = 4
normalize: bool = False
num_dual: int = 1
num_iter: int = 10
num_primal: int = 5
use_primal_only: bool = True
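
As a rough illustration of how these defaults fit together, the sketch below mirrors the documented fields in a plain standard-library dataclass. `XPDNetConfigSketch` is a hypothetical stand-in, not the library class: the real `XPDNetConfig` derives from `ModelConfig`, and its `model_name` default of `'???'` looks like a placeholder for a value the user is expected to supply (an assumption, not something this page states).

```python
from dataclasses import asdict, dataclass, replace
from typing import Optional


@dataclass
class XPDNetConfigSketch:
    """Hypothetical stand-in that copies the documented XPDNetConfig defaults."""

    model_name: str = "???"  # placeholder default shown in the signature
    engine_name: Optional[str] = None
    num_primal: int = 5
    num_dual: int = 1
    num_iter: int = 10
    use_primal_only: bool = True
    kspace_model_architecture: str = "CONV"
    dual_conv_hidden_channels: int = 16
    dual_conv_n_convs: int = 4
    dual_conv_batchnorm: bool = False
    dual_didn_hidden_channels: int = 64
    dual_didn_num_dubs: int = 6
    dual_didn_num_convs_recon: int = 9
    mwcnn_hidden_channels: int = 16
    mwcnn_num_scales: int = 4
    mwcnn_bias: bool = True
    mwcnn_batchnorm: bool = False
    normalize: bool = False


# Start from the documented defaults, supplying only the model name.
default_cfg = XPDNetConfigSketch(model_name="XPDNet")

# Derive a variant without mutating the original instance.
custom_cfg = replace(default_cfg, num_iter=12, use_primal_only=False)

# Collect only the fields whose values differ from the defaults.
defaults = asdict(default_cfg)
changed = {k: v for k, v in asdict(custom_cfg).items() if defaults[k] != v}
print(changed)
```

Keeping overrides as a small diff against the defaults makes it easier to see which hyperparameters an experiment actually changed.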

direct.nn.xpdnet.xpdnet module

class direct.nn.xpdnet.xpdnet.XPDNet(forward_operator, backward_operator, num_primal=5, num_dual=1, num_iter=10, use_primal_only=True, image_model_architecture='MWCNN', kspace_model_architecture=None, normalize=False, **kwargs)

Bases: CrossDomainNetwork

XPDNet as implemented in [1].

References

[1] Ramzi, Zaccharie, et al. “XPDNet for MRI Reconstruction: An Application to the 2020 FastMRI Challenge.” ArXiv:2010.07290 [Physics, Stat], July 2021. arXiv.org, http://arxiv.org/abs/2010.07290.

training: bool
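
The constructor names only some of the configuration fields explicitly and accepts the rest through `**kwargs`. The sketch below shows one way a configuration mapping could be split along that line, using the parameter names from the signature above. The helper `split_model_kwargs` and the choice to treat `model_name` and `engine_name` as non-model keys are assumptions made for illustration. How the library itself forwards the configuration is not documented on this page.

```python
from typing import Any, Dict, Tuple

# Keyword parameters named explicitly in the documented XPDNet signature.
XPDNET_NAMED_ARGS = (
    "num_primal",
    "num_dual",
    "num_iter",
    "use_primal_only",
    "image_model_architecture",
    "kspace_model_architecture",
    "normalize",
)

# Assumption: these entries select the model and engine rather than configure the network.
NON_MODEL_KEYS = ("model_name", "engine_name")


def split_model_kwargs(cfg: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: separate named constructor arguments from extras."""
    named = {k: cfg[k] for k in XPDNET_NAMED_ARGS if k in cfg}
    extra = {
        k: v
        for k, v in cfg.items()
        if k not in XPDNET_NAMED_ARGS and k not in NON_MODEL_KEYS
    }
    return named, extra


# A few of the documented configuration defaults, written as a plain dict.
cfg = {
    "model_name": "XPDNet",
    "engine_name": None,
    "num_primal": 5,
    "num_dual": 1,
    "num_iter": 10,
    "use_primal_only": True,
    "kspace_model_architecture": "CONV",
    "dual_conv_hidden_channels": 16,
    "mwcnn_hidden_channels": 16,
    "mwcnn_num_scales": 4,
    "normalize": False,
}

named, extra = split_model_kwargs(cfg)
print(sorted(named))
print(sorted(extra))
```

Two details are visible from the signatures alone: the configuration class has no `image_model_architecture` field, so the constructor default `'MWCNN'` would apply unless it is passed separately, and the constructor default for `kspace_model_architecture` is `None` while the configuration default is `'CONV'`.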

direct.nn.xpdnet.xpdnet_engine module

class direct.nn.xpdnet.xpdnet_engine.XPDNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

XPDNet Engine.

forward_function(data)

Performs the model’s forward pass on data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Tensor, None]

Module contents