direct.nn.kikinet package

Submodules

direct.nn.kikinet.config module

class direct.nn.kikinet.config.KIKINetConfig(model_name='???', engine_name=None, num_iter=10, image_model_architecture='MWCNN', kspace_model_architecture='UNET', image_mwcnn_hidden_channels=16, image_mwcnn_num_scales=4, image_mwcnn_bias=True, image_mwcnn_batchnorm=False, image_unet_num_filters=8, image_unet_num_pool_layers=4, image_unet_dropout_probability=0.0, kspace_conv_hidden_channels=16, kspace_conv_n_convs=4, kspace_conv_batchnorm=False, kspace_didn_hidden_channels=64, kspace_didn_num_dubs=6, kspace_didn_num_convs_recon=9, kspace_unet_num_filters=8, kspace_unet_num_pool_layers=4, kspace_unet_dropout_probability=0.0, normalize=False)

Bases: ModelConfig

num_iter = 10
image_model_architecture = 'MWCNN'
kspace_model_architecture = 'UNET'
image_mwcnn_hidden_channels = 16
image_mwcnn_num_scales = 4
image_mwcnn_bias = True
image_mwcnn_batchnorm = False
image_unet_num_filters = 8
image_unet_num_pool_layers = 4
image_unet_dropout_probability = 0.0
kspace_conv_hidden_channels = 16
kspace_conv_n_convs = 4
kspace_conv_batchnorm = False
kspace_didn_hidden_channels = 64
kspace_didn_num_dubs = 6
kspace_didn_num_convs_recon = 9
kspace_unet_num_filters = 8
kspace_unet_num_pool_layers = 4
kspace_unet_dropout_probability = 0.0
normalize = False
__init__(model_name='???', engine_name=None, num_iter=10, image_model_architecture='MWCNN', kspace_model_architecture='UNET', image_mwcnn_hidden_channels=16, image_mwcnn_num_scales=4, image_mwcnn_bias=True, image_mwcnn_batchnorm=False, image_unet_num_filters=8, image_unet_num_pool_layers=4, image_unet_dropout_probability=0.0, kspace_conv_hidden_channels=16, kspace_conv_n_convs=4, kspace_conv_batchnorm=False, kspace_didn_hidden_channels=64, kspace_didn_num_dubs=6, kspace_didn_num_convs_recon=9, kspace_unet_num_filters=8, kspace_unet_num_pool_layers=4, kspace_unet_dropout_probability=0.0, normalize=False)
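The config fields follow a <domain>_<architecture>_<option> naming pattern, so only the fields matching the selected image and k-space architectures are relevant to a given model. The sketch below illustrates that relationship. It is hypothetical: KIKINetConfigSketch is a stand-in dataclass mirroring the documented defaults (model_name and engine_name omitted), architecture_kwargs is not part of DIRECT, and the assumption that "NORMUNET" reuses the *_unet_* fields is inferred from the absence of separate *_normunet_* fields. DIRECT itself may wire these options differently.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict


@dataclass
class KIKINetConfigSketch:
    """Hypothetical stand-in mirroring KIKINetConfig's documented defaults."""

    num_iter: int = 10
    image_model_architecture: str = "MWCNN"
    kspace_model_architecture: str = "UNET"
    image_mwcnn_hidden_channels: int = 16
    image_mwcnn_num_scales: int = 4
    image_mwcnn_bias: bool = True
    image_mwcnn_batchnorm: bool = False
    image_unet_num_filters: int = 8
    image_unet_num_pool_layers: int = 4
    image_unet_dropout_probability: float = 0.0
    kspace_conv_hidden_channels: int = 16
    kspace_conv_n_convs: int = 4
    kspace_conv_batchnorm: bool = False
    kspace_didn_hidden_channels: int = 64
    kspace_didn_num_dubs: int = 6
    kspace_didn_num_convs_recon: int = 9
    kspace_unet_num_filters: int = 8
    kspace_unet_num_pool_layers: int = 4
    kspace_unet_dropout_probability: float = 0.0
    normalize: bool = False


def architecture_kwargs(cfg: KIKINetConfigSketch, domain: str) -> Dict[str, Any]:
    """Hypothetical helper: return only the fields of the architecture chosen for `domain`."""
    arch = getattr(cfg, f"{domain}_model_architecture").lower()
    if arch == "normunet":
        # Assumption: NORMUNET is configured through the *_unet_* fields.
        arch = "unet"
    prefix = f"{domain}_{arch}_"
    return {f.name: getattr(cfg, f.name) for f in fields(cfg) if f.name.startswith(prefix)}


cfg = replace(KIKINetConfigSketch(), kspace_model_architecture="DIDN")
print(architecture_kwargs(cfg, "kspace"))
# {'kspace_didn_hidden_channels': 64, 'kspace_didn_num_dubs': 6, 'kspace_didn_num_convs_recon': 9}
```

Note that the config defaults (num_iter=10, kspace_model_architecture='UNET') differ from the KIKINet constructor defaults (num_iter=2, kspace_model_architecture='DIDN'), so the effective defaults depend on whether the model is built from the config or constructed directly.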

direct.nn.kikinet.kikinet module

class direct.nn.kikinet.kikinet.KIKINet(forward_operator, backward_operator, image_model_architecture='MWCNN', kspace_model_architecture='DIDN', num_iter=2, normalize=False, **kwargs)

Bases: Module

Based on the KIKINet implementation [1], modified to work with multi-coil k-space data.

References:

__init__(forward_operator, backward_operator, image_model_architecture='MWCNN', kspace_model_architecture='DIDN', num_iter=2, normalize=False, **kwargs)

Inits KIKINet.

Parameters:
  • forward_operator (Callable) – Forward Operator.

  • backward_operator (Callable) – Backward Operator.

  • image_model_architecture (str) – Image model architecture. Currently implemented only for "MWCNN", "UNET" and "NORMUNET". Default: "MWCNN".

  • kspace_model_architecture (str) – K-space model architecture. Currently implemented only for "CONV", "DIDN", "UNET" and "NORMUNET". Default: "DIDN".

  • num_iter (int) – Number of unrolled iterations. Default: 2.

  • normalize (bool) – If True, the input is normalised using the scaling_factor passed to forward. Default: False.

  • **kwargs – Keyword arguments for the selected image and k-space model architectures (for example, the image_mwcnn_* and kspace_didn_* options listed in KIKINetConfig).
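The constructor options above can be summarised as a small validation step: each unrolled iteration applies one k-space model followed by one image model, and unsupported names are rejected. This is a hypothetical sketch (plan_kikinet, IMAGE_ARCHITECTURES and KSPACE_ARCHITECTURES are not part of DIRECT); it treats "(NORM)UNET" as the two names "UNET" and "NORMUNET", and raising NotImplementedError for unsupported names is an assumption.

```python
from typing import List, Tuple

# Hypothetical constants listing the architecture names documented as supported.
IMAGE_ARCHITECTURES = {"MWCNN", "UNET", "NORMUNET"}
KSPACE_ARCHITECTURES = {"CONV", "DIDN", "UNET", "NORMUNET"}


def plan_kikinet(
    image_model_architecture: str = "MWCNN",
    kspace_model_architecture: str = "DIDN",
    num_iter: int = 2,
) -> List[Tuple[str, str]]:
    """Hypothetical helper: validate names and list the (k-space, image) stage of each iteration."""
    if image_model_architecture not in IMAGE_ARCHITECTURES:
        raise NotImplementedError(f"Unsupported image model: {image_model_architecture!r}")
    if kspace_model_architecture not in KSPACE_ARCHITECTURES:
        raise NotImplementedError(f"Unsupported k-space model: {kspace_model_architecture!r}")
    if num_iter < 1:
        raise ValueError("num_iter must be at least 1")
    # One "KI" pair per unrolled iteration: k-space model first, then image model.
    return [(kspace_model_architecture, image_model_architecture)] * num_iter


print(plan_kikinet())
# [('DIDN', 'MWCNN'), ('DIDN', 'MWCNN')]
```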

forward(masked_kspace, sampling_mask, sensitivity_map, scaling_factor=None)

Computes forward pass of KIKINet.

Parameters:
  • masked_kspace (Tensor) – Masked multi-coil k-space.

  • sampling_mask (Tensor) – Sampling mask.

  • sensitivity_map (Tensor) – Coil sensitivity maps.

  • scaling_factor (Optional[Tensor]) – Scaling factor, used to normalise the input when normalize=True. Default: None.

Returns:

image – Output image of shape (N, height, width, complex=2).

Return type:

Tensor
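The cross-domain idea behind the forward pass (alternating a k-space model and an image model over num_iter unrolled iterations, with coils combined through the sensitivity maps) can be sketched in NumPy. This is a generic KIKI-style sketch under stated assumptions, not DIRECT's implementation: it uses complex arrays of shape (coil, height, width) without a batch dimension instead of real tensors with a trailing complex=2 axis, the networks are placeholder callables, and both the data-consistency step and the plain division by scaling_factor are illustrative choices. fft2c, ifft2c, reduce_coils, expand_coils and kiki_forward are hypothetical names.

```python
from typing import Callable, Optional

import numpy as np

AXES = (-2, -1)  # spatial axes (height, width)


def fft2c(x: np.ndarray) -> np.ndarray:
    """Centered, orthonormal 2D FFT over the spatial axes."""
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(x, axes=AXES), norm="ortho"), axes=AXES)


def ifft2c(x: np.ndarray) -> np.ndarray:
    """Centered, orthonormal 2D inverse FFT over the spatial axes."""
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(x, axes=AXES), norm="ortho"), axes=AXES)


def reduce_coils(kspace: np.ndarray, sens: np.ndarray) -> np.ndarray:
    """(coil, H, W) k-space -> (H, W) image: sum over coils of conj(S_c) * IFFT(k_c)."""
    return np.sum(np.conj(sens) * ifft2c(kspace), axis=0)


def expand_coils(image: np.ndarray, sens: np.ndarray) -> np.ndarray:
    """(H, W) image -> (coil, H, W) k-space: FFT(S_c * x) for each coil."""
    return fft2c(sens * image[None])


def kiki_forward(
    masked_kspace: np.ndarray,
    mask: np.ndarray,
    sens: np.ndarray,
    kspace_net: Callable[[np.ndarray], np.ndarray],
    image_net: Callable[[np.ndarray], np.ndarray],
    num_iter: int = 2,
    scaling_factor: Optional[float] = None,
) -> np.ndarray:
    # Assumed normalisation: divide the input by scaling_factor and undo it at the end.
    acquired = masked_kspace if scaling_factor is None else masked_kspace / scaling_factor
    kspace = acquired.copy()
    for _ in range(num_iter):
        kspace = kspace_net(kspace)                    # "K" step: refine multi-coil k-space
        image = image_net(reduce_coils(kspace, sens))  # "I" step: refine the coil-combined image
        # Assumed data consistency: re-insert the acquired samples where mask is True.
        kspace = np.where(mask, acquired, expand_coils(image, sens))
    image = reduce_coils(kspace, sens)
    return image if scaling_factor is None else image * scaling_factor
```

With identity networks, a fully sampled mask and sensitivity maps whose squared magnitudes sum to 1 over the coils, expand_coils followed by reduce_coils is the identity, so kiki_forward returns the original image. This makes a convenient sanity check before plugging in real networks.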

direct.nn.kikinet.kikinet_engine module

class direct.nn.kikinet.kikinet_engine.KIKINetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

KIKINet Engine.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Inits KIKINetEngine.

forward_function(data)

Performs the model's forward pass given data, which contains all tensor inputs.

Implements the method that MRIModelEngine requires its child classes to provide.

Return type:

Tuple[Tensor, None]
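The contract of forward_function can be illustrated with a hypothetical stand-alone function. It assumes data is a dictionary keyed by the names of KIKINet.forward's parameters; the real engine may look the inputs up differently or preprocess them first. The second element of the returned tuple is None, matching the documented return type Tuple[Tensor, None].

```python
from typing import Any, Callable, Dict, Tuple


def forward_function_sketch(model: Callable[..., Any], data: Dict[str, Any]) -> Tuple[Any, None]:
    """Hypothetical: unpack the tensor inputs from `data` and run the model once."""
    output_image = model(
        masked_kspace=data["masked_kspace"],
        sampling_mask=data["sampling_mask"],
        sensitivity_map=data["sensitivity_map"],
        # Assumption: scaling_factor is optional in `data` and defaults to None.
        scaling_factor=data.get("scaling_factor"),
    )
    # Only an image is produced, so the second element of the tuple is None.
    return output_image, None
```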

Module contents