direct.nn.recurrentvarnet package
Submodules
direct.nn.recurrentvarnet.config module
- class direct.nn.recurrentvarnet.config.RecurrentVarNetConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps: int = 15, recurrent_hidden_channels: int = 64, recurrent_num_layers: int = 4, no_parameter_sharing: bool = True, learned_initializer: bool = True, initializer_initialization: Optional[str] = <InitType.SENSE: 'sense'>, initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64), initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4), initializer_multiscale: int = 1, normalized: bool = False)
Bases: ModelConfig
- initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64)
- initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4)
- initializer_initialization: Optional[str] = 'sense'
- initializer_multiscale: int = 1
- learned_initializer: bool = True
- no_parameter_sharing: bool = True
- normalized: bool = False
- num_steps: int = 15
- recurrent_hidden_channels: int = 64
- recurrent_num_layers: int = 4
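The defaults listed above can be collected in a small stand-in for experimentation. This is a minimal sketch, not the DIRECT class itself: `RecurrentVarNetConfigSketch` is a hypothetical name, and the only thing taken from the documentation is the set of field names, types, and default values.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class RecurrentVarNetConfigSketch:
    """Hypothetical stand-in mirroring the documented defaults (not the DIRECT class)."""

    model_name: str = "???"  # placeholder default shown in the documented signature
    engine_name: Optional[str] = None
    num_steps: int = 15
    recurrent_hidden_channels: int = 64
    recurrent_num_layers: int = 4
    no_parameter_sharing: bool = True
    learned_initializer: bool = True
    initializer_initialization: Optional[str] = "sense"
    initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64)
    initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4)
    initializer_multiscale: int = 1
    normalized: bool = False


default_cfg = RecurrentVarNetConfigSketch()

# dataclasses.replace returns a modified copy and leaves the original untouched.
small_cfg = replace(default_cfg, num_steps=8, recurrent_hidden_channels=32)
```

Note that the default tuples for `initializer_channels` and `initializer_dilations` have the same length (four entries each), which suggests they are meant to be overridden together.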
direct.nn.recurrentvarnet.recurrentvarnet module
- class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentInit(in_channels, out_channels, channels, dilations, depth=2, multiscale_depth=1)
Bases: Module
Recurrent State Initializer (RSI) module of Recurrent Variational Network as presented in [1].
The RSI module learns to initialize the recurrent hidden state \(h_0\), which is the input to the first RecurrentVarNetBlock of the RecurrentVarNet.
References
[1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.
- forward(x)
Computes initialization for recurrent unit given input x.
- Parameters:
- x: torch.Tensor
Input tensor from which the initial recurrent hidden state is computed.
- Returns:
- out: torch.Tensor
Initial recurrent hidden state from input x.
- Return type:
Tensor
- training: bool
- class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentVarNet(forward_operator, backward_operator, in_channels=2, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=False, initializer_initialization=None, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False, **kwargs)
Bases: Module
Recurrent Variational Network implementation as presented in [1].
References
[1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.
- compute_sense_init(kspace, sensitivity_map)
Computes the SENSE initialization \(x_{\text{SENSE}}\):
\[x_{\text{SENSE}} = \sum_{k=1}^{n_c} {S^{k}}^* \times y^k\]
where \(y^k\) denotes the data from coil \(k\), \(S^{k}\) its sensitivity map, and \(n_c\) the number of coils.
- Parameters:
- kspace: torch.Tensor
k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- Returns:
- input_image: torch.Tensor
SENSE initialization \(x_{\text{SENSE}}\).
- Return type:
Tensor
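The coil combination in the formula above can be illustrated with plain Python complex numbers. This is a sketch of the formula only, not the library's implementation: `sense_init` is a hypothetical helper, the list-of-lists layout is chosen for illustration, and the per-coil data are assumed to already be in the domain where the combination is applied.

```python
def sense_init(coil_data, sensitivity_maps):
    """Evaluate x[p] = sum_k conj(S_k[p]) * y_k[p] for every sample index p.

    Both arguments hold one list of complex samples per coil
    (an illustrative layout, not the tensor layout used by the library).
    """
    if len(coil_data) != len(sensitivity_maps):
        raise ValueError("coil_data and sensitivity_maps need the same coil count")
    num_samples = len(coil_data[0])
    combined = [0j] * num_samples
    for y_k, s_k in zip(coil_data, sensitivity_maps):
        for p in range(num_samples):
            # Multiply by the complex conjugate of the sensitivity, then sum over coils.
            combined[p] += s_k[p].conjugate() * y_k[p]
    return combined


# Two coils, two samples each.
y = [[1 + 1j, 2 + 0j], [0 + 1j, 1 - 1j]]
S = [[1 + 0j, 0 + 1j], [0 + 1j, 1 + 0j]]
x_sense = sense_init(y, S)  # [(2+1j), (1-3j)]
```

The result has no coil axis: the sum over \(k\) collapses all coils into a single complex-valued array.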
- forward(masked_kspace, sampling_mask, sensitivity_map, **kwargs)
Computes forward pass of RecurrentVarNet.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- sensitivity_map: torch.Tensor
Coil sensitivities of shape (N, coil, height, width, complex=2).
- Returns:
- kspace_prediction: torch.Tensor
k-space prediction.
- Return type:
Tensor
- training: bool
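The shapes documented for `forward` constrain one another: the sampling mask has singleton coil and complex dimensions, while the k-space and the sensitivity map share a full shape. A small check on shape tuples makes this explicit. `check_forward_shapes` is a hypothetical helper written for this page, not part of DIRECT.

```python
def check_forward_shapes(masked_kspace, sampling_mask, sensitivity_map):
    """Validate shape tuples against the documented layouts.

    Expects shapes (not tensors) and returns (N, coil, height, width).
    """
    masked_kspace = tuple(masked_kspace)
    if len(masked_kspace) != 5 or masked_kspace[-1] != 2:
        raise ValueError("masked_kspace must be (N, coil, height, width, 2)")
    n, coil, height, width, _ = masked_kspace
    if tuple(sensitivity_map) != (n, coil, height, width, 2):
        raise ValueError("sensitivity_map must match masked_kspace")
    if tuple(sampling_mask) != (n, 1, height, width, 1):
        raise ValueError("sampling_mask must be (N, 1, height, width, 1)")
    return n, coil, height, width


dims = check_forward_shapes(
    masked_kspace=(2, 8, 64, 48, 2),
    sampling_mask=(2, 1, 64, 48, 1),
    sensitivity_map=(2, 8, 64, 48, 2),
)
```

With real tensors, the same check could be run on each tensor's `.shape` before calling the model.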
- class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentVarNetBlock(forward_operator, backward_operator, in_channels=2, hidden_channels=64, num_layers=4, normalized=False)
Bases: Module
Recurrent Variational Network Block \(\mathcal{H}_{\theta_{t}}\) as presented in [1].
References
[1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.
- forward(current_kspace, masked_kspace, sampling_mask, sensitivity_map, hidden_state, coil_dim=1, spatial_dims=(2, 3))
Computes forward pass of RecurrentVarNetBlock.
- Parameters:
- current_kspace: torch.Tensor
Current k-space prediction of shape (N, coil, height, width, complex=2).
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- sensitivity_map: torch.Tensor
Coil sensitivities of shape (N, coil, height, width, complex=2).
- hidden_state: torch.Tensor or None
Recurrent unit hidden state of shape (N, hidden_channels, height, width, num_layers) if not None. Optional.
- coil_dim: int
Coil dimension. Default: 1.
- spatial_dims: tuple of ints
Spatial dimensions. Default: (2, 3).
- Returns:
- new_kspace: torch.Tensor
New k-space prediction of shape (N, coil, height, width, complex=2).
- hidden_state: torch.Tensor
Next hidden state of shape (N, hidden_channels, height, width, num_layers).
- Return type:
Tuple[Tensor, Tensor]
- training: bool
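Because a block takes the current k-space prediction and a hidden state and returns updated versions of both, blocks can be chained step by step. The sketch below shows one way such an unrolled loop might look. It rests on assumptions not stated on this page: that iteration starts from the measured k-space, and that `no_parameter_sharing=True` means one distinct block per step. `ToyBlock`, `build_blocks`, and `unroll` are hypothetical names, and `ToyBlock` only counts calls rather than doing any reconstruction.

```python
class ToyBlock:
    """Stand-in using the documented argument order; it only counts its calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, current_kspace, masked_kspace, sampling_mask,
                 sensitivity_map, hidden_state):
        self.calls += 1
        steps_so_far = 0 if hidden_state is None else hidden_state
        # Return (new_kspace, next_hidden_state), like the documented forward.
        return current_kspace, steps_so_far + 1


def build_blocks(make_block, num_steps=15, no_parameter_sharing=True):
    """Assumed meaning: distinct blocks per step, or one block reused at every step."""
    if no_parameter_sharing:
        return [make_block() for _ in range(num_steps)]
    shared = make_block()
    return [shared] * num_steps


def unroll(blocks, masked_kspace, sampling_mask, sensitivity_map, hidden_state=None):
    """Feed each block the previous block's k-space prediction and hidden state."""
    current_kspace = masked_kspace  # assumption: start from the measured data
    for block in blocks:
        current_kspace, hidden_state = block(
            current_kspace, masked_kspace, sampling_mask, sensitivity_map, hidden_state
        )
    return current_kspace, hidden_state


blocks = build_blocks(ToyBlock, num_steps=15, no_parameter_sharing=True)
kspace, hidden = unroll(blocks, "y", "mask", "smaps")

shared_blocks = build_blocks(ToyBlock, num_steps=15, no_parameter_sharing=False)
unroll(shared_blocks, "y", "mask", "smaps")
```

In the first case every block is called once; in the second, a single block is called 15 times, which is the trade-off between parameter count and per-step flexibility that the flag name suggests.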
direct.nn.recurrentvarnet.recurrentvarnet_engine module
- class direct.nn.recurrentvarnet.recurrentvarnet_engine.RecurrentVarNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
Recurrent Variational Network Engine.