direct.nn.cirim package
Submodules
direct.nn.cirim.cirim module
- class direct.nn.cirim.cirim.CIRIM(forward_operator, backward_operator, depth=2, in_channels=2, time_steps=8, recurrent_hidden_channels=64, num_cascades=8, no_parameter_sharing=True, **kwargs)[source]
Bases: Module
Cascades of Independently Recurrent Inference Machines implementation as presented in [1].
References
[1] Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: https://arxiv.org/abs/2111.15498v1
- forward(masked_kspace, sampling_mask, sensitivity_map)[source]
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- sensitivity_map: torch.Tensor
Coil sensitivities of shape (N, coil, height, width, complex=2).
- Returns:
- imspace_prediction: torch.Tensor
Image space prediction.
- Return type:
List[List[Union[Tensor, Any]]]
- training: bool
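A minimal usage sketch for CIRIM, assuming direct.data.transforms.fft2/ifft2 as the forward and backward operators and dummy tensors with the shapes listed in the forward() docstring; in a real DIRECT run the operators, data, and hyperparameters come from the experiment configuration instead.

```python
# A minimal sketch, not a verified training setup: fft2/ifft2 are assumed to be
# suitable operators, and all tensor shapes follow the forward() docstring above.
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.cirim.cirim import CIRIM

model = CIRIM(
    forward_operator=fft2,
    backward_operator=ifft2,
    depth=2,
    time_steps=8,
    recurrent_hidden_channels=64,
    num_cascades=8,
    no_parameter_sharing=True,
)

masked_kspace = torch.randn(1, 8, 128, 128, 2)       # (N, coil, height, width, complex=2)
sampling_mask = torch.rand(1, 1, 128, 128, 1) > 0.5  # (N, 1, height, width, 1)
sensitivity_map = torch.randn(1, 8, 128, 128, 2)     # (N, coil, height, width, complex=2)

# Per the Returns section: a nested list of image space predictions,
# presumably one inner list per cascade.
predictions = model(masked_kspace, sampling_mask, sensitivity_map)
```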
- class direct.nn.cirim.cirim.ConvNonlinear(input_size, features, kernel_size, dilation, bias)[source]
Bases: Module
A convolutional layer with nonlinearity.
- forward(_input)[source]
Forward pass of the convolutional layer.
- Parameters:
- _input: torch.Tensor
Input tensor of shape (batch_size, seq_len, input_size).
- Returns:
- output: torch.Tensor
Output tensor of shape (batch_size, seq_len, features).
- training: bool
- class direct.nn.cirim.cirim.ConvRNNStack(convs, recurrent)[source]
Bases: Module
A stack of convolutional RNNs.
Takes as input a sequence of recurrent and convolutional layers.
- forward(_input, hidden)[source]
- Parameters:
- _input: torch.Tensor
Input tensor of shape (batch_size, seq_len, input_size).
- hidden: torch.Tensor
Hidden state of shape (num_layers * num_directions, batch_size, hidden_size).
- Returns:
- output: torch.Tensor
Output tensor of shape (batch_size, seq_len, hidden_size).
- training: bool
- class direct.nn.cirim.cirim.IndRNNCell(in_channels, hidden_channels, kernel_size=1, dilation=1, bias=True)[source]
Bases: Module
Base class for Independently Recurrent Neural Network (IndRNN) cells as presented in [1].
References
[1] Li, S. et al. (2018) ‘Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN’, Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, (1), pp. 5457–5466. doi: 10.1109/CVPR.2018.00572.
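For orientation, the defining property of an IndRNN cell in [1] is that the recurrent weight acts elementwise on the previous hidden state instead of through a full hidden-to-hidden matrix, so each hidden unit evolves independently over time; in IndRNNCell the input-to-hidden transform is additionally convolutional (hence the kernel_size and dilation arguments). Below is a plain, non-convolutional sketch of that update, using hypothetical names w_ih and u_hh.

```python
import torch

# Hypothetical illustration of the IndRNN update from [1]:
#   h_t = relu(W x_t + u * h_{t-1} + b),
# where u is a per-unit vector applied elementwise (Hadamard product),
# not a full recurrent weight matrix.
in_channels, hidden_channels = 2, 64
w_ih = torch.randn(hidden_channels, in_channels)  # input-to-hidden weights W
u_hh = torch.randn(hidden_channels)               # per-unit recurrent weights u
bias = torch.zeros(hidden_channels)

x_t = torch.randn(in_channels)                    # input at time step t
h_prev = torch.zeros(hidden_channels)             # previous hidden state
h_t = torch.relu(w_ih @ x_t + u_hh * h_prev + bias)
```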
- forward(_input, hx)[source]
Forward pass of the cell.
- Parameters:
- _input: torch.Tensor
Input tensor of shape (batch_size, seq_len, input_size), containing the input features.
- hx: torch.Tensor
Hidden state of shape (batch_size, hidden_channels, 1, 1), containing the hidden state features.
- Returns:
- output: torch.Tensor
Output tensor of shape (batch_size, seq_len, hidden_channels), containing the next hidden state.
- static orthotogonalize_weights(weights, chunks=1)[source]
Orthogonalize weights.
- Parameters:
- weights: torch.Tensor
The weights to orthogonalize.
- chunks: int
Number of chunks. Default: 1.
- Returns:
- weights: torch.Tensor
The orthogonalized weights.
- training: bool
- class direct.nn.cirim.cirim.RIMBlock(forward_operator, backward_operator, depth=2, in_channels=2, hidden_channels=64, time_steps=4, no_parameter_sharing=False)[source]
Bases: Module
Recurrent Inference Machines block as presented in [1].
References
[1] Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: https://arxiv.org/abs/2111.15498v1
- forward(current_prediction, masked_kspace, sampling_mask, sensitivity_map, hidden_state, parameter_sharing=False, coil_dim=1, spatial_dims=(2, 3))[source]
- Parameters:
- current_prediction: torch.Tensor
Current k-space prediction.
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- sensitivity_map: torch.Tensor
Coil sensitivities of shape (N, coil, height, width, complex=2).
- hidden_state: torch.Tensor or None
IndRNN hidden state of shape (N, hidden_channels, height, width, num_layers) if not None. Optional.
- parameter_sharing: bool
If True, the weights of the convolutional layers are shared between the forward and backward pass.
- coil_dim: int
Coil dimension. Default: 1.
- spatial_dims: tuple of ints
Spatial dimensions. Default: (2, 3).
- Returns:
- new_kspace: torch.Tensor
New k-space prediction of shape (N, coil, height, width, complex=2).
- hidden_state: torch.Tensor or None
Next hidden state of shape (N, hidden_channels, height, width, num_layers) if parameter_sharing else None.
- Return type:
Union[Tuple[List, None], Tuple[List, Union[List, Tensor]]]
- training: bool
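A hedged sketch of a single RIMBlock refinement step, again assuming direct.data.transforms.fft2/ifft2 as operators and using the masked k-space itself as the initial current_prediction; CIRIM chains num_cascades such blocks, feeding each block's k-space prediction (and, presumably, its hidden state when parameter sharing is enabled) into the next.

```python
# A minimal sketch, not verified against the training pipeline: operators and
# shapes are assumptions taken from the docstrings above.
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.cirim.cirim import RIMBlock

block = RIMBlock(
    forward_operator=fft2,
    backward_operator=ifft2,
    depth=2,
    hidden_channels=64,
    time_steps=4,
)

masked_kspace = torch.randn(1, 8, 128, 128, 2)
sampling_mask = torch.rand(1, 1, 128, 128, 1) > 0.5
sensitivity_map = torch.randn(1, 8, 128, 128, 2)

# Start from the masked k-space as the current prediction; per the Returns
# section the block yields a new k-space prediction and the next hidden state
# (None when parameter_sharing is False).
new_kspace, hidden_state = block(
    masked_kspace,
    masked_kspace,
    sampling_mask,
    sensitivity_map,
    hidden_state=None,
)
```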
direct.nn.cirim.cirim_engine module
- class direct.nn.cirim.cirim_engine.CIRIMEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]
Bases: MRIModelEngine
Cascades of Independently Recurrent Inference Machines Engine.
direct.nn.cirim.config module
- class direct.nn.cirim.config.CIRIMConfig(model_name: str = '???', engine_name: str | None = None, time_steps: int = 8, depth: int = 2, recurrent_hidden_channels: int = 64, num_cascades: int = 8, no_parameter_sharing: bool = True)[source]
Bases: ModelConfig
- depth: int = 2
- no_parameter_sharing: bool = True
- num_cascades: int = 8
- time_steps: int = 8
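The config above is a plain dataclass; a minimal sketch of constructing it directly with the documented defaults (in practice DIRECT typically builds it from the YAML experiment configuration):

```python
from direct.nn.cirim.config import CIRIMConfig

# Values mirror the defaults in the signature above.
cfg = CIRIMConfig(
    time_steps=8,
    depth=2,
    recurrent_hidden_channels=64,
    num_cascades=8,
    no_parameter_sharing=True,
)
```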