Adding your own model

Adding your own model#

To add a new model follow the steps below:

  • Implement your custom model under direct/nn/<model_name>/<model_name>.py. For example:

import torch
from torch import nn
from torch.nn import functional as F

class MyMRIModel(nn.Module):
    """My custom MRI model."""

    def __init__(self, param1: param1_type, ...):
        """Inits :class:`MyMRIModel`.

        Parameters
        ----------
        param1 : param1_type
            ...
        ...
        """
        super().__init__()

    def my_method(self, ...) -> ...:
        pass

    @staticmethod
    def my_static_method(...) -> ...:
        pass

    def forward(
        self,
        masked_kspace: torch.Tensor,
        sampling_mask: torch.Tensor,
        sensitivity_map: torch.Tensor,
        ...
    ) -> torch.Tensor:
        """Computes forward pass of :class:`MyMRIModel`.

        Parameters
        ----------
        masked_kspace: torch.Tensor
            Masked k-space of shape (N, coil, height, width, complex=2).
        sampling_mask: torch.Tensor
            Sampling mask of shape (N, 1, height, width, 1).
        sensitivity_map: torch.Tensor
            Sensitivity map of shape (N, coil, height, width, complex=2).
        ...

        Returns
        -------
        out_image: torch.Tensor
            Output image of shape (N, height, width, complex=2).
        ...
        """
  • Implement your custom model’s engine under direct/nn/<model_name>/<model_name>_engine.py. For example:

from __future__ import annotations

from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch import nn

from direct.config import BaseConfig
from direct.nn.mri_models import MRIModelEngine


class MyMRIModelEngine(MRIModelEngine):
    """:class:`MyMRIModel` Engine."""

    def __init__(
        self,
        cfg: BaseConfig,
        model: nn.Module,
        device: str,
        forward_operator: Optional[Callable] = None,
        backward_operator: Optional[Callable] = None,
        mixed_precision: bool = False,
        **models: nn.Module,
    ):
        """Inits :class:`MyMRIModel`."""
        super().__init__(
            cfg,
            model,
            device,
            forward_operator=forward_operator,
            backward_operator=backward_operator,
            mixed_precision=mixed_precision,
            **models,
        )

    def forward_function(self, data: Dict[str, Any]) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
        output_image = self.model(
            masked_kspace=data["masked_kspace"],
            sampling_mask=data["sampling_mask"],
            sensitivity_map=data["sensitivity_map"],
            ...=...
        )
        # ΟR
        output_kspace = self.model(
            masked_kspace=data["masked_kspace"],
            sampling_mask=data["sampling_mask"],
            sensitivity_map=data["sensitivity_map"],
            ...=...
        )
        ...
        return output_image, output_kspace
  • Implement your custom model’s config under direct/nn/<model_name>/config.py. For example:

from dataclasses import dataclass

from direct.config.defaults import ModelConfig


@dataclass
class MyMRIModelConfig(ModelConfig):
    param1: param1_type = param1_default_value
    ...