direct.nn.recurrentvarnet package

Submodules

direct.nn.recurrentvarnet.config module

class direct.nn.recurrentvarnet.config.RecurrentVarNetConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps: int = 15, recurrent_hidden_channels: int = 64, recurrent_num_layers: int = 4, no_parameter_sharing: bool = True, learned_initializer: bool = True, initializer_initialization: Optional[str] = <InitType.SENSE: 'sense'>, initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64), initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4), initializer_multiscale: int = 1, normalized: bool = False)

Bases: ModelConfig

initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64)
initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4)
initializer_initialization: Optional[str] = 'sense'
initializer_multiscale: int = 1
learned_initializer: bool = True
no_parameter_sharing: bool = True
normalized: bool = False
num_steps: int = 15
recurrent_hidden_channels: int = 64
recurrent_num_layers: int = 4
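The documented defaults can be mirrored in a plain dataclass to see how the fields relate. This is a minimal sketch, not the library's class: RecurrentVarNetConfigSketch, validate and num_distinct_blocks are hypothetical names, and both the checked constraints (one dilation per initializer layer, initializer_multiscale not exceeding the number of initializer layers) and the block-count rule for no_parameter_sharing are assumptions about how RecurrentVarNet consumes these values.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class RecurrentVarNetConfigSketch:
    # Hypothetical stand-in mirroring the documented defaults of RecurrentVarNetConfig.
    num_steps: int = 15
    recurrent_hidden_channels: int = 64
    recurrent_num_layers: int = 4
    no_parameter_sharing: bool = True
    learned_initializer: bool = True
    initializer_initialization: Optional[str] = "sense"
    initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64)
    initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4)
    initializer_multiscale: int = 1
    normalized: bool = False


def validate(cfg: RecurrentVarNetConfigSketch) -> None:
    # Assumed constraints: one dilation per initializer conv layer, and the
    # multiscale head can only concatenate layers that exist.
    if not cfg.learned_initializer:
        return
    if len(cfg.initializer_channels) != len(cfg.initializer_dilations):
        raise ValueError("initializer_channels and initializer_dilations must have the same length")
    if not 1 <= cfg.initializer_multiscale <= len(cfg.initializer_channels):
        raise ValueError("initializer_multiscale must be between 1 and the number of initializer layers")


def num_distinct_blocks(cfg: RecurrentVarNetConfigSketch) -> int:
    # Assumption: one block per step without sharing, a single reused block otherwise.
    return cfg.num_steps if cfg.no_parameter_sharing else 1


cfg = RecurrentVarNetConfigSketch()
validate(cfg)
shared = replace(cfg, num_steps=8, no_parameter_sharing=False)
print(num_distinct_blocks(cfg), num_distinct_blocks(shared))  # 15 1
```

Freezing the dataclass and deriving variants with dataclasses.replace keeps the defaults intact while experimenting with individual fields.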

direct.nn.recurrentvarnet.recurrentvarnet module

class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentInit(in_channels, out_channels, channels, dilations, depth=2, multiscale_depth=1)

Bases: Module

Recurrent State Initializer (RSI) module of Recurrent Variational Network as presented in [1].

The RSI module learns to initialize the recurrent hidden state \(h_0\), which is the input to the first RecurrentVarNetBlock of the RecurrentVarNet.

References

[1]

Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.

forward(x)

Computes the initialization for the recurrent unit given input x.

Parameters:
x: torch.Tensor

Input to RecurrentInit, of shape (N, in_channels, height, width).

Returns:
out: torch.Tensor

Initial recurrent hidden state computed from input x, of shape (N, out_channels, height, width, depth).

Return type:

Tensor
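A NumPy sketch of the RSI data flow, assuming the layer layout suggested by the constructor arguments: one replication-padded, dilated 3x3 convolution followed by ReLU per entry of channels/dilations, concatenation of the last multiscale_depth feature maps when multiscale_depth > 1, then depth separate 1x1 convolution heads (with ReLU) whose outputs are stacked along a new last axis. The weights are random, the batch dimension is dropped, and dilated_conv3x3 and recurrent_init_sketch are hypothetical helpers; the point is the resulting (out_channels, height, width, depth) layout, which matches the hidden-state shape expected by RecurrentVarNetBlock.

```python
import numpy as np


def dilated_conv3x3(x, weight, bias, dilation):
    # x: (C_in, H, W); weight: (C_out, C_in, 3, 3); cross-correlation as in Conv2d.
    # Replication padding by `dilation` keeps the spatial size unchanged.
    xp = np.pad(x, ((0, 0), (dilation, dilation), (dilation, dilation)), mode="edge")
    height, width = x.shape[1:]
    out = np.zeros((weight.shape[0], height, width))
    for i in range(3):
        for j in range(3):
            patch = xp[:, i * dilation : i * dilation + height, j * dilation : j * dilation + width]
            out += np.einsum("oc,chw->ohw", weight[:, :, i, j], patch)
    return out + bias[:, None, None]


def recurrent_init_sketch(x, out_channels, channels, dilations, depth=2, multiscale_depth=1, seed=0):
    rng = np.random.default_rng(seed)
    features = []
    c_in = x.shape[0]
    for c_out, dilation in zip(channels, dilations):
        weight = 0.1 * rng.standard_normal((c_out, c_in, 3, 3))
        x = np.maximum(dilated_conv3x3(x, weight, np.zeros(c_out), dilation), 0.0)  # ReLU
        features.append(x)
        c_in = c_out
    if multiscale_depth > 1:
        # Concatenate the last `multiscale_depth` feature maps along the channel axis.
        x = np.concatenate(features[-multiscale_depth:], axis=0)
    heads = []
    for _ in range(depth):
        weight = 0.1 * rng.standard_normal((out_channels, x.shape[0]))  # 1x1 conv
        heads.append(np.maximum(np.einsum("oc,chw->ohw", weight, x), 0.0))
    return np.stack(heads, axis=-1)  # (out_channels, H, W, depth)


image = np.random.default_rng(1).standard_normal((2, 16, 16))  # real/imag as 2 channels
h0 = recurrent_init_sketch(image, out_channels=64, channels=(32, 32, 64, 64),
                           dilations=(1, 1, 2, 4), depth=4)
print(h0.shape)  # (64, 16, 16, 4)
```

Growing dilations (1, 1, 2, 4) enlarge the receptive field without downsampling, so the hidden state keeps the full image resolution.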

training: bool
class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentVarNet(forward_operator, backward_operator, in_channels=2, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=False, initializer_initialization=None, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False, **kwargs)

Bases: Module

Recurrent Variational Network implementation as presented in [1].

References

[1]

Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.

compute_sense_init(kspace, sensitivity_map)

Computes the SENSE initialization \(x_{\text{SENSE}}\):

\[x_{\text{SENSE}} = \sum_{k=1}^{n_c} {S^{k}}^* \times \mathcal{F}^{-1}(y^k)\]

where \(y^k\) denotes the k-space data from coil \(k\), \(S^k\) the sensitivity map of coil \(k\), \(n_c\) the number of coils, and \(\mathcal{F}^{-1}\) the backward (inverse Fourier) operator.

Parameters:
kspace: torch.Tensor

k-space of shape (N, coil, height, width, complex=2).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

Returns:
input_image: torch.Tensor

SENSE initialization \(x_{\text{SENSE}}\) of shape (N, height, width, complex=2).

Return type:

Tensor
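A NumPy sketch of the SENSE initialization formula, assuming the backward operator is a centered, orthonormal 2D inverse FFT over spatial_dims (the actual operator is whatever is passed to RecurrentVarNet as backward_operator). Tensors follow the trailing complex=2 convention used throughout this page; to_complex, to_real, centered_fft2, centered_ifft2 and sense_init are hypothetical helpers.

```python
import numpy as np


def to_complex(t):
    # (..., 2) real/imaginary layout -> complex array
    return t[..., 0] + 1j * t[..., 1]


def to_real(z):
    # complex array -> (..., 2) real/imaginary layout
    return np.stack([z.real, z.imag], axis=-1)


def centered_fft2(z, axes):
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(z, axes=axes), axes=axes, norm="ortho"), axes=axes)


def centered_ifft2(z, axes):
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(z, axes=axes), axes=axes, norm="ortho"), axes=axes)


def sense_init(kspace, sensitivity_map, coil_dim=1, spatial_dims=(2, 3)):
    # x_SENSE = sum_k conj(S^k) * F^{-1}(y^k); output shape (N, height, width, 2)
    coil_images = centered_ifft2(to_complex(kspace), axes=spatial_dims)
    combined = (np.conj(to_complex(sensitivity_map)) * coil_images).sum(axis=coil_dim)
    return to_real(combined)


rng = np.random.default_rng(0)
n, coils, h, w = 1, 4, 8, 8
x = rng.standard_normal((n, h, w)) + 1j * rng.standard_normal((n, h, w))
s = rng.standard_normal((n, coils, h, w)) + 1j * rng.standard_normal((n, coils, h, w))
s /= np.sqrt((np.abs(s) ** 2).sum(axis=1, keepdims=True))  # so that sum_k |S^k|^2 = 1
kspace = to_real(centered_fft2(s * x[:, None], axes=(2, 3)))  # simulated fully sampled coil data
x_sense = sense_init(kspace, to_real(s))
print(np.allclose(to_complex(x_sense), x))  # True
```

With fully sampled data and sensitivity maps normalized so that their squared magnitudes sum to one over coils, the SENSE combination recovers the underlying image exactly; with undersampled k-space it yields the aliased starting estimate.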

forward(masked_kspace, sampling_mask, sensitivity_map, **kwargs)

Computes the forward pass of RecurrentVarNet.

Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

sensitivity_map: torch.Tensor

Coil sensitivities of shape (N, coil, height, width, complex=2).

Returns:
kspace_prediction: torch.Tensor

k-space prediction of shape (N, coil, height, width, complex=2).

Return type:

Tensor
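The forward pass unrolls num_steps block updates, threading the k-space prediction and the hidden state through each step. A framework-free sketch of that control flow, assuming the prediction starts from a copy of masked_kspace, the initial hidden state comes from RecurrentInit (or is None without a learned initializer), and no_parameter_sharing chooses between one block per step and a single reused block. The block signature is simplified here to (kspace, hidden_state); unrolled_forward and make_toy_block are hypothetical.

```python
from typing import Callable, List, Optional, Tuple

import numpy as np

Block = Callable[[np.ndarray, Optional[np.ndarray]], Tuple[np.ndarray, Optional[np.ndarray]]]


def unrolled_forward(masked_kspace, blocks: List[Block], num_steps, no_parameter_sharing, initial_state=None):
    kspace_prediction = masked_kspace.copy()
    hidden_state = initial_state  # None when no learned initializer is used
    for step in range(num_steps):
        # Separate parameters per step, or the same block reused at every step.
        block = blocks[step] if no_parameter_sharing else blocks[0]
        kspace_prediction, hidden_state = block(kspace_prediction, hidden_state)
    return kspace_prediction


def make_toy_block(increment, calls):
    # Toy block: adds a constant to the k-space and records each call.
    def block(kspace, hidden):
        calls.append(increment)
        return kspace + increment, hidden
    return block


calls = []
blocks = [make_toy_block(1.0, calls) for _ in range(3)]
out = unrolled_forward(np.zeros((1, 2, 4, 4, 2)), blocks, num_steps=3, no_parameter_sharing=True)
print(out[0, 0, 0, 0, 0], len(calls))  # 3.0 3
```

Sharing parameters keeps the model size independent of num_steps, while separate blocks let each step specialize at the cost of num_steps times the parameters.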

training: bool
class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentVarNetBlock(forward_operator, backward_operator, in_channels=2, hidden_channels=64, num_layers=4, normalized=False)

Bases: Module

Recurrent Variational Network Block \(\mathcal{H}_{\theta_{t}}\) as presented in [1].

References

[1]

Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.

forward(current_kspace, masked_kspace, sampling_mask, sensitivity_map, hidden_state, coil_dim=1, spatial_dims=(2, 3))

Computes the forward pass of RecurrentVarNetBlock.

Parameters:
current_kspace: torch.Tensor

Current k-space prediction of shape (N, coil, height, width, complex=2).

masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

sensitivity_map: torch.Tensor

Coil sensitivities of shape (N, coil, height, width, complex=2).

hidden_state: torch.Tensor or None

Optional recurrent unit hidden state of shape (N, hidden_channels, height, width, num_layers), or None.

coil_dim: int

Coil dimension. Default: 1.

spatial_dims: tuple of ints

Spatial dimensions. Default: (2, 3).

Returns:
new_kspace: torch.Tensor

New k-space prediction of shape (N, coil, height, width, complex=2).

hidden_state: torch.Tensor

Next hidden state of shape (N, hidden_channels, height, width, num_layers).

Return type:

Tuple[Tensor, Tensor]
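The k-space update of a block can be read as a data-consistency step plus a learned term. A NumPy sketch, assuming the form \(w_{t+1} = w_t - \lambda\, M \odot (w_t - \tilde{y}) + \mathcal{R}_t\), where \(M\) is sampling_mask, \(\tilde{y}\) is masked_kspace, \(\lambda\) a step size, and \(\mathcal{R}_t\) the k-space output of the recurrent regularizer (which maps the coil-combined image and hidden_state to a refinement and the next hidden state). The regularizer is replaced by a precomputed array; block_update and learning_rate are hypothetical names.

```python
import numpy as np


def block_update(current_kspace, masked_kspace, sampling_mask, regularizer_term, learning_rate):
    # Data-consistency error, kept only at sampled k-space locations.
    # sampling_mask (N, 1, H, W, 1) broadcasts against (N, coil, H, W, 2).
    kspace_error = np.where(sampling_mask == 0, 0.0, current_kspace - masked_kspace)
    # Gradient-like step on the data term plus the learned (regularizer) term.
    return current_kspace - learning_rate * kspace_error + regularizer_term


rng = np.random.default_rng(0)
shape = (1, 3, 4, 4, 2)
mask = (rng.random((1, 1, 4, 4, 1)) > 0.5).astype(float)
measured = mask * rng.standard_normal(shape)  # masked k-space
current = rng.standard_normal(shape)          # current k-space prediction
new = block_update(current, measured, mask, np.zeros(shape), learning_rate=1.0)
```

With learning_rate=1.0 and a zero regularizer term, sampled locations are overwritten by the measurements and unsampled ones are left unchanged, which is a quick sanity check on the masking.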

training: bool

direct.nn.recurrentvarnet.recurrentvarnet_engine module

class direct.nn.recurrentvarnet.recurrentvarnet_engine.RecurrentVarNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

Recurrent Variational Network Engine.

forward_function(data)

Runs the model’s forward pass on data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Tensor, Tensor]
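The return type indicates a pair of tensors. A NumPy sketch under stated assumptions: data is a dict keyed like the arguments of RecurrentVarNet.forward, the pair is (output_image, output_kspace), and the image is the root-sum-of-squares over coils of the k-space prediction transformed with a centered orthonormal inverse FFT. These choices, forward_function_sketch and identity_model are illustrative assumptions, not the engine's confirmed behavior.

```python
import numpy as np


def centered_ifft2(z, axes):
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(z, axes=axes), axes=axes, norm="ortho"), axes=axes)


def forward_function_sketch(data, model, coil_dim=1, spatial_dims=(2, 3)):
    # `data` is assumed to be a dict with the keys used by RecurrentVarNet.forward.
    output_kspace = model(data["masked_kspace"], data["sampling_mask"], data["sensitivity_map"])
    coil_images = centered_ifft2(output_kspace[..., 0] + 1j * output_kspace[..., 1], axes=spatial_dims)
    # Root-sum-of-squares over coils -> (N, height, width)
    output_image = np.sqrt((np.abs(coil_images) ** 2).sum(axis=coil_dim))
    return output_image, output_kspace


def identity_model(masked_kspace, sampling_mask, sensitivity_map):
    # Placeholder for a trained RecurrentVarNet: returns its input k-space.
    return masked_kspace


rng = np.random.default_rng(0)
image = rng.standard_normal((1, 1, 8, 8))  # (N, coil=1, H, W), real-valued
kspace_c = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(image, axes=(2, 3)), axes=(2, 3), norm="ortho"), axes=(2, 3))
data = {
    "masked_kspace": np.stack([kspace_c.real, kspace_c.imag], axis=-1),
    "sampling_mask": np.ones((1, 1, 8, 8, 1)),
    "sensitivity_map": np.ones((1, 1, 8, 8, 2)),  # unused by the identity model
}
img, ksp = forward_function_sketch(data, identity_model)
print(img.shape)  # (1, 8, 8)
```

For a single coil, the root-sum-of-squares image reduces to the magnitude of the coil image, which the identity model makes easy to verify.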

Module contents