Source code for direct.nn.unet.unet_3d

# Copyright 2025 AI for Oncology Research Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code for three-dimensional U-Net adapted from the 2D variant."""

from __future__ import annotations

import math

import torch
from torch import nn
from torch.nn import functional as F


class ConvBlock3D(nn.Module):
    """3D U-Net convolutional block."""

    def __init__(self, in_channels: int, out_channels: int, dropout_probability: float) -> None:
        """Inits :class:`ConvBlock3D`.

        Parameters
        ----------
        in_channels : int
            Number of channels in the input tensor.
        out_channels : int
            Number of channels produced by the convolutional layers.
        dropout_probability : float
            Dropout probability applied after convolutional layers.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dropout_probability = dropout_probability

        self.layers = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout3d(dropout_probability),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout3d(dropout_probability),
        )

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Performs the forward pass of :class:`ConvBlock3D`.

        Parameters
        ----------
        input_data : torch.Tensor
            Input data.

        Returns
        -------
        torch.Tensor
        """
        return self.layers(input_data)
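

# A minimal usage sketch (editor's illustration, not part of the original module):
# the ``_demo_conv_block3d`` helper and the shapes below are arbitrary, chosen only
# to show that ConvBlock3D preserves the spatial dimensions and changes the channel count.
def _demo_conv_block3d() -> None:
    block = ConvBlock3D(in_channels=2, out_channels=8, dropout_probability=0.0)
    x = torch.rand(1, 2, 8, 32, 32)  # (N, C, slice/time, height, width)
    with torch.no_grad():
        y = block(x)
    assert y.shape == (1, 8, 8, 32, 32)  # spatial size unchanged, channels 2 -> 8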


class TransposeConvBlock3D(nn.Module):
    """3D U-Net transpose convolutional block."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Inits :class:`TransposeConvBlock3D`.

        Parameters
        ----------
        in_channels : int
            Number of channels in the input tensor.
        out_channels : int
            Number of channels produced by the convolutional layers.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.layers = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2, bias=False),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Performs the forward pass of :class:`TransposeConvBlock3D`.

        Parameters
        ----------
        input_data : torch.Tensor
            Input data.

        Returns
        -------
        torch.Tensor
        """
        return self.layers(input_data)
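

# A minimal usage sketch (editor's illustration, not part of the original module):
# ``_demo_transpose_conv_block3d`` and its shapes are arbitrary and only show that the
# kernel-2/stride-2 transpose convolution doubles each spatial dimension.
def _demo_transpose_conv_block3d() -> None:
    block = TransposeConvBlock3D(in_channels=8, out_channels=4)
    x = torch.rand(1, 8, 4, 16, 16)  # (N, C, slice/time, height, width)
    with torch.no_grad():
        y = block(x)
    assert y.shape == (1, 4, 8, 32, 32)  # every spatial dimension doubled, channels 8 -> 4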


class UnetModel3d(nn.Module):
    """PyTorch implementation of a 3D U-Net model.

    This class defines a 3D U-Net architecture consisting of down-sampling and up-sampling layers with
    3D convolutional blocks. This is an extension to 3D volumes of
    :class:`direct.nn.unet.unet_2d.UnetModel2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_filters: int,
        num_pool_layers: int,
        dropout_probability: float,
    ) -> None:
        """Inits :class:`UnetModel3d`.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        num_filters : int
            Number of output channels of the first convolutional layer.
        num_pool_layers : int
            Number of down-sampling and up-sampling layers (depth).
        dropout_probability : float
            Dropout probability.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_filters = num_filters
        self.num_pool_layers = num_pool_layers
        self.dropout_probability = dropout_probability

        self.down_sample_layers = nn.ModuleList([ConvBlock3D(in_channels, num_filters, dropout_probability)])
        ch = num_filters
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers += [ConvBlock3D(ch, ch * 2, dropout_probability)]
            ch *= 2
        self.conv = ConvBlock3D(ch, ch * 2, dropout_probability)

        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_transpose_conv += [TransposeConvBlock3D(ch * 2, ch)]
            self.up_conv += [ConvBlock3D(ch * 2, ch, dropout_probability)]
            ch //= 2

        self.up_transpose_conv += [TransposeConvBlock3D(ch * 2, ch)]
        self.up_conv += [
            nn.Sequential(
                ConvBlock3D(ch * 2, ch, dropout_probability),
                nn.Conv3d(ch, out_channels, kernel_size=1, stride=1),
            )
        ]

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Performs the forward pass of :class:`UnetModel3d`.

        Parameters
        ----------
        input_data : torch.Tensor
            Input tensor of shape (N, in_channels, slice/time, height, width).

        Returns
        -------
        torch.Tensor
            Output of shape (N, out_channels, slice/time, height, width).
        """
        stack = []
        output, inp_pad = pad_to_pow_of_2(input_data, self.num_pool_layers)

        # Apply down-sampling layers
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.avg_pool3d(output, kernel_size=2, stride=2, padding=0)

        output = self.conv(output)

        # Apply up-sampling layers
        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
            downsample_layer = stack.pop()
            output = transpose_conv(output)

            # Reflect-pad on the right/bottom/back if needed to handle odd input dimensions
            padding = [0, 0, 0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1  # padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1  # padding bottom
            if output.shape[-3] != downsample_layer.shape[-3]:
                padding[5] = 1  # padding back
            if sum(padding) != 0:
                output = F.pad(output, padding, "reflect")

            output = torch.cat([output, downsample_layer], dim=1)
            output = conv(output)

        # Crop away the padding added by pad_to_pow_of_2
        if sum(inp_pad) != 0:
            output = output[
                :,
                :,
                inp_pad[4] : output.shape[2] - inp_pad[5],
                inp_pad[2] : output.shape[3] - inp_pad[3],
                inp_pad[0] : output.shape[4] - inp_pad[1],
            ]

        return output
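

# A minimal usage sketch (editor's illustration, not part of the original module):
# ``_demo_unet_model3d`` and its hyperparameters are arbitrary and only show that,
# thanks to pad_to_pow_of_2 and the final crop, the output spatial size of
# UnetModel3d matches the input.
def _demo_unet_model3d() -> None:
    model = UnetModel3d(
        in_channels=2, out_channels=2, num_filters=8, num_pool_layers=3, dropout_probability=0.0
    )
    x = torch.rand(1, 2, 10, 48, 52)  # (N, C, slice/time, height, width)
    with torch.no_grad():
        y = model(x)
    assert y.shape == (1, 2, 10, 48, 52)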


class NormUnetModel3d(nn.Module):
    """Implementation of a normalized U-Net model for 3D data.

    This is an extension to 3D volumes of :class:`direct.nn.unet.unet_2d.NormUnetModel2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_filters: int,
        num_pool_layers: int,
        dropout_probability: float,
        norm_groups: int = 2,
    ) -> None:
        """Inits :class:`NormUnetModel3d`.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        num_filters : int
            Number of output channels of the first convolutional layer.
        num_pool_layers : int
            Number of down-sampling and up-sampling layers (depth).
        dropout_probability : float
            Dropout probability.
        norm_groups : int
            Number of normalization groups. Default: 2.
        """
        super().__init__()

        self.unet3d = UnetModel3d(
            in_channels=in_channels,
            out_channels=out_channels,
            num_filters=num_filters,
            num_pool_layers=num_pool_layers,
            dropout_probability=dropout_probability,
        )
        self.norm_groups = norm_groups

    @staticmethod
    def norm(input_data: torch.Tensor, groups: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Applies group normalization for 3D data.

        Parameters
        ----------
        input_data : torch.Tensor
            The input tensor to normalize.
        groups : int
            The number of groups to divide the tensor into for normalization.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            A tuple containing the normalized tensor, the mean, and the standard deviation used for
            normalization.
        """
        # Group norm
        b, c, z, h, w = input_data.shape
        input_data = input_data.reshape(b, groups, -1)

        mean = input_data.mean(-1, keepdim=True)
        std = input_data.std(-1, keepdim=True)

        output = (input_data - mean) / std
        output = output.reshape(b, c, z, h, w)

        return output, mean, std

    @staticmethod
    def unnorm(
        input_data: torch.Tensor,
        mean: torch.Tensor,
        std: torch.Tensor,
        groups: int,
    ) -> torch.Tensor:
        """Reverts the normalization applied to the 3D tensor.

        Parameters
        ----------
        input_data : torch.Tensor
            The normalized tensor to revert normalization on.
        mean : torch.Tensor
            The mean used during normalization.
        std : torch.Tensor
            The standard deviation used during normalization.
        groups : int
            The number of groups the tensor was divided into during normalization.

        Returns
        -------
        torch.Tensor
            The tensor after reverting the normalization.
        """
        b, c, z, h, w = input_data.shape
        input_data = input_data.reshape(b, groups, -1)
        return (input_data * std + mean).reshape(b, c, z, h, w)

    @staticmethod
    def pad(
        input_data: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]]:
        """Applies padding to the input 3D tensor to ensure its dimensions are multiples of 16.

        Parameters
        ----------
        input_data : torch.Tensor
            The input tensor to pad.

        Returns
        -------
        tuple[torch.Tensor, tuple[list[int], list[int], int, int, list[int], list[int]]]
            A tuple containing the padded tensor and a tuple with the padding applied to each dimension
            (height, width, depth) and the target dimensions after padding.
        """
        _, _, z, h, w = input_data.shape
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        z_mult = ((z - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        z_pad = [math.floor((z_mult - z) / 2), math.ceil((z_mult - z) / 2)]

        output = F.pad(input_data, w_pad + h_pad + z_pad)
        return output, (h_pad, w_pad, z_pad, h_mult, w_mult, z_mult)

    @staticmethod
    def unpad(
        input_data: torch.Tensor,
        h_pad: list[int],
        w_pad: list[int],
        z_pad: list[int],
        h_mult: int,
        w_mult: int,
        z_mult: int,
    ) -> torch.Tensor:
        """Removes padding from the 3D input tensor, reverting it to its original dimensions before
        padding was applied.

        This method is typically used after the model has processed the padded input.

        Parameters
        ----------
        input_data : torch.Tensor
            The tensor from which padding will be removed.
        h_pad : list[int]
            Padding applied to the height, specified as [top, bottom].
        w_pad : list[int]
            Padding applied to the width, specified as [left, right].
        z_pad : list[int]
            Padding applied to the depth, specified as [front, back].
        h_mult : int
            The height as computed in the `pad` method.
        w_mult : int
            The width as computed in the `pad` method.
        z_mult : int
            The depth as computed in the `pad` method.

        Returns
        -------
        torch.Tensor
            The tensor with padding removed, restored to its original dimensions.
        """
        return input_data[
            ..., z_pad[0] : z_mult - z_pad[1], h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
        ]

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Performs the forward pass of :class:`NormUnetModel3d`.

        Parameters
        ----------
        input_data : torch.Tensor
            Input tensor of shape (N, in_channels, slice/time, height, width).

        Returns
        -------
        torch.Tensor
            Output of shape (N, out_channels, slice/time, height, width).
        """
        output, mean, std = self.norm(input_data, self.norm_groups)
        output, pad_sizes = self.pad(output)

        output = self.unet3d(output)

        output = self.unpad(output, *pad_sizes)
        output = self.unnorm(output, mean, std, self.norm_groups)

        return output
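

# A minimal usage sketch (editor's illustration, not part of the original module):
# ``_demo_norm_unet_model3d`` and its shapes are arbitrary and only show the
# normalize -> pad -> U-Net -> unpad -> unnormalize round trip, which leaves the
# output shape equal to the input shape.
def _demo_norm_unet_model3d() -> None:
    model = NormUnetModel3d(
        in_channels=2, out_channels=2, num_filters=8, num_pool_layers=3, dropout_probability=0.0
    )
    x = torch.rand(1, 2, 12, 40, 44)  # (N, C, slice/time, height, width)
    with torch.no_grad():
        y = model(x)
    assert y.shape == (1, 2, 12, 40, 44)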


def pad_to_pow_of_2(inp: torch.Tensor, k: int) -> tuple[torch.Tensor, list[int]]:
    """Pads the spatial dimensions (depth, height, width) of the input tensor up to at least ``2**k``.

    This is necessary for certain operations in the 3D U-Net architecture to maintain dimensionality.

    Parameters
    ----------
    inp : torch.Tensor
        The input tensor to be padded.
    k : int
        The exponent to which the base of 2 is raised to determine the target size. Spatial dimensions
        smaller than ``2**k`` are padded up to ``2**k``; larger dimensions are left unchanged.

    Returns
    -------
    tuple[torch.Tensor, list[int]]
        A tuple containing the padded tensor and the list of padding applied to each spatial dimension,
        in the :func:`torch.nn.functional.pad` ordering
        [width_left, width_right, height_top, height_bottom, depth_front, depth_back].

    Examples
    --------
    >>> inp = torch.rand(1, 1, 15, 15, 15)  # A random tensor with shape [1, 1, 15, 15, 15]
    >>> padded_inp, padding = pad_to_pow_of_2(inp, 4)
    >>> print(padded_inp.shape, padding)
    torch.Size([1, 1, 16, 16, 16]) [0, 1, 0, 1, 0, 1]
    """
    diffs = [dim - 2**k for dim in inp.shape[2:]]
    padding = [0, 0, 0, 0, 0, 0]
    for i, diff in enumerate(diffs[::-1]):
        if diff < 1:
            padding[2 * i] = abs(diff) // 2
            padding[2 * i + 1] = abs(diff) - padding[2 * i]

    if sum(padding) > 0:
        inp = F.pad(inp, padding)

    return inp, padding
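

# A minimal usage sketch (editor's illustration, not part of the original module):
# with k=4 the target spatial size is 2**4 = 16, so a (15, 15, 15) volume receives one
# voxel of right/bottom/back padding per axis; dimensions already >= 16 would be left untouched.
def _demo_pad_to_pow_of_2() -> None:
    inp = torch.rand(1, 1, 15, 15, 15)
    padded, padding = pad_to_pow_of_2(inp, 4)
    assert padded.shape == (1, 1, 16, 16, 16)
    assert padding == [0, 1, 0, 1, 0, 1]  # [w_left, w_right, h_top, h_bottom, z_front, z_back]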