Source code for direct.nn.conjgradnet.conjgradnet_engine

# coding=utf-8
# Copyright (c) DIRECT Contributors

from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch import nn

from direct.config import BaseConfig
from direct.nn.mri_models import MRIModelEngine


[docs] class ConjGradNetEngine(MRIModelEngine): def __init__( self, cfg: BaseConfig, model: nn.Module, device: str, forward_operator: Optional[Callable] = None, backward_operator: Optional[Callable] = None, mixed_precision: bool = False, **models: nn.Module, ): """Inits :class:`ConjGradNetEngine`. Parameters ---------- cfg: BaseConfig Configuration file. model: nn.Module Model. device: str Device. Can be "cuda:{idx}" or "cpu". forward_operator: Callable, optional The forward operator. Default: None. backward_operator: Callable, optional The backward operator. Default: None. mixed_precision: bool Use mixed precision. Default: False. **models: nn.Module Additional models. """ super().__init__( cfg, model, device, forward_operator=forward_operator, backward_operator=backward_operator, mixed_precision=mixed_precision, **models, )
[docs] def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]: output_image = self.model( masked_kspace=data["masked_kspace"], sampling_mask=data["sampling_mask"], sensitivity_map=data["sensitivity_map"], ) # shape (batch, height, width) output_kspace = None return output_image, output_kspace