direct.nn package#

Subpackages#

Submodules#

direct.nn.get_nn_model_config module#

direct.nn.get_nn_model_config module.

direct.nn.mri_models module#

MRI model engine of DIRECT.

class direct.nn.mri_models.MRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: Engine

Engine for MRI models.

Each child class should implement its own forward_function().

build_loss()[source][source]#
Return type:

Dict

compute_loss_on_data(loss_dict, loss_fns, data, output_image=None, output_kspace=None, weight=1.0)[source][source]#
Return type:

Dict[str, Tensor]

compute_model_per_coil(model_name, data)[source][source]#

Performs a forward pass of the model model_name in self.models per coil.

Parameters:
model_name: str

Model to run.

data: torch.Tensor

Multi-coil data of shape (batch, coil, complex=2, height, width).

Returns:
output: torch.Tensor

Computed output per coil.

Return type:

Tensor
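
For intuition, below is a minimal, self-contained sketch of applying a 2D model per coil by folding the coil axis into the batch axis. It follows the documented input shape (batch, coil, complex=2, height, width) but does not reproduce DIRECT's internal implementation; the helper name apply_model_per_coil and the toy Conv2d model are illustrative only.

    import torch

    def apply_model_per_coil(model: torch.nn.Module, data: torch.Tensor) -> torch.Tensor:
        """Run a 2D model over multi-coil data of shape (batch, coil, complex=2, height, width)."""
        batch, coil, complex_dim, height, width = data.shape
        # Fold the coil axis into the batch axis so the 2D model sees (N, 2, H, W).
        folded = data.reshape(batch * coil, complex_dim, height, width)
        output = model(folded)
        # Restore the (batch, coil, ...) layout.
        return output.reshape(batch, coil, *output.shape[1:])

    model = torch.nn.Conv2d(2, 2, kernel_size=3, padding=1)
    data = torch.randn(1, 8, 2, 64, 64)
    print(apply_model_per_coil(model, data).shape)  # torch.Size([1, 8, 2, 64, 64])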

compute_sensitivity_map(sensitivity_map)[source][source]#

Computes sensitivity maps \(\{S^k\}_{k=1}^{n_c}\) if sensitivity_model is available.

\(\{S^k\}_{k=1}^{n_c}\) are normalized such that

\[\sum_{k=1}^{n_c}S^k {S^k}^* = I.\]
Parameters:
sensitivity_map: torch.Tensor

Sensitivity maps of shape (batch, coil, height, width, complex=2).

Returns:
sensitivity_map: torch.Tensor

Normalized and refined sensitivity maps of shape (batch, coil, height, width, complex=2).

Return type:

Tensor
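
The normalization above can be illustrated with a short, standalone torch sketch: each pixel's coil vector is divided by its root-sum-of-squares so that \(\sum_{k=1}^{n_c}S^k {S^k}^*\) equals one pointwise. This only illustrates the formula with the documented (batch, coil, height, width, complex=2) layout; it is not DIRECT's sensitivity refinement, and the function name and eps value are assumptions.

    import torch

    def normalize_sensitivity_map(sensitivity_map: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
        """Scale maps of shape (batch, coil, height, width, complex=2) so the coil-wise RSS is 1."""
        s = torch.view_as_complex(sensitivity_map.contiguous())    # (batch, coil, H, W)
        rss = (s * s.conj()).sum(dim=1, keepdim=True).real.sqrt()  # root-sum-of-squares over coils
        return torch.view_as_real(s / (rss + eps))

    maps = torch.randn(2, 8, 32, 32, 2)
    normalized = torch.view_as_complex(normalize_sensitivity_map(maps))
    print(torch.allclose((normalized.abs() ** 2).sum(dim=1), torch.ones(2, 32, 32)))  # True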

evaluate(data_loader, loss_fns)[source][source]#

Validation process.

Assumes that each batch only contains slices of the same volume AND that these are sequentially ordered.

Parameters:
data_loader: DataLoader
loss_fns: Dict[str, Callable], optional
Returns:
loss_dict, all_gathered_metrics, visualize_slices, visualize_target
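
A hedged usage sketch, assuming engine is an already-constructed child of MRIModelEngine, val_loader is a DataLoader yielding sequentially ordered slices of one volume per batch, and loss_fns maps loss names to callables; the unpacking follows the Returns entry above, and none of the variable names come from DIRECT itself.

    # `engine`, `val_loader` and `loss_fns` must already exist; illustrative only.
    loss_dict, all_gathered_metrics, visualize_slices, visualize_target = engine.evaluate(
        val_loader, loss_fns
    )
    print(sorted(loss_dict), len(visualize_slices))
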
forward_function(data)[source][source]#

Performs the model’s forward pass given data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Optional[Tensor], Optional[Tensor]]
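
A schematic child class is sketched below. The only contract taken from this page is that forward_function(data) returns an (output_image, output_kspace) pair in which either element may be None; the class name ExampleMRIEngine, the data keys, the model call signature, and the assumption that the constructor's model is exposed as self.model are illustrative, not DIRECT's actual conventions.

    from typing import Optional, Tuple

    import torch

    from direct.nn.mri_models import MRIModelEngine

    class ExampleMRIEngine(MRIModelEngine):
        """Hypothetical engine whose model maps masked k-space to an image."""

        def forward_function(self, data) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
            # `data` is assumed to be a dict of tensors produced by the dataset
            # transforms; the exact keys depend on the configuration.
            sensitivity_map = self.compute_sensitivity_map(data["sensitivity_map"])
            output_image = self.model(data["masked_kspace"], sensitivity_map)  # assumed call signature
            return output_image, None  # no k-space output in this sketch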

reconstruct_volumes(data_loader, loss_fns=None, regularizer_fns=None, add_target=True, crop=None)[source][source]#

Validation process. Assumes that each batch only contains slices of the same volume AND that these are sequentially ordered.

Parameters:
data_loader: DataLoader
loss_fns: Dict[str, Callable], optional
regularizer_fns: Dict[str, Callable], optional
add_target: bool

If True, the target is added to the output.

crop: str, optional

Crop type.

Yields:
(curr_volume, [curr_target,] loss_dict_list, filename): torch.Tensor, [torch.Tensor,], dict, pathlib.Path
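
An illustrative way to consume the generator, assuming an engine such as the ExampleMRIEngine sketched above has been constructed as engine and val_loader is a DataLoader over sequentially ordered slices of single volumes; with add_target=True each yielded item unpacks as in the Yields entry.

    # `engine` and `val_loader` must already exist; illustrative only.
    for curr_volume, curr_target, loss_dict_list, filename in engine.reconstruct_volumes(
        val_loader, add_target=True
    ):
        print(filename, tuple(curr_volume.shape), tuple(curr_target.shape))
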

direct.nn.types module#

direct.nn.types module.

class direct.nn.types.ActivationType(value)[source][source]#

Bases: DirectEnum

An enumeration of activation function types.

LEAKY_RELU = 'leaky_relu'#
PRELU = 'prelu'#
RELU = 'relu'#
class direct.nn.types.InitType(value)[source][source]#

Bases: DirectEnum

An enumeration of initialization types.

INPUT_IMAGE = 'input_image'#
SENSE = 'sense'#
ZEROS = 'zeros'#
ZERO_FILLED = 'zero_filled'#
class direct.nn.types.LossFunType(value)[source][source]#

Bases: DirectEnum

An enumeration of loss function types.

GRAD_L1_LOSS = 'grad_l1_loss'#
GRAD_L2_LOSS = 'grad_l2_loss'#
HFEN_L1_LOSS = 'hfen_l1_loss'#
HFEN_L1_NORM_LOSS = 'hfen_l1_norm_loss'#
HFEN_L2_LOSS = 'hfen_l2_loss'#
HFEN_L2_NORM_LOSS = 'hfen_l2_norm_loss'#
KSPACE_L1_LOSS = 'kspace_l1_loss'#
KSPACE_L2_LOSS = 'kspace_l2_loss'#
KSPACE_NMAE_LOSS = 'kspace_nmae_loss'#
KSPACE_NMSE_LOSS = 'kspace_nmse_loss'#
KSPACE_NRMSE_LOSS = 'kspace_nrmse_loss'#
L1_LOSS = 'l1_loss'#
L2_LOSS = 'l2_loss'#
NMAE_LOSS = 'nmae_loss'#
NMSE_LOSS = 'nmse_loss'#
NRMSE_LOSS = 'nrmse_loss'#
PSNR_LOSS = 'psnr_loss'#
SNR_LOSS = 'snr_loss'#
SSIM_3D_LOSS = 'ssim_3d_loss'#
SSIM_LOSS = 'ssim_loss'#
class direct.nn.types.ModelName(value)[source][source]#

Bases: DirectEnum

An enumeration of model names.

CONV = 'conv'#
DIDN = 'didn'#
NORMUNET = 'normunet'#
RESNET = 'resnet'#
UNET = 'unet'#
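
The sketch below illustrates how these enum members relate to the configuration strings they wrap, using standard Enum semantics only; whether DirectEnum adds extra behaviour (for example case-insensitive lookup) is not covered here.

    from direct.nn.types import ActivationType, LossFunType, ModelName

    print(ActivationType.RELU.value)               # 'relu', the string used in configs
    print(ModelName("unet") is ModelName.UNET)     # look a member up from its config string -> True
    print([loss.value for loss in LossFunType][:3])  # first few loss identifiers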

Module contents#