direct.nn.ssl package

Submodules

direct.nn.ssl.mri_models module

SSL MRI model engines of DIRECT.

class direct.nn.ssl.mri_models.SSLMRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

Base Engine for SSL MRI models.

This engine is used for training models with self-supervised learning (SSL). During training, the loss is computed as follows:

\[\mathcal{L}\big(\mathcal{A}_{\text{tar}}(x_{\text{out}}), y_{\text{tar}}\big)\]

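where \(x_{\text{out}}=f_{\theta}(y_{\text{inp}})\) is the model output and \(\mathcal{A}_{\text{tar}}\) denotes the forward operator (image to k-space) followed by \(U_{\text{tar}}\). The input and target k-spaces \(y_{\text{inp}}=U_{\text{inp}}(\tilde{y})\) and \(y_{\text{tar}}=U_{\text{tar}}(\tilde{y})\) are splits of the original measured k-space \(\tilde{y}\), obtained via two (disjoint or not) sub-sampling operators with \(U_{\text{inp}} + U_{\text{tar}} = U\), where \(U\) is the original sub-sampling operator, so that \(y_{\text{inp}} + y_{\text{tar}}=\tilde{y}\). These sums hold exactly when the two operators are disjoint.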
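As a rough illustration of the split and the loss, the NumPy sketch below draws a random disjoint split of a binary mask \(U\) into \(U_{\text{inp}}\) and \(U_{\text{tar}}\) and evaluates the loss with a single-coil \(\mathcal{A}_{\text{tar}} = U_{\text{tar}} \circ \text{FFT}\). The helpers split_mask and ssl_loss, the uniform random split, the mean-squared-error loss, and the zero-filled stand-in for \(f_{\theta}\) are all assumptions for illustration; DIRECT itself operates on PyTorch tensors and may split masks differently.

```python
import numpy as np


def split_mask(mask: np.ndarray, target_fraction: float, rng: np.random.Generator):
    """Hypothetical helper: split a binary mask U into disjoint U_inp and U_tar.

    A uniform random subset of the sampled locations is assigned to U_tar (assumption).
    """
    sampled = np.flatnonzero(mask)  # flat indices where U == 1
    n_target = int(round(target_fraction * sampled.size))
    target_idx = rng.choice(sampled, size=n_target, replace=False)
    target_mask = np.zeros_like(mask)
    target_mask.flat[target_idx] = 1
    input_mask = mask - target_mask  # disjoint by construction, so U_inp + U_tar = U
    return input_mask, target_mask


def ssl_loss(x_out: np.ndarray, y_tar: np.ndarray, target_mask: np.ndarray) -> float:
    """Hypothetical L(A_tar(x_out), y_tar): single-coil A_tar = U_tar o FFT, MSE loss."""
    pred_kspace = target_mask * np.fft.fft2(x_out, norm="ortho")
    return float(np.mean(np.abs(pred_kspace - y_tar) ** 2))


rng = np.random.default_rng(0)
mask = (rng.random((8, 8)) < 0.5).astype(int)  # original sub-sampling operator U
y_tilde = mask * np.fft.fft2(rng.standard_normal((8, 8)), norm="ortho")  # measured k-space
u_inp, u_tar = split_mask(mask, target_fraction=0.4, rng=rng)
y_inp, y_tar = u_inp * y_tilde, u_tar * y_tilde
x_out = np.fft.ifft2(y_inp, norm="ortho")  # zero-filled stand-in for f_theta(y_inp)
loss = ssl_loss(x_out, y_tar, u_tar)
```

Because the model only sees \(y_{\text{inp}}\), the loss on the held-out locations in \(U_{\text{tar}}\) is what drives learning without fully sampled targets.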

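During inference, the output is computed as \((\mathbb{1} - U)f_{\theta}(\tilde{y}) + \tilde{y}\), i.e., the measured k-space values are kept at sampled locations and the model prediction is used only at unsampled ones.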
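A minimal NumPy sketch of this inference-time data consistency step, assuming that \(f_{\theta}(\tilde{y})\) is already expressed in k-space and that \(\tilde{y}\) is zero wherever \(U\) is zero. The function name data_consistent_output and the random stand-in for the network output are hypothetical.

```python
import numpy as np


def data_consistent_output(pred_kspace: np.ndarray, y_tilde: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper computing (1 - U) f_theta(y_tilde) + y_tilde.

    Assumes pred_kspace = f_theta(y_tilde) is in k-space and y_tilde = U(y_tilde).
    """
    return (1 - mask) * pred_kspace + y_tilde


rng = np.random.default_rng(1)
mask = (rng.random((4, 4)) < 0.5).astype(int)  # sub-sampling operator U
full_kspace = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))
y_tilde = mask * full_kspace  # measured (sub-sampled) k-space
pred_kspace = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))  # stand-in network output
out = data_consistent_output(pred_kspace, y_tilde, mask)
```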

Note

This engine also overrides the log_first_training_example_and_model method of the base MRIModelEngine so that the logged first training example includes the SSL-specific sampling masks.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Inits SSLMRIModelEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda” or “cpu”.

  • forward_operator (Optional[Callable]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.
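The forward_operator and backward_operator arguments are callables that map between image space and k-space. Purely as an assumption for illustration, the sketch below uses a centered, orthonormal 2D FFT and its inverse on complex NumPy arrays; DIRECT's own operators act on PyTorch tensors and may use a different complex layout and normalization.

```python
import numpy as np


def forward_operator(image: np.ndarray) -> np.ndarray:
    """Hypothetical forward operator: centered, orthonormal 2D FFT over the last two axes."""
    shifted = np.fft.ifftshift(image, axes=(-2, -1))
    return np.fft.fftshift(np.fft.fft2(shifted, norm="ortho"), axes=(-2, -1))


def backward_operator(kspace: np.ndarray) -> np.ndarray:
    """Hypothetical backward operator: exact inverse of forward_operator."""
    shifted = np.fft.ifftshift(kspace, axes=(-2, -1))
    return np.fft.fftshift(np.fft.ifft2(shifted, norm="ortho"), axes=(-2, -1))


rng = np.random.default_rng(0)
image = rng.standard_normal((2, 6, 6)) + 1j * rng.standard_normal((2, 6, 6))  # (coils, height, width)
kspace = forward_operator(image)
recovered = backward_operator(kspace)
```

With norm="ortho" the transform preserves the signal energy, so the two operators are adjoint to each other as well as inverse.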

log_first_training_example_and_model(data)

Logs the first training example for SSL-based MRI models.

This differs from the corresponding method of the base MRIModelEngine in that, if SSL is used, it also logs the input and target sampling masks and creates the actual sampling mask.

Parameters:
  • data (dict[str, Any]) – Dictionary containing the data. The dictionary should contain the following keys:

      - "filename" – Filename of the data.

      - "slice_no" – Slice number of the data.

      - "input_sampling_mask" – Sampling mask for the input k-space if SSL is used.

      - "target_sampling_mask" – Sampling mask for the target k-space if SSL is used.

      - "sampling_mask" – Sampling mask if SSL is not used.

      - "target" – Target image. This is the reconstruction of the target k-space (i.e., subsampled using the target_sampling_mask).

      - "initial_image" – Initial image.

Return type:

None
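A minimal sketch of how the sampling mask to log could be selected from the keys listed above. It assumes the actual sampling mask follows the split \(U_{\text{inp}} + U_{\text{tar}} = U\) described for SSLMRIModelEngine, computed here as a logical OR (equal to the sum for disjoint binary masks). The helper get_logged_sampling_mask is hypothetical, not DIRECT's implementation.

```python
from typing import Any

import numpy as np


def get_logged_sampling_mask(data: dict[str, Any]) -> np.ndarray:
    """Hypothetical helper: choose the sampling mask to log for the first training example."""
    if "input_sampling_mask" in data and "target_sampling_mask" in data:
        # SSL: rebuild U from U_inp and U_tar (logical OR == sum when the masks are disjoint).
        return np.logical_or(data["input_sampling_mask"], data["target_sampling_mask"])
    # No SSL: the original sampling mask is available directly.
    return np.asarray(data["sampling_mask"], dtype=bool)


ssl_data = {
    "input_sampling_mask": np.array([1, 0, 1, 0, 0, 0], dtype=bool),
    "target_sampling_mask": np.array([0, 1, 0, 0, 1, 0], dtype=bool),
}
supervised_data = {"sampling_mask": np.array([1, 1, 0, 0, 1, 0])}
ssl_mask = get_logged_sampling_mask(ssl_data)
supervised_mask = get_logged_sampling_mask(supervised_data)
```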

abstract forward_function(data)

Must be implemented by child classes.

Parameters:

data (dict[str, Any]) – Dictionary containing the data.

Return type:

tuple[Optional[Tensor], Optional[Tensor]]

Raises:

NotImplementedError – Must be implemented by child class.
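To illustrate the contract of this abstract method, the sketch below uses Python's abc module with a hypothetical stand-in base class (EngineSketch) and child class (ZeroFilledEngine); neither is part of DIRECT. The two entries of the returned tuple are assumed here to be an image-domain output and a k-space output, either of which may be None, and the "masked_kspace" key is likewise an assumption.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional

import numpy as np


class EngineSketch(ABC):
    """Hypothetical stand-in for an engine base class (not DIRECT's MRIModelEngine)."""

    @abstractmethod
    def forward_function(self, data: dict[str, Any]) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        raise NotImplementedError("Must be implemented by child class.")


class ZeroFilledEngine(EngineSketch):
    """Hypothetical child: returns a zero-filled image and no k-space output."""

    def forward_function(self, data: dict[str, Any]) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        output_image = np.fft.ifft2(data["masked_kspace"], norm="ortho")
        return output_image, None


engine = ZeroFilledEngine()
output_image, output_kspace = engine.forward_function({"masked_kspace": np.ones((4, 4), dtype=complex)})
```

Instantiating EngineSketch directly raises TypeError, which is how abc enforces that child classes provide forward_function.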

class direct.nn.ssl.mri_models.JSSLMRIModelEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: SSLMRIModelEngine

Base Engine for JSSL MRI models.

This engine is used for training models with joint supervised and self-supervised learning (JSSL), based on the work of Yiasemis et al. [1].

During training, the loss for self-supervised samples is computed as in SSLMRIModelEngine, while the loss for supervised samples is computed as in standard supervised MRI learning.
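A minimal single-coil sketch of this per-sample switch, assuming a mean-squared-error loss in both branches. The function jssl_sample_loss and the "is_ssl" and "target_kspace" keys are hypothetical; DIRECT's actual losses and data keys may differ.

```python
from typing import Any

import numpy as np


def jssl_sample_loss(output_image: np.ndarray, sample: dict[str, Any]) -> float:
    """Hypothetical per-sample JSSL loss (single-coil, MSE assumed in both branches)."""
    if sample["is_ssl"]:
        # Self-supervised: compare U_tar(FFT(output)) with the target k-space y_tar.
        pred_kspace = sample["target_sampling_mask"] * np.fft.fft2(output_image, norm="ortho")
        return float(np.mean(np.abs(pred_kspace - sample["target_kspace"]) ** 2))
    # Supervised: compare the image output with the fully sampled target image.
    return float(np.mean(np.abs(output_image - sample["target"]) ** 2))


rng = np.random.default_rng(0)
image = rng.standard_normal((4, 4))
mask = (rng.random((4, 4)) < 0.5).astype(int)
supervised = {"is_ssl": False, "target": image}
self_supervised = {
    "is_ssl": True,
    "target_sampling_mask": mask,
    "target_kspace": mask * np.fft.fft2(image, norm="ortho"),
}
sup_loss = jssl_sample_loss(image, supervised)
ssl_loss_value = jssl_sample_loss(image, self_supervised)
```

A perfect prediction gives zero loss in both branches; the difference lies only in which reference (full image or held-out k-space) is available for each sample.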

During inference, the output is computed as \((\mathbb{1} - U)f_{\theta}(\tilde{y}) + \tilde{y}\).

References:

[1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Inits JSSLMRIModelEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be "cuda" or "cpu".

  • forward_operator (Optional[Callable]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

Module contents