direct.nn.recurrentvarnet package

Submodules

direct.nn.recurrentvarnet.config module

class direct.nn.recurrentvarnet.config.RecurrentVarNetConfig(model_name='???', engine_name=None, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=True, initializer_initialization=InitType.SENSE, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False)

Bases: ModelConfig

num_steps = 15
recurrent_hidden_channels = 64
recurrent_num_layers = 4
no_parameter_sharing = True
learned_initializer = True
initializer_initialization = 'sense'
initializer_channels = (32, 32, 64, 64)
initializer_dilations = (1, 1, 2, 4)
initializer_multiscale = 1
normalized = False
__init__(model_name='???', engine_name=None, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=True, initializer_initialization=InitType.SENSE, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False)
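To make the field list above concrete, here is a hypothetical stand-in written with the standard library's dataclasses. It is not direct's class (which derives from ModelConfig and stores initializer_initialization as an InitType); it only mirrors the documented field names and defaults, keeping the init type as its string value 'sense'.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass
class RecurrentVarNetConfigSketch:
    """Hypothetical mirror of the documented RecurrentVarNetConfig fields and defaults."""

    model_name: str = "???"
    engine_name: Optional[str] = None
    num_steps: int = 15
    recurrent_hidden_channels: int = 64
    recurrent_num_layers: int = 4
    no_parameter_sharing: bool = True
    learned_initializer: bool = True
    # Stored here as the string value of InitType.SENSE (an assumption of this sketch).
    initializer_initialization: str = "sense"
    # Tuples are immutable, so they are safe as plain dataclass defaults.
    initializer_channels: Tuple[int, ...] = (32, 32, 64, 64)
    initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4)
    initializer_multiscale: int = 1
    normalized: bool = False


# Override a couple of fields; every other field keeps its default.
cfg = replace(RecurrentVarNetConfigSketch(), num_steps=8, normalized=True)
print(cfg.num_steps, cfg.normalized, cfg.initializer_dilations)  # 8 True (1, 1, 2, 4)
```

Note that the documented config default for learned_initializer is True, whereas the RecurrentVarNet constructor itself defaults to learned_initializer=False and initializer_initialization=None.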

direct.nn.recurrentvarnet.recurrentvarnet module

class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentInit(in_channels, out_channels, channels, dilations, depth=2, multiscale_depth=1)

Bases: Module

Recurrent State Initializer (RSI) module of the Recurrent Variational Network as presented in [1].

The RSI module learns to initialize the recurrent hidden state \(h_0\), which is the input of the first RecurrentVarNetBlock of the RecurrentVarNet.

References:

__init__(in_channels, out_channels, channels, dilations, depth=2, multiscale_depth=1)

Inits RecurrentInit.

Parameters:
  • in_channels (int) – Input channels.

  • out_channels (int) – Number of hidden channels of the recurrent unit of RecurrentVarNet Block.

  • channels (Tuple[int, ...]) – Channels \(n_d\) in the convolutional layers of the initializer.

  • dilations (Tuple[int, ...]) – Dilations \(p\) of the convolutional layers of the initializer.

  • depth (int) – Number of layers \(n_l\) of the recurrent unit of the RecurrentVarNet Block; one initial hidden state is produced per layer. Default: 2.

  • multiscale_depth (int) – Number of feature layers to aggregate for the output; if 1, multi-scale context aggregation is disabled. Default: 1.
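The hyperparameters above fix a few sizes that can be worked out by hand. The sketch below assumes, as an illustration rather than a statement about direct's code, that every initializer layer is a 3×3 convolution padded to keep the spatial size, and that the output heads read the concatenation of the last multiscale_depth feature maps. The function name rsi_summary is hypothetical.

```python
from typing import Sequence, Tuple


def rsi_summary(
    out_channels: int,
    channels: Sequence[int],
    dilations: Sequence[int],
    depth: int = 2,
    multiscale_depth: int = 1,
    batch: int = 1,
    height: int = 32,
    width: int = 32,
    kernel_size: int = 3,  # assumption: 3x3 convolutions
) -> Tuple[int, int, Tuple[int, ...]]:
    """Hypothetical helper: receptive field, aggregated channels and hidden-state shape."""
    if len(channels) != len(dilations):
        raise ValueError("channels and dilations must have the same length")
    if not 1 <= multiscale_depth <= len(channels):
        raise ValueError("multiscale_depth must be between 1 and len(channels)")
    # Each dilated k x k convolution widens the receptive field by (k - 1) * dilation.
    receptive_field = 1 + sum((kernel_size - 1) * d for d in dilations)
    # Channels seen by the output heads: the last `multiscale_depth` feature maps, concatenated.
    aggregated_channels = sum(channels[-multiscale_depth:])
    # One head per recurrent layer, stacked on a trailing axis.
    hidden_state_shape = (batch, out_channels, height, width, depth)
    return receptive_field, aggregated_channels, hidden_state_shape


print(rsi_summary(64, (32, 32, 64, 64), (1, 1, 2, 4), depth=4))
# (17, 64, (1, 64, 32, 32, 4))
print(rsi_summary(64, (32, 32, 64, 64), (1, 1, 2, 4), depth=4, multiscale_depth=3))
# (17, 160, (1, 64, 32, 32, 4))
```

With the default dilations (1, 1, 2, 4), four 3×3 layers already see a 17×17 neighbourhood, which is the usual motivation for dilation over simply stacking more layers.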

forward(x)

Computes the initial hidden state of the recurrent unit from the input.

Parameters:

x (Tensor) – Input to the RSI module, from which the initial hidden state is computed.

Return type:

Tensor

Returns:

Initial recurrent hidden state from input.
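A NumPy sketch of what this forward pass computes, assuming the architecture described by the constructor arguments: a stack of dilated 3×3 convolutions (one per entry of channels/dilations) each followed by ReLU, concatenation of the last multiscale_depth feature maps, and one 1×1 head per recurrent layer, stacked on a trailing axis to give shape (N, out_channels, H, W, depth). Kernel size, edge padding, ReLU placement and the random weights are assumptions of this sketch, not a copy of direct's implementation; the helper names are hypothetical.

```python
import numpy as np


def dilated_conv3x3(x, weight, bias, dilation):
    """3x3 cross-correlation with edge (replication) padding that keeps H and W.

    x: (N, C_in, H, W); weight: (C_out, C_in, 3, 3); bias: (C_out,).
    """
    _, _, h, w = x.shape
    d = dilation
    x_pad = np.pad(x, ((0, 0), (0, 0), (d, d), (d, d)), mode="edge")
    out = np.zeros((x.shape[0], weight.shape[0], h, w))
    for i in range(3):
        for j in range(3):
            # Tap (i, j) reads the input shifted by ((i - 1) * d, (j - 1) * d).
            patch = x_pad[:, :, i * d : i * d + h, j * d : j * d + w]
            out += np.einsum("oc,nchw->nohw", weight[:, :, i, j], patch)
    return out + bias[None, :, None, None]


def rsi_forward(x, params, multiscale_depth=1):
    """Returns an initial hidden state of shape (N, out_channels, H, W, depth)."""
    features = []
    for weight, bias, dilation in params["convs"]:
        x = np.maximum(dilated_conv3x3(x, weight, bias, dilation), 0.0)  # ReLU
        features.append(x)
    # Multi-scale context aggregation; with multiscale_depth=1 this is just the last map.
    x = np.concatenate(features[-multiscale_depth:], axis=1)
    heads = [
        np.maximum(np.einsum("oc,nchw->nohw", w, x) + b[None, :, None, None], 0.0)
        for w, b in params["heads"]
    ]
    return np.stack(heads, axis=-1)  # one slice per recurrent layer


def init_params(rng, in_channels, out_channels, channels, dilations, depth, multiscale_depth):
    """Random weights, only for shape-checking the sketch."""
    convs, c_prev = [], in_channels
    for c, d in zip(channels, dilations):
        convs.append((0.1 * rng.standard_normal((c, c_prev, 3, 3)), np.zeros(c), d))
        c_prev = c
    c_agg = sum(channels[-multiscale_depth:])
    heads = [(0.1 * rng.standard_normal((out_channels, c_agg)), np.zeros(out_channels)) for _ in range(depth)]
    return {"convs": convs, "heads": heads}


rng = np.random.default_rng(0)
params = init_params(rng, 2, 8, (4, 4, 8, 8), (1, 1, 2, 4), depth=3, multiscale_depth=2)
h0 = rsi_forward(rng.standard_normal((1, 2, 16, 16)), params, multiscale_depth=2)
print(h0.shape)  # (1, 8, 16, 16, 3)
```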

class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentVarNet(forward_operator, backward_operator, in_channels=COMPLEX_SIZE, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=False, initializer_initialization=None, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False, **kwargs)

Bases: Module

Recurrent Variational Network implementation as presented in [1].

References:

__init__(forward_operator, backward_operator, in_channels=COMPLEX_SIZE, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=False, initializer_initialization=None, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False, **kwargs)

Inits RecurrentVarNet.

Parameters:
  • forward_operator (Callable) – Forward operator.

  • backward_operator (Callable) – Backward operator.

  • num_steps (int) – Number of iterations \(T\). Default: 15.

  • in_channels (int) – Input channel number. Default is 2 (COMPLEX_SIZE) for complex data.

  • recurrent_hidden_channels (int) – Number of hidden channels of the recurrent unit of the RecurrentVarNet Blocks. Default: 64.

  • recurrent_num_layers (int) – Number of layers \(n_l\) of the recurrent unit of the RecurrentVarNet Block. Default: 4.

  • no_parameter_sharing (bool) – If False, the same RecurrentVarNetBlock is used for all num_steps. Default: True.

  • learned_initializer (bool) – If True, the RSI module is used to initialize the recurrent hidden state. Default: False.

  • initializer_initialization (Optional[InitType]) – Type of initialization for the RSI module. Can be either 'sense', 'zero-filled' or 'input-image'. Default: None.

  • initializer_channels (Optional[Tuple[int, ...]]) – Channels \(n_d\) in the convolutional layers of the RSI module. Default: (32, 32, 64, 64).

  • initializer_dilations (Optional[Tuple[int, ...]]) – Dilations \(p\) of the convolutional layers of the RSI module. Default: (1, 1, 2, 4).

  • initializer_multiscale (int) – Number of RSI module feature layers to aggregate for the output; if 1, multi-scale context aggregation is disabled. Default: 1.

  • normalized (bool) – If True, NormConv2dGRU will be used as a regularizer in the RecurrentVarNetBlocks. Default: False.
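The initializer arguments interact: with the constructor defaults (learned_initializer=False, initializer_initialization=None) no RSI module is needed. The hypothetical helper below spells out one plausible reading of that interaction, assuming the RSI is built only when learned_initializer is True and an initialization type is given, and that any string other than the three documented ones is rejected. It is an illustration, not direct's validation code, and it uses plain strings in place of InitType members.

```python
from typing import Optional

# Allowed values listed in the documentation of `initializer_initialization`.
VALID_INITIALIZATIONS = ("sense", "zero-filled", "input-image")


def uses_learned_initializer(
    learned_initializer: bool,
    initializer_initialization: Optional[str],
) -> bool:
    """Hypothetical check: would an RSI module be built for these arguments?"""
    if not learned_initializer or initializer_initialization is None:
        # Assumption: without both settings the hidden state is not learned.
        return False
    if initializer_initialization not in VALID_INITIALIZATIONS:
        raise ValueError(
            f"initializer_initialization must be one of {VALID_INITIALIZATIONS}, "
            f"got {initializer_initialization!r}."
        )
    return True


print(uses_learned_initializer(False, None))    # False: the constructor defaults
print(uses_learned_initializer(True, "sense"))  # True: the config defaults
```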

compute_sense_init(kspace, sensitivity_map)

Computes the SENSE initialization \(x_{\text{SENSE}}\):

\[x_{\text{SENSE}} = \sum_{k=1}^{n_c} {S^{k}}^* \times \mathcal{F}^{-1}(y^k)\]

where \(y^k\) denotes the k-space data from coil \(k\), \(S^k\) the sensitivity map of coil \(k\), \(\mathcal{F}^{-1}\) the backward operator, and \(n_c\) the number of coils.

Parameters:
  • kspace (Tensor) – k-space of shape (N, coil, height, width, complex=2).

  • sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, height, width, complex=2).

Returns:

input_image: SENSE initialization \(x_{\text{SENSE}}\).

Return type:

Tensor
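The SENSE combination can be checked numerically with a small NumPy sketch. Two assumptions differ from direct's tensors: complex values are native complex arrays of shape (coil, height, width) rather than a trailing complex=2 axis, and a plain orthonormal 2D FFT stands in for the forward and backward operators. The function name sense_init is hypothetical.

```python
import numpy as np


def sense_init(kspace: np.ndarray, sensitivity_map: np.ndarray) -> np.ndarray:
    """Sketch of x_SENSE = sum_k conj(S^k) * F^{-1}(y^k).

    kspace, sensitivity_map: complex arrays of shape (coil, height, width).
    Returns a complex image of shape (height, width).
    """
    # Backward operator applied coil by coil (stand-in: orthonormal 2D inverse FFT).
    coil_images = np.fft.ifft2(kspace, axes=(-2, -1), norm="ortho")
    # Combine the coil images weighted by the conjugate sensitivities.
    return np.sum(np.conj(sensitivity_map) * coil_images, axis=0)


# Toy check: with sensitivities normalized so that sum_k |S^k|^2 = 1 at every pixel,
# the SENSE combination recovers the image that generated the k-space.
rng = np.random.default_rng(0)
coils, height, width = 4, 8, 8
image = rng.standard_normal((height, width)) + 1j * rng.standard_normal((height, width))
sens = rng.standard_normal((coils, height, width)) + 1j * rng.standard_normal((coils, height, width))
sens /= np.sqrt(np.sum(np.abs(sens) ** 2, axis=0, keepdims=True))
kspace = np.fft.fft2(sens * image, axes=(-2, -1), norm="ortho")
print(np.allclose(sense_init(kspace, sens), image))  # True
```

The check works because \(\mathcal{F}^{-1}\mathcal{F}\) is the identity, so the sum reduces to \(\sum_k |S^k|^2 x = x\) for normalized sensitivities.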

forward(masked_kspace, sampling_mask, sensitivity_map, **kwargs)

Computes the forward pass of RecurrentVarNet.

Parameters:
  • masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).

  • sampling_mask (Tensor) – Sampling mask.

  • sensitivity_map (Tensor) – Coil sensitivities of shape (N, coil, height, width, complex=2).

Returns:

kspace_prediction: k-space prediction.

Return type:

Tensor
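The unrolled structure implied by num_steps and no_parameter_sharing can be sketched in plain Python. This is an illustration under assumptions: the blocks are hypothetical callables that take only the current prediction and hidden state (the real block also receives masked_kspace, sampling_mask and sensitivity_map), the prediction is assumed to start from masked_kspace, and the hidden state starts as None unless an RSI output is passed in.

```python
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical block signature: (current_kspace, hidden_state) -> (new_kspace, new_hidden_state).
Block = Callable[[Any, Optional[Any]], Tuple[Any, Any]]


def unrolled_forward(
    masked_kspace: Any,
    blocks: List[Block],
    num_steps: int,
    no_parameter_sharing: bool,
    initial_state: Optional[Any] = None,
) -> Any:
    # One block per step without sharing, a single shared block otherwise.
    expected = num_steps if no_parameter_sharing else 1
    if len(blocks) != expected:
        raise ValueError(f"expected {expected} block(s), got {len(blocks)}")
    # Start from the measured k-space; the hidden state is the RSI output or None.
    kspace_prediction, hidden_state = masked_kspace, initial_state
    for step in range(num_steps):
        block = blocks[step] if no_parameter_sharing else blocks[0]
        kspace_prediction, hidden_state = block(kspace_prediction, hidden_state)
    return kspace_prediction


calls: List[str] = []


def make_toy_block(name: str) -> Block:
    """Toy block that records its use and nudges the 'k-space' by 1."""

    def block(kspace: Any, state: Optional[Any]) -> Tuple[Any, Any]:
        calls.append(name)
        return kspace + 1, (0 if state is None else state) + 1

    return block


out = unrolled_forward(0, [make_toy_block("shared")], num_steps=3, no_parameter_sharing=False)
print(out, calls)  # 3 ['shared', 'shared', 'shared']
```

Sharing trades model capacity for a parameter count that does not grow with num_steps; without sharing, each of the \(T\) steps has its own block.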

class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentVarNetBlock(forward_operator, backward_operator, in_channels=2, hidden_channels=64, num_layers=4, normalized=False)

Bases: Module

Recurrent Variational Network Block \(\mathcal{H}_{\theta_{t}}\) as presented in [1].

References:

__init__(forward_operator, backward_operator, in_channels=2, hidden_channels=64, num_layers=4, normalized=False)

Inits RecurrentVarNetBlock.

Parameters:
  • forward_operator (Callable) – Forward Fourier transform.

  • backward_operator (Callable) – Backward Fourier transform.

  • in_channels (int) – Input channel number. Default is 2 for complex data.

  • hidden_channels (int) – Hidden channels. Default: 64.

  • num_layers (int) – Number of layers \(n_l\) of the recurrent unit. Default: 4.

  • normalized (bool) – If True, NormConv2dGRU will be used as a regularizer. Default: False.

forward(current_kspace, masked_kspace, sampling_mask, sensitivity_map, hidden_state, coil_dim=1, spatial_dims=(2, 3))

Computes the forward pass of RecurrentVarNetBlock.

Parameters:
  • current_kspace (Tensor) – Current k-space prediction of shape (N, coil, height, width, complex=2).

  • masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).

  • sampling_mask (Tensor) – Sampling mask.

  • sensitivity_map (Tensor) – Coil sensitivities of shape (N, coil, height, width, complex=2).

  • hidden_state (Optional[Tensor]) – Recurrent unit hidden state of shape (N, hidden_channels, height, width, num_layers), or None.

  • coil_dim (int) – Coil dimension. Default: 1.

  • spatial_dims (Tuple[int, int]) – Spatial dimensions. Default: (2, 3).

Returns:

new_kspace: New k-space prediction of shape (N, coil, height, width, complex=2).
hidden_state: Next hidden state of shape (N, hidden_channels, height, width, num_layers).

Return type:

Tuple[Tensor, Tensor]
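The block's k-space update can be sketched with NumPy, assuming it takes the gradient-descent-like form \(k_{t+1} = k_t - \alpha_t M(k_t - \tilde{y}) + r_t\), where \(M\) keeps only acquired locations, \(\tilde{y}\) is masked_kspace and \(r_t\) is the recurrent regularizer's output mapped back to k-space. The form of the update and the step size \(\alpha_t\) are assumptions here, since this page does not spell them out; the regularizer term is simply an array supplied by the caller, and block_kspace_update is a hypothetical name.

```python
import numpy as np


def block_kspace_update(current_kspace, masked_kspace, sampling_mask, step_size, regularizer_term):
    """Sketch of k_{t+1} = k_t - step_size * M(k_t - y) + regularizer_term.

    current_kspace, masked_kspace: complex arrays of shape (coil, height, width).
    sampling_mask: boolean array broadcastable to (coil, height, width); True where acquired.
    regularizer_term: k-space output of the recurrent regularizer (supplied by the caller here).
    """
    # Data-consistency residual, kept only at acquired k-space locations.
    kspace_error = np.where(sampling_mask, current_kspace - masked_kspace, 0)
    return current_kspace - step_size * kspace_error + regularizer_term


rng = np.random.default_rng(1)
shape = (2, 4, 4)
mask = rng.random((1, 4, 4)) < 0.5  # one mask shared by all coils
measured = np.where(mask, rng.standard_normal(shape) + 1j * rng.standard_normal(shape), 0)
current = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
# With a zero regularizer and step size 1, acquired entries snap to the measurements
# and unacquired entries are left unchanged.
new = block_kspace_update(current, measured, mask, 1.0, np.zeros(shape))
print(np.allclose(np.where(mask, new, 0), measured), np.allclose(new[:, ~mask[0]], current[:, ~mask[0]]))  # True True
```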

direct.nn.recurrentvarnet.recurrentvarnet_engine module

class direct.nn.recurrentvarnet.recurrentvarnet_engine.RecurrentVarNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

Recurrent Variational Network Engine.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Inits RecurrentVarNetEngine.

forward_function(data)

Performs the model's forward method given data, which contains all tensor inputs.

The base engine requires child classes to implement this method; RecurrentVarNetEngine provides that implementation.

Return type:

Tuple[Tensor, Tensor]

Module contents