direct.nn.varnet package
Submodules
direct.nn.varnet.config module
- class direct.nn.varnet.config.EndToEndVarNetConfig(model_name: str = '???', engine_name: str | None = None, num_layers: int = 8, regularizer_num_filters: int = 18, regularizer_num_pull_layers: int = 4, regularizer_dropout: float = 0.0)
Bases: ModelConfig
- num_layers: int = 8
- regularizer_dropout: float = 0.0
- regularizer_num_filters: int = 18
- regularizer_num_pull_layers: int = 4
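As a rough illustration of how these fields line up with the EndToEndVarNet constructor arguments, the sketch below uses a hypothetical stand-in dataclass, `EndToEndVarNetConfigSketch`, instead of the real class (which derives from ModelConfig). The helper `model_kwargs` is also hypothetical and assumes that every field except `model_name` and `engine_name` is passed to the model under the same name. The '???' default is assumed to be OmegaConf's marker for a mandatory value.

```python
from dataclasses import asdict, dataclass
from typing import Any, Optional


@dataclass
class EndToEndVarNetConfigSketch:
    """Hypothetical stand-in mirroring the documented EndToEndVarNetConfig fields."""

    model_name: str = "???"  # assumed: OmegaConf's "mandatory value" marker
    engine_name: Optional[str] = None
    num_layers: int = 8
    regularizer_num_filters: int = 18
    regularizer_num_pull_layers: int = 4
    regularizer_dropout: float = 0.0


def model_kwargs(cfg: EndToEndVarNetConfigSketch) -> dict[str, Any]:
    """Hypothetical helper: keep only the fields named like EndToEndVarNet arguments."""
    kwargs = asdict(cfg)
    # Assumed: model_name and engine_name select classes, they are not hyperparameters.
    kwargs.pop("model_name")
    kwargs.pop("engine_name")
    return kwargs


# Override one field and inspect the resulting constructor arguments.
print(model_kwargs(EndToEndVarNetConfigSketch(num_layers=12)))
```

Keeping the field names identical to the constructor arguments is what makes a plain dictionary hand-off like this possible.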
direct.nn.varnet.varnet module
- class direct.nn.varnet.varnet.EndToEndVarNet(forward_operator, backward_operator, num_layers, regularizer_num_filters=18, regularizer_num_pull_layers=4, regularizer_dropout=0.0, in_channels=2, **kwargs)
Bases: Module
End-to-End Variational Network based on [1].
References
[1] Sriram, Anuroop, et al. “End-to-End Variational Networks for Accelerated MRI Reconstruction.” ArXiv:2004.06688 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.06688.
- forward(masked_kspace, sampling_mask, sensitivity_map)
Performs the forward pass of EndToEndVarNet.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- Returns:
- kspace_prediction: torch.Tensor
K-space prediction of shape (N, coil, height, width, complex=2).
- Return type: Tensor
- training: bool
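The forward pass can be pictured as a cascade of blocks that repeatedly refine a k-space estimate. The NumPy sketch below assumes, as in [1], that the cascade starts from the masked k-space, and assumes that `num_layers` sets the number of blocks. `soft_dc_block` and `varnet_forward` are hypothetical names; `soft_dc_block` stands in for EndToEndVarNetBlock and keeps only its soft data-consistency step. Arrays follow the documented (N, coil, height, width, complex=2) layout.

```python
import numpy as np


def soft_dc_block(current_kspace, masked_kspace, sampling_mask, sensitivity_map, learning_rate=0.5):
    """Hypothetical stand-in block: soft data consistency only, no learned regularizer."""
    # sampling_mask (N, 1, H, W, 1) broadcasts over the coil and complex axes.
    return current_kspace - learning_rate * sampling_mask * (current_kspace - masked_kspace)


def varnet_forward(masked_kspace, sampling_mask, sensitivity_map, blocks):
    """Run the cascade: start from the measured k-space and refine it block by block."""
    kspace_prediction = masked_kspace.copy()
    for block in blocks:
        kspace_prediction = block(kspace_prediction, masked_kspace, sampling_mask, sensitivity_map)
    return kspace_prediction


rng = np.random.default_rng(0)
n, coil, height, width = 2, 4, 16, 16
mask = (rng.random((n, 1, height, width, 1)) < 0.3).astype(np.float64)
masked_kspace = rng.standard_normal((n, coil, height, width, 2)) * mask
sens = rng.standard_normal((n, coil, height, width, 2))
prediction = varnet_forward(masked_kspace, mask, sens, [soft_dc_block] * 8)  # num_layers = 8
print(prediction.shape)  # (2, 4, 16, 16, 2)
```

With data consistency alone the estimate never moves away from the masked k-space; in the real model it is the learned regularizer inside each block that fills in the unsampled locations.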
- class direct.nn.varnet.varnet.EndToEndVarNetBlock(forward_operator, backward_operator, regularizer_model)
Bases: Module
End-to-End Variational Network block.
- forward(current_kspace, masked_kspace, sampling_mask, sensitivity_map)
Performs the forward pass of EndToEndVarNetBlock.
- Parameters:
- current_kspace: torch.Tensor
Current k-space prediction of shape (N, coil, height, width, complex=2).
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- Returns:
- torch.Tensor
Next k-space prediction of shape (N, coil, height, width, complex=2).
- Return type: Tensor
- training: bool
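Following [1], each block combines a soft data-consistency step with a learned refinement that reduces the coil k-space to a coil-combined image with the sensitivity maps, applies the regularizer, and expands the result back to coil k-space. The NumPy sketch below is an illustration under stated assumptions, not the DIRECT implementation: it uses a plain (uncentered) orthonormal FFT, a fixed `learning_rate` (in [1] the step size is learned), a `regularizer` callable on complex images in place of the CNN passed as `regularizer_model`, and a `+` sign on the refinement term (because the regularizer is learned, its sign can be absorbed into it). All helper names are hypothetical.

```python
import numpy as np


def _to_complex(x):
    """(..., 2) real/imaginary layout -> complex array."""
    return x[..., 0] + 1j * x[..., 1]


def _to_real(z):
    """Complex array -> (..., 2) real/imaginary layout."""
    return np.stack([z.real, z.imag], axis=-1)


def sens_reduce(kspace, sens):
    """Coil k-space (N, coil, H, W) -> coil-combined image (N, H, W)."""
    images = np.fft.ifft2(kspace, norm="ortho")  # over the last two axes (H, W)
    return np.sum(np.conj(sens) * images, axis=1)


def sens_expand(image, sens):
    """Coil-combined image (N, H, W) -> coil k-space (N, coil, H, W)."""
    return np.fft.fft2(sens * image[:, None], norm="ortho")


def varnet_block(current_kspace, masked_kspace, sampling_mask, sensitivity_map, regularizer, learning_rate=1.0):
    """Hypothetical block update: k - lr * M(k - k0) + expand(regularizer(reduce(k)))."""
    k = _to_complex(current_kspace)
    k0 = _to_complex(masked_kspace)
    sens = _to_complex(sensitivity_map)
    mask = sampling_mask[..., 0]  # (N, 1, H, W), broadcasts over coils
    # Soft data consistency: only sampled locations are pulled toward the measurements.
    dc_term = learning_rate * mask * (k - k0)
    # Learned refinement: reduce -> regularizer -> expand.
    reg_term = sens_expand(regularizer(sens_reduce(k, sens)), sens)
    return _to_real(k - dc_term + reg_term)


rng = np.random.default_rng(0)
shape = (1, 4, 8, 8, 2)
masked_kspace = rng.standard_normal(shape)
current_kspace = rng.standard_normal(shape)
sens = rng.standard_normal(shape)
full_mask = np.ones((1, 1, 8, 8, 1))
no_refinement = lambda image: np.zeros_like(image)
# With a full mask, step size 1 and no refinement, the block returns the measurements.
out = varnet_block(current_kspace, masked_kspace, full_mask, sens, no_refinement)
```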
direct.nn.varnet.varnet_engine module
Engines for the End-to-End Variational Network model.
Includes supervised, self-supervised, and joint supervised and self-supervised learning (JSSL) engines.
- class direct.nn.varnet.varnet_engine.EndToEndVarNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
End-to-End Variational Network Engine.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
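The forward_operator and backward_operator arguments are callables that map between image space and k-space. DIRECT provides its own operators; the functions below are hypothetical NumPy stand-ins, assuming a centered, orthonormal 2D FFT over the height and width axes of the (..., height, width, complex=2) layout used throughout this page.

```python
import numpy as np


def _as_complex(data):
    """(..., 2) real/imaginary layout -> complex array."""
    return data[..., 0] + 1j * data[..., 1]


def _as_real(z):
    """Complex array -> (..., 2) real/imaginary layout."""
    return np.stack([z.real, z.imag], axis=-1)


def centered_fft2(data):
    """Hypothetical forward operator: image -> k-space, centered and orthonormal."""
    z = np.fft.ifftshift(_as_complex(data), axes=(-2, -1))
    z = np.fft.fft2(z, norm="ortho")
    return _as_real(np.fft.fftshift(z, axes=(-2, -1)))


def centered_ifft2(data):
    """Hypothetical backward operator: k-space -> image, the inverse of centered_fft2."""
    z = np.fft.ifftshift(_as_complex(data), axes=(-2, -1))
    z = np.fft.ifft2(z, norm="ortho")
    return _as_real(np.fft.fftshift(z, axes=(-2, -1)))


image = np.random.default_rng(0).standard_normal((1, 4, 16, 16, 2))
kspace = centered_fft2(image)
```

With orthonormal scaling, centered_ifft2 exactly undoes centered_fft2 and signal energy is preserved between the two domains.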
- class direct.nn.varnet.varnet_engine.EndToEndVarNetJSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: JSSLMRIModelEngine
Joint Supervised and Self-supervised Learning End-to-End Variational Network Engine.
Used for the supplementary experiments with the End-to-End Variational Network model trained with JSSL in the JSSL paper [1].
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- forward_function(data)
Forward function for EndToEndVarNetJSSLEngine.
- Parameters:
- data: dict[str, Any]
Data dictionary. Should contain the following keys:
  - “is_ssl”: boolean tensor indicating whether training is SSL
  - “input_kspace” if training with SSL, “masked_kspace” if inference
  - “input_sampling_mask” if training with SSL, “sampling_mask” if inference
  - “sensitivity_map”
- Returns:
- tuple[None, torch.Tensor]
None in place of an image prediction, followed by the output k-space.
- Return type: tuple[None, Tensor]
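The key selection described above can be sketched as follows. This is a hypothetical illustration, not the engine's code: `model` stands for any callable taking (kspace, sampling_mask, sensitivity_map), `training` is a plain flag replacing the engine's own training mode, and supervised training steps (“is_ssl” false) are assumed to read the same keys as inference.

```python
from typing import Any, Callable


def jssl_forward_function(
    data: dict[str, Any], model: Callable[..., Any], training: bool
) -> tuple[None, Any]:
    """Hypothetical sketch of the JSSL key selection; returns (None, output k-space)."""
    # "is_ssl" is only consulted during training; bool() also accepts a one-element tensor.
    if training and bool(data["is_ssl"]):
        kspace, mask = data["input_kspace"], data["input_sampling_mask"]
    else:
        # Assumed: supervised training steps and inference use the full acquisition.
        kspace, mask = data["masked_kspace"], data["sampling_mask"]
    output_kspace = model(kspace, mask, data["sensitivity_map"])
    # No image prediction is returned, only the k-space prediction.
    return None, output_kspace


record_inputs = lambda kspace, mask, sens: (kspace, mask, sens)
batch = {
    "is_ssl": True,
    "input_kspace": "ik",
    "input_sampling_mask": "im",
    "masked_kspace": "mk",
    "sampling_mask": "sm",
    "sensitivity_map": "s",
}
print(jssl_forward_function(batch, record_inputs, training=True))  # (None, ('ik', 'im', 's'))
```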
- class direct.nn.varnet.varnet_engine.EndToEndVarNetSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: SSLMRIModelEngine
Self-supervised Learning End-to-End Variational Network Engine.
Used for the supplementary experiments with the End-to-End Variational Network model trained with SSL in the JSSL paper [1].
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.
- forward_function(data)
Forward function for EndToEndVarNetSSLEngine.
- Parameters:
- data: dict[str, Any]
Data dictionary. Should contain the following keys:
  - “input_kspace” if training, “masked_kspace” if inference
  - “input_sampling_mask” if training, “sampling_mask” if inference
  - “sensitivity_map”
- Returns:
- tuple[None, torch.Tensor]
None in place of an image prediction, followed by the output k-space.
- Return type: tuple[None, Tensor]
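For the SSL engine the key choice depends only on whether the engine is training, so it can be written as a lookup table from training mode to key names. The sketch below is hypothetical: `SSL_INPUT_KEYS` and `ssl_forward_function` are illustrative names, `model` stands for any callable taking (kspace, sampling_mask, sensitivity_map), `training` is a plain flag standing in for the engine's training mode, and the input_* keys are assumed to hold the subset of acquired samples fed to the network during self-supervised training.

```python
from typing import Any, Callable

# Hypothetical mapping: training mode -> (k-space key, sampling-mask key).
SSL_INPUT_KEYS = {
    True: ("input_kspace", "input_sampling_mask"),  # training: assumed SSL input subset
    False: ("masked_kspace", "sampling_mask"),  # inference: all acquired samples
}


def ssl_forward_function(
    data: dict[str, Any], model: Callable[..., Any], training: bool
) -> tuple[None, Any]:
    """Hypothetical sketch of the SSL key selection; returns (None, output k-space)."""
    kspace_key, mask_key = SSL_INPUT_KEYS[bool(training)]
    output_kspace = model(data[kspace_key], data[mask_key], data["sensitivity_map"])
    return None, output_kspace
```

A table keeps the two modes side by side, which makes it easy to check which keys a data pipeline must provide for training versus inference.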