Source code for direct.nn.kikinet.kikinet

# Copyright 2025 AI for Oncology Research Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional

import torch
import torch.nn as nn

import direct.data.transforms as T
from direct.nn.conv.conv import Conv2d
from direct.nn.crossdomain.multicoil import MultiCoil
from direct.nn.didn.didn import DIDN
from direct.nn.mwcnn.mwcnn import MWCNN
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d


class KIKINet(nn.Module):
    """Based on KIKINet implementation [1]_. Modified to work with multi-coil k-space data.

    References
    ----------

    .. [1] Eo, Taejoon, et al. "KIKI-Net: Cross-Domain Convolutional Neural Networks for Reconstructing
        Undersampled Magnetic Resonance Images." Magnetic Resonance in Medicine, vol. 80, no. 5, Nov. 2018,
        pp. 2188-201. PubMed, https://doi.org/10.1002/mrm.27201.
    """

    def __init__(
        self,
        forward_operator: Callable,
        backward_operator: Callable,
        image_model_architecture: str = "MWCNN",
        kspace_model_architecture: str = "DIDN",
        num_iter: int = 2,
        normalize: bool = False,
        **kwargs,
    ):
        """Inits :class:`KIKINet`.

        Parameters
        ----------
        forward_operator: Callable
            Forward Operator.
        backward_operator: Callable
            Backward Operator.
        image_model_architecture: str
            Image model architecture. Currently only implemented for MWCNN and (NORM)UNET. Default: 'MWCNN'.
        kspace_model_architecture: str
            Kspace model architecture. Currently only implemented for CONV and DIDN and (NORM)UNET.
            Default: 'DIDN'.
        num_iter: int
            Number of unrolled iterations.
        normalize: bool
            If true, input is normalised based on input scaling_factor.
        kwargs: dict
            Keyword arguments for model architectures.

        Raises
        ------
        NotImplementedError
            If `image_model_architecture` or `kspace_model_architecture` is not one of the supported options.
        """
        super().__init__()

        # Build the image-space denoiser. All sub-models map 2-channel (real/imag)
        # images to 2-channel images.
        image_model: nn.Module
        if image_model_architecture == "MWCNN":
            image_model = MWCNN(
                input_channels=2,
                first_conv_hidden_channels=kwargs.get("image_mwcnn_hidden_channels", 32),
                num_scales=kwargs.get("image_mwcnn_num_scales", 4),
                bias=kwargs.get("image_mwcnn_bias", False),
                batchnorm=kwargs.get("image_mwcnn_batchnorm", False),
            )
        elif image_model_architecture in ["UNET", "NORMUNET"]:
            unet = UnetModel2d if image_model_architecture == "UNET" else NormUnetModel2d
            image_model = unet(
                in_channels=2,
                out_channels=2,
                num_filters=kwargs.get("image_unet_num_filters", 8),
                num_pool_layers=kwargs.get("image_unet_num_pool_layers", 4),
                dropout_probability=kwargs.get("image_unet_dropout_probability", 0.0),
            )
        else:
            # Fixed: message previously named "XPDNet" (copy-paste) and had an
            # unbalanced quote / missing space before "Got".
            raise NotImplementedError(
                f"KIKINet is currently implemented only with image_model_architecture == 'MWCNN', 'UNET' or "
                f"'NORMUNET'. Got {image_model_architecture}."
            )

        # Build the k-space denoiser; wrapped in MultiCoil below so it is applied per coil.
        kspace_model: nn.Module
        if kspace_model_architecture == "CONV":
            kspace_model = Conv2d(
                in_channels=2,
                out_channels=2,
                hidden_channels=kwargs.get("kspace_conv_hidden_channels", 16),
                n_convs=kwargs.get("kspace_conv_n_convs", 4),
                batchnorm=kwargs.get("kspace_conv_batchnorm", False),
            )
        elif kspace_model_architecture == "DIDN":
            kspace_model = DIDN(
                in_channels=2,
                out_channels=2,
                hidden_channels=kwargs.get("kspace_didn_hidden_channels", 16),
                num_dubs=kwargs.get("kspace_didn_num_dubs", 6),
                num_convs_recon=kwargs.get("kspace_didn_num_convs_recon", 9),
            )
        elif kspace_model_architecture in ["UNET", "NORMUNET"]:
            unet = UnetModel2d if kspace_model_architecture == "UNET" else NormUnetModel2d
            kspace_model = unet(
                in_channels=2,
                out_channels=2,
                num_filters=kwargs.get("kspace_unet_num_filters", 8),
                num_pool_layers=kwargs.get("kspace_unet_num_pool_layers", 4),
                dropout_probability=kwargs.get("kspace_unet_dropout_probability", 0.0),
            )
        else:
            # Fixed: message previously named "XPDNet" (copy-paste).
            raise NotImplementedError(
                f"KIKINet is currently implemented for kspace_model_architecture == 'CONV', 'DIDN',"
                f" 'UNET' or 'NORMUNET'. Got kspace_model_architecture == {kspace_model_architecture}."
            )

        # Tensor layout: (N, coil, height, width, complex=2).
        self._coil_dim = 1
        self._complex_dim = -1
        self._spatial_dims = (2, 3)

        # NOTE(review): `[model] * num_iter` repeats the SAME module instance, so every
        # unrolled iteration shares one set of weights. Confirm this weight tying is
        # intended; use copies of the model if independent weights per iteration are wanted.
        self.image_model_list = nn.ModuleList([image_model] * num_iter)
        self.kspace_model_list = nn.ModuleList([MultiCoil(kspace_model, self._coil_dim)] * num_iter)

        self.forward_operator = forward_operator
        self.backward_operator = backward_operator

        self.num_iter = num_iter
        self.normalize = normalize

    def forward(
        self,
        masked_kspace: torch.Tensor,
        sampling_mask: torch.Tensor,
        sensitivity_map: torch.Tensor,
        scaling_factor: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Computes forward pass of :class:`KIKINet`.

        Parameters
        ----------
        masked_kspace: torch.Tensor
            Masked k-space of shape (N, coil, height, width, complex=2).
        sampling_mask: torch.Tensor
            Sampling mask of shape (N, 1, height, width, 1).
        sensitivity_map: torch.Tensor
            Sensitivity map of shape (N, coil, height, width, complex=2).
        scaling_factor: Optional[torch.Tensor]
            Scaling factor of shape (N,). If None, no scaling is applied. Default: None.

        Returns
        -------
        image: torch.Tensor
            Output image of shape (N, height, width, complex=2).
        """
        kspace = masked_kspace.clone()

        if self.normalize and scaling_factor is not None:
            # Undo the (squared) input scaling; re-applied symmetrically after the loop.
            kspace = kspace / (scaling_factor**2).view(-1, 1, 1, 1, 1)

        for idx in range(self.num_iter):
            # K-space model expects channels-first (N, coil, complex, H, W); permute back after.
            kspace = self.kspace_model_list[idx](kspace.permute(0, 1, 4, 2, 3)).permute(0, 1, 3, 4, 2)

            # Zero out non-sampled positions, go to image space, and combine coils
            # with the sensitivity maps.
            image = T.reduce_operator(
                self.backward_operator(
                    torch.where(
                        sampling_mask == 0,
                        torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device),
                        kspace,
                    ).contiguous(),
                    self._spatial_dims,
                ),
                sensitivity_map,
                self._coil_dim,
            )

            # Image model expects channels-first (N, complex, H, W); permute back after.
            image = self.image_model_list[idx](image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

            if idx < self.num_iter - 1:
                # Not the last iteration: project the image back to masked multi-coil
                # k-space for the next k-space denoising step.
                kspace = torch.where(
                    sampling_mask == 0,
                    torch.tensor([0.0], dtype=image.dtype).to(image.device),
                    self.forward_operator(
                        T.expand_operator(image, sensitivity_map, self._coil_dim), dim=self._spatial_dims
                    ),
                )

        if self.normalize and scaling_factor is not None:
            image = image * (scaling_factor**2).view(-1, 1, 1, 1)

        return image