direct.nn.recurrentvarnet package
Submodules
direct.nn.recurrentvarnet.config module
- class direct.nn.recurrentvarnet.config.RecurrentVarNetConfig(model_name='???', engine_name=None, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=True, initializer_initialization=InitType.SENSE, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False)
Bases: ModelConfig
- num_steps = 15
- recurrent_hidden_channels = 64
- recurrent_num_layers = 4
- no_parameter_sharing = True
- learned_initializer = True
- initializer_initialization = 'sense'
- initializer_channels = (32, 32, 64, 64)
- initializer_dilations = (1, 1, 2, 4)
- initializer_multiscale = 1
- normalized = False
- __init__(model_name='???', engine_name=None, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=True, initializer_initialization=InitType.SENSE, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False)
direct.nn.recurrentvarnet.recurrentvarnet module
- class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentInit(in_channels, out_channels, channels, dilations, depth=2, multiscale_depth=1)
Bases: Module
Recurrent State Initializer (RSI) module of the Recurrent Variational Network, as presented in [1].
The RSI module learns to initialize the recurrent hidden state \(h_0\), which is the input of the first RecurrentVarNetBlock of the RecurrentVarNet.
- __init__(in_channels, out_channels, channels, dilations, depth=2, multiscale_depth=1)
Inits RecurrentInit.
- Parameters:
  - in_channels (int) – Input channels.
  - out_channels (int) – Number of hidden channels of the recurrent unit of RecurrentVarNetBlock.
  - channels (Tuple[int, ...]) – Channels \(n_d\) in the convolutional layers of the initializer.
  - dilations (Tuple[int, ...]) – Dilations \(p\) of the convolutional layers of the initializer.
  - depth (int) – RecurrentVarNet Block number of layers \(n_l\). Default: 2.
  - multiscale_depth (int) – Number of feature layers to aggregate for the output; if 1, multi-scale context aggregation is disabled. Default: 1.
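As a rough guide to what `dilations` controls, the sketch below computes the receptive field of a stack of stride-1 dilated convolutions. It assumes 3x3 kernels (the kernel size is not stated on this page), so each layer with dilation \(p\) widens the receptive field by \(2p\) pixels. `receptive_field` is a hypothetical helper, not part of DIRECT.

```python
from typing import Sequence


def receptive_field(dilations: Sequence[int], kernel_size: int = 3) -> int:
    """Receptive field (in pixels) of stacked stride-1 dilated convolutions.

    Hypothetical helper: assumes every layer uses the same square kernel.
    """
    rf = 1
    for p in dilations:
        # A k x k kernel with dilation p spans (k - 1) * p + 1 pixels,
        # so each layer adds (k - 1) * p pixels to the receptive field.
        rf += (kernel_size - 1) * p
    return rf


# Default initializer dilations from the signatures on this page.
print(receptive_field((1, 1, 2, 4)))  # 17
```

Larger dilations in later layers grow the context quickly without adding parameters, which is one reason to prefer them over simply stacking more undilated layers.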
- forward(x)
Computes the initialization of the recurrent unit from the input x.
- Parameters:
  - x (Tensor) – Initialization input for RecurrentInit.
- Return type: Tensor
- Returns: Initial recurrent hidden state computed from x.
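The RSI output serves as \(h_0\), which RecurrentVarNetBlock expects with shape (N, hidden_channels, height, width, num_layers). The NumPy sketch below illustrates, under assumptions, how multi-scale aggregation and one output head per recurrent layer could produce that shape: the dilated convolution stack is replaced by random feature maps, aggregation is assumed to be channel-wise concatenation, and each head is modelled as a 1x1 convolution (a channel-mixing matrix) followed by ReLU. `rsi_heads` and its arguments are hypothetical names, not DIRECT's API.

```python
import numpy as np


def rsi_heads(features, out_channels, depth, multiscale_depth=1, seed=0):
    """Aggregate feature maps and stack one output head per recurrent layer.

    features: list of arrays of shape (N, C_i, H, W), one per conv layer
    (a hypothetical stand-in for the dilated conv stack).
    Returns an array of shape (N, out_channels, H, W, depth).
    """
    rng = np.random.default_rng(seed)
    # Assumed multi-scale aggregation: concatenate the last `multiscale_depth`
    # feature maps along the channel axis (1 keeps only the last map).
    x = np.concatenate(features[-multiscale_depth:], axis=1)
    outputs = []
    for _ in range(depth):
        # A 1x1 convolution is a matrix product over the channel axis.
        weight = rng.standard_normal((out_channels, x.shape[1]))
        y = np.einsum("oc,nchw->nohw", weight, x)
        outputs.append(np.maximum(y, 0.0))  # ReLU
    # One head per recurrent layer, stacked on a trailing axis.
    return np.stack(outputs, axis=-1)


rng = np.random.default_rng(1)
channels = (32, 32, 64, 64)
feats = [rng.standard_normal((2, c, 8, 8)) for c in channels]
h0 = rsi_heads(feats, out_channels=64, depth=4, multiscale_depth=2)
print(h0.shape)  # (2, 64, 8, 8, 4)
```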
- class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentVarNet(forward_operator, backward_operator, in_channels=COMPLEX_SIZE, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=False, initializer_initialization=None, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False, **kwargs)
Bases: Module
Recurrent Variational Network implementation, as presented in [1].
References:
[1] Yiasemis, George, et al. “Recurrent Variational Network: A Deep Learning Inverse Problem Solver Applied to the Task of Accelerated MRI Reconstruction.” ArXiv:2111.09639 [Physics], Nov. 2021. arXiv.org, http://arxiv.org/abs/2111.09639.
- __init__(forward_operator, backward_operator, in_channels=COMPLEX_SIZE, num_steps=15, recurrent_hidden_channels=64, recurrent_num_layers=4, no_parameter_sharing=True, learned_initializer=False, initializer_initialization=None, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False, **kwargs)
Inits RecurrentVarNet.
- Parameters:
  - forward_operator (Callable) – Forward operator.
  - backward_operator (Callable) – Backward operator.
  - num_steps (int) – Number of iterations \(T\). Default: 15.
  - in_channels (int) – Input channel number. Default: COMPLEX_SIZE (2, for complex data).
  - recurrent_hidden_channels (int) – Number of hidden channels of the recurrent unit of the RecurrentVarNet Blocks. Default: 64.
  - recurrent_num_layers (int) – Number of layers of the recurrent unit of the RecurrentVarNet Block (\(n_l\)). Default: 4.
  - no_parameter_sharing (bool) – If True, a separate RecurrentVarNetBlock is used for each of the num_steps; if False, the same RecurrentVarNetBlock is used for all num_steps. Default: True.
  - learned_initializer (bool) – If True, an RSI module (RecurrentInit) is used to initialize the recurrent hidden state. Default: False.
  - initializer_initialization (Optional[InitType]) – Type of initialization for the RSI module. Can be 'sense', 'zero-filled' or 'input-image'. Default: None.
  - initializer_channels (Optional[Tuple[int, ...]]) – Channels \(n_d\) in the convolutional layers of the RSI module. Default: (32, 32, 64, 64).
  - initializer_dilations (Optional[Tuple[int, ...]]) – Dilations \(p\) of the convolutional layers of the RSI module. Default: (1, 1, 2, 4).
  - initializer_multiscale (int) – Number of feature layers of the RSI module to aggregate for the output; if 1, multi-scale context aggregation is disabled. Default: 1.
  - normalized (bool) – If True, NormConv2dGRU will be used as a regularizer in the RecurrentVarNetBlocks. Default: False.
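- compute_sense_init(kspace, sensitivity_map)
Computes the SENSE initialization \(x_{\text{SENSE}}\):
\[x_{\text{SENSE}} = \sum_{k=1}^{n_c} {S^{k}}^* \times y^k\]
where \(y^k\) denotes the data from coil \(k\) (the backward operator applied to that coil's k-space), \(S^k\) is the sensitivity map of coil \(k\), \({}^*\) denotes complex conjugation and \(n_c\) is the number of coils.
- Parameters:
  - kspace (Tensor) – k-space of shape (N, coil, height, width, complex=2).
  - sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, height, width, complex=2).
- Returns: input_image – SENSE initialization \(x_{\text{SENSE}}\).
- Return type: Tensor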
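The SENSE formula can be checked numerically. The NumPy sketch below is an illustration under stated assumptions, not DIRECT's implementation: it uses native complex arrays of shape (coil, height, width) instead of a trailing complex=2 axis, and an orthonormal 2D inverse FFT without fftshift as the backward operator. `sense_init` is a hypothetical name. When the sensitivities satisfy \(\sum_k |S^k|^2 = 1\) and the k-space is fully sampled, the SENSE combination recovers the underlying image \(x\).

```python
import numpy as np


def sense_init(kspace, sensitivity_map):
    """SENSE combination x = sum_k conj(S^k) * y^k.

    kspace, sensitivity_map: complex arrays of shape (coil, height, width).
    y^k is the coil image, i.e. the inverse FFT of coil k's k-space.
    """
    coil_images = np.fft.ifft2(kspace, axes=(-2, -1), norm="ortho")
    return np.sum(np.conj(sensitivity_map) * coil_images, axis=0)


rng = np.random.default_rng(0)
coils, h, w = 4, 16, 16
x = rng.standard_normal((h, w)) + 1j * rng.standard_normal((h, w))
s = rng.standard_normal((coils, h, w)) + 1j * rng.standard_normal((coils, h, w))
# Normalize so that sum_k |S^k|^2 = 1 at every pixel.
s /= np.sqrt(np.sum(np.abs(s) ** 2, axis=0, keepdims=True))
# Simulate fully sampled multi-coil k-space: FFT of S^k * x for each coil.
kspace = np.fft.fft2(s * x, axes=(-2, -1), norm="ortho")
print(np.allclose(sense_init(kspace, s), x))  # True
```

The recovery works because \(\sum_k {S^k}^* S^k x = x \sum_k |S^k|^2 = x\) under the normalization above.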
- forward(masked_kspace, sampling_mask, sensitivity_map, **kwargs)
Computes the forward pass of RecurrentVarNet.
- Parameters:
  - masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).
  - sampling_mask (Tensor) – Sampling mask of shape (N, 1, height, width, 1).
  - sensitivity_map (Tensor) – Coil sensitivities of shape (N, coil, height, width, complex=2).
- Returns: kspace_prediction – k-space prediction.
- Return type: Tensor
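The parameters above imply a forward pass with an optional RSI initial state followed by num_steps applications of a RecurrentVarNetBlock, each carrying the k-space prediction and the hidden state forward, with a separate block per step when no_parameter_sharing=True. The pure-Python sketch below mimics that control flow with hypothetical callables (`make_block`, `initializer`) standing in for the PyTorch modules, and numbers standing in for tensors. It assumes the prediction starts from the masked k-space; it is not DIRECT's code.

```python
from typing import Any, Callable, List, Optional, Tuple


def run_recurrent_varnet(
    masked_kspace: Any,
    make_block: Callable[[], Callable],
    num_steps: int = 15,
    no_parameter_sharing: bool = True,
    initializer: Optional[Callable] = None,
) -> Tuple[Any, List[int]]:
    """Control-flow sketch; returns the prediction and the block used per step."""
    # One block per step without sharing, a single shared block otherwise.
    num_blocks = num_steps if no_parameter_sharing else 1
    blocks = [make_block() for _ in range(num_blocks)]
    # h_0 comes from the RSI module when one is used, otherwise it is None.
    hidden = initializer(masked_kspace) if initializer is not None else None
    kspace = masked_kspace  # assumption: the prediction starts from the data
    block_ids = []
    for step in range(num_steps):
        idx = step if no_parameter_sharing else 0
        block_ids.append(idx)
        kspace, hidden = blocks[idx](kspace, hidden)
    return kspace, block_ids


def make_toy_block():
    # Toy block: adds 1 to the prediction and counts steps in the hidden state.
    def block(kspace, hidden):
        count = 0 if hidden is None else hidden
        return kspace + 1, count + 1

    return block


print(run_recurrent_varnet(0, make_toy_block, num_steps=3))  # (3, [0, 1, 2])
print(run_recurrent_varnet(0, make_toy_block, num_steps=3, no_parameter_sharing=False))  # (3, [0, 0, 0])
```

Sharing trades capacity for memory: with no_parameter_sharing=False the parameter count no longer grows with num_steps.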
- class direct.nn.recurrentvarnet.recurrentvarnet.RecurrentVarNetBlock(forward_operator, backward_operator, in_channels=2, hidden_channels=64, num_layers=4, normalized=False)
Bases: Module
Recurrent Variational Network Block \(\mathcal{H}_{\theta_{t}}\), as presented in [1].
- __init__(forward_operator, backward_operator, in_channels=2, hidden_channels=64, num_layers=4, normalized=False)
Inits RecurrentVarNetBlock.
- Parameters:
  - forward_operator (Callable) – Forward Fourier transform.
  - backward_operator (Callable) – Backward Fourier transform.
  - in_channels (int) – Input channel number. Default: 2 (for complex data).
  - hidden_channels (int) – Hidden channels. Default: 64.
  - num_layers (int) – Number of layers \(n_l\) of the recurrent unit. Default: 4.
  - normalized (bool) – If True, NormConv2dGRU will be used as a regularizer. Default: False.
- forward(current_kspace, masked_kspace, sampling_mask, sensitivity_map, hidden_state, coil_dim=1, spatial_dims=(2, 3))
Computes the forward pass of RecurrentVarNetBlock.
- Parameters:
  - current_kspace (Tensor) – Current k-space prediction of shape (N, coil, height, width, complex=2).
  - masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).
  - sampling_mask (Tensor) – Sampling mask of shape (N, 1, height, width, 1).
  - sensitivity_map (Tensor) – Coil sensitivities of shape (N, coil, height, width, complex=2).
  - hidden_state (Optional[Tensor]) – Recurrent unit hidden state of shape (N, hidden_channels, height, width, num_layers), or None.
  - coil_dim (int) – Coil dimension. Default: 1.
  - spatial_dims (Tuple[int, int]) – Spatial dimensions. Default: (2, 3).
- Returns:
  - new_kspace – New k-space prediction of shape (N, coil, height, width, complex=2).
  - hidden_state – Next hidden state of shape (N, hidden_channels, height, width, num_layers).
- Return type: Tuple[Tensor, Tensor]
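The block takes both the current prediction and the measured (masked) k-space, which points to a gradient-style update where the data-consistency error only counts at sampled locations. The NumPy sketch below assumes an update of the form \(w_{t+1} = w_t - \alpha_t M (w_t - \tilde{w}) + r_t\), where \(w_t\) is current_kspace, \(\tilde{w}\) is masked_kspace, \(M\) is the sampling mask, \(\alpha_t\) a step size and \(r_t\) a hypothetical placeholder for the recurrent unit's k-space output. These symbols and `block_update` are illustrative assumptions, not the exact DIRECT implementation.

```python
import numpy as np


def block_update(current_kspace, masked_kspace, sampling_mask, regularizer_term, step_size=1.0):
    """One assumed gradient-style update of the k-space prediction.

    sampling_mask is broadcastable to the k-space (1 = sampled, 0 = not sampled);
    regularizer_term is a hypothetical placeholder for the recurrent unit output.
    """
    # Data-consistency error, zeroed where k-space was not sampled.
    kspace_error = np.where(sampling_mask == 0, 0.0, current_kspace - masked_kspace)
    return current_kspace - step_size * kspace_error + regularizer_term


rng = np.random.default_rng(0)
shape = (1, 2, 4, 4, 2)  # (N, coil, height, width, complex=2)
current = rng.standard_normal(shape)
mask = (rng.random((1, 1, 4, 4, 1)) > 0.5).astype(float)  # (N, 1, height, width, 1)
masked = rng.standard_normal(shape) * mask
new = block_update(current, masked, mask, regularizer_term=np.zeros(shape))
# With step size 1 and no regularizer, sampled entries are replaced by the data
# and unsampled entries keep the current prediction.
expected = np.where(mask == 1, masked, current)
print(np.allclose(new, expected))  # True
```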
direct.nn.recurrentvarnet.recurrentvarnet_engine module
- class direct.nn.recurrentvarnet.recurrentvarnet_engine.RecurrentVarNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
Recurrent Variational Network Engine.