# Copyright 2025 AI for Oncology Research Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DIRECT module containing utility functions for the transformers models."""
from __future__ import annotations
from math import ceil, floor
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import init
__all__ = ["init_weights", "norm", "pad_to_divisible", "pad_to_square", "unnorm", "unpad_to_original", "DropoutPath"]
def pad_to_divisible(x: torch.Tensor, pad_size: tuple[int, ...]) -> tuple[torch.Tensor, tuple[tuple[int, int], ...]]:
    """Zero-pad the trailing spatial dimensions of `x` so each becomes divisible by the given size.

    Padding is split as evenly as possible around each dimension (the extra pixel, if any,
    goes to the trailing side).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (*, spatial_1, ..., spatial_N); the number of spatial dimensions
        is inferred from ``len(pad_size)``.
    pad_size : tuple[int, ...]
        Target divisor for each spatial dimension, one integer per dimension.

    Returns
    -------
    tuple
        The padded tensor and a tuple of ``(pad_before, pad_after)`` pairs, one per
        spatial dimension (in the same order as the dimensions appear in ``x``).
    """
    n_spatial = len(pad_size)
    pad_pairs: list[tuple[int, int]] = []
    for size, multiple in zip(x.shape[-n_spatial:], pad_size):
        total = (multiple - size % multiple) % multiple
        half = total / 2
        pad_pairs.append((floor(half), ceil(half)))
    # F.pad expects pads for the LAST dimension first:
    # (pad_n_before, pad_n_after, ..., pad_1_before, pad_1_after)
    flat: list[int] = []
    for before, after in reversed(pad_pairs):
        flat.extend((before, after))
    return F.pad(x, tuple(flat)), tuple(pad_pairs)
def unpad_to_original(x: torch.Tensor, *pads: tuple[int, int]) -> torch.Tensor:
    """Strip previously-added spatial padding from `x`.

    Parameters
    ----------
    x : torch.Tensor
        Tensor whose trailing ``len(pads)`` dimensions are padded.
    pads : tuple[int, int]
        One ``(pad_before, pad_after)`` pair per spatial dimension, in dimension order.

    Returns
    -------
    torch.Tensor
        A view of ``x`` with the padding sliced away, i.e. the original pre-padding shape.
    """
    n_spatial = len(pads)
    # Leading (batch/channel) dimensions are kept whole.
    index: list[slice] = [slice(None)] * (x.ndim - n_spatial)
    for offset, (before, after) in enumerate(pads):
        end = x.shape[x.ndim - n_spatial + offset] - after
        index.append(slice(before, end))
    return x[tuple(index)]
def pad_to_square(
    inp: torch.Tensor, factor: float
) -> tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]]:
    """Center-pad a tensor to a square spatial shape whose side is a multiple of `factor`.

    Parameters
    ----------
    inp : torch.Tensor
        Input tensor of shape (\\*, channels, height, width).
    factor : float
        The square side length is the smallest multiple of this factor that fits
        both spatial dimensions.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]]
        The padded tensor, the complementary mask (1 on padded pixels, 0 on the
        original image region), and the width and height padding offsets.

    Examples
    --------
    1.
    >>> x = torch.rand(1, 3, 224, 192)
    >>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0)
    >>> padded_x.shape, mask.shape
    (torch.Size([1, 3, 224, 224]), torch.Size([1, 1, 224, 224]))
    2.
    >>> x = torch.rand(3, 13, 2, 234, 180)
    >>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0)
    >>> padded_x.shape, wpad, hpad
    (torch.Size([3, 13, 2, 240, 240]), (30, 30), (3, 3))
    """
    channels, h, w = inp.shape[-3:]
    # Side of the square: round the larger spatial extent up to a multiple of `factor`.
    size = int(ceil(max(h, w) / float(factor)) * factor)
    lead = inp.shape[:-3]
    # Zero canvases; `type_as` matches the input's dtype.
    img = torch.zeros(*lead, channels, size, size, device=inp.device).type_as(inp)
    mask = torch.zeros(*((1,) * len(lead)), 1, size, size, device=inp.device).type_as(inp)
    # Center the input inside the square canvas.
    top = (size - h) // 2
    left = (size - w) // 2
    img[..., top : top + h, left : left + w] = inp.clone()
    mask[..., top : top + h, left : left + w].fill_(1.0)
    # Odd extents contribute one extra pixel to the trailing side of the reported pad.
    w_extra = 1 if w % 2 != 0 else 0
    h_extra = 1 if h % 2 != 0 else 0
    return img, 1 - mask, (left, left + w_extra), (top, top + h_extra)
[docs]
def norm(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Normalize the input tensor by subtracting the mean and dividing by the standard deviation
across each channel and pixel for arbitrary spatial dimensions.
Parameters
----------
x : torch.Tensor
Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number (e.g., 2D, 3D, etc.).
Returns
-------
tuple
Containing the normalized tensor, mean tensor, and standard deviation tensor.
"""
# Flatten spatial dimensions and compute mean and std across them
spatial_dims = x.shape[2:] # Get all spatial dimensions
flattened = x.view(x.shape[0], x.shape[1], -1) # Flatten the spatial dimensions for mean/std calculation
mean = flattened.mean(-1, keepdim=True).view(x.shape[0], x.shape[1], *([1] * len(spatial_dims)))
std = flattened.std(-1, keepdim=True).view(x.shape[0], x.shape[1], *([1] * len(spatial_dims)))
# Normalize
x = (x - mean) / std
return x, mean, std
def unnorm(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Invert the `norm` transform: scale by the standard deviation and shift by the mean.

    Parameters
    ----------
    x : torch.Tensor
        Normalized tensor of shape (B, C, *spatial_dims).
    mean : torch.Tensor
        Mean tensor produced during normalization (broadcastable to ``x``).
    std : torch.Tensor
        Standard deviation tensor produced during normalization (broadcastable to ``x``).

    Returns
    -------
    torch.Tensor
        The denormalized tensor, same shape as ``x``.
    """
    denormalized = std * x + mean
    return denormalized
def init_weights(m: nn.Module) -> None:
    """Initializes the weights of the network using a truncated normal distribution.

    Linear layers get truncated-normal weights (std=0.02) and zero bias;
    LayerNorm layers get unit weight and zero bias. Other module types are untouched.

    Parameters
    ----------
    m : nn.Module
        A module of the network whose weights need to be initialized.
    """
    if isinstance(m, nn.LayerNorm):
        init.constant_(m.bias, 0)
        init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Linear):
        init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0)
class DropoutPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        """Inits :class:`DropoutPath`.

        Parameters
        ----------
        drop_prob : float
            Probability of dropping a residual connection. Default: 0.0.
        scale_by_keep : bool
            Whether to scale the remaining activations by 1 / (1 - drop_prob) to maintain
            the expected value of the activations. Default: True.
        """
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    @staticmethod
    def _dropout_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
        # Identity during evaluation or when dropping is disabled.
        if not training or drop_prob == 0.0:
            return x
        keep_prob = 1.0 - drop_prob
        # One Bernoulli draw per sample, broadcast across all remaining dimensions
        # (works for any tensor rank, not just 2D ConvNet activations).
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask

    def forward(self, x):
        return self._dropout_path(x, self.drop_prob, self.training, self.scale_by_keep)