direct.nn.rim package
Submodules
direct.nn.rim.config module
- class direct.nn.rim.config.RIMConfig(model_name: str = '???', engine_name: str | None = None, hidden_channels: int = 16, length: int = 8, depth: int = 2, steps: int = 1, no_parameter_sharing: bool = False, instance_norm: bool = False, dense_connect: bool = False, whiten_input: bool = False, replication_padding: bool = True, image_initialization: str = 'zero_filled', scale_loglikelihood: float | None = None, learned_initializer: bool = False, initializer_channels: Tuple[int, ...] = (32, 32, 64, 64), initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4), initializer_multiscale: int = 1, normalized: bool = False)
Bases: ModelConfig
- dense_connect: bool = False
- depth: int = 2
- image_initialization: str = 'zero_filled'
- initializer_channels: Tuple[int, ...] = (32, 32, 64, 64)
- initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4)
- initializer_multiscale: int = 1
- instance_norm: bool = False
- learned_initializer: bool = False
- length: int = 8
- no_parameter_sharing: bool = False
- normalized: bool = False
- replication_padding: bool = True
- scale_loglikelihood: Optional[float] = None
- steps: int = 1
- whiten_input: bool = False
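To show how the fields above are typically overridden, here is a minimal sketch that mirrors the documented defaults in a stand-in dataclass. `RIMConfigSketch` is hypothetical and exists only so the example runs without DIRECT installed; `model_name` and `engine_name` are left out for brevity.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass
class RIMConfigSketch:
    """Hypothetical stand-in that copies the documented RIMConfig defaults."""

    hidden_channels: int = 16
    length: int = 8
    depth: int = 2
    steps: int = 1
    no_parameter_sharing: bool = False
    instance_norm: bool = False
    dense_connect: bool = False
    whiten_input: bool = False
    replication_padding: bool = True
    image_initialization: str = "zero_filled"
    scale_loglikelihood: Optional[float] = None
    learned_initializer: bool = False
    initializer_channels: Tuple[int, ...] = (32, 32, 64, 64)
    initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4)
    initializer_multiscale: int = 1
    normalized: bool = False


default_cfg = RIMConfigSketch()
# Override a few fields; every other field keeps its documented default.
custom_cfg = replace(default_cfg, hidden_channels=64, length=16, learned_initializer=True)
```

Using `dataclasses.replace` leaves `default_cfg` untouched, which keeps a baseline configuration around for comparison.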
direct.nn.rim.rim module
- class direct.nn.rim.rim.MRILogLikelihood(forward_operator, backward_operator)
Bases: Module
Defines the MRI log-likelihood, assuming a single noise vector for the complex images across all coils:
\[\frac{1}{\sigma^2} \sum_{i=1}^{N_c} S_i^{\text{H}} \mathcal{F}^{-1} P^{*} (P \mathcal{F} S_i x_{\tau} - y_{\tau})\]
for each time step \(\tau\).
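The expression can be traced term by term on toy data. The sketch below is not DIRECT's implementation: it assumes 1D complex signals stored as Python lists, a naive orthonormal DFT in place of the forward and backward operators, a binary sampling mask, and an explicit `sigma` argument. The names `dft` and `loglikelihood_term` are hypothetical.

```python
import cmath


def dft(signal, inverse=False):
    """Naive orthonormal 1D DFT (inverse transform when `inverse` is True)."""
    n = len(signal)
    sign = 1 if inverse else -1
    return [
        sum(signal[k] * cmath.exp(sign * 2j * cmath.pi * f * k / n) for k in range(n)) / n ** 0.5
        for f in range(n)
    ]


def loglikelihood_term(x, y, sens, mask, sigma=1.0):
    """Evaluate (1/sigma^2) * sum_i S_i^H F^{-1} P^* (P F S_i x - y_i) on 1D toy data."""
    out = [0j] * len(x)
    for s_i, y_i in zip(sens, y):
        coil_image = [s * v for s, v in zip(s_i, x)]  # S_i x
        kspace = dft(coil_image)  # F S_i x
        # P F S_i x - y_i, then P^*; for a binary mask both reduce to one masking step.
        residual = [m * (k - d) for m, k, d in zip(mask, kspace, y_i)]
        back = dft(residual, inverse=True)  # F^{-1} P^* (...)
        out = [o + s.conjugate() * b for o, s, b in zip(out, s_i, back)]  # S_i^H (...)
    return [o / sigma ** 2 for o in out]


# Toy problem: two coils with constant sensitivities and 8 samples.
mask = [1, 0, 1, 1, 0, 1, 0, 1]
sens = [[1 + 0j] * 8, [0.5j] * 8]
x_true = [complex(k, -k) for k in range(8)]
y = [
    [m * v for m, v in zip(mask, dft([s * t for s, t in zip(s_i, x_true)]))]
    for s_i in sens
]

at_truth = loglikelihood_term(x_true, y, sens, mask)
at_zero = loglikelihood_term([0j] * 8, y, sens, mask)
```

In this toy setting the term vanishes at `x_true`, because the simulated measurements are noise-free and consistent with it, while a zero image yields a non-zero result.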
- forward(input_image, masked_kspace, sensitivity_map, sampling_mask, loglikelihood_scaling=None)
Performs forward pass of MRILogLikelihood.
- Parameters:
- input_image: torch.Tensor
Initial or previous iteration of the image, with the complex dimension first, of shape (N, complex, height, width).
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex).
- sampling_mask: torch.Tensor
- loglikelihood_scaling: torch.Tensor
Multiplier for loglikelihood, for instance for the k-space noise, of shape (1,).
- Returns:
- out: torch.Tensor
The MRI log-likelihood.
- Return type: Tensor
- training: bool
- class direct.nn.rim.rim.RIM(forward_operator, backward_operator, hidden_channels, x_channels=2, length=8, depth=1, no_parameter_sharing=True, instance_norm=False, dense_connect=False, skip_connections=True, replication_padding=True, image_initialization='zero_filled', learned_initializer=False, initializer_channels=(32, 32, 64, 64), initializer_dilations=(1, 1, 2, 4), initializer_multiscale=1, normalized=False, **kwargs)
Bases: Module
Recurrent Inference Machine Module as in [1].
References
[1] Putzky, Patrick, and Max Welling. “Recurrent Inference Machines for Solving Inverse Problems.” ArXiv:1706.04008 [Cs], June 2017. arXiv.org, http://arxiv.org/abs/1706.04008.
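The recurrent refinement scheme of [1] repeatedly feeds the current estimate's gradient and a hidden state into an update function that returns an increment and a new state. The sketch below shows only that loop structure. It swaps the learned recurrent network for a hand-written momentum rule (`toy_update`, hypothetical), works on plain Python lists, and reads `length` as the number of time steps, which is an assumption about this module.

```python
def toy_update(gradient, state, step_size=0.2, momentum=0.5):
    """Hypothetical stand-in for the learned recurrent cell: gradient step with momentum."""
    new_state = [momentum * s - step_size * g for s, g in zip(state, gradient)]
    return new_state, new_state  # (increment, next hidden state)


def rim_like_loop(x0, gradient_fn, length=8):
    """Refine x over `length` time steps, carrying a hidden state between steps."""
    x = list(x0)
    state = [0.0] * len(x)  # zero-initialized hidden state
    for _ in range(length):
        increment, state = toy_update(gradient_fn(x), state)
        x = [xi + di for xi, di in zip(x, increment)]
    return x


target = [1.0, -2.0, 0.5]


def gradient_fn(x):
    """Gradient of 0.5 * ||x - target||^2, standing in for the likelihood gradient."""
    return [xi - ti for xi, ti in zip(x, target)]


estimate = rim_like_loop([0.0, 0.0, 0.0], gradient_fn)
```

With these toy settings the estimate moves close to `target` within the eight steps; in a trained RIM the update rule is learned rather than fixed.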
- forward(input_image, masked_kspace, sampling_mask, sensitivity_map=None, previous_state=None, loglikelihood_scaling=None, **kwargs)
Performs forward pass of RIM.
- Parameters:
- input_image: torch.Tensor
Initial or intermediate guess of input. Has shape (N, height, width, complex=2).
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- previous_state: torch.Tensor
- loglikelihood_scaling: torch.Tensor
Float tensor of shape (1,).
- Returns:
- torch.Tensor
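The documented shapes must agree with each other on batch size, coil count and spatial size. As a minimal sketch, the helper below checks plain shape tuples (for example `tensor.shape`) against the layouts listed above; `check_rim_shapes` is a hypothetical name and not part of DIRECT.

```python
def check_rim_shapes(input_image, masked_kspace, sensitivity_map, sampling_mask):
    """Raise ValueError if the shape tuples disagree with the documented layout."""
    if len(masked_kspace) != 5:
        raise ValueError("masked_kspace: expected 5 dimensions (N, coil, height, width, 2)")
    n, coil, height, width, complex_dim = masked_kspace
    if complex_dim != 2:
        raise ValueError("masked_kspace: last dimension must be complex=2")

    # Expected layouts, derived from the k-space shape.
    expected = {
        "input_image": (n, height, width, 2),
        "sensitivity_map": (n, coil, height, width, 2),
        "sampling_mask": (n, 1, height, width, 1),
    }
    actual = {
        "input_image": tuple(input_image),
        "sensitivity_map": tuple(sensitivity_map),
        "sampling_mask": tuple(sampling_mask),
    }
    for name, shape in expected.items():
        if actual[name] != shape:
            raise ValueError(f"{name}: expected {shape}, got {actual[name]}")


# Consistent shapes pass silently.
check_rim_shapes(
    input_image=(4, 320, 320, 2),
    masked_kspace=(4, 8, 320, 320, 2),
    sensitivity_map=(4, 8, 320, 320, 2),
    sampling_mask=(4, 1, 320, 320, 1),
)
```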
- training: bool
- class direct.nn.rim.rim.RIMInit(x_ch, out_ch, channels, dilations, depth=2, multiscale_depth=1)
Bases: Module
Learned initializer for RIM, based on multi-scale context aggregation with dilated convolutions, that replaces the zero initializer for the RIM hidden vector. Inspired by [1].
References
[1] Yu, Fisher, and Vladlen Koltun. “Multi-Scale Context Aggregation by Dilated Convolutions.” ArXiv:1511.07122 [Cs], Apr. 2016. arXiv.org, http://arxiv.org/abs/1511.07122.
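Dilated convolutions widen the context seen by each output pixel without adding parameters. As a rough illustration, the sketch below computes the receptive field of a stack of stride-1 dilated convolutions. The kernel size of 3 and the purely sequential stacking are assumptions, since this page does not state them, and `receptive_field` is a hypothetical helper.

```python
def receptive_field(dilations, kernel_size=3):
    """Receptive field (in pixels, per axis) of sequential stride-1 dilated convolutions."""
    field = 1
    for dilation in dilations:
        # Each layer extends the field by (kernel_size - 1) * dilation pixels.
        field += (kernel_size - 1) * dilation
    return field


# Default initializer_dilations from the signatures on this page.
default_field = receptive_field((1, 1, 2, 4))  # 1 + 2 * (1 + 1 + 2 + 4) = 17
undilated_field = receptive_field((1, 1, 1, 1))  # 1 + 2 * 4 = 9
```

Under these assumptions the default dilations cover 17 pixels per axis, compared with 9 for four undilated layers of the same kernel size.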
- forward(x)
Define the computation performed at every call. Should be overridden by all subclasses.
- Return type: Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
direct.nn.rim.rim_engine module
- class direct.nn.rim.rim_engine.RIMEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
Recurrent Inference Machine Engine.