direct.nn package
Subpackages
- direct.nn.cirim package
- direct.nn.conjgradnet package
- Submodules
- direct.nn.conjgradnet.config module
- ConjGradNetConfig: num_steps, image_init, no_parameter_sharing, cg_tol, cg_iters, cg_param_update_type, denoiser_architecture, resnet_hidden_channels, resnet_num_blocks, resenet_batchnorm, resenet_scale, unet_num_filters, unet_num_pool_layers, unet_dropout, didn_hidden_channels, didn_num_dubs, didn_num_convs_recon, conv_hidden_channels, conv_n_convs, conv_activation, conv_batchnorm, __init__()
- direct.nn.conjgradnet.conjgrad module
- direct.nn.conjgradnet.conjgradnet module
- direct.nn.conjgradnet.conjgradnet_engine module
- Module contents
- direct.nn.conv package
- direct.nn.crossdomain package
- direct.nn.didn package
- direct.nn.iterdualnet package
- Submodules
- direct.nn.iterdualnet.config module
- IterDualNetConfig: num_iter, image_normunet, kspace_normunet, image_unet_num_filters, image_unet_num_pool_layers, image_unet_dropout, kspace_unet_num_filters, kspace_unet_num_pool_layers, kspace_unet_dropout, image_no_parameter_sharing, kspace_no_parameter_sharing, compute_per_coil, __init__()
- direct.nn.iterdualnet.iterdualnet module
- direct.nn.iterdualnet.iterdualnet_engine module
- Module contents
- direct.nn.jointicnet package
- Submodules
- direct.nn.jointicnet.config module
- JointICNetConfig: num_iter, use_norm_unet, image_unet_num_filters, image_unet_num_pool_layers, image_unet_dropout, kspace_unet_num_filters, kspace_unet_num_pool_layers, kspace_unet_dropout, sens_unet_num_filters, sens_unet_num_pool_layers, sens_unet_dropout, __init__()
- direct.nn.jointicnet.jointicnet module
- direct.nn.jointicnet.jointicnet_engine module
- Module contents
- direct.nn.kikinet package
- Submodules
- direct.nn.kikinet.config module
- KIKINetConfig: num_iter, image_model_architecture, kspace_model_architecture, image_mwcnn_hidden_channels, image_mwcnn_num_scales, image_mwcnn_bias, image_mwcnn_batchnorm, image_unet_num_filters, image_unet_num_pool_layers, image_unet_dropout_probability, kspace_conv_hidden_channels, kspace_conv_n_convs, kspace_conv_batchnorm, kspace_didn_hidden_channels, kspace_didn_num_dubs, kspace_didn_num_convs_recon, kspace_unet_num_filters, kspace_unet_num_pool_layers, kspace_unet_dropout_probability, normalize, __init__()
- direct.nn.kikinet.kikinet module
- direct.nn.kikinet.kikinet_engine module
- Module contents
- direct.nn.lpd package
- Submodules
- direct.nn.lpd.config module
- LPDNetConfig: num_iter, num_primal, num_dual, primal_model_architecture, dual_model_architecture, primal_mwcnn_hidden_channels, primal_mwcnn_num_scales, primal_mwcnn_bias, primal_mwcnn_batchnorm, primal_unet_num_filters, primal_unet_num_pool_layers, primal_unet_dropout_probability, dual_conv_hidden_channels, dual_conv_n_convs, dual_conv_batchnorm, dual_didn_hidden_channels, dual_didn_num_dubs, dual_didn_num_convs_recon, dual_unet_num_filters, dual_unet_num_pool_layers, dual_unet_dropout_probability, __init__()
- direct.nn.lpd.lpd module
- direct.nn.lpd.lpd_engine module
- Module contents
- direct.nn.mobilenet package
- direct.nn.multidomainnet package
- Submodules
- direct.nn.multidomainnet.config module
- direct.nn.multidomainnet.multidomain module
- direct.nn.multidomainnet.multidomainnet module
- direct.nn.multidomainnet.multidomainnet_engine module
- Module contents
- direct.nn.mwcnn package
- direct.nn.recurrent package
- direct.nn.recurrentvarnet package
- Submodules
- direct.nn.recurrentvarnet.config module
- RecurrentVarNetConfig: num_steps, recurrent_hidden_channels, recurrent_num_layers, no_parameter_sharing, learned_initializer, initializer_initialization, initializer_channels, initializer_dilations, initializer_multiscale, normalized, __init__()
- direct.nn.recurrentvarnet.recurrentvarnet module
- direct.nn.recurrentvarnet.recurrentvarnet_engine module
- Module contents
- direct.nn.resnet package
- direct.nn.rim package
- Submodules
- direct.nn.rim.config module
- RIMConfig: hidden_channels, length, depth, steps, no_parameter_sharing, instance_norm, dense_connect, whiten_input, replication_padding, image_initialization, scale_loglikelihood, learned_initializer, initializer_channels, initializer_dilations, initializer_multiscale, normalized, __init__()
- direct.nn.rim.rim module
- direct.nn.rim.rim_engine module
- Module contents
- direct.nn.ssl package
- direct.nn.transformers package
- Submodules
- direct.nn.transformers.config module
- UFormerModelConfig: in_channels, out_channels, patch_size, embedding_dim, encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads, win_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, patch_norm, token_projection, token_mlp, shift_flag, modulator, cross_modulator, normalized, __init__()
- ImageDomainMRIUFormerConfig: patch_size, embedding_dim, encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads, win_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, patch_norm, token_projection, token_mlp, shift_flag, modulator, cross_modulator, normalized, __init__()
- MRIViTConfig: embedding_dim, depth, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, dropout_path_rate, use_gpsa, locality_strength, use_pos_embedding, normalized, __init__()
- VisionTransformer2DConfig
- VisionTransformer3DConfig
- ImageDomainMRIViT2DConfig
- ImageDomainMRIViT3DConfig
- KSpaceDomainMRIViT2DConfig
- KSpaceDomainMRIViT3DConfig
- direct.nn.transformers.transformers module
- direct.nn.transformers.transformers_engine module
- direct.nn.transformers.uformer module
- direct.nn.transformers.utils module
- direct.nn.transformers.vit module
- Module contents
- direct.nn.unet package
- Submodules
- direct.nn.unet.config module
- direct.nn.unet.unet_2d module
- direct.nn.unet.unet_3d module
- direct.nn.unet.unet_engine module
- Module contents
- direct.nn.varnet package
- direct.nn.varsplitnet package
- Submodules
- direct.nn.varsplitnet.config module
- MRIVarSplitNetConfig: num_steps_reg, num_steps_dc, image_init, no_parameter_sharing, kspace_no_parameter_sharing, image_model_architecture, kspace_model_architecture, image_resnet_hidden_channels, image_resnet_num_blocks, image_resnet_batchnorm, image_resnet_scale, image_unet_num_filters, image_unet_num_pool_layers, image_unet_dropout, image_didn_hidden_channels, image_didn_num_dubs, image_didn_num_convs_recon, kspace_resnet_hidden_channels, kspace_resnet_num_blocks, kspace_resnet_batchnorm, kspace_resnet_scale, kspace_unet_num_filters, kspace_unet_num_pool_layers, kspace_unet_dropout, kspace_didn_hidden_channels, kspace_didn_num_dubs, kspace_didn_num_convs_recon, image_conv_hidden_channels, image_conv_n_convs, image_conv_activation, image_conv_batchnorm, kspace_conv_hidden_channels, kspace_conv_n_convs, kspace_conv_activation, kspace_conv_batchnorm, __init__()
- direct.nn.varsplitnet.varsplitnet module
- direct.nn.varsplitnet.varsplitnet_engine module
- Module contents
- direct.nn.vsharp package
- Submodules
- direct.nn.vsharp.config module
- VSharpNetConfig: num_steps, num_steps_dc_gd, image_init, no_parameter_sharing, auxiliary_steps, image_model_architecture, initializer_channels, initializer_dilations, initializer_multiscale, initializer_activation, image_resnet_hidden_channels, image_resnet_num_blocks, image_resnet_batchnorm, image_resnet_scale, image_unet_num_filters, image_unet_num_pool_layers, image_unet_dropout, image_didn_hidden_channels, image_didn_num_dubs, image_didn_num_convs_recon, image_conv_hidden_channels, image_conv_n_convs, image_conv_activation, image_conv_batchnorm, __init__()
- VSharpNet3DConfig: num_steps, num_steps_dc_gd, image_init, no_parameter_sharing, auxiliary_steps, initializer_channels, initializer_dilations, initializer_multiscale, initializer_activation, unet_num_filters, unet_num_pool_layers, unet_dropout, unet_norm, __init__()
- direct.nn.vsharp.vsharp module
- direct.nn.vsharp.vsharp_engine module
- Module contents
- direct.nn.xpdnet package
- Submodules
- direct.nn.xpdnet.config module
- XPDNetConfig: num_primal, num_dual, num_iter, use_primal_only, kspace_model_architecture, dual_conv_hidden_channels, dual_conv_n_convs, dual_conv_batchnorm, dual_didn_hidden_channels, dual_didn_num_dubs, dual_didn_num_convs_recon, mwcnn_hidden_channels, mwcnn_num_scales, mwcnn_bias, mwcnn_batchnorm, normalize, __init__()
- direct.nn.xpdnet.xpdnet module
- direct.nn.xpdnet.xpdnet_engine module
- Module contents
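Several of the configs above pair an architecture selector (denoiser_architecture in ConjGradNetConfig, image_model_architecture in VSharpNetConfig and MRIVarSplitNetConfig) with groups of fields named after an architecture (resnet_*, unet_*, didn_*, conv_*). The following minimal sketch shows one hypothetical way such prefixed fields could be routed to the selected model. It does not import DIRECT and does not necessarily reflect how the library does this. ToyDenoiserConfig, select_denoiser_kwargs and all default values are invented for illustration. The local ModelName enum only mirrors the values documented for direct.nn.types.ModelName.

```python
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict


class ModelName(str, Enum):
    # Local stand-in mirroring the values documented for direct.nn.types.ModelName.
    UNET = "unet"
    NORMUNET = "normunet"
    RESNET = "resnet"
    DIDN = "didn"
    CONV = "conv"


@dataclass
class ToyDenoiserConfig:
    # Hypothetical subset of ConjGradNetConfig-style fields; defaults are illustrative only.
    denoiser_architecture: ModelName = ModelName.RESNET
    resnet_hidden_channels: int = 128
    resnet_num_blocks: int = 15
    unet_num_filters: int = 32
    unet_num_pool_layers: int = 4
    conv_hidden_channels: int = 64


def select_denoiser_kwargs(cfg: ToyDenoiserConfig) -> Dict[str, Any]:
    """Keep only fields whose prefix matches the chosen architecture, with the prefix stripped."""
    prefix = cfg.denoiser_architecture.value + "_"
    return {
        name[len(prefix):]: value
        for name, value in asdict(cfg).items()
        if name.startswith(prefix)
    }


print(select_denoiser_kwargs(ToyDenoiserConfig(denoiser_architecture=ModelName.UNET)))
# {'num_filters': 32, 'num_pool_layers': 4}
```

Under this scheme, an architecture with no matching fields (here NORMUNET or DIDN) gets an empty dictionary. A real implementation would need its own rule for such cases.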
Submodules
direct.nn.get_nn_model_config module
direct.nn.get_nn_model_config module.
direct.nn.mri_models module
MRI model engine of DIRECT.
- class direct.nn.mri_models.MRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: Engine
Engine for MRI models.
Each child class should implement its own forward_function().
- __init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Initializes the MRIModelEngine.
- Parameters:
cfg (BaseConfig) – Configuration object.
model (Module) – Model.
device (str) – Device. Can be "cuda:{idx}" or "cpu".
forward_operator (Optional[Callable]) – The forward operator. Default: None.
backward_operator (Optional[Callable]) – The backward operator. Default: None.
mixed_precision (bool) – Whether to use mixed precision. Default: False.
**models (Module) – Additional models.
- forward_function(data)
Performs the model’s forward pass on data, which contains all tensor inputs.
Must be implemented by child classes.
- Return type:
Tuple[Optional[Tensor], Optional[Tensor]]
- compute_sensitivity_map(sensitivity_map)
Refines the sensitivity maps \(\{S^k\}_{k=1}^{n_c}\) with sensitivity_model, if one is available, and normalizes them such that
\[\sum_{k=1}^{n_c} S^k {S^k}^* = I.\]
- Parameters:
sensitivity_map (Tensor) – Sensitivity maps of shape (batch, coil, height, width, complex=2).
- Return type:
Tensor
- Returns:
Normalized and refined sensitivity maps of shape (batch, coil, height, width, complex=2).
- reconstruct_volumes(data_loader, loss_fns=None, regularizer_fns=None, add_target=True, crop=None)
Validation process. Assumes that each batch only contains slices of the same volume AND that these slices are sequentially ordered.
- Parameters:
data_loader (DataLoader) – Data loader.
loss_fns (Optional[Dict[str, Callable]]) – Callable loss functions.
regularizer_fns (Optional[Dict[str, Callable]]) – Callable regularization functions.
add_target (bool) – If True, adds the target to the output.
crop (Optional[str]) – Crop type.
- Yields:
(curr_volume, [curr_target,] loss_dict_list, filename) – torch.Tensor, [torch.Tensor,], dict, pathlib.Path
- evaluate(data_loader, loss_fns)
Validation process.
Assumes that each batch only contains slices of the same volume AND that these are sequentially ordered.
- Parameters:
data_loader (DataLoader) – Data loader.
loss_fns (Optional[Dict[str, Callable]]) – Callable loss functions.
- Returns:
loss_dict, all_gathered_metrics, visualize_slices, visualize_target
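The normalization described for compute_sensitivity_map can be sketched in NumPy for maps stored, as documented, with shape (batch, coil, height, width, complex=2). This is an illustration under assumptions, not DIRECT's implementation. The function name normalize_sensitivity_maps and the eps guard against division by zero are choices made for this example. Dividing each coil map by the root-sum-of-squares over coils makes \(\sum_k S^k {S^k}^*\) equal to 1 at every pixel where the maps are nonzero.

```python
import numpy as np


def normalize_sensitivity_maps(sens: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale coil maps so that sum_k |S^k|^2 = 1 at every pixel.

    sens has shape (batch, coil, height, width, 2), with real and imaginary
    parts in the last axis (assumed layout, matching the docstring above).
    """
    # |S^k|^2 = re^2 + im^2, summed over the coil axis (axis=1).
    rss = np.sqrt((sens ** 2).sum(axis=-1).sum(axis=1))  # shape (batch, height, width)
    # Broadcast the per-pixel norm back over the coil and complex axes;
    # eps keeps all-zero pixels at zero instead of producing NaN.
    return sens / np.maximum(rss, eps)[:, None, :, :, None]


rng = np.random.default_rng(0)
maps = rng.standard_normal((1, 4, 8, 8, 2))
normalized = normalize_sensitivity_maps(maps)
# S^k (S^k)^* = |S^k|^2 is real; summed over coils it should now be 1 everywhere.
power = (normalized ** 2).sum(axis=-1).sum(axis=1)
print(np.allclose(power, 1.0))  # True
```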
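Both reconstruct_volumes and evaluate assume that the slices of a volume arrive consecutively. The standard-library sketch below shows why that assumption matters. With ordered input, a single pass with itertools.groupby can emit each volume as soon as its run of slices ends. The (filename, slice) records and the function assemble_volumes are hypothetical stand-ins for the data loader's batches, not DIRECT's data format.

```python
from itertools import groupby
from typing import Iterable, Iterator, List, Tuple

# Hypothetical record: (filename, reconstructed slice), standing in for one item of a batch.
SliceRecord = Tuple[str, List[float]]


def assemble_volumes(records: Iterable[SliceRecord]) -> Iterator[Tuple[str, List[List[float]]]]:
    """Yield (filename, volume) pairs, one volume per run of consecutive records.

    itertools.groupby only merges *consecutive* items with equal keys, so this
    depends on the same ordering assumption stated for reconstruct_volumes.
    """
    for filename, run in groupby(records, key=lambda record: record[0]):
        yield filename, [slice_data for _, slice_data in run]


ordered = [("vol_a.h5", [1.0]), ("vol_a.h5", [2.0]), ("vol_b.h5", [3.0])]
print([(name, len(vol)) for name, vol in assemble_volumes(ordered)])
# [('vol_a.h5', 2), ('vol_b.h5', 1)]

# Interleaved input breaks the assumption: vol_a.h5 is split into two partial volumes.
interleaved = [("vol_a.h5", [1.0]), ("vol_b.h5", [3.0]), ("vol_a.h5", [2.0])]
print([name for name, _ in assemble_volumes(interleaved)])
# ['vol_a.h5', 'vol_b.h5', 'vol_a.h5']
```

Streaming volume by volume like this avoids holding the whole dataset in memory. The cost is that the order of the data loader becomes a correctness requirement.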
direct.nn.types module
direct.nn.types module.
- class direct.nn.types.ActivationType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- RELU = 'relu'
- PRELU = 'prelu'
- LEAKY_RELU = 'leaky_relu'
- class direct.nn.types.ModelName(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- UNET = 'unet'
- NORMUNET = 'normunet'
- RESNET = 'resnet'
- DIDN = 'didn'
- CONV = 'conv'
- class direct.nn.types.InitType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- INPUT_IMAGE = 'input_image'
- SENSE = 'sense'
- ZERO_FILLED = 'zero_filled'
- ZEROS = 'zeros'
- class direct.nn.types.LossFunType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- L1_LOSS = 'l1_loss'
- KSPACE_L1_LOSS = 'kspace_l1_loss'
- L2_LOSS = 'l2_loss'
- KSPACE_L2_LOSS = 'kspace_l2_loss'
- SSIM_LOSS = 'ssim_loss'
- SSIM_3D_LOSS = 'ssim_3d_loss'
- GRAD_L1_LOSS = 'grad_l1_loss'
- GRAD_L2_LOSS = 'grad_l2_loss'
- NMSE_LOSS = 'nmse_loss'
- KSPACE_NMSE_LOSS = 'kspace_nmse_loss'
- NRMSE_LOSS = 'nrmse_loss'
- KSPACE_NRMSE_LOSS = 'kspace_nrmse_loss'
- NMAE_LOSS = 'nmae_loss'
- KSPACE_NMAE_LOSS = 'kspace_nmae_loss'
- SNR_LOSS = 'snr_loss'
- PSNR_LOSS = 'psnr_loss'
- HFEN_L1_LOSS = 'hfen_l1_loss'
- HFEN_L2_LOSS = 'hfen_l2_loss'
- HFEN_L1_NORM_LOSS = 'hfen_l1_norm_loss'
- HFEN_L2_NORM_LOSS = 'hfen_l2_norm_loss'
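The enums in direct.nn.types are string-valued, so a config entry such as "leaky_relu" can be turned into an enum member and then into behavior. Configs above expose fields such as conv_activation and initializer_activation. Assuming those hold ActivationType values, a dispatch could look like the sketch below. It uses a local stand-in for ActivationType, assumed to behave like a str-mixin Python Enum; DirectEnum's exact behavior is not documented here. The scalar activation functions are written only for illustration. In a real network they would be tensor modules, and PReLU's slope would be learned rather than fixed at 0.25.

```python
from enum import Enum
from typing import Callable, Dict


class ActivationType(str, Enum):
    # Local stand-in mirroring the members documented for direct.nn.types.ActivationType.
    RELU = "relu"
    PRELU = "prelu"
    LEAKY_RELU = "leaky_relu"


def relu(x: float) -> float:
    return max(0.0, x)


def leaky_relu(x: float, negative_slope: float = 0.01) -> float:
    return x if x >= 0 else negative_slope * x


def prelu(x: float, weight: float = 0.25) -> float:
    # PReLU learns `weight`; 0.25 is only an illustrative initial value.
    return x if x >= 0 else weight * x


ACTIVATIONS: Dict[ActivationType, Callable[[float], float]] = {
    ActivationType.RELU: relu,
    ActivationType.PRELU: prelu,
    ActivationType.LEAKY_RELU: leaky_relu,
}


def get_activation(name: str) -> Callable[[float], float]:
    """Look up an activation by its config string; unknown names raise ValueError."""
    return ACTIVATIONS[ActivationType(name)]


print(get_activation("leaky_relu")(-2.0))  # -0.02
```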
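LossFunType lists the available losses by their config strings. As a rough reference for what the normalized error names usually mean, the following pure-Python sketch implements common textbook definitions of four of them on flat lists of floats, keyed by the documented string values. These definitions are assumptions, and DIRECT's own implementations may differ in several ways: reduction over batches, handling of complex data, the data range used for PSNR, and the sign used when PSNR serves as a loss to minimize.

```python
import math
from typing import Callable, Dict, Sequence


def nmse(pred: Sequence[float], target: Sequence[float]) -> float:
    # ||pred - target||_2^2 / ||target||_2^2
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / sum(t ** 2 for t in target)


def nrmse(pred: Sequence[float], target: Sequence[float]) -> float:
    # ||pred - target||_2 / ||target||_2
    return math.sqrt(nmse(pred, target))


def nmae(pred: Sequence[float], target: Sequence[float]) -> float:
    # ||pred - target||_1 / ||target||_1
    return sum(abs(p - t) for p, t in zip(pred, target)) / sum(abs(t) for t in target)


def psnr(pred: Sequence[float], target: Sequence[float]) -> float:
    # 10 * log10(peak^2 / MSE), taking the target's peak magnitude as the data range
    # (one common convention; others use a fixed range such as 1.0 or 255).
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    return 10 * math.log10(max(abs(t) for t in target) ** 2 / mse)


# Keys are the string values listed for LossFunType above.
METRICS: Dict[str, Callable[[Sequence[float], Sequence[float]], float]] = {
    "nmse_loss": nmse,
    "nrmse_loss": nrmse,
    "nmae_loss": nmae,
    "psnr_loss": psnr,
}

pred, target = [9.0, 11.0], [10.0, 10.0]
print({name: round(fn(pred, target), 4) for name, fn in METRICS.items()})
# {'nmse_loss': 0.01, 'nrmse_loss': 0.1, 'nmae_loss': 0.1, 'psnr_loss': 20.0}
```

Since PSNR grows as the reconstruction improves, a loss built on it would typically be minimized as its negative.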