direct.nn.ssl package
Submodules
direct.nn.ssl.mri_models module
SSL MRI model engines of DIRECT.
- class direct.nn.ssl.mri_models.JSSLMRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: SSLMRIModelEngine
Base Engine for JSSL MRI models.
This engine trains models with joint supervised and self-supervised learning (JSSL), based on the work of Yiasemis et al. [1].
During training, the loss for self-supervised samples is computed as in SSLMRIModelEngine, and the loss for supervised samples is computed as in standard supervised MRI learning. During inference, the output is computed as \((\mathbb{1} - U)f_{\theta}(\tilde{y}) + \tilde{y}\).
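The per-sample switch between the two losses can be sketched in plain Python with NumPy. This is a minimal illustration under stated assumptions, not DIRECT's implementation: the boolean `"is_ssl"` key, the helper name `jssl_batch_loss`, and the choice to average per-sample losses over the batch are all hypothetical.

```python
from typing import Any, Callable, Dict, List

import numpy as np

# Hypothetical sample type: a dict of arrays plus an "is_ssl" flag (assumption).
Sample = Dict[str, Any]


def jssl_batch_loss(
    samples: List[Sample],
    ssl_loss: Callable[[Sample], float],
    supervised_loss: Callable[[Sample], float],
) -> float:
    """Average loss over a batch that mixes SSL and supervised samples.

    Hypothetical helper: samples flagged with "is_ssl" use the self-supervised
    k-space split loss, all others use an ordinary supervised loss.
    """
    losses = [
        ssl_loss(sample) if sample["is_ssl"] else supervised_loss(sample)
        for sample in samples
    ]
    return float(np.mean(losses))


def toy_ssl_loss(sample: Sample) -> float:
    """Toy constant SSL loss so the dispatch can be checked without a model."""
    return 1.0


def toy_supervised_loss(sample: Sample) -> float:
    """Toy L1 distance between a prediction and a fully sampled reference."""
    return float(np.mean(np.abs(sample["prediction"] - sample["target"])))


batch = [
    {"is_ssl": True},
    {"is_ssl": False, "prediction": np.zeros(4), "target": np.full(4, 3.0)},
]
print(jssl_batch_loss(batch, toy_ssl_loss, toy_supervised_loss))  # 2.0
```

Averaging is only one way to combine the two sample types; a weighted sum per sample type would be an equally plausible design.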
References
[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction. arXiv:2311.15856 (2023). http://arxiv.org/abs/2311.15856, https://doi.org/10.48550/arXiv.2311.15856.
- class direct.nn.ssl.mri_models.SSLMRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
Base Engine for SSL MRI models.
This engine trains models with self-supervised learning. During training, the loss is computed as
\[\mathcal{L}\big(\mathcal{A}_{\text{tar}}(x_{\text{out}}), y_{\text{tar}}\big),\]
where \(x_{\text{out}}=f_{\theta}(y_{\text{inp}})\), \(\mathcal{A}_{\text{tar}}\) is the forward operator followed by sub-sampling with \(U_{\text{tar}}\), and \(y_{\text{inp}}=U_{\text{inp}}(\tilde{y})\) and \(y_{\text{tar}}=U_{\text{tar}}(\tilde{y})\) are splits of the original measured k-space \(\tilde{y}\) obtained via two (disjoint or not) sub-sampling operators \(U_{\text{inp}}\) and \(U_{\text{tar}}\) that together cover \(U\), the original sub-sampling operator. When the splits are disjoint, \(U_{\text{inp}} + U_{\text{tar}} = U\) and hence \(y_{\text{inp}} + y_{\text{tar}}=\tilde{y}\).
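To make the split concrete, here is a minimal single-coil NumPy sketch. It assumes a uniform random disjoint split of the sampled locations (so \(U_{\text{inp}} + U_{\text{tar}} = U\)), a centered orthonormal FFT as the forward operator, and an L1 loss evaluated on the target locations. The helper names (`fft2c`, `ifft2c`, `split_mask`, `ssl_loss`) are hypothetical; DIRECT's own mask splitters and losses may differ.

```python
import numpy as np


def fft2c(image: np.ndarray) -> np.ndarray:
    """Centered orthonormal 2D FFT (image -> k-space)."""
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(image), norm="ortho"))


def ifft2c(kspace: np.ndarray) -> np.ndarray:
    """Centered orthonormal 2D inverse FFT (k-space -> image)."""
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace), norm="ortho"))


def split_mask(mask: np.ndarray, target_fraction: float, rng: np.random.Generator):
    """Split a boolean mask U into disjoint U_inp and U_tar with U_inp + U_tar = U."""
    sampled = np.flatnonzero(mask)  # flat indices of measured k-space points
    n_target = int(round(target_fraction * sampled.size))
    target_idx = rng.choice(sampled, size=n_target, replace=False)
    target_mask = np.zeros(mask.size, dtype=bool)
    target_mask[target_idx] = True
    target_mask = target_mask.reshape(mask.shape)
    input_mask = mask & ~target_mask  # measured points that are not held out
    return input_mask, target_mask


def ssl_loss(model, kspace, mask, target_fraction=0.4, seed=0) -> float:
    """L1 loss between A_tar(f(y_inp)) and y_tar on the held-out locations."""
    rng = np.random.default_rng(seed)
    input_mask, target_mask = split_mask(mask, target_fraction, rng)
    y_inp = input_mask * kspace  # U_inp applied to the measured k-space
    y_tar = target_mask * kspace  # U_tar applied to the measured k-space
    x_out = model(y_inp)  # f_theta(y_inp), an image
    pred = target_mask * fft2c(x_out)  # A_tar(x_out)
    return float(np.mean(np.abs(pred - y_tar)[target_mask]))


# Usage: a zero-filled "model" (inverse FFT) on random single-coil data.
rng = np.random.default_rng(1)
mask = rng.random((8, 8)) < 0.5
kspace = mask * fft2c(rng.standard_normal((8, 8)))
print(ssl_loss(ifft2c, kspace, mask))
```

Because the zero-filled stand-in never sees the held-out samples, its loss here is simply the mean magnitude of \(y_{\text{tar}}\); a trained \(f_{\theta}\) would have to predict those samples from \(y_{\text{inp}}\) alone.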
During inference, the output is computed as \((\mathbb{1} - U)f_{\theta}(\tilde{y}) + \tilde{y}\).
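The inference rule keeps the measured samples and fills only the unmeasured k-space locations with the model prediction. Below is a minimal single-coil NumPy sketch, assuming that \(f_{\theta}\) returns an image that is mapped to k-space with a centered orthonormal FFT before \((\mathbb{1} - U)\) is applied. `ssl_inference` and the magnitude-based stand-in model are hypothetical.

```python
import numpy as np


def fft2c(image: np.ndarray) -> np.ndarray:
    """Centered orthonormal 2D FFT (image -> k-space)."""
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(image), norm="ortho"))


def ifft2c(kspace: np.ndarray) -> np.ndarray:
    """Centered orthonormal 2D inverse FFT (k-space -> image)."""
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace), norm="ortho"))


def ssl_inference(model, kspace: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return (1 - U) F(f(y~)) + y~ as a k-space array.

    `kspace` is the measured k-space y~ (zero outside the boolean mask U), so
    measured locations are kept and only unmeasured ones come from the model.
    """
    pred_kspace = fft2c(model(kspace))  # model prediction mapped to k-space
    return (~mask) * pred_kspace + kspace


def magnitude_model(kspace: np.ndarray) -> np.ndarray:
    """Toy non-linear stand-in for f_theta: magnitude of the zero-filled image."""
    return np.abs(ifft2c(kspace))


rng = np.random.default_rng(0)
mask = rng.random((8, 8)) < 0.5
kspace = mask * fft2c(rng.standard_normal((8, 8)))
output_kspace = ssl_inference(magnitude_model, kspace, mask)
output_image = ifft2c(output_kspace)
```

A non-linear stand-in is used on purpose: a purely linear, shift-invariant filter would leave the unmeasured k-space locations at zero and make the data-consistency step invisible.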
- abstract forward_function(data)
Must be implemented by child classes.
- Parameters:
- data: dict[str, Any]
- Raises:
- NotImplementedError
Must be implemented by child class.
- Return type:
tuple[Optional[Tensor], Optional[Tensor]]
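A child class only has to return the two optional outputs. The sketch below mirrors that contract with Python's `abc` module, using NumPy arrays in place of `torch.Tensor`. The class names, the `"masked_kspace"` key, and the assumption that the tuple is ordered as (image, k-space) are hypothetical and not taken from DIRECT's source.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple

import numpy as np

# Assumed output order: (output image, output k-space); either may be None.
Output = Tuple[Optional[np.ndarray], Optional[np.ndarray]]


class EngineSketch(ABC):
    """Hypothetical stand-in for an engine base class with an abstract forward."""

    @abstractmethod
    def forward_function(self, data: Dict[str, Any]) -> Output:
        """Must be implemented by child classes."""
        raise NotImplementedError("Must be implemented by child class.")


class ZeroFilledEngine(EngineSketch):
    """Toy child class: returns a zero-filled image and no k-space output."""

    def forward_function(self, data: Dict[str, Any]) -> Output:
        image = np.fft.ifft2(data["masked_kspace"], norm="ortho")
        return image, None


engine = ZeroFilledEngine()
image, kspace = engine.forward_function({"masked_kspace": np.ones((4, 4))})
```

Because `forward_function` is declared with `@abstractmethod`, the base class itself cannot be instantiated; only subclasses that override it can.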
- log_first_training_example_and_model(data)
Logs the first training example for SSL-based MRI models.
This differs from the corresponding method of the base MRIModelEngine in that, when SSL is used, it also logs the input and target sampling masks and creates the actual sampling mask.
- Parameters:
- data: dict[str, Any]
Dictionary containing the data. The dictionary should contain the following keys:
“filename”: Filename of the data.
“slice_no”: Slice number of the data.
“input_sampling_mask”: Sampling mask for the input k-space if SSL is used.
“target_sampling_mask”: Sampling mask for the target k-space if SSL is used.
“sampling_mask”: Sampling mask if SSL is not used.
“target”: Target image. This is the reconstruction of the target k-space (i.e. subsampled using the target_sampling_mask).
“initial_image”: Initial image.
- Return type:
None
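The mask-selection step can be sketched as follows. The sketch assumes, without confirmation from the source, that the actual sampling mask is the element-wise union of the input and target masks, and that SSL is detected by the presence of both SSL mask keys. `masks_to_log` is a hypothetical helper, not part of DIRECT.

```python
from typing import Any, Dict

import numpy as np


def masks_to_log(data: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Pick the sampling masks to log (hypothetical helper).

    With SSL masks present, log both and rebuild the full sampling mask as
    their union (an assumption); otherwise log the plain sampling mask.
    """
    if "input_sampling_mask" in data and "target_sampling_mask" in data:
        inp = np.asarray(data["input_sampling_mask"]).astype(bool)
        tar = np.asarray(data["target_sampling_mask"]).astype(bool)
        return {
            "input_sampling_mask": inp,
            "target_sampling_mask": tar,
            "sampling_mask": inp | tar,  # assumed reconstruction of U
        }
    return {"sampling_mask": np.asarray(data["sampling_mask"]).astype(bool)}


ssl_data = {
    "input_sampling_mask": np.array([1, 0, 0, 1]),
    "target_sampling_mask": np.array([0, 1, 0, 0]),
}
logged = masks_to_log(ssl_data)
plain = masks_to_log({"sampling_mask": np.array([1, 0])})
```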