direct.nn.cirim package

Submodules

direct.nn.cirim.cirim module

class direct.nn.cirim.cirim.CIRIM(forward_operator, backward_operator, depth=2, in_channels=2, time_steps=8, recurrent_hidden_channels=64, num_cascades=8, no_parameter_sharing=True, **kwargs)

Bases: Module

Implementation of Cascades of Independently Recurrent Inference Machines (CIRIM), as presented in [1].

References

[1]

Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: https://arxiv.org/abs/2111.15498v1
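The way `num_cascades` and `time_steps` nest can be sketched in plain Python. This is a minimal sketch, not the DIRECT implementation: `toy_rim_block` and `toy_cirim` are hypothetical names, a fixed update rule stands in for the learned recurrent network, the hidden state is reset at every cascade for simplicity, and we assume the outer list of the result indexes cascades and the inner list indexes time steps, which is how we read the `List[List[...]]` return type of `forward`.

```python
from typing import List

import numpy as np


def toy_rim_block(prediction: np.ndarray, measurement: np.ndarray, time_steps: int) -> List[np.ndarray]:
    """Hypothetical RIM block: a fixed update rule stands in for the learned network."""
    hidden = np.zeros_like(prediction)  # hidden state, reset at every cascade in this sketch
    steps = []
    for _ in range(time_steps):
        grad = prediction - measurement  # gradient of 0.5 * ||x - y||^2
        hidden = 0.5 * hidden + grad  # toy recurrent memory of past gradients
        prediction = prediction - 0.25 * hidden  # update the estimate
        steps.append(prediction)
    return steps


def toy_cirim(measurement: np.ndarray, num_cascades: int = 8, time_steps: int = 8) -> List[List[np.ndarray]]:
    """Run `num_cascades` toy RIM blocks in sequence and keep every intermediate estimate."""
    prediction = np.zeros_like(measurement)
    cascades = []
    for _ in range(num_cascades):
        steps = toy_rim_block(prediction, measurement, time_steps)
        cascades.append(steps)  # outer index: cascade, inner index: time step
        prediction = steps[-1]  # the next cascade starts from the last estimate
    return cascades


estimates = toy_cirim(np.array([1.0, -2.0, 3.0]))
final_estimate = estimates[-1][-1]  # last time step of the last cascade
```

With the defaults the sketch keeps 8 × 8 = 64 intermediate estimates, and `estimates[-1][-1]` is the one a caller would typically treat as the reconstruction.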

forward(masked_kspace, sampling_mask, sensitivity_map)
Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

sensitivity_map: torch.Tensor

Coil sensitivities of shape (N, coil, height, width, complex=2).

Returns:
imspace_prediction: List[List[Union[torch.Tensor, Any]]]

Nested lists of image-space predictions.

Return type:

List[List[Union[Tensor, Any]]]
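A minimal NumPy sketch of the shape conventions above: complex data carry a trailing axis of size 2 (real, imaginary), and the (N, 1, height, width, 1) sampling mask broadcasts over the coil and complex axes. The every-other-column mask and the helper names `apply_mask` and `to_complex` are hypothetical; DIRECT itself operates on `torch.Tensor`s.

```python
import numpy as np


def apply_mask(kspace: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Zero out unsampled locations; a (N, 1, H, W, 1) mask broadcasts over coils and the complex axis."""
    return kspace * mask


def to_complex(x: np.ndarray) -> np.ndarray:
    """Collapse the trailing (real, imaginary) axis of size 2 into a complex array."""
    return x[..., 0] + 1j * x[..., 1]


N, coils, height, width = 2, 4, 8, 8
rng = np.random.default_rng(0)
kspace = rng.standard_normal((N, coils, height, width, 2))  # (N, coil, height, width, complex=2)

mask = np.zeros((N, 1, height, width, 1))  # (N, 1, height, width, 1)
mask[..., ::2, :] = 1.0  # hypothetical mask: keep every other k-space column

masked_kspace = apply_mask(kspace, mask)
complex_kspace = to_complex(masked_kspace)  # (N, coil, height, width), complex dtype
```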

training: bool

class direct.nn.cirim.cirim.ConvNonlinear(input_size, features, kernel_size, dilation, bias)

Bases: Module

A convolutional layer with a nonlinearity.

forward(_input)

Forward pass of the convolutional layer.

Parameters:
_input: torch.Tensor

Input tensor of shape (batch_size, seq_len, input_size).

Returns:
output: torch.Tensor

Output tensor of shape (batch_size, seq_len, features).
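To illustrate what a convolutional layer with a nonlinearity computes, here is a NumPy sketch of a dilated 2D convolution (cross-correlation, as in PyTorch) followed by a ReLU. The ReLU, the (channels, height, width) layout, the odd square kernel, and the "same" padding of dilation * (kernel_size - 1) // 2 are assumptions for illustration, not taken from the DIRECT source; `conv_nonlinear` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def conv_nonlinear(
    x: np.ndarray, weight: np.ndarray, bias: Optional[np.ndarray] = None, dilation: int = 1
) -> np.ndarray:
    """Hypothetical sketch: dilated 2D cross-correlation with 'same' padding, then ReLU.

    x: (in_channels, height, width); weight: (features, in_channels, k, k) with odd k.
    """
    features, in_channels, k, k2 = weight.shape
    assert k == k2 and k % 2 == 1, "sketch assumes square kernels of odd size"
    assert x.shape[0] == in_channels
    pad = dilation * (k - 1) // 2  # keeps height and width unchanged
    _, height, width = x.shape
    padded = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))  # zero padding
    out = np.zeros((features, height, width))
    for u in range(k):
        for v in range(k):
            # input window seen by kernel tap (u, v), spaced by the dilation
            window = padded[:, u * dilation : u * dilation + height, v * dilation : v * dilation + width]
            # contract over input channels for this kernel tap
            out += np.einsum("fc,chw->fhw", weight[:, :, u, v], window)
    if bias is not None:
        out += bias[:, None, None]
    return np.maximum(out, 0.0)  # assumed ReLU nonlinearity
```

Padding by dilation * (kernel_size - 1) // 2 is what lets a stack of such layers keep the image size fixed while the dilation widens the receptive field.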

reset_parameters()

Resets the parameters of the convolutional layer.

training: bool

class direct.nn.cirim.cirim.ConvRNNStack(convs, recurrent)

Bases: Module

A stack of convolutional RNNs.

Its constructor takes the convolutional layers (convs) and the recurrent layers (recurrent) that make up the stack.

forward(_input, hidden)
Parameters:
_input: torch.Tensor

Input tensor of shape (batch_size, seq_len, input_size).

hidden: torch.Tensor

Hidden state of shape (num_layers * num_directions, batch_size, hidden_size).

Returns:
output: torch.Tensor

Output tensor of shape (batch_size, seq_len, hidden_size).
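One plausible reading of this interface is that the stack applies its convolutional layers to the input first and then passes the result, together with the hidden state, to the recurrent layer, whose output is the next hidden state. The sketch below assumes exactly that; `ConvRNNStackSketch`, `toy_convs` and `toy_recurrent` are hypothetical, written with NumPy callables instead of `torch.nn` modules.

```python
from typing import Callable

import numpy as np

Array = np.ndarray


class ConvRNNStackSketch:
    """Hypothetical stand-in: convolutional layers first, then the recurrent cell."""

    def __init__(self, convs: Callable[[Array], Array], recurrent: Callable[[Array, Array], Array]):
        self.convs = convs
        self.recurrent = recurrent

    def __call__(self, x: Array, hidden: Array) -> Array:
        features = self.convs(x)  # feature extraction
        return self.recurrent(features, hidden)  # recurrent update; output is the new hidden state


def toy_convs(x: Array) -> Array:
    # hypothetical "convolution": scale and rectify
    return np.maximum(2.0 * x, 0.0)


def toy_recurrent(features: Array, hidden: Array) -> Array:
    # hypothetical recurrent cell: mix features with a decayed hidden state
    return np.maximum(features + 0.5 * hidden, 0.0)


stack = ConvRNNStackSketch(toy_convs, toy_recurrent)
hidden = np.zeros(2)
for _ in range(3):  # the caller feeds the output back in as the next hidden state
    hidden = stack(np.array([1.0, -1.0]), hidden)
```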

training: bool

class direct.nn.cirim.cirim.IndRNNCell(in_channels, hidden_channels, kernel_size=1, dilation=1, bias=True)

Bases: Module

Base class for Independently Recurrent Neural Network (IndRNN) cells, as presented in [1].

References

[1]

Li, S. et al. (2018) ‘Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN’, Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, (1), pp. 5457–5466. doi: 10.1109/CVPR.2018.00572.

forward(_input, hx)

Forward pass of the cell.

Parameters:
_input: torch.Tensor

Input tensor of shape (batch_size, seq_len, input_size) containing the input features.

hx: torch.Tensor

Hidden state of shape (batch_size, hidden_channels, 1, 1) containing the hidden state features.

Returns:
output: torch.Tensor

Output tensor of shape (batch_size, seq_len, hidden_channels) containing the next hidden state.
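The IndRNN update from [1] replaces the dense hidden-to-hidden matrix of a vanilla RNN with an element-wise recurrent weight u, so each hidden unit only sees its own past: h_t = σ(W x_t + u ⊙ h_{t−1} + b). Below is a NumPy sketch of one step, assuming a ReLU for σ and a dense input projection W in place of the cell's convolution; `indrnn_step` is a hypothetical name.

```python
import numpy as np


def indrnn_step(x: np.ndarray, h_prev: np.ndarray, w: np.ndarray, u: np.ndarray, b: np.ndarray) -> np.ndarray:
    """One IndRNN step (sketch): h = relu(W x + u * h_prev + b).

    x: (batch, in_features); h_prev: (batch, hidden); w: (hidden, in_features);
    u, b: (hidden,). The recurrent weight u acts element-wise, not as a matrix.
    """
    return np.maximum(x @ w.T + u * h_prev + b, 0.0)


rng = np.random.default_rng(0)
batch, in_features, hidden, seq_len = 2, 3, 4, 5
w = rng.standard_normal((hidden, in_features))  # input-to-hidden weights
u = rng.uniform(0.0, 1.0, hidden)  # one recurrent weight per hidden unit
b = np.zeros(hidden)
h = np.zeros((batch, hidden))  # initial hidden state
for _ in range(seq_len):
    h = indrnn_step(rng.standard_normal((batch, in_features)), h, w, u, b)
```

Because u is a vector, the recurrence on each unit is a scalar one, which is what makes the per-unit behaviour over long sequences easy to control in [1].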

static orthotogonalize_weights(weights, chunks=1)

Orthogonalize weights.

Parameters:
weights: torch.Tensor

The weights to orthogonalize.

chunks: int

Number of chunks. Default: 1.

Returns:
weights: torch.Tensor

The orthogonalized weights.
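A NumPy sketch of chunk-wise orthogonalization for a 2D weight matrix, under assumptions the docs above do not state: the weights are split into `chunks` blocks along their first dimension, and each block is orthonormalized with a QR decomposition (orthonormal rows for wide blocks, orthonormal columns for tall or square ones). The library may instead draw fresh orthogonal matrices, for example with `torch.nn.init.orthogonal_`; `orthogonalize_in_chunks` is a hypothetical name.

```python
import numpy as np


def _orthonormalize_block(block: np.ndarray) -> np.ndarray:
    """Orthonormal rows for wide blocks, orthonormal columns for tall or square ones."""
    transpose = block.shape[0] < block.shape[1]
    tall = block.T if transpose else block  # shape (m, n) with m >= n
    q, r = np.linalg.qr(tall)  # reduced QR: q has orthonormal columns
    signs = np.sign(np.diag(r))
    signs[signs == 0] = 1.0  # fix QR's sign ambiguity so the result is deterministic
    q = q * signs
    return q.T if transpose else q


def orthogonalize_in_chunks(weights: np.ndarray, chunks: int = 1) -> np.ndarray:
    """Hypothetical sketch: split along the first axis into `chunks` blocks and orthonormalize each."""
    blocks = np.array_split(weights, chunks, axis=0)
    return np.concatenate([_orthonormalize_block(b) for b in blocks], axis=0)
```

Orthogonalizing per chunk is useful when one weight tensor stores several gates or sub-layers stacked along its first axis, so that each part is orthogonal on its own.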

reset_parameters()

Reset the parameters.

training: bool

class direct.nn.cirim.cirim.RIMBlock(forward_operator, backward_operator, depth=2, in_channels=2, hidden_channels=64, time_steps=4, no_parameter_sharing=False)

Bases: Module

Recurrent Inference Machine (RIM) block, as presented in [1].

References

[1]

Karkalousos, D. et al. (2021) ‘Assessment of Data Consistency through Cascades of Independently Recurrent Inference Machines for fast and robust accelerated MRI reconstruction’. Available at: https://arxiv.org/abs/2111.15498v1

forward(current_prediction, masked_kspace, sampling_mask, sensitivity_map, hidden_state, parameter_sharing=False, coil_dim=1, spatial_dims=(2, 3))
Parameters:
current_prediction: torch.Tensor

Current k-space.

masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

sensitivity_map: torch.Tensor

Coil sensitivities of shape (N, coil, height, width, complex=2).

hidden_state: torch.Tensor or None

Optional IndRNN hidden state of shape (N, hidden_channels, height, width, num_layers), or None.

parameter_sharing: bool

If True, the weights of the convolutional layers are shared between the forward and backward passes.

coil_dim: int

Coil dimension. Default: 1.

spatial_dims: tuple of ints

Spatial dimensions. Default: (2, 3).

Returns:
new_kspace: torch.Tensor

New k-space prediction of shape (N, coil, height, width, complex=2).

hidden_state: torch.Tensor or None

Next hidden state of shape (N, hidden_channels, height, width, num_layers) if parameter_sharing is True, otherwise None.

Return type:

Union[Tuple[List, None], Tuple[List, Union[List, Tensor]]]
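Recurrent inference machines feed the gradient of the data log-likelihood into their recurrent update. Under the usual multi-coil SENSE model (an assumption here, since the docs above do not spell the operator out), the forward operator maps an image x to masked coil k-spaces A x = M F(S_c x), and the gradient of the data-fidelity term ½‖A x − y‖² (the negative Gaussian log-likelihood up to constants) is Aᴴ(A x − y) = Σ_c conj(S_c) Fᴴ(M (F(S_c x) − y_c)). The NumPy sketch below uses complex arrays instead of the trailing complex=2 axis and orthonormal FFTs, so the inverse FFT is the adjoint; the function names are hypothetical.

```python
import numpy as np


def sense_forward(image: np.ndarray, sens: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """A x: weight by coil sensitivities, orthonormal 2D FFT, then sample.

    image: (H, W) complex; sens: (coil, H, W) complex; mask: (H, W) with 0/1 entries.
    """
    return mask * np.fft.fft2(sens * image, norm="ortho")


def sense_adjoint(kspace: np.ndarray, sens: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """A^H y: sample, inverse orthonormal 2D FFT, combine coils with conjugate sensitivities."""
    return np.sum(np.conj(sens) * np.fft.ifft2(mask * kspace, norm="ortho"), axis=0)


def data_fidelity_gradient(
    image: np.ndarray, masked_kspace: np.ndarray, sens: np.ndarray, mask: np.ndarray
) -> np.ndarray:
    """Gradient of 0.5 * ||A x - y||^2 with respect to x: A^H (A x - y)."""
    return sense_adjoint(sense_forward(image, sens, mask) - masked_kspace, sens, mask)
```

The gradient vanishes when the estimate already reproduces the measured k-space, which is the data-consistency condition the cascades aim for.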

training: bool

direct.nn.cirim.cirim_engine module

class direct.nn.cirim.cirim_engine.CIRIMEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

Engine for the Cascades of Independently Recurrent Inference Machines (CIRIM) model.

direct.nn.cirim.config module

class direct.nn.cirim.config.CIRIMConfig(model_name: str = '???', engine_name: str | None = None, time_steps: int = 8, depth: int = 2, recurrent_hidden_channels: int = 64, num_cascades: int = 8, no_parameter_sharing: bool = True)

Bases: ModelConfig

depth: int = 2
no_parameter_sharing: bool = True
num_cascades: int = 8
recurrent_hidden_channels: int = 64
time_steps: int = 8
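Since the `CIRIMConfig` signature reads like a dataclass, its fields and defaults can be mirrored with the standard library to see how they combine. `CIRIMConfigSketch` below is a hypothetical stand-in that does not inherit from DIRECT's `ModelConfig`; the `'???'` default for `model_name` is copied from the signature above (OmegaConf uses `'???'` to mark a mandatory value), `"cirim"` is a made-up model name, and `total_rim_steps` assumes every cascade runs `time_steps` RIM iterations.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class CIRIMConfigSketch:
    """Hypothetical stand-in mirroring the CIRIMConfig fields and defaults."""

    model_name: str = "???"  # OmegaConf's marker for a mandatory value
    engine_name: Optional[str] = None
    time_steps: int = 8
    depth: int = 2
    recurrent_hidden_channels: int = 64
    num_cascades: int = 8
    no_parameter_sharing: bool = True

    def total_rim_steps(self) -> int:
        # assumption: every cascade runs `time_steps` RIM iterations
        return self.num_cascades * self.time_steps


cfg = CIRIMConfigSketch(model_name="cirim")  # hypothetical model name
smaller = replace(cfg, num_cascades=4, time_steps=4)  # override fields without mutating cfg
```

With the defaults this gives 8 × 8 = 64 RIM iterations per forward pass, so reducing either `num_cascades` or `time_steps` is the direct lever on compute.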

Module contents