direct.nn package
Subpackages
- direct.nn.cirim package
- direct.nn.conjgradnet package
- Submodules
- direct.nn.conjgradnet.config module
- ConjGradNetConfig: cg_iters, cg_param_update_type, cg_tol, conv_activation, conv_batchnorm, conv_hidden_channels, conv_n_convs, denoiser_architecture, didn_hidden_channels, didn_num_convs_recon, didn_num_dubs, image_init, no_parameter_sharing, num_steps, resenet_batchnorm, resenet_scale, resnet_hidden_channels, resnet_num_blocks, unet_dropout, unet_num_filters, unet_num_pool_layers
- direct.nn.conjgradnet.conjgrad module
- direct.nn.conjgradnet.conjgradnet module
- direct.nn.conjgradnet.conjgradnet_engine module
- Module contents
- direct.nn.conv package
- direct.nn.crossdomain package
- direct.nn.didn package
- direct.nn.iterdualnet package
- Submodules
- direct.nn.iterdualnet.config module
- IterDualNetConfig: compute_per_coil, image_no_parameter_sharing, image_normunet, image_unet_dropout, image_unet_num_filters, image_unet_num_pool_layers, kspace_no_parameter_sharing, kspace_normunet, kspace_unet_dropout, kspace_unet_num_filters, kspace_unet_num_pool_layers, num_iter
- direct.nn.iterdualnet.iterdualnet module
- direct.nn.iterdualnet.iterdualnet_engine module
- Module contents
- direct.nn.jointicnet package
- Submodules
- direct.nn.jointicnet.config module
- JointICNetConfig: image_unet_dropout, image_unet_num_filters, image_unet_num_pool_layers, kspace_unet_dropout, kspace_unet_num_filters, kspace_unet_num_pool_layers, num_iter, sens_unet_dropout, sens_unet_num_filters, sens_unet_num_pool_layers, use_norm_unet
- direct.nn.jointicnet.jointicnet module
- direct.nn.jointicnet.jointicnet_engine module
- Module contents
- direct.nn.kikinet package
- Submodules
- direct.nn.kikinet.config module
- KIKINetConfig: image_model_architecture, image_mwcnn_batchnorm, image_mwcnn_bias, image_mwcnn_hidden_channels, image_mwcnn_num_scales, image_unet_dropout_probability, image_unet_num_filters, image_unet_num_pool_layers, kspace_conv_batchnorm, kspace_conv_hidden_channels, kspace_conv_n_convs, kspace_didn_hidden_channels, kspace_didn_num_convs_recon, kspace_didn_num_dubs, kspace_model_architecture, kspace_unet_dropout_probability, kspace_unet_num_filters, kspace_unet_num_pool_layers, normalize, num_iter
- direct.nn.kikinet.kikinet module
- direct.nn.kikinet.kikinet_engine module
- Module contents
- direct.nn.lpd package
- Submodules
- direct.nn.lpd.config module
- LPDNetConfig: dual_conv_batchnorm, dual_conv_hidden_channels, dual_conv_n_convs, dual_didn_hidden_channels, dual_didn_num_convs_recon, dual_didn_num_dubs, dual_model_architecture, dual_unet_dropout_probability, dual_unet_num_filters, dual_unet_num_pool_layers, num_dual, num_iter, num_primal, primal_model_architecture, primal_mwcnn_batchnorm, primal_mwcnn_bias, primal_mwcnn_hidden_channels, primal_mwcnn_num_scales, primal_unet_dropout_probability, primal_unet_num_filters, primal_unet_num_pool_layers
- direct.nn.lpd.lpd module
- direct.nn.lpd.lpd_engine module
- Module contents
- direct.nn.mobilenet package
- direct.nn.multidomainnet package
- direct.nn.mwcnn package
- direct.nn.recurrent package
- direct.nn.recurrentvarnet package
- Submodules
- direct.nn.recurrentvarnet.config module
- RecurrentVarNetConfig: initializer_channels, initializer_dilations, initializer_initialization, initializer_multiscale, learned_initializer, no_parameter_sharing, normalized, num_steps, recurrent_hidden_channels, recurrent_num_layers
- direct.nn.recurrentvarnet.recurrentvarnet module
- direct.nn.recurrentvarnet.recurrentvarnet_engine module
- Module contents
- direct.nn.resnet package
- direct.nn.rim package
- Submodules
- direct.nn.rim.config module
- RIMConfig: dense_connect, depth, hidden_channels, image_initialization, initializer_channels, initializer_dilations, initializer_multiscale, instance_norm, learned_initializer, length, no_parameter_sharing, normalized, replication_padding, scale_loglikelihood, steps, whiten_input
- direct.nn.rim.rim module
- direct.nn.rim.rim_engine module
- Module contents
- direct.nn.ssl package
- direct.nn.transformers package
- Submodules
- direct.nn.transformers.config module
- ImageDomainMRIUFormerConfig: attn_drop_rate, bottleneck_depth, bottleneck_num_heads, cross_modulator, drop_path_rate, drop_rate, embedding_dim, encoder_depths, encoder_num_heads, mlp_ratio, modulator, normalized, patch_norm, patch_size, qk_scale, qkv_bias, shift_flag, token_mlp, token_projection, win_size
- ImageDomainMRIViT2DConfig
- ImageDomainMRIViT3DConfig
- KSpaceDomainMRIViT2DConfig
- KSpaceDomainMRIViT3DConfig
- MRIViTConfig: attn_drop_rate, depth, drop_rate, dropout_path_rate, embedding_dim, locality_strength, mlp_ratio, normalized, num_heads, qk_scale, qkv_bias, use_gpsa, use_pos_embedding
- UFormerModelConfig: attn_drop_rate, bottleneck_depth, bottleneck_num_heads, cross_modulator, drop_path_rate, drop_rate, embedding_dim, encoder_depths, encoder_num_heads, in_channels, mlp_ratio, modulator, normalized, out_channels, patch_norm, patch_size, qk_scale, qkv_bias, shift_flag, token_mlp, token_projection, win_size
- VisionTransformer2DConfig
- VisionTransformer3DConfig
- direct.nn.transformers.transformers module
- direct.nn.transformers.transformers_engine module
- direct.nn.transformers.uformer module
- direct.nn.transformers.utils module
- direct.nn.transformers.vit module
- Module contents
- direct.nn.unet package
- direct.nn.varnet package
- direct.nn.varsplitnet package
- Submodules
- direct.nn.varsplitnet.config module
- MRIVarSplitNetConfig: image_conv_activation, image_conv_batchnorm, image_conv_hidden_channels, image_conv_n_convs, image_didn_hidden_channels, image_didn_num_convs_recon, image_didn_num_dubs, image_init, image_model_architecture, image_resnet_batchnorm, image_resnet_hidden_channels, image_resnet_num_blocks, image_resnet_scale, image_unet_dropout, image_unet_num_filters, image_unet_num_pool_layers, kspace_conv_activation, kspace_conv_batchnorm, kspace_conv_hidden_channels, kspace_conv_n_convs, kspace_didn_hidden_channels, kspace_didn_num_convs_recon, kspace_didn_num_dubs, kspace_model_architecture, kspace_no_parameter_sharing, kspace_resnet_batchnorm, kspace_resnet_hidden_channels, kspace_resnet_num_blocks, kspace_resnet_scale, kspace_unet_dropout, kspace_unet_num_filters, kspace_unet_num_pool_layers, no_parameter_sharing, num_steps_dc, num_steps_reg
- direct.nn.varsplitnet.varsplitnet module
- direct.nn.varsplitnet.varsplitnet_engine module
- Module contents
- direct.nn.vsharp package
- Submodules
- direct.nn.vsharp.config module
- VSharpNet3DConfig: auxiliary_steps, image_init, initializer_activation, initializer_channels, initializer_dilations, initializer_multiscale, no_parameter_sharing, num_steps, num_steps_dc_gd, unet_dropout, unet_norm, unet_num_filters, unet_num_pool_layers
- VSharpNetConfig: auxiliary_steps, image_conv_activation, image_conv_batchnorm, image_conv_hidden_channels, image_conv_n_convs, image_didn_hidden_channels, image_didn_num_convs_recon, image_didn_num_dubs, image_init, image_model_architecture, image_resnet_batchnorm, image_resnet_hidden_channels, image_resnet_num_blocks, image_resnet_scale, image_unet_dropout, image_unet_num_filters, image_unet_num_pool_layers, initializer_activation, initializer_channels, initializer_dilations, initializer_multiscale, no_parameter_sharing, num_steps, num_steps_dc_gd
- direct.nn.vsharp.vsharp module
- direct.nn.vsharp.vsharp_engine module
- Module contents
- direct.nn.xpdnet package
- Submodules
- direct.nn.xpdnet.config module
- XPDNetConfig: dual_conv_batchnorm, dual_conv_hidden_channels, dual_conv_n_convs, dual_didn_hidden_channels, dual_didn_num_convs_recon, dual_didn_num_dubs, kspace_model_architecture, mwcnn_batchnorm, mwcnn_bias, mwcnn_hidden_channels, mwcnn_num_scales, normalize, num_dual, num_iter, num_primal, use_primal_only
- direct.nn.xpdnet.xpdnet module
- direct.nn.xpdnet.xpdnet_engine module
- Module contents
Submodules
direct.nn.get_nn_model_config module
direct.nn.get_nn_model_config module.
direct.nn.mri_models module
MRI model engine of DIRECT.
- class direct.nn.mri_models.MRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: Engine
Engine for MRI models.
Each child class should implement its own forward_function().
- compute_loss_on_data(loss_dict, loss_fns, data, output_image=None, output_kspace=None, weight=1.0)
- Return type:
Dict[str,Tensor]
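Only the signature and return type of compute_loss_on_data are documented. A hypothetical sketch of the accumulate-weighted-losses pattern the signature suggests, assuming the target is stored under data["target"] and that every entry in loss_fns compares the output image against it; output_kspace and any k-space loss handling are deliberately omitted, and compute_loss_on_data_sketch is an illustrative name, not DIRECT's code:

```python
from typing import Callable, Dict, Optional

import torch

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def compute_loss_on_data_sketch(
    loss_dict: Dict[str, torch.Tensor],
    loss_fns: Dict[str, LossFn],
    data: Dict[str, torch.Tensor],
    output_image: Optional[torch.Tensor] = None,
    weight: float = 1.0,
) -> Dict[str, torch.Tensor]:
    """Hypothetical sketch: add weighted image losses to running totals."""
    if output_image is None:
        # Nothing to compare; return the totals unchanged.
        return loss_dict
    target = data["target"]  # assumed key, for illustration only
    for key, loss_fn in loss_fns.items():
        # Missing keys start from zero; existing totals are accumulated.
        loss_dict[key] = loss_dict.get(key, 0.0) + weight * loss_fn(output_image, target)
    return loss_dict
```

Passing a weight below 1.0 would let intermediate outputs (for example from unrolled iterations) contribute less than the final one, if the engine calls the method repeatedly.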
- compute_model_per_coil(model_name, data)
Performs the forward pass of the model model_name from self.models separately for each coil.
- Parameters:
- model_name: str
Model to run.
- data: torch.Tensor
Multi-coil data of shape (batch, coil, complex=2, height, width).
- Returns:
- output: torch.Tensor
Computed output per coil.
- Return type:
Tensor
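A minimal sketch of per-coil evaluation, assuming the coil axis is dimension 1 of the documented (batch, coil, complex=2, height, width) layout and that the wrapped model accepts a single-coil (batch, complex=2, height, width) tensor. It takes the model object directly rather than looking it up by model_name in self.models, and compute_model_per_coil_sketch is a hypothetical name:

```python
import torch
from torch import nn


def compute_model_per_coil_sketch(model: nn.Module, data: torch.Tensor, coil_dim: int = 1) -> torch.Tensor:
    """Apply `model` to each coil of `data` and restack along `coil_dim`."""
    outputs = []
    for idx in range(data.size(coil_dim)):
        # select() removes the coil dimension, yielding (batch, complex=2, height, width).
        coil_data = data.select(coil_dim, idx)
        outputs.append(model(coil_data))
    # stack() re-inserts the coil dimension at its original position.
    return torch.stack(outputs, dim=coil_dim)
```

Because every coil passes through the same weights, a model built for single-coil input can be reused on multi-coil data without changing its channel count.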
- compute_sensitivity_map(sensitivity_map)
Computes sensitivity maps \(\{S^k\}_{k=1}^{n_c}\) if sensitivity_model is available.
\(\{S^k\}_{k=1}^{n_c}\) are normalized such that
\[\sum_{k=1}^{n_c}S^k {S^k}^* = I.\]
- Parameters:
- sensitivity_map: torch.Tensor
Sensitivity maps of shape (batch, coil, height, width, complex=2).
- Returns:
- sensitivity_map: torch.Tensor
Normalized and refined sensitivity maps of shape (batch, coil, height, width, complex=2).
- Return type:
Tensor
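For scalar-valued coil maps, the normalization condition \(\sum_{k} S^k {S^k}^* = I\) means that at every pixel the squared magnitudes \(|S^k|^2\) sum to one. A minimal sketch of one common way to enforce it, root-sum-of-squares scaling, assuming the documented (batch, coil, height, width, complex=2) layout; the function name normalize_sensitivity_map and the eps guard are illustrative assumptions, and the sketch leaves out any refinement by sensitivity_model:

```python
import torch


def normalize_sensitivity_map(sensitivity_map: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale coil maps so that sum_k |S^k|^2 = 1 at every pixel."""
    # |S^k|^2 = re^2 + im^2 (sum over the last dim), then sum over coils (dim 1).
    rss = sensitivity_map.pow(2).sum(dim=-1).sum(dim=1, keepdim=True).sqrt()
    # rss has shape (batch, 1, height, width); add a trailing axis so it
    # broadcasts over both the coil and complex dimensions.
    return sensitivity_map / (rss.unsqueeze(-1) + eps)
```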
- evaluate(data_loader, loss_fns)
Validation process.
Assumes that each batch only contains slices of the same volume AND that these are sequentially ordered.
- Parameters:
- data_loader: DataLoader
- loss_fns: Dict[str, Callable], optional
- Returns:
- loss_dict, all_gathered_metrics, visualize_slices, visualize_target
- forward_function(data)
Performs the model's forward pass on data, which contains all tensor inputs.
Must be implemented by child classes.
- Return type:
Tuple[Optional[Tensor],Optional[Tensor]]
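The contract above can be illustrated with a hypothetical stand-in base class (EngineStandIn is not part of DIRECT): a child overrides forward_function and returns an (output_image, output_kspace) pair, using None for whichever output it does not produce. The "masked_image" key is an assumption made for illustration:

```python
from typing import Dict, Optional, Tuple

import torch

Output = Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]


class EngineStandIn:
    """Hypothetical stand-in for MRIModelEngine; not the DIRECT class."""

    def forward_function(self, data: Dict[str, torch.Tensor]) -> Output:
        raise NotImplementedError("Must be implemented by child classes.")


class ImageOnlyEngine(EngineStandIn):
    """Hypothetical child that predicts only an image."""

    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model

    def forward_function(self, data: Dict[str, torch.Tensor]) -> Output:
        # "masked_image" is an illustrative key, not a documented DIRECT key.
        output_image = self.model(data["masked_image"])
        # No k-space prediction, so the second element is None.
        return output_image, None
```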
- reconstruct_volumes(data_loader, loss_fns=None, regularizer_fns=None, add_target=True, crop=None)
Validation process. Assumes that each batch only contains slices of the same volume AND that these are sequentially ordered.
- Parameters:
- data_loader: DataLoader
- loss_fns: Dict[str, Callable], optional
- regularizer_fns: Dict[str, Callable], optional
- add_target: bool
If True, the target is added to the output.
- crop: str, optional
Crop type.
- Yields:
- (curr_volume, [curr_target,] loss_dict_list, filename): torch.Tensor, [torch.Tensor,], dict, pathlib.Path
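The assumption that slices of a volume arrive contiguously and in order is what lets volumes be assembled on the fly: a volume is complete as soon as the filename changes. The generator below is a simplified, hypothetical illustration of that idea (group_slices_into_volumes is not a DIRECT function) and ignores batching, cropping, losses and targets:

```python
from typing import Iterable, Iterator, List, Optional, Tuple

import torch


def group_slices_into_volumes(
    slices: Iterable[Tuple[str, torch.Tensor]],
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Stack consecutive slices that share a filename into one volume.

    If slices of one volume were interleaved with another, that volume
    would be emitted as several partial volumes instead of one.
    """
    current_name: Optional[str] = None
    buffer: List[torch.Tensor] = []
    for filename, slice_data in slices:
        if current_name is not None and filename != current_name:
            # Filename changed: the previous volume is complete.
            yield current_name, torch.stack(buffer, dim=0)
            buffer = []
        current_name = filename
        buffer.append(slice_data)
    if buffer:
        # Flush the final volume.
        yield current_name, torch.stack(buffer, dim=0)
```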
direct.nn.types module
direct.nn.types module.
- class direct.nn.types.ActivationType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- LEAKY_RELU = 'leaky_relu'
- PRELU = 'prelu'
- RELU = 'relu'
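Since the enum values are plain strings, a configuration string can be resolved to a layer by value lookup. A minimal sketch, using a local str-based stand-in for ActivationType (the real class derives from DirectEnum, whose base types are not shown here) and a hypothetical ACTIVATIONS table that maps members to torch.nn layers:

```python
from enum import Enum

from torch import nn


class ActivationType(str, Enum):
    """Local stand-in mirroring the documented members; not the DIRECT class."""

    LEAKY_RELU = "leaky_relu"
    PRELU = "prelu"
    RELU = "relu"


# Hypothetical lookup table; DIRECT's own mapping (if any) may differ.
ACTIVATIONS = {
    ActivationType.LEAKY_RELU: nn.LeakyReLU,
    ActivationType.PRELU: nn.PReLU,
    ActivationType.RELU: nn.ReLU,
}


def build_activation(name: str) -> nn.Module:
    # ActivationType(name) looks a member up by value and raises
    # ValueError for strings that are not listed, e.g. "gelu".
    return ACTIVATIONS[ActivationType(name)]()
```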
- class direct.nn.types.InitType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- INPUT_IMAGE = 'input_image'
- SENSE = 'sense'
- ZEROS = 'zeros'
- ZERO_FILLED = 'zero_filled'
- class direct.nn.types.LossFunType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- GRAD_L1_LOSS = 'grad_l1_loss'
- GRAD_L2_LOSS = 'grad_l2_loss'
- HFEN_L1_LOSS = 'hfen_l1_loss'
- HFEN_L1_NORM_LOSS = 'hfen_l1_norm_loss'
- HFEN_L2_LOSS = 'hfen_l2_loss'
- HFEN_L2_NORM_LOSS = 'hfen_l2_norm_loss'
- KSPACE_L1_LOSS = 'kspace_l1_loss'
- KSPACE_L2_LOSS = 'kspace_l2_loss'
- KSPACE_NMAE_LOSS = 'kspace_nmae_loss'
- KSPACE_NMSE_LOSS = 'kspace_nmse_loss'
- KSPACE_NRMSE_LOSS = 'kspace_nrmse_loss'
- L1_LOSS = 'l1_loss'
- L2_LOSS = 'l2_loss'
- NMAE_LOSS = 'nmae_loss'
- NMSE_LOSS = 'nmse_loss'
- NRMSE_LOSS = 'nrmse_loss'
- PSNR_LOSS = 'psnr_loss'
- SNR_LOSS = 'snr_loss'
- SSIM_3D_LOSS = 'ssim_3d_loss'
- SSIM_LOSS = 'ssim_loss'
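For orientation, the normalized error losses NMSE_LOSS, NRMSE_LOSS and NMAE_LOSS conventionally correspond to the textbook quantities below, with error norms divided by the target's norm (NMSE uses squared L2 norms, NRMSE is its square root, NMAE uses L1 norms). This is a sketch of those standard definitions only; DIRECT's own implementations may reduce over different dimensions or normalize differently:

```python
import torch


def nmse(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # ||output - target||_2^2 / ||target||_2^2
    return (output - target).pow(2).sum() / target.pow(2).sum()


def nrmse(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Square root of NMSE: ||output - target||_2 / ||target||_2
    return nmse(output, target).sqrt()


def nmae(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # ||output - target||_1 / ||target||_1
    return (output - target).abs().sum() / target.abs().sum()
```

With target [3, 4] and output [3, 0], the error is [0, -4], so NMSE is 16/25 = 0.64, NRMSE is 0.8 and NMAE is 4/7.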