Source code for direct.nn.varnet.varnet_engine

# Copyright (c) DIRECT Contributors

"""Engines for End-to-End Variational Network model.

Includes supervised, self-supervised and joint supervised and self-supervised learning engines.
"""

from __future__ import annotations

from typing import Any, Callable, Optional

import torch
from torch import nn

import direct.data.transforms as T
from direct.config import BaseConfig
from direct.nn.mri_models import MRIModelEngine
from direct.nn.ssl.mri_models import JSSLMRIModelEngine, SSLMRIModelEngine


class EndToEndVarNetEngine(MRIModelEngine):
    """Engine for the End-to-End Variational Network model.

    Parameters
    ----------
    cfg: BaseConfig
        Configuration file.
    model: nn.Module
        Model.
    device: str
        Device. Can be "cuda:{idx}" or "cpu".
    forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
        The forward operator. Default: None.
    backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
        The backward operator. Default: None.
    mixed_precision: bool
        Use mixed precision. Default: False.
    **models: nn.Module
        Additional models.
    """

    def __init__(
        self,
        cfg: BaseConfig,
        model: nn.Module,
        device: str,
        forward_operator: Optional[Callable] = None,
        backward_operator: Optional[Callable] = None,
        mixed_precision: bool = False,
        **models: nn.Module,
    ):
        """Inits :class:`EndToEndVarNetEngine`.

        Every argument is handed unchanged to the parent engine's constructor; see the
        class docstring for the meaning of each parameter.
        """
        super().__init__(
            cfg,
            model,
            device,
            forward_operator=forward_operator,
            backward_operator=backward_operator,
            mixed_precision=mixed_precision,
            **models,
        )

    def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward function for :class:`EndToEndVarNetEngine`.

        Parameters
        ----------
        data : dict[str, Any]
            Data dictionary. Should contain the following keys:
                - "masked_kspace"
                - "sampling_mask"
                - "sensitivity_map"

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Output image and output k-space, in that order.
        """
        # The model takes exactly these three entries of the batch as keyword arguments.
        model_inputs = {
            name: data[name] for name in ("masked_kspace", "sampling_mask", "sensitivity_map")
        }
        predicted_kspace = self.model(**model_inputs)

        # Map the predicted k-space back through the backward operator over the spatial
        # dimensions, then combine along the coil dimension with root-sum-of-squares.
        coil_images = self.backward_operator(predicted_kspace, dim=self._spatial_dims)
        reconstruction = T.root_sum_of_squares(
            coil_images,
            dim=self._coil_dim,
        )  # shape (batch, height, width)

        return reconstruction, predicted_kspace
class EndToEndVarNetSSLEngine(SSLMRIModelEngine):
    """Self-supervised Learning End-to-End Variational Network Engine.

    Used for supplementary experiments for End-to-End Variational Network model with SSL
    in the JSSL paper [1].

    Parameters
    ----------
    cfg: BaseConfig
        Configuration file.
    model: nn.Module
        Model.
    device: str
        Device. Can be "cuda:{idx}" or "cpu".
    forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
        The forward operator. Default: None.
    backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
        The backward operator. Default: None.
    mixed_precision: bool
        Use mixed precision. Default: False.
    **models: nn.Module
        Additional models.

    References
    ----------
    .. [1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and
        Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023).
        https://doi.org/10.48550/arXiv.2311.15856.
    """

    def __init__(
        self,
        cfg: BaseConfig,
        model: nn.Module,
        device: str,
        forward_operator: Optional[Callable] = None,
        backward_operator: Optional[Callable] = None,
        mixed_precision: bool = False,
        **models: nn.Module,
    ):
        """Inits :class:`EndToEndVarNetSSLEngine`.

        Every argument is handed unchanged to the parent engine's constructor; see the
        class docstring for the meaning of each parameter.
        """
        super().__init__(
            cfg,
            model,
            device,
            forward_operator=forward_operator,
            backward_operator=backward_operator,
            mixed_precision=mixed_precision,
            **models,
        )

    def forward_function(self, data: dict[str, Any]) -> tuple[None, torch.Tensor]:
        """Forward function for :class:`EndToEndVarNetSSLEngine`.

        Parameters
        ----------
        data : dict[str, Any]
            Data dictionary. Should contain the following keys:
                - "input_kspace" if training, "masked_kspace" if inference
                - "input_sampling_mask" if training, "sampling_mask" if inference
                - "sensitivity_map"

        Returns
        -------
        tuple[None, torch.Tensor]
            None for image and output k-space.
        """
        # Training reads the "input_*" entries; inference reads the plain masked
        # k-space and its sampling mask.
        if self.model.training:
            kspace_key, mask_key = "input_kspace", "input_sampling_mask"
        else:
            kspace_key, mask_key = "masked_kspace", "sampling_mask"

        predicted_kspace = self.model(
            masked_kspace=data[kspace_key],
            sampling_mask=data[mask_key],
            sensitivity_map=data["sensitivity_map"],
        )

        # No image is computed here; only the k-space prediction is returned.
        return None, predicted_kspace
class EndToEndVarNetJSSLEngine(JSSLMRIModelEngine):
    """Joint Supervised and Self-supervised Learning End-to-End Variational Network Engine.

    Used for supplementary experiments for End-to-End Variational Network model with JSSL
    in the JSSL paper [1].

    Parameters
    ----------
    cfg: BaseConfig
        Configuration file.
    model: nn.Module
        Model.
    device: str
        Device. Can be "cuda:{idx}" or "cpu".
    forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
        The forward operator. Default: None.
    backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
        The backward operator. Default: None.
    mixed_precision: bool
        Use mixed precision. Default: False.
    **models: nn.Module
        Additional models.

    References
    ----------
    .. [1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and
        Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023).
        https://doi.org/10.48550/arXiv.2311.15856.
    """

    def __init__(
        self,
        cfg: BaseConfig,
        model: nn.Module,
        device: str,
        forward_operator: Optional[Callable] = None,
        backward_operator: Optional[Callable] = None,
        mixed_precision: bool = False,
        **models: nn.Module,
    ):
        """Inits :class:`EndToEndVarNetJSSLEngine`.

        Every argument is handed unchanged to the parent engine's constructor; see the
        class docstring for the meaning of each parameter.
        """
        super().__init__(
            cfg,
            model,
            device,
            forward_operator=forward_operator,
            backward_operator=backward_operator,
            mixed_precision=mixed_precision,
            **models,
        )

    def forward_function(self, data: dict[str, Any]) -> tuple[None, torch.Tensor]:
        """Forward function for :class:`EndToEndVarNetJSSLEngine`.

        Parameters
        ----------
        data : dict[str, Any]
            Data dictionary. Should contain the following keys:
                - "is_ssl" boolean tensor indicating if training is SSL
                - "input_kspace" if training and training is SSL, "masked_kspace" if inference
                - "input_sampling_mask" if training and training is SSL, "sampling_mask" if inference
                - "sensitivity_map"

        Returns
        -------
        tuple[None, torch.Tensor]
            None for image and output k-space.
        """
        # Only the first element of "is_ssl" is inspected. The "input_*" entries are
        # used solely when that flag is set AND the model is in training mode; every
        # other case reads the plain masked k-space and its sampling mask.
        ssl_training_step = bool(data["is_ssl"][0]) and self.model.training
        kspace_key, mask_key = (
            ("input_kspace", "input_sampling_mask")
            if ssl_training_step
            else ("masked_kspace", "sampling_mask")
        )

        predicted_kspace = self.model(
            masked_kspace=data[kspace_key],
            sampling_mask=data[mask_key],
            sensitivity_map=data["sensitivity_map"],
        )

        # No image is computed here; only the k-space prediction is returned.
        return None, predicted_kspace