direct.nn package
Subpackages
- direct.nn.cirim package
- direct.nn.conjgradnet package
- Submodules
- direct.nn.conjgradnet.config module
ConjGradNetConfig
ConjGradNetConfig.cg_iters
ConjGradNetConfig.cg_param_update_type
ConjGradNetConfig.cg_tol
ConjGradNetConfig.conv_activation
ConjGradNetConfig.conv_batchnorm
ConjGradNetConfig.conv_hidden_channels
ConjGradNetConfig.conv_n_convs
ConjGradNetConfig.denoiser_architecture
ConjGradNetConfig.didn_hidden_channels
ConjGradNetConfig.didn_num_convs_recon
ConjGradNetConfig.didn_num_dubs
ConjGradNetConfig.image_init
ConjGradNetConfig.no_parameter_sharing
ConjGradNetConfig.num_steps
ConjGradNetConfig.resenet_batchnorm
ConjGradNetConfig.resenet_scale
ConjGradNetConfig.resnet_hidden_channels
ConjGradNetConfig.resnet_num_blocks
ConjGradNetConfig.unet_dropout
ConjGradNetConfig.unet_num_filters
ConjGradNetConfig.unet_num_pool_layers
- direct.nn.conjgradnet.conjgrad module
- direct.nn.conjgradnet.conjgradnet module
- direct.nn.conjgradnet.conjgradnet_engine module
- Module contents
- direct.nn.conv package
- direct.nn.crossdomain package
- direct.nn.didn package
- direct.nn.iterdualnet package
- Submodules
- direct.nn.iterdualnet.config module
IterDualNetConfig
IterDualNetConfig.compute_per_coil
IterDualNetConfig.image_no_parameter_sharing
IterDualNetConfig.image_normunet
IterDualNetConfig.image_unet_dropout
IterDualNetConfig.image_unet_num_filters
IterDualNetConfig.image_unet_num_pool_layers
IterDualNetConfig.kspace_no_parameter_sharing
IterDualNetConfig.kspace_normunet
IterDualNetConfig.kspace_unet_dropout
IterDualNetConfig.kspace_unet_num_filters
IterDualNetConfig.kspace_unet_num_pool_layers
IterDualNetConfig.num_iter
- direct.nn.iterdualnet.iterdualnet module
- direct.nn.iterdualnet.iterdualnet_engine module
- Module contents
- direct.nn.jointicnet package
- Submodules
- direct.nn.jointicnet.config module
JointICNetConfig
JointICNetConfig.image_unet_dropout
JointICNetConfig.image_unet_num_filters
JointICNetConfig.image_unet_num_pool_layers
JointICNetConfig.kspace_unet_dropout
JointICNetConfig.kspace_unet_num_filters
JointICNetConfig.kspace_unet_num_pool_layers
JointICNetConfig.num_iter
JointICNetConfig.sens_unet_dropout
JointICNetConfig.sens_unet_num_filters
JointICNetConfig.sens_unet_num_pool_layers
JointICNetConfig.use_norm_unet
- direct.nn.jointicnet.jointicnet module
- direct.nn.jointicnet.jointicnet_engine module
- Module contents
- direct.nn.kikinet package
- Submodules
- direct.nn.kikinet.config module
KIKINetConfig
KIKINetConfig.image_model_architecture
KIKINetConfig.image_mwcnn_batchnorm
KIKINetConfig.image_mwcnn_bias
KIKINetConfig.image_mwcnn_hidden_channels
KIKINetConfig.image_mwcnn_num_scales
KIKINetConfig.image_unet_dropout_probability
KIKINetConfig.image_unet_num_filters
KIKINetConfig.image_unet_num_pool_layers
KIKINetConfig.kspace_conv_batchnorm
KIKINetConfig.kspace_conv_hidden_channels
KIKINetConfig.kspace_conv_n_convs
KIKINetConfig.kspace_didn_hidden_channels
KIKINetConfig.kspace_didn_num_convs_recon
KIKINetConfig.kspace_didn_num_dubs
KIKINetConfig.kspace_model_architecture
KIKINetConfig.kspace_unet_dropout_probability
KIKINetConfig.kspace_unet_num_filters
KIKINetConfig.kspace_unet_num_pool_layers
KIKINetConfig.normalize
KIKINetConfig.num_iter
- direct.nn.kikinet.kikinet module
- direct.nn.kikinet.kikinet_engine module
- Module contents
- direct.nn.lpd package
- Submodules
- direct.nn.lpd.config module
LPDNetConfig
LPDNetConfig.dual_conv_batchnorm
LPDNetConfig.dual_conv_hidden_channels
LPDNetConfig.dual_conv_n_convs
LPDNetConfig.dual_didn_hidden_channels
LPDNetConfig.dual_didn_num_convs_recon
LPDNetConfig.dual_didn_num_dubs
LPDNetConfig.dual_model_architecture
LPDNetConfig.dual_unet_dropout_probability
LPDNetConfig.dual_unet_num_filters
LPDNetConfig.dual_unet_num_pool_layers
LPDNetConfig.num_dual
LPDNetConfig.num_iter
LPDNetConfig.num_primal
LPDNetConfig.primal_model_architecture
LPDNetConfig.primal_mwcnn_batchnorm
LPDNetConfig.primal_mwcnn_bias
LPDNetConfig.primal_mwcnn_hidden_channels
LPDNetConfig.primal_mwcnn_num_scales
LPDNetConfig.primal_unet_dropout_probability
LPDNetConfig.primal_unet_num_filters
LPDNetConfig.primal_unet_num_pool_layers
- direct.nn.lpd.lpd module
- direct.nn.lpd.lpd_engine module
- Module contents
- direct.nn.mobilenet package
- direct.nn.multidomainnet package
- Submodules
- direct.nn.multidomainnet.config module
- direct.nn.multidomainnet.multidomain module
- direct.nn.multidomainnet.multidomainnet module
- direct.nn.multidomainnet.multidomainnet_engine module
- Module contents
- direct.nn.mwcnn package
- direct.nn.recurrent package
- direct.nn.recurrentvarnet package
- Submodules
- direct.nn.recurrentvarnet.config module
RecurrentVarNetConfig
RecurrentVarNetConfig.initializer_channels
RecurrentVarNetConfig.initializer_dilations
RecurrentVarNetConfig.initializer_initialization
RecurrentVarNetConfig.initializer_multiscale
RecurrentVarNetConfig.learned_initializer
RecurrentVarNetConfig.no_parameter_sharing
RecurrentVarNetConfig.normalized
RecurrentVarNetConfig.num_steps
RecurrentVarNetConfig.recurrent_hidden_channels
RecurrentVarNetConfig.recurrent_num_layers
- direct.nn.recurrentvarnet.recurrentvarnet module
- direct.nn.recurrentvarnet.recurrentvarnet_engine module
- Module contents
- direct.nn.resnet package
- direct.nn.rim package
- Submodules
- direct.nn.rim.config module
RIMConfig
RIMConfig.dense_connect
RIMConfig.depth
RIMConfig.hidden_channels
RIMConfig.image_initialization
RIMConfig.initializer_channels
RIMConfig.initializer_dilations
RIMConfig.initializer_multiscale
RIMConfig.instance_norm
RIMConfig.learned_initializer
RIMConfig.length
RIMConfig.no_parameter_sharing
RIMConfig.normalized
RIMConfig.replication_padding
RIMConfig.scale_loglikelihood
RIMConfig.steps
RIMConfig.whiten_input
- direct.nn.rim.rim module
- direct.nn.rim.rim_engine module
- Module contents
- direct.nn.ssl package
- direct.nn.unet package
- Submodules
- direct.nn.unet.config module
- direct.nn.unet.unet_2d module
- direct.nn.unet.unet_3d module
- direct.nn.unet.unet_engine module
- Module contents
- direct.nn.varnet package
- direct.nn.varsplitnet package
- Submodules
- direct.nn.varsplitnet.config module
MRIVarSplitNetConfig
MRIVarSplitNetConfig.image_conv_activation
MRIVarSplitNetConfig.image_conv_batchnorm
MRIVarSplitNetConfig.image_conv_hidden_channels
MRIVarSplitNetConfig.image_conv_n_convs
MRIVarSplitNetConfig.image_didn_hidden_channels
MRIVarSplitNetConfig.image_didn_num_convs_recon
MRIVarSplitNetConfig.image_didn_num_dubs
MRIVarSplitNetConfig.image_init
MRIVarSplitNetConfig.image_model_architecture
MRIVarSplitNetConfig.image_resnet_batchnorm
MRIVarSplitNetConfig.image_resnet_hidden_channels
MRIVarSplitNetConfig.image_resnet_num_blocks
MRIVarSplitNetConfig.image_resnet_scale
MRIVarSplitNetConfig.image_unet_dropout
MRIVarSplitNetConfig.image_unet_num_filters
MRIVarSplitNetConfig.image_unet_num_pool_layers
MRIVarSplitNetConfig.kspace_conv_activation
MRIVarSplitNetConfig.kspace_conv_batchnorm
MRIVarSplitNetConfig.kspace_conv_hidden_channels
MRIVarSplitNetConfig.kspace_conv_n_convs
MRIVarSplitNetConfig.kspace_didn_hidden_channels
MRIVarSplitNetConfig.kspace_didn_num_convs_recon
MRIVarSplitNetConfig.kspace_didn_num_dubs
MRIVarSplitNetConfig.kspace_model_architecture
MRIVarSplitNetConfig.kspace_no_parameter_sharing
MRIVarSplitNetConfig.kspace_resnet_batchnorm
MRIVarSplitNetConfig.kspace_resnet_hidden_channels
MRIVarSplitNetConfig.kspace_resnet_num_blocks
MRIVarSplitNetConfig.kspace_resnet_scale
MRIVarSplitNetConfig.kspace_unet_dropout
MRIVarSplitNetConfig.kspace_unet_num_filters
MRIVarSplitNetConfig.kspace_unet_num_pool_layers
MRIVarSplitNetConfig.no_parameter_sharing
MRIVarSplitNetConfig.num_steps_dc
MRIVarSplitNetConfig.num_steps_reg
- direct.nn.varsplitnet.varsplitnet module
- direct.nn.varsplitnet.varsplitnet_engine module
- Module contents
- direct.nn.vsharp package
- Submodules
- direct.nn.vsharp.config module
VSharpNet3DConfig
VSharpNet3DConfig.auxiliary_steps
VSharpNet3DConfig.image_init
VSharpNet3DConfig.initializer_activation
VSharpNet3DConfig.initializer_channels
VSharpNet3DConfig.initializer_dilations
VSharpNet3DConfig.initializer_multiscale
VSharpNet3DConfig.no_parameter_sharing
VSharpNet3DConfig.num_steps
VSharpNet3DConfig.num_steps_dc_gd
VSharpNet3DConfig.unet_dropout
VSharpNet3DConfig.unet_norm
VSharpNet3DConfig.unet_num_filters
VSharpNet3DConfig.unet_num_pool_layers
VSharpNetConfig
VSharpNetConfig.auxiliary_steps
VSharpNetConfig.image_conv_activation
VSharpNetConfig.image_conv_batchnorm
VSharpNetConfig.image_conv_hidden_channels
VSharpNetConfig.image_conv_n_convs
VSharpNetConfig.image_didn_hidden_channels
VSharpNetConfig.image_didn_num_convs_recon
VSharpNetConfig.image_didn_num_dubs
VSharpNetConfig.image_init
VSharpNetConfig.image_model_architecture
VSharpNetConfig.image_resnet_batchnorm
VSharpNetConfig.image_resnet_hidden_channels
VSharpNetConfig.image_resnet_num_blocks
VSharpNetConfig.image_resnet_scale
VSharpNetConfig.image_unet_dropout
VSharpNetConfig.image_unet_num_filters
VSharpNetConfig.image_unet_num_pool_layers
VSharpNetConfig.initializer_activation
VSharpNetConfig.initializer_channels
VSharpNetConfig.initializer_dilations
VSharpNetConfig.initializer_multiscale
VSharpNetConfig.no_parameter_sharing
VSharpNetConfig.num_steps
VSharpNetConfig.num_steps_dc_gd
- direct.nn.vsharp.vsharp module
- direct.nn.vsharp.vsharp_engine module
- Module contents
- direct.nn.xpdnet package
- Submodules
- direct.nn.xpdnet.config module
XPDNetConfig
XPDNetConfig.dual_conv_batchnorm
XPDNetConfig.dual_conv_hidden_channels
XPDNetConfig.dual_conv_n_convs
XPDNetConfig.dual_didn_hidden_channels
XPDNetConfig.dual_didn_num_convs_recon
XPDNetConfig.dual_didn_num_dubs
XPDNetConfig.kspace_model_architecture
XPDNetConfig.mwcnn_batchnorm
XPDNetConfig.mwcnn_bias
XPDNetConfig.mwcnn_hidden_channels
XPDNetConfig.mwcnn_num_scales
XPDNetConfig.normalize
XPDNetConfig.num_dual
XPDNetConfig.num_iter
XPDNetConfig.num_primal
XPDNetConfig.use_primal_only
- direct.nn.xpdnet.xpdnet module
- direct.nn.xpdnet.xpdnet_engine module
- Module contents
Submodules
direct.nn.get_nn_model_config module
direct.nn.get_nn_model_config module.
direct.nn.mri_models module
MRI model engine of DIRECT.
- class direct.nn.mri_models.MRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: Engine
Engine for MRI models.
Each child class should implement its own forward_function().
- compute_loss_on_data(loss_dict, loss_fns, data, output_image=None, output_kspace=None, weight=1.0)
- Return type:
Dict[str, Tensor]
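The signature of compute_loss_on_data suggests that each entry of loss_fns is evaluated against either the image output or the k-space output, scaled by weight, and accumulated into loss_dict. The following is a minimal pure-Python sketch of that accumulation pattern, not DIRECT's implementation. It assumes (hypothetically) that loss names containing "kspace" are evaluated on output_kspace and all others on output_image, that data holds the keys "target" and "kspace", and it uses floats in place of tensors.

```python
from typing import Callable, Dict, Optional


def compute_loss_on_data_sketch(
    loss_dict: Dict[str, float],
    loss_fns: Dict[str, Callable[[float, float], float]],
    data: Dict[str, float],
    output_image: Optional[float] = None,
    output_kspace: Optional[float] = None,
    weight: float = 1.0,
) -> Dict[str, float]:
    """Hypothetical stand-in for MRIModelEngine.compute_loss_on_data.

    Floats replace tensors; the keys "target" and "kspace" in `data`
    are assumptions made for this sketch.
    """
    for name in loss_dict:
        if "kspace" in name:
            # Assumed: k-space losses compare the predicted k-space with the measured one.
            if output_kspace is None:
                raise ValueError(f"{name} requires output_kspace")
            value = loss_fns[name](output_kspace, data["kspace"])
        else:
            # Assumed: all other losses compare the predicted image with the target.
            if output_image is None:
                raise ValueError(f"{name} requires output_image")
            value = loss_fns[name](output_image, data["target"])
        # Accumulate the weighted loss into the running total for this key.
        loss_dict[name] = loss_dict[name] + weight * value
    return loss_dict


def l1(prediction: float, target: float) -> float:
    return abs(prediction - target)


losses = compute_loss_on_data_sketch(
    {"l1_loss": 0.0, "kspace_l1_loss": 0.0},
    {"l1_loss": l1, "kspace_l1_loss": l1},
    {"target": 1.0, "kspace": 3.0},
    output_image=1.5,
    output_kspace=2.0,
    weight=2.0,
)
print(losses)  # {'l1_loss': 1.0, 'kspace_l1_loss': 2.0}
```

Passing loss_dict in and returning it lets the same call be repeated, for example once per intermediate output with a different weight, while the totals keep accumulating.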
- compute_model_per_coil(model_name, data)
Performs the forward pass of the model model_name in self.models on each coil separately.
- Parameters:
- model_name: str
Model to run.
- data: torch.Tensor
Multi-coil data of shape (batch, coil, complex=2, height, width).
- Returns:
- output: torch.Tensor
Computed output per coil.
- Return type:
Tensor
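A per-coil forward pass amounts to slicing out one coil at a time, running the model on it, and stacking the results back along the coil dimension. The sketch below illustrates that pattern using NumPy arrays in place of torch.Tensor and a plain callable in place of self.models[model_name]; the function name is hypothetical and this is not DIRECT's code.

```python
from typing import Callable

import numpy as np


def compute_model_per_coil_sketch(
    model: Callable[[np.ndarray], np.ndarray], data: np.ndarray
) -> np.ndarray:
    """Applies `model` to each coil of `data` independently.

    `data` has shape (batch, coil, complex=2, height, width); the model
    receives one coil at a time, i.e. shape (batch, complex=2, height, width).
    """
    num_coils = data.shape[1]
    # Run the model on every coil slice and re-stack along the coil axis (dim 1).
    outputs = [model(data[:, coil_idx]) for coil_idx in range(num_coils)]
    return np.stack(outputs, axis=1)


data = np.random.rand(2, 4, 2, 8, 8)  # batch=2, coil=4, complex=2, 8x8 pixels
output = compute_model_per_coil_sketch(lambda x: 2.0 * x, data)
print(output.shape)  # (2, 4, 2, 8, 8)
```

Looping over coils keeps each model call at single-coil size; an alternative with the same result for a per-sample model is to fold the coil axis into the batch axis, run one call, and reshape back.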
- compute_sensitivity_map(sensitivity_map)
Computes the sensitivity maps \(\{S^k\}_{k=1}^{n_c}\) if a sensitivity_model is available.
The maps \(\{S^k\}_{k=1}^{n_c}\) are normalized such that
\[\sum_{k=1}^{n_c}S^k {S^k}^* = I.\]
- Parameters:
- sensitivity_map: torch.Tensor
Sensitivity maps of shape (batch, coil, height, width, complex=2).
- Returns:
- sensitivity_map: torch.Tensor
Normalized and refined sensitivity maps of shape (batch, coil, height, width, complex=2).
- Return type:
Tensor
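Because each \(S^k\) acts as a pixel-wise multiplication, the condition \(\sum_{k=1}^{n_c}S^k {S^k}^* = I\) reduces to \(\sum_{k=1}^{n_c}|S^k|^2 = 1\) at every pixel, which dividing each map by the root-sum-of-squares over coils achieves. The NumPy sketch below shows only this normalization step (the refinement by sensitivity_model is omitted); the function name and the eps guard for pixels without coil signal are illustrative assumptions.

```python
import numpy as np


def normalize_sensitivity_maps(sensitivity_map: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Normalizes maps of shape (batch, coil, height, width, complex=2) so that
    sum_k S^k conj(S^k) = 1 at every pixel."""
    # Interpret the trailing dimension of size 2 as (real, imaginary).
    complex_map = sensitivity_map[..., 0] + 1j * sensitivity_map[..., 1]
    # Root-sum-of-squares over the coil axis, kept for broadcasting.
    rss = np.sqrt(np.sum(np.abs(complex_map) ** 2, axis=1, keepdims=True))
    # Divide where the RSS is non-zero; pixels with no coil signal stay zero.
    normalized = np.where(rss > eps, complex_map / np.maximum(rss, eps), 0.0)
    return np.stack([normalized.real, normalized.imag], axis=-1)


maps = np.random.randn(1, 8, 16, 16, 2)
normalized_maps = normalize_sensitivity_maps(maps)
energy = np.sum(normalized_maps[..., 0] ** 2 + normalized_maps[..., 1] ** 2, axis=1)
print(np.allclose(energy, 1.0))  # True
```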
- evaluate(data_loader, loss_fns)
Validation process.
Assumes that each batch only contains slices of the same volume AND that these are sequentially ordered.
- Parameters:
- data_loader: DataLoader
- loss_fns: Dict[str, Callable], optional
- Returns:
- loss_dict, all_gathered_metrics, visualize_slices, visualize_target
- forward_function(data)
Performs the model's forward pass given data, which contains all tensor inputs.
Must be implemented by child classes.
- Return type:
Tuple[Optional[Tensor], Optional[Tensor]]
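The Tuple[Optional[Tensor], Optional[Tensor]] return type lets a child engine return only the outputs its model actually produces. The sketch below shows that subclass contract with an abstract base class. It assumes, based on the output_image and output_kspace arguments of compute_loss_on_data, that the two entries are the predicted image and k-space; the class names and the "masked_image" key are hypothetical, and lists stand in for tensors.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple


class EngineSketch(ABC):
    """Hypothetical stand-in for MRIModelEngine showing the subclass contract."""

    @abstractmethod
    def forward_function(self, data: Dict[str, Any]) -> Tuple[Optional[Any], Optional[Any]]:
        """Returns (output_image, output_kspace); either entry may be None."""


class ImageOnlyEngine(EngineSketch):
    """A child engine whose model predicts only an image."""

    def forward_function(self, data: Dict[str, Any]) -> Tuple[Optional[Any], Optional[Any]]:
        # Placeholder "model": scale the input image; no k-space prediction.
        output_image = [2 * value for value in data["masked_image"]]
        return output_image, None


engine = ImageOnlyEngine()
output_image, output_kspace = engine.forward_function({"masked_image": [1, 2, 3]})
print(output_image, output_kspace)  # [2, 4, 6] None
```

Marking the method abstract makes instantiating the base class fail with a TypeError, so a child that forgets to implement forward_function is caught at construction time rather than during training.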
- reconstruct_volumes(data_loader, loss_fns=None, regularizer_fns=None, add_target=True, crop=None)
Validation process. Assumes that each batch only contains slices of the same volume AND that these are sequentially ordered.
- Parameters:
- data_loader: DataLoader
- loss_fns: Dict[str, Callable], optional
- regularizer_fns: Dict[str, Callable], optional
- add_target: bool
If True, the target is added to the output.
- crop: str, optional
Crop type.
- Yields:
- (curr_volume, [curr_target,] loss_dict_list, filename): torch.Tensor, [torch.Tensor,], dict, pathlib.Path
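Both evaluate and reconstruct_volumes assume that slices of one volume arrive consecutively and in order. Under that assumption, volumes can be assembled in a single streaming pass: a change of filename marks the start of the next volume. The sketch below illustrates this with itertools.groupby over hypothetical (filename, slice_no, slice_data) tuples; it is not DIRECT's implementation.

```python
from itertools import groupby
from typing import Iterable, Iterator, List, Tuple


def group_slices_into_volumes(
    slices: Iterable[Tuple[str, int, float]],
) -> Iterator[Tuple[str, List[float]]]:
    """Yields (filename, slices) per volume from a stream of (filename, slice_no, slice_data).

    Relies on the ordering assumption above: all slices of a volume arrive
    consecutively, so consecutive items with the same filename form one volume.
    """
    for filename, group in groupby(slices, key=lambda item: item[0]):
        yield filename, [slice_data for _, _, slice_data in group]


stream = [("vol_a.h5", 0, 0.1), ("vol_a.h5", 1, 0.2), ("vol_b.h5", 0, 0.3)]
for filename, volume in group_slices_into_volumes(stream):
    print(filename, volume)
# vol_a.h5 [0.1, 0.2]
# vol_b.h5 [0.3]
```

This is why the ordering assumption matters: if a slice of vol_a.h5 appeared after vol_b.h5, groupby would start a second, separate group for vol_a.h5 instead of completing the first volume.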
direct.nn.types module
direct.nn.types module.
- class direct.nn.types.ActivationType(value)
Bases: DirectEnum
An enumeration.
- LEAKY_RELU = 'leaky_relu'
- PRELU = 'prelu'
- RELU = 'relu'
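Configuration files store these enumerations as plain strings such as 'leaky_relu'. The sketch below shows how such a string maps to a member via lookup by value. It uses hypothetical stand-in classes and assumes that DirectEnum mixes in str, so members also compare equal to their raw values; check the actual DirectEnum definition before relying on that behaviour.

```python
from enum import Enum


class DirectEnumSketch(str, Enum):
    """Hypothetical stand-in for DirectEnum, assumed to mix in str."""


class ActivationTypeSketch(DirectEnumSketch):
    """Mirrors the three ActivationType members listed above."""

    LEAKY_RELU = "leaky_relu"
    PRELU = "prelu"
    RELU = "relu"


# Look a member up by the string value a config file would contain.
activation = ActivationTypeSketch("leaky_relu")
print(activation is ActivationTypeSketch.LEAKY_RELU)  # True
# With the str mixin, members also compare equal to their raw values.
print(activation == "leaky_relu")  # True
# Unknown values are rejected: ActivationTypeSketch("gelu") raises ValueError.
```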
- class direct.nn.types.InitType(value)
Bases: DirectEnum
An enumeration.
- INPUT_IMAGE = 'input_image'
- SENSE = 'sense'
- ZEROS = 'zeros'
- ZERO_FILLED = 'zero_filled'
- class direct.nn.types.LossFunType(value)
Bases: DirectEnum
An enumeration.
- GRAD_L1_LOSS = 'grad_l1_loss'
- GRAD_L2_LOSS = 'grad_l2_loss'
- HFEN_L1_LOSS = 'hfen_l1_loss'
- HFEN_L1_NORM_LOSS = 'hfen_l1_norm_loss'
- HFEN_L2_LOSS = 'hfen_l2_loss'
- HFEN_L2_NORM_LOSS = 'hfen_l2_norm_loss'
- KSPACE_L1_LOSS = 'kspace_l1_loss'
- KSPACE_L2_LOSS = 'kspace_l2_loss'
- KSPACE_NMAE_LOSS = 'kspace_nmae_loss'
- KSPACE_NMSE_LOSS = 'kspace_nmse_loss'
- KSPACE_NRMSE_LOSS = 'kspace_nrmse_loss'
- L1_LOSS = 'l1_loss'
- L2_LOSS = 'l2_loss'
- NMAE_LOSS = 'nmae_loss'
- NMSE_LOSS = 'nmse_loss'
- NRMSE_LOSS = 'nrmse_loss'
- PSNR_LOSS = 'psnr_loss'
- SNR_LOSS = 'snr_loss'
- SSIM_3D_LOSS = 'ssim_3d_loss'
- SSIM_LOSS = 'ssim_loss'
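The LossFunType values follow a visible naming convention: losses computed in k-space carry a kspace_ prefix, while the remaining ones presumably act on the reconstructed image. The sketch below parses a list of configured loss names and splits them by that prefix. The stand-in enum holds only a subset of the members listed above, and the prefix-based split is an assumption drawn from the naming, not a documented API.

```python
from enum import Enum
from typing import Iterable, List, Tuple


class LossFunTypeSketch(str, Enum):
    """Hypothetical stand-in holding a subset of the LossFunType members above."""

    KSPACE_L1_LOSS = "kspace_l1_loss"
    KSPACE_NMSE_LOSS = "kspace_nmse_loss"
    L1_LOSS = "l1_loss"
    SSIM_LOSS = "ssim_loss"


def split_losses(
    names: Iterable[str],
) -> Tuple[List[LossFunTypeSketch], List[LossFunTypeSketch]]:
    """Parses loss names and splits them into (k-space losses, image losses) by the "kspace_" prefix."""
    kspace_losses: List[LossFunTypeSketch] = []
    image_losses: List[LossFunTypeSketch] = []
    for name in names:
        loss = LossFunTypeSketch(name)  # raises ValueError for unknown names
        target_list = kspace_losses if loss.value.startswith("kspace_") else image_losses
        target_list.append(loss)
    return kspace_losses, image_losses


kspace_losses, image_losses = split_losses(["l1_loss", "kspace_nmse_loss", "ssim_loss"])
print([loss.value for loss in kspace_losses])  # ['kspace_nmse_loss']
print([loss.value for loss in image_losses])  # ['l1_loss', 'ssim_loss']
```

Parsing through the enum first means a misspelled loss name in a config fails immediately with a ValueError instead of being silently routed to the image-domain group.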