direct.nn package

Subpackages

Submodules

direct.nn.get_nn_model_config module

direct.nn.get_nn_model_config module.

direct.nn.mri_models module

MRI model engine of DIRECT.

class direct.nn.mri_models.MRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: Engine

Engine for MRI models.

Each child class must implement its own forward_function().

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Initializes MRIModelEngine.

Parameters:
  • cfg (BaseConfig) – Configuration object.

  • model (Module) – Model.

  • device (str) – Device. Can be "cuda:{idx}" or "cpu".

  • forward_operator (Optional[Callable]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

forward_function(data)

Performs the model's forward pass on data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Optional[Tensor], Optional[Tensor]]
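The contract between the base engine and its children can be sketched in plain Python. EngineSketch and ImageOnlyEngineSketch are hypothetical stand-ins, not DIRECT classes; data is assumed to be a dictionary, the key "masked_kspace" is an assumed name, and the two returned elements are read, for illustration only, as an image-domain and a k-space output (mirroring the output_image and output_kspace arguments of compute_loss_on_data).

```python
from typing import Any, Callable, Dict, Optional, Tuple


class EngineSketch:
    """Hypothetical stand-in for MRIModelEngine: only the forward_function contract."""

    def forward_function(self, data: Dict[str, Any]) -> Tuple[Optional[Any], Optional[Any]]:
        # The base class leaves this abstract; each child engine overrides it.
        raise NotImplementedError("Must be implemented by child classes.")


class ImageOnlyEngineSketch(EngineSketch):
    """Hypothetical child engine whose model predicts only an image."""

    def __init__(self, model: Callable[[Any], Any]) -> None:
        self.model = model

    def forward_function(self, data: Dict[str, Any]) -> Tuple[Optional[Any], Optional[Any]]:
        # "masked_kspace" is an assumed key name, chosen only for illustration.
        output_image = self.model(data["masked_kspace"])
        # This model makes no k-space prediction, so the second element is None.
        return output_image, None
```

Returning None for an output the model does not produce keeps the return type uniform across child engines, which matches the Optional elements in the documented return type.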

build_loss()
Return type:

Dict

compute_sensitivity_map(sensitivity_map)

Computes sensitivity maps \(\{S^k\}_{k=1}^{n_c}\) if sensitivity_model is available.

The maps \(\{S^k\}_{k=1}^{n_c}\), where \(n_c\) is the number of coils, are normalized such that

\[\sum_{k=1}^{n_c}S^k {S^k}^* = I.\]

Parameters:

sensitivity_map (Tensor) – Sensitivity maps of shape (batch, coil, height, width, complex=2).

Return type:

Tensor

Returns:

Normalized and refined sensitivity maps of shape (batch, coil, height, width, complex=2).
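The normalization condition can be illustrated with a NumPy sketch that divides every coil map by the root-sum-of-squares over the coil axis, so that \(\sum_k |S^k|^2 = 1\) at each pixel. It assumes the (batch, coil, height, width, complex=2) layout above and deliberately omits the refinement by sensitivity_model; normalize_sensitivity_map and eps are hypothetical names, not part of DIRECT.

```python
import numpy as np


def normalize_sensitivity_map(sensitivity_map: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Hypothetical helper: scale coil maps so that sum_k S^k conj(S^k) = 1 per pixel.

    Assumes shape (batch, coil, height, width, 2) with the last axis = (real, imag).
    """
    # Turn the trailing (real, imag) axis into complex values: (batch, coil, height, width).
    s = sensitivity_map[..., 0] + 1j * sensitivity_map[..., 1]
    # Root-sum-of-squares over the coil axis, kept for broadcasting: (batch, 1, height, width).
    rss = np.sqrt(np.sum(np.abs(s) ** 2, axis=1, keepdims=True))
    # eps avoids division by zero where all coils vanish; such pixels stay zero.
    s = s / np.maximum(rss, eps)
    # Restore the (..., 2) real/imaginary layout.
    return np.stack([s.real, s.imag], axis=-1)
```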

reconstruct_volumes(data_loader, loss_fns=None, regularizer_fns=None, add_target=True, crop=None)

Validation process that reconstructs full volumes. Assumes that each batch contains only slices of the same volume and that these slices are sequentially ordered.

Parameters:
  • data_loader (DataLoader) – Data loader.

  • loss_fns (Optional[Dict[str, Callable]]) – Callable loss functions.

  • regularizer_fns (Optional[Dict[str, Callable]]) – Callable regularization functions.

  • add_target (bool) – If True, the target is included in the yielded output.

  • crop (Optional[str]) – Crop type.

Yields:

(curr_volume, [curr_target,] loss_dict_list, filename) – torch.Tensor, [torch.Tensor,], dict, pathlib.Path
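The ordering assumption is what lets volumes be emitted one at a time: because all batches of a volume arrive consecutively, a change of filename signals that the previous volume is complete. The pure-Python sketch below shows that grouping idea only; it is not DIRECT's implementation, and group_batches_into_volumes and the (filename, slices) batch format are hypothetical.

```python
from typing import Any, Iterable, Iterator, List, Optional, Tuple


def group_batches_into_volumes(
    batches: Iterable[Tuple[str, List[Any]]],
) -> Iterator[Tuple[str, List[Any]]]:
    """Yield (filename, slices) for each volume, in the order volumes appear.

    Each batch is a hypothetical (filename, slices) pair whose slices all belong
    to one volume; batches of the same volume are assumed to arrive consecutively.
    """
    current_filename: Optional[str] = None
    current_slices: List[Any] = []
    for filename, slices in batches:
        # A new filename means the previous volume has received all of its slices.
        if current_filename is not None and filename != current_filename:
            yield current_filename, current_slices
            current_slices = []
        current_filename = filename
        current_slices.extend(slices)
    # Flush the last volume once the loader is exhausted.
    if current_filename is not None:
        yield current_filename, current_slices
```

If the ordering assumption is violated (for example, a shuffled loader), the same volume would be yielded more than once in fragments, which is why validation loaders must preserve slice order.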

evaluate(data_loader, loss_fns)

Validation process.

Assumes that each batch contains only slices of the same volume and that these slices are sequentially ordered.

Parameters:
  • data_loader (DataLoader) – Data loader.

  • loss_fns (Optional[Dict[str, Callable]]) – Callable loss functions.

Returns:

loss_dict, all_gathered_metrics, visualize_slices, visualize_target

compute_model_per_coil(model_name, data)

Performs a forward pass of the model model_name in self.models on each coil separately.

Parameters:
  • model_name (str) – Model to run.

  • data (Tensor) – Multi-coil data of shape (batch, coil, complex=2, height, width).

Return type:

Tensor

Returns:

Computed output per coil.
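Per-coil processing can be sketched as a loop over the coil axis: select one coil, run the model on the resulting (batch, complex=2, height, width) array, and stack the outputs back along the coil axis. NumPy arrays stand in for torch tensors here (in PyTorch, Tensor.select and torch.stack are the analogous calls); apply_per_coil, coil_dim, and the model callable are hypothetical and not DIRECT's code.

```python
from typing import Callable

import numpy as np


def apply_per_coil(
    model: Callable[[np.ndarray], np.ndarray], data: np.ndarray, coil_dim: int = 1
) -> np.ndarray:
    """Hypothetical helper: run `model` on each coil independently and restack along `coil_dim`."""
    outputs = []
    for idx in range(data.shape[coil_dim]):
        # A scalar index drops the coil axis:
        # (batch, coil, complex=2, height, width) -> (batch, complex=2, height, width)
        coil_data = np.take(data, idx, axis=coil_dim)
        outputs.append(model(coil_data))
    # Re-insert the coil axis so the output keeps the input's layout.
    return np.stack(outputs, axis=coil_dim)
```

Running the model once per coil means it never sees cross-coil information and its weights are shared across coils, so the same network works for any number of coils.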

compute_loss_on_data(loss_dict, loss_fns, data, output_image=None, output_kspace=None, weight=1.0)
Return type:

Dict[str, Tensor]

direct.nn.types module

direct.nn.types module.

class direct.nn.types.ActivationType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: DirectEnum

RELU = 'relu'
PRELU = 'prelu'
LEAKY_RELU = 'leaky_relu'

class direct.nn.types.ModelName(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: DirectEnum

UNET = 'unet'
NORMUNET = 'normunet'
RESNET = 'resnet'
DIDN = 'didn'
CONV = 'conv'

class direct.nn.types.InitType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: DirectEnum

INPUT_IMAGE = 'input_image'
SENSE = 'sense'
ZERO_FILLED = 'zero_filled'
ZEROS = 'zeros'

class direct.nn.types.LossFunType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: DirectEnum

L1_LOSS = 'l1_loss'
KSPACE_L1_LOSS = 'kspace_l1_loss'
L2_LOSS = 'l2_loss'
KSPACE_L2_LOSS = 'kspace_l2_loss'
SSIM_LOSS = 'ssim_loss'
SSIM_3D_LOSS = 'ssim_3d_loss'
GRAD_L1_LOSS = 'grad_l1_loss'
GRAD_L2_LOSS = 'grad_l2_loss'
NMSE_LOSS = 'nmse_loss'
KSPACE_NMSE_LOSS = 'kspace_nmse_loss'
NRMSE_LOSS = 'nrmse_loss'
KSPACE_NRMSE_LOSS = 'kspace_nrmse_loss'
NMAE_LOSS = 'nmae_loss'
KSPACE_NMAE_LOSS = 'kspace_nmae_loss'
SNR_LOSS = 'snr_loss'
PSNR_LOSS = 'psnr_loss'
HFEN_L1_LOSS = 'hfen_l1_loss'
HFEN_L2_LOSS = 'hfen_l2_loss'
HFEN_L1_NORM_LOSS = 'hfen_l1_norm_loss'
HFEN_L2_NORM_LOSS = 'hfen_l2_norm_loss'
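Because each member carries a string value, configuration strings can be mapped onto members by value lookup. The sketch below mirrors a few LossFunType members in a hypothetical LossFunTypeSketch; it assumes a (str, Enum) base because DirectEnum itself is not documented here, and parse_losses and kspace_losses are illustrative helpers, not DIRECT functions.

```python
from enum import Enum
from typing import List


class LossFunTypeSketch(str, Enum):
    """Hypothetical stand-in for a subset of LossFunType; DirectEnum is not reproduced."""

    L1_LOSS = "l1_loss"
    KSPACE_L1_LOSS = "kspace_l1_loss"
    SSIM_LOSS = "ssim_loss"
    KSPACE_NMSE_LOSS = "kspace_nmse_loss"


def parse_losses(names: List[str]) -> List[LossFunTypeSketch]:
    # Lookup by value raises ValueError for unknown names, catching config typos early.
    return [LossFunTypeSketch(name) for name in names]


def kspace_losses(losses: List[LossFunTypeSketch]) -> List[LossFunTypeSketch]:
    # In the member list above, k-space losses share the "kspace_" value prefix.
    return [loss for loss in losses if loss.value.startswith("kspace_")]
```

Mixing in str means members compare equal to their plain string values, so code that still passes raw strings keeps working.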

Module contents