direct.nn.rim package

Submodules

direct.nn.rim.config module

class direct.nn.rim.config.RIMConfig(model_name: str = '???', engine_name: str | None = None, hidden_channels: int = 16, length: int = 8, depth: int = 2, steps: int = 1, no_parameter_sharing: bool = False, instance_norm: bool = False, dense_connect: bool = False, whiten_input: bool = False, replication_padding: bool = True, image_initialization: str = 'zero_filled', scale_loglikelihood: float | None = None, learned_initializer: bool = False, initializer_channels: Tuple[int, ...] = (32, 32, 64, 64), initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4), initializer_multiscale: int = 1, normalized: bool = False)

Bases: ModelConfig

dense_connect: bool = False
depth: int = 2
hidden_channels: int = 16
image_initialization: str = 'zero_filled'
initializer_channels: Tuple[int, ...] = (32, 32, 64, 64)
initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4)
initializer_multiscale: int = 1
instance_norm: bool = False
learned_initializer: bool = False
length: int = 8
no_parameter_sharing: bool = False
normalized: bool = False
replication_padding: bool = True
scale_loglikelihood: Optional[float] = None
steps: int = 1
whiten_input: bool = False

direct.nn.rim.rim module

class direct.nn.rim.rim.MRILogLikelihood(forward_operator, backward_operator)

Bases: Module

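Defines the gradient of the MRI log-likelihood with respect to the image, assuming a single noise level \(\sigma\) shared by the complex images of all coils:

\[\frac{1}{\sigma^2} \sum_{i=1}^{N_c} S_i^{\text{H}} \mathcal{F}^{-1} P^{*} \left(P \mathcal{F} S_i x_{\tau} - y_i\right)\]

for each time step \(\tau\), where \(N_c\) is the number of coils, \(S_i\) is the sensitivity map of coil \(i\), \(\mathcal{F}\) is the two-dimensional Fourier transform, \(P\) is the sampling operator defined by the mask, \(y_i\) is the masked k-space of coil \(i\), and \(x_{\tau}\) is the image estimate at time step \(\tau\).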
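The expression above can be checked numerically. The following is a minimal NumPy sketch, not DIRECT's implementation: it assumes a single image without a batch dimension, native complex arrays instead of a trailing complex=2 dimension, orthonormal FFTs, a 0/1 mask and \(\sigma = 1\); the helper names (fft2c, ifft2c, loglikelihood_grad, data_misfit) are hypothetical. It evaluates the sum over coils and compares it with a finite-difference derivative of the data misfit \(\frac{1}{2\sigma^2}\sum_i \lVert P \mathcal{F} S_i x - y_i \rVert^2\), of which the expression is the gradient (here \(y_i\) denotes the measured k-space of coil \(i\)).

```python
import numpy as np

def fft2c(x):
    # Orthonormal 2D FFT over the last two axes (height, width), so F^{-1} = F^H.
    return np.fft.fft2(x, norm="ortho")

def ifft2c(k):
    # Orthonormal inverse 2D FFT over the last two axes.
    return np.fft.ifft2(k, norm="ortho")

def loglikelihood_grad(x, y, sens, mask, sigma=1.0):
    """(1 / sigma^2) * sum_i S_i^H F^{-1} P^* (P F S_i x - y_i).

    x:    complex image estimate x_tau, shape (H, W)
    y:    masked multi-coil k-space y_i, shape (C, H, W)
    sens: complex coil sensitivities S_i, shape (C, H, W)
    mask: 0/1 sampling mask P, shape (H, W), broadcast over coils
    """
    residual = mask * fft2c(sens * x) - y      # P F S_i x - y_i for every coil
    coil_images = ifft2c(mask * residual)      # F^{-1} P^*(...); P^* = P for a real 0/1 mask
    return np.sum(np.conj(sens) * coil_images, axis=0) / sigma**2  # sum_i S_i^H (...)

def data_misfit(x, y, sens, mask, sigma=1.0):
    # 1 / (2 sigma^2) * sum_i ||P F S_i x - y_i||^2, whose gradient is the expression above.
    residual = mask * fft2c(sens * x) - y
    return 0.5 * np.sum(np.abs(residual) ** 2) / sigma**2

# Synthetic example: random sensitivities, a random 0/1 mask, noiseless measurements.
rng = np.random.default_rng(0)
C, H, W = 4, 8, 8
sens = rng.standard_normal((C, H, W)) + 1j * rng.standard_normal((C, H, W))
mask = (rng.random((H, W)) < 0.4).astype(float)
x_true = rng.standard_normal((H, W)) + 1j * rng.standard_normal((H, W))
y = mask * fft2c(sens * x_true)

x = np.zeros((H, W), dtype=complex)            # zero-filled estimate
g = loglikelihood_grad(x, y, sens, mask)

# Central finite difference of the misfit along a random complex direction d.
d = rng.standard_normal((H, W)) + 1j * rng.standard_normal((H, W))
eps = 1e-4
fd = (data_misfit(x + eps * d, y, sens, mask) - data_misfit(x - eps * d, y, sens, mask)) / (2 * eps)
analytic = np.real(np.vdot(g, d))              # Re <g, d>, with vdot conjugating g
```

Because the misfit is quadratic in x, the central difference agrees with Re⟨g, d⟩ up to rounding, and the gradient vanishes at the image that generated the data.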

forward(input_image, masked_kspace, sensitivity_map, sampling_mask, loglikelihood_scaling=None)

Performs the forward pass of MRILogLikelihood.

Parameters:
input_image: torch.Tensor

Image estimate from the initialization or the previous iteration, with the complex dimension first, of shape (N, complex, height, width).

masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

loglikelihood_scaling: torch.Tensor

Optional multiplier for the log-likelihood, for instance to account for the k-space noise level, of shape (1,).

Returns:
out: torch.Tensor

The MRI log-likelihood term.

Return type:

Tensor
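The parameters above use two layouts for complex data: input_image carries the real and imaginary parts as a channel dimension right after the batch, (N, complex, height, width), while masked_kspace and sensitivity_map keep them in a trailing dimension. The NumPy sketch below converts between these layouts and native complex arrays. It assumes index 0 of the complex dimension holds the real part and index 1 the imaginary part; the helper names are hypothetical and not part of DIRECT.

```python
import numpy as np

def channels_first_to_complex(x):
    """(N, 2, H, W) real-valued array -> complex array of shape (N, H, W)."""
    return x[:, 0] + 1j * x[:, 1]

def complex_last_to_complex(k):
    """(..., 2) real-valued array -> complex array of shape (...)."""
    return k[..., 0] + 1j * k[..., 1]

def complex_to_complex_last(z):
    """Complex array of shape (...) -> real-valued (..., 2) array."""
    return np.stack([z.real, z.imag], axis=-1)

rng = np.random.default_rng(0)
N, coils, H, W = 2, 4, 16, 16
input_image = rng.standard_normal((N, 2, H, W))            # (N, complex, height, width)
masked_kspace = rng.standard_normal((N, coils, H, W, 2))   # (N, coil, height, width, complex)

image_c = channels_first_to_complex(input_image)           # (N, H, W), complex
kspace_c = complex_last_to_complex(masked_kspace)          # (N, coil, H, W), complex
image_last = np.moveaxis(input_image, 1, -1)               # (N, H, W, 2): complex axis moved last
```

Moving the complex axis with np.moveaxis only reorders dimensions, so the same real and imaginary values are recovered whichever layout is used.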

training: bool
class direct.nn.rim.rim.RIM(forward_operator, backward_operator, hidden_channels, x_channels=2, length=8, depth=1, no_parameter_sharing=True, instance_norm=False, dense_connect=False, skip_connections=True, replication_padding=True, image_initialization='zero_filled', learned_initializer=False, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False, **kwargs)

Bases: Module

Recurrent Inference Machine Module as in [1].

References

[1]

Putzky, Patrick, and Max Welling. “Recurrent Inference Machines for Solving Inverse Problems.” ArXiv:1706.04008 [Cs], June 2017. arXiv.org, http://arxiv.org/abs/1706.04008.
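In [1], a RIM refines an estimate by repeatedly passing the log-likelihood gradient, the current estimate and a hidden state to a learned recurrent network, which returns an increment for the estimate and an updated hidden state. The toy NumPy sketch below shows only that loop structure. The learned network is replaced by a hand-written heavy-ball rule (an assumption made purely so the code runs), the forward operator is a small real matrix A rather than the MRI operator, and length is treated as the number of time steps. toy_update and rim_like_loop are hypothetical names.

```python
import numpy as np

def toy_update(grad, x, h, alpha=0.1, beta=0.5):
    """Hand-written stand-in for the learned recurrent cell of a RIM.

    A RIM learns this map with a recurrent network; here the hidden state h is
    a heavy-ball momentum buffer so that the loop runs. x is unused by this toy
    rule, but the learned cell also receives the current estimate.
    """
    h_next = beta * h + grad          # update the hidden state
    delta_x = -alpha * h_next         # increment for the estimate
    return delta_x, h_next

def rim_like_loop(A, y, length=8):
    """Unrolled loop x_{tau+1} = x_tau + delta_x for `length` time steps."""
    x = np.zeros(A.shape[1])          # zero-filled initial estimate
    h = np.zeros_like(x)              # zero-initialized hidden state
    for _ in range(length):
        grad = A.T @ (A @ x - y)      # gradient of the data misfit 0.5 * ||A x - y||^2
        delta_x, h = toy_update(grad, x, h)
        x = x + delta_x
    return x, h

A = np.diag([1.0, 2.0, 3.0])          # small well-conditioned forward operator
x_true = np.ones(3)
y = A @ x_true                        # noiseless measurements
x_hat, state = rim_like_loop(A, y, length=8)
```

The point of the RIM is that toy_update is learned from data instead of fixed by hand, which lets the network adapt the step and the use of its hidden state to the reconstruction problem.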

compute_sense_init(kspace, sensitivity_map)
Return type:

Tensor
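compute_sense_init is listed without a description. Judging by its name and arguments, it forms an initial image from the k-space and the sensitivity maps; one common SENSE-style combination is \(x = \sum_i S_i^{\text{H}} \mathcal{F}^{-1} y_i\). The NumPy sketch below implements that formula under stated assumptions (native complex arrays, orthonormal FFT, no batch dimension). Whether DIRECT's method uses exactly this form or normalizes the result is not documented here, and sense_combine is a hypothetical name.

```python
import numpy as np

def sense_combine(kspace, sens):
    """SENSE-style coil combination: sum_i conj(S_i) * F^{-1}(y_i).

    kspace, sens: complex arrays of shape (coils, H, W).
    """
    coil_images = np.fft.ifft2(kspace, norm="ortho")     # per-coil images F^{-1}(y_i)
    return np.sum(np.conj(sens) * coil_images, axis=0)   # combine with conjugate sensitivities

rng = np.random.default_rng(0)
coils, H, W = 4, 16, 16
sens = rng.standard_normal((coils, H, W)) + 1j * rng.standard_normal((coils, H, W))
sens /= np.sqrt(np.sum(np.abs(sens) ** 2, axis=0))       # normalize so sum_i |S_i|^2 = 1
x_true = rng.standard_normal((H, W)) + 1j * rng.standard_normal((H, W))
kspace = np.fft.fft2(sens * x_true, norm="ortho")        # fully sampled multi-coil k-space

x_init = sense_combine(kspace, sens)

# With undersampled k-space the combination is only an approximation (zero-filled, aliased).
mask = rng.random((H, W)) < 0.3
x_under = sense_combine(kspace * mask, sens)
```

With normalized sensitivities and fully sampled data, the combination returns the original image exactly, since \(\sum_i \overline{S_i} S_i x = x\).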

forward(input_image, masked_kspace, sampling_mask, sensitivity_map=None, previous_state=None, loglikelihood_scaling=None, **kwargs)

Performs the forward pass of RIM.

Parameters:
input_image: torch.Tensor

Initial or intermediate estimate of the image, of shape (N, height, width, complex=2).

masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

previous_state: torch.Tensor

Optional hidden state from a previous call.

loglikelihood_scaling: torch.Tensor

Optional multiplier for the log-likelihood, a float tensor of shape (1,).

Returns:
torch.Tensor
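The singleton dimensions of the sampling mask, (N, 1, height, width, 1), suggest that one mask per sample is broadcast against k-space of shape (N, coil, height, width, complex=2), covering every coil and both the real and imaginary parts. The NumPy sketch below illustrates this broadcasting with a hypothetical mask that keeps every second column; NumPy stands in for PyTorch, whose broadcasting rules are the same in this case.

```python
import numpy as np

rng = np.random.default_rng(0)
N, coils, H, W = 2, 4, 8, 8
kspace = rng.standard_normal((N, coils, H, W, 2))        # (N, coil, height, width, complex=2)

# Hypothetical Cartesian mask keeping every second column, shape (N, 1, height, width, 1).
sampling_mask = np.zeros((N, 1, H, W, 1), dtype=bool)
sampling_mask[:, :, :, ::2, :] = True

# The singleton axes broadcast over the coils and over the real/imaginary parts.
masked_kspace = np.where(sampling_mask, kspace, 0.0)

# Acceleration factor: total locations divided by sampled locations (2.0 here).
acceleration = sampling_mask.size / sampling_mask.sum()
```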
training: bool
class direct.nn.rim.rim.RIMInit(x_ch, out_ch, channels, dilations, depth=2, multiscale_depth=1)

Bases: Module

Learned initializer for the RIM hidden state, based on multi-scale context aggregation with dilated convolutions; it replaces the zero initialization of the hidden state. Inspired by [1].

References

[1]

Yu, Fisher, and Vladlen Koltun. “Multi-Scale Context Aggregation by Dilated Convolutions.” ArXiv:1511.07122 [Cs], Apr. 2016. arXiv.org, http://arxiv.org/abs/1511.07122.
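With stride-1 convolutions, each layer of kernel size k and dilation d enlarges the receptive field by (k - 1) * d, so dilations widen the context without downsampling. The short sketch below computes the receptive field for the default initializer_dilations, assuming 3x3 kernels (the kernel size is not stated on this page); receptive_field is a hypothetical helper.

```python
def receptive_field(dilations, kernel_size=3):
    """Receptive field (in pixels) of a stack of stride-1 dilated convolutions.

    Each layer with dilation d widens the field by (kernel_size - 1) * d.
    """
    field = 1
    for d in dilations:
        field += (kernel_size - 1) * d
    return field

# Default initializer_dilations, assuming 3x3 kernels.
print(receptive_field((1, 1, 2, 4)))  # 17

# Exponentially growing dilations, the pattern used in [1], grow the field much faster.
print(receptive_field((1, 1, 2, 4, 8, 16)))
```

Doubling the dilation at each layer makes the receptive field grow exponentially with depth while the number of parameters grows only linearly.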

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

direct.nn.rim.rim_engine module

class direct.nn.rim.rim_engine.RIMEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

Recurrent Inference Machine Engine.

Module contents