Source code for direct.nn.conjgradnet.conjgradnet_engine

# Copyright 2025 AI for Oncology Research Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch import nn

from direct.config import BaseConfig
from direct.nn.mri_models import MRIModelEngine


class ConjGradNetEngine(MRIModelEngine):
    """Engine for ConjGradNet models, built on :class:`MRIModelEngine`.

    Construction is delegated entirely to :class:`MRIModelEngine`. The only
    specialisation is :meth:`forward_function`: the model's output is used
    directly as the image, and no k-space prediction is returned.
    """

    def __init__(
        self,
        cfg: BaseConfig,
        model: nn.Module,
        device: str,
        forward_operator: Optional[Callable] = None,
        backward_operator: Optional[Callable] = None,
        mixed_precision: bool = False,
        **models: nn.Module,
    ):
        """Inits :class:`ConjGradNetEngine`.

        Every argument is passed through unchanged to :class:`MRIModelEngine`.

        Parameters
        ----------
        cfg: BaseConfig
            Configuration file.
        model: nn.Module
            Model.
        device: str
            Device. Can be "cuda:{idx}" or "cpu".
        forward_operator: Callable, optional
            The forward operator. Default: None.
        backward_operator: Callable, optional
            The backward operator. Default: None.
        mixed_precision: bool
            Use mixed precision. Default: False.
        **models: nn.Module
            Additional models.
        """
        # Collect the optional engine settings once and forward them by keyword,
        # followed by any extra named models.
        shared_options = dict(
            forward_operator=forward_operator,
            backward_operator=backward_operator,
            mixed_precision=mixed_precision,
        )
        super().__init__(cfg, model, device, **shared_options, **models)

    def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor, None]:
        """Run the model on one batch.

        Parameters
        ----------
        data: Dict[str, Any]
            Batch dictionary. It must contain ``"masked_kspace"``, ``"sampling_mask"``
            and ``"sensitivity_map"``; these are passed to the model as keyword arguments.

        Returns
        -------
        Tuple[torch.Tensor, None]
            The model output (treated as the reconstructed image) and ``None`` in place
            of a k-space prediction.
        """
        masked_kspace = data["masked_kspace"]
        sampling_mask = data["sampling_mask"]
        sensitivity_map = data["sensitivity_map"]

        # NOTE(review): presumably shape (batch, height, width) -- confirm against the
        # model's forward, which may also carry a trailing complex dimension.
        reconstruction = self.model(
            masked_kspace=masked_kspace,
            sampling_mask=sampling_mask,
            sensitivity_map=sensitivity_map,
        )

        # The model output is used as the image directly; there is no k-space output.
        return reconstruction, None