direct.nn.transformers package#

Submodules#

direct.nn.transformers.config module#

class direct.nn.transformers.config.ImageDomainMRIUFormerConfig(model_name: str = '???', engine_name: Optional[str] = None, patch_size: 'int' = 256, embedding_dim: 'int' = 32, encoder_depths: 'tuple[int, ...]' = (2, 2, 2, 2), encoder_num_heads: 'tuple[int, ...]' = (1, 2, 4, 8), bottleneck_depth: 'int' = 2, bottleneck_num_heads: 'int' = 16, win_size: 'int' = 8, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = True, qk_scale: 'Optional[float]' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, drop_path_rate: 'float' = 0.1, patch_norm: 'bool' = True, token_projection: 'AttentionTokenProjectionType' = <AttentionTokenProjectionType.LINEAR: 'linear'>, token_mlp: 'LeWinTransformerMLPTokenType' = <LeWinTransformerMLPTokenType.LEFF: 'leff'>, shift_flag: 'bool' = True, modulator: 'bool' = False, cross_modulator: 'bool' = False, normalized: 'bool' = True)[source][source]#

Bases: ModelConfig

attn_drop_rate: float = 0.0#
bottleneck_depth: int = 2#
bottleneck_num_heads: int = 16#
cross_modulator: bool = False#
drop_path_rate: float = 0.1#
drop_rate: float = 0.0#
embedding_dim: int = 32#
encoder_depths: tuple[int, ...] = (2, 2, 2, 2)#
encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8)#
mlp_ratio: float = 4.0#
modulator: bool = False#
normalized: bool = True#
patch_norm: bool = True#
patch_size: int = 256#
qk_scale: Optional[float] = None#
qkv_bias: bool = True#
shift_flag: bool = True#
token_mlp: LeWinTransformerMLPTokenType = 'leff'#
token_projection: AttentionTokenProjectionType = 'linear'#
win_size: int = 8#
class direct.nn.transformers.config.ImageDomainMRIViT2DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, average_size: 'tuple[int, int]' = (320, 320), patch_size: 'tuple[int, int]' = (16, 16))[source][source]#

Bases: MRIViTConfig

average_size: tuple[int, int] = (320, 320)#
patch_size: tuple[int, int] = (16, 16)#
class direct.nn.transformers.config.ImageDomainMRIViT3DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, average_size: 'tuple[int, int]' = (320, 320, 320), patch_size: 'tuple[int, int]' = (16, 16, 16))[source][source]#

Bases: MRIViTConfig

average_size: tuple[int, int] = (320, 320, 320)#
patch_size: tuple[int, int] = (16, 16, 16)#
class direct.nn.transformers.config.KSpaceDomainMRIViT2DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, average_size: 'tuple[int, int]' = (320, 320), patch_size: 'tuple[int, int]' = (16, 16), compute_per_coil: 'bool' = True)[source][source]#

Bases: MRIViTConfig

average_size: tuple[int, int] = (320, 320)#
compute_per_coil: bool = True#
patch_size: tuple[int, int] = (16, 16)#
class direct.nn.transformers.config.KSpaceDomainMRIViT3DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, average_size: 'tuple[int, int]' = (320, 320, 320), patch_size: 'tuple[int, int]' = (16, 16, 16), compute_per_coil: 'bool' = True)[source][source]#

Bases: MRIViTConfig

average_size: tuple[int, int] = (320, 320, 320)#
compute_per_coil: bool = True#
patch_size: tuple[int, int] = (16, 16, 16)#
class direct.nn.transformers.config.MRIViTConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True)[source][source]#

Bases: ModelConfig

attn_drop_rate: float = 0.0#
depth: int = 8#
drop_rate: float = 0.0#
dropout_path_rate: float = 0.0#
embedding_dim: int = 64#
locality_strength: float = 1.0#
mlp_ratio: float = 4.0#
normalized: bool = True#
num_heads: int = 9#
qk_scale: float = None#
qkv_bias: bool = False#
use_gpsa: bool = True#
use_pos_embedding: bool = True#
class direct.nn.transformers.config.UFormerModelConfig(model_name: str = '???', engine_name: Optional[str] = None, in_channels: 'int' = 2, out_channels: 'Optional[int]' = None, patch_size: 'int' = 256, embedding_dim: 'int' = 32, encoder_depths: 'tuple[int, ...]' = (2, 2, 2, 2), encoder_num_heads: 'tuple[int, ...]' = (1, 2, 4, 8), bottleneck_depth: 'int' = 2, bottleneck_num_heads: 'int' = 16, win_size: 'int' = 8, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = True, qk_scale: 'Optional[float]' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, drop_path_rate: 'float' = 0.1, patch_norm: 'bool' = True, token_projection: 'AttentionTokenProjectionType' = <AttentionTokenProjectionType.LINEAR: 'linear'>, token_mlp: 'LeWinTransformerMLPTokenType' = <LeWinTransformerMLPTokenType.LEFF: 'leff'>, shift_flag: 'bool' = True, modulator: 'bool' = False, cross_modulator: 'bool' = False, normalized: 'bool' = True)[source][source]#

Bases: ModelConfig

attn_drop_rate: float = 0.0#
bottleneck_depth: int = 2#
bottleneck_num_heads: int = 16#
cross_modulator: bool = False#
drop_path_rate: float = 0.1#
drop_rate: float = 0.0#
embedding_dim: int = 32#
encoder_depths: tuple[int, ...] = (2, 2, 2, 2)#
encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8)#
in_channels: int = 2#
mlp_ratio: float = 4.0#
modulator: bool = False#
normalized: bool = True#
out_channels: Optional[int] = None#
patch_norm: bool = True#
patch_size: int = 256#
qk_scale: Optional[float] = None#
qkv_bias: bool = True#
shift_flag: bool = True#
token_mlp: LeWinTransformerMLPTokenType = 'leff'#
token_projection: AttentionTokenProjectionType = 'linear'#
win_size: int = 8#
class direct.nn.transformers.config.VisionTransformer2DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, in_channels: 'int' = 2, out_channels: 'Optional[int]' = None, average_img_size: 'tuple[int, int]' = '???', patch_size: 'tuple[int, int]' = (16, 16))[source][source]#

Bases: MRIViTConfig

average_img_size: tuple[int, int] = '???'#
in_channels: int = 2#
out_channels: Optional[int] = None#
patch_size: tuple[int, int] = (16, 16)#
class direct.nn.transformers.config.VisionTransformer3DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, in_channels: 'int' = 2, out_channels: 'Optional[int]' = None, average_img_size: 'tuple[int, int, int]' = '???', patch_size: 'tuple[int, int, int]' = (16, 16, 16))[source][source]#

Bases: MRIViTConfig

average_img_size: tuple[int, int, int] = '???'#
in_channels: int = 2#
out_channels: Optional[int] = None#
patch_size: tuple[int, int, int] = (16, 16, 16)#
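
These configs are dataclass-style ModelConfig subclasses, so they can be instantiated directly (or composed into an OmegaConf structured config) with keyword overrides; anything not passed keeps the documented default. A minimal sketch using the fields listed above:

from direct.nn.transformers.config import ImageDomainMRIViT2DConfig

# Override a few documented fields; the rest keep their defaults.
cfg = ImageDomainMRIViT2DConfig(
    embedding_dim=64,
    depth=8,
    num_heads=8,
    average_size=(320, 320),
    patch_size=(16, 16),
)
print(cfg.patch_size)  # (16, 16)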

direct.nn.transformers.transformers module#

DIRECT Vision Transformer models for MRI reconstruction.

class direct.nn.transformers.transformers.ImageDomainMRIUFormer(forward_operator, backward_operator, patch_size=256, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True, **kwargs)[source][source]#

Bases: Module

U-Former model for MRI reconstruction in the image domain.

Parameters:
forward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Forward operator function.

backward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Backward operator function.

patch_size : int

Size of the patch. Default: 256.

in_channels : int

Number of input channels. Default: 2.

out_channels : int, optional

Number of output channels. Default: None.

embedding_dim : int

Size of the feature embedding. Default: 32.

encoder_depths : tuple

Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

encoder_num_heads : tuple

Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

bottleneck_depth : int

Number of layers in the bottleneck. Default: 2.

bottleneck_num_heads : int

Number of attention heads in the bottleneck. Default: 16.

win_size : int

Window size for the attention mechanism. Default: 8.

mlp_ratio : float

Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

qkv_bias : bool

Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

qk_scale : float

Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

drop_rate : float

Dropout rate for the token-level dropout layer. Default: 0.0.

attn_drop_rate : float

Dropout rate for the attention score matrix. Default: 0.0.

drop_path_rate : float

Dropout rate for the stochastic depth regularization. Default: 0.1.

patch_norm : bool

Whether to use normalization for the patch embeddings. Default: True.

token_projection : AttentionTokenProjectionType

Type of token projection. Must be one of [“linear”, “conv”]. Default: AttentionTokenProjectionType.LINEAR.

token_mlp : LeWinTransformerMLPTokenType

Type of token-level MLP. Must be one of [“leff”, “mlp”, “ffn”]. Default: LeWinTransformerMLPTokenType.LEFF.

shift_flag : bool

Whether to use the shift operation in the local attention mechanism. Default: True.

modulator : bool

Whether to use a modulator in the attention mechanism. Default: False.

cross_modulator : bool

Whether to use cross-modulation in the attention mechanism. Default: False.

normalized : bool

Whether to apply normalization before and denormalization after the forward pass. Default: True.

forward(masked_kspace, sensitivity_map)[source][source]#

Forward pass of ImageDomainMRIUFormer.

Parameters:
masked_kspace : torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map : torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

Returns:
out : torch.Tensor

The output tensor of shape (N, height, width, complex=2).

Return type:

Tensor
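
A hedged usage sketch of the image-domain U-Former: tensor shapes follow the forward() documentation above, and the centered FFT pair fft2/ifft2 from direct.data.transforms is assumed as the forward/backward operator pair (adjust the import if your setup builds the operators differently).

import torch

from direct.data.transforms import fft2, ifft2  # assumed operator pair
from direct.nn.transformers.transformers import ImageDomainMRIUFormer

model = ImageDomainMRIUFormer(
    forward_operator=fft2,
    backward_operator=ifft2,
    patch_size=128,     # smaller than the default 256 to keep the example light
    embedding_dim=16,
)

# Shapes as documented: (N, coil, height, width, complex=2).
masked_kspace = torch.randn(1, 4, 128, 128, 2)
sensitivity_map = torch.randn(1, 4, 128, 128, 2)

out = model(masked_kspace, sensitivity_map)
print(out.shape)  # expected: torch.Size([1, 128, 128, 2])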

class direct.nn.transformers.transformers.ImageDomainMRIViT2D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, **kwargs)[source][source]#

Bases: Module

Vision Transformer for MRI reconstruction in 2D.

Parameters:
forward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Forward operator function.

backward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Backward operator function.

average_size : int or tuple[int, int]

The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.

patch_size : int or tuple[int, int]

The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

embedding_dim : int

Dimension of the output embedding.

depth : int

Number of transformer blocks.

num_heads : int

Number of attention heads.

mlp_ratio : float

The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

qkv_bias : bool

Whether to add bias to the query, key, and value projections. Default: False.

qk_scale : float

The scale factor for the query-key dot product. Default: None.

drop_rate : float

The dropout probability for all dropout layers except dropout_path. Default: 0.0.

attn_drop_rate : float

The dropout probability for the attention layer. Default: 0.0.

dropout_path_rate : float

The dropout probability for the dropout path. Default: 0.0.

use_gpsa : bool, optional

Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

locality_strength : float

The strength of the locality assumption in initialization. Default: 1.0.

use_pos_embedding : bool

Whether to use positional embeddings. Default: True.

normalized : bool

Whether to normalize the input tensor. Default: True.

forward(masked_kspace, sensitivity_map)[source][source]#

Forward pass of ImageDomainMRIViT2D.

Parameters:
masked_kspace : torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map : torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

Returns:
out : torch.Tensor

The output tensor of shape (N, height, width, complex=2).

Return type:

Tensor

class direct.nn.transformers.transformers.ImageDomainMRIViT3D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, **kwargs)[source][source]#

Bases: Module

Vision Transformer for MRI reconstruction in 3D.

Parameters:
forward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Forward operator function.

backward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Backward operator function.

average_size : int or tuple[int, int, int]

The average size of the input image. If an int is provided, this will be defined as (average_size, average_size, average_size). Default: 320.

patch_size : int or tuple[int, int, int]

The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.

embedding_dim : int

Dimension of the output embedding.

depth : int

Number of transformer blocks.

num_heads : int

Number of attention heads.

mlp_ratio : float

The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

qkv_bias : bool

Whether to add bias to the query, key, and value projections. Default: False.

qk_scale : float

The scale factor for the query-key dot product. Default: None.

drop_rate : float

The dropout probability for all dropout layers except dropout_path. Default: 0.0.

attn_drop_rate : float

The dropout probability for the attention layer. Default: 0.0.

dropout_path_rate : float

The dropout probability for the dropout path. Default: 0.0.

use_gpsa : bool, optional

Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

locality_strength : float

The strength of the locality assumption in initialization. Default: 1.0.

use_pos_embedding : bool

Whether to use positional embeddings. Default: True.

normalized : bool

Whether to normalize the input tensor. Default: True.

forward(masked_kspace, sensitivity_map)[source][source]#

Forward pass of ImageDomainMRIViT3D.

Parameters:
masked_kspace : torch.Tensor

Masked k-space of shape (N, coil, slice/time, height, width, complex=2).

sensitivity_map : torch.Tensor

Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).

Returns:
out : torch.Tensor

The output tensor of shape (N, slice/time, height, width, complex=2).

Return type:

Tensor
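
A hedged sketch for the 3D variant, mainly to illustrate the five-dimensional input layout (N, coil, slice/time, height, width, complex=2) documented above. The operator pair is again assumed to come from direct.data.transforms, and the small sizes are chosen only to keep the example cheap.

import torch

from direct.data.transforms import fft2, ifft2  # assumed operator pair
from direct.nn.transformers.transformers import ImageDomainMRIViT3D

model = ImageDomainMRIViT3D(
    forward_operator=fft2,
    backward_operator=ifft2,
    average_size=(16, 64, 64),  # rough expected (slice/time, height, width)
    patch_size=(4, 8, 8),
    embedding_dim=32,
    depth=2,
    num_heads=4,
)

masked_kspace = torch.randn(1, 4, 16, 64, 64, 2)
sensitivity_map = torch.randn(1, 4, 16, 64, 64, 2)

out = model(masked_kspace, sensitivity_map)
print(out.shape)  # expected: torch.Size([1, 16, 64, 64, 2])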

class direct.nn.transformers.transformers.KSpaceDomainMRIViT2D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, compute_per_coil=True, **kwargs)[source][source]#

Bases: Module

Vision Transformer for MRI reconstruction in 2D in k-space.

Parameters:
forward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Forward operator function.

backward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Backward operator function.

average_size : int or tuple[int, int]

The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.

patch_size : int or tuple[int, int]

The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

embedding_dim : int

Dimension of the output embedding.

depth : int

Number of transformer blocks.

num_heads : int

Number of attention heads.

mlp_ratio : float

The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

qkv_bias : bool

Whether to add bias to the query, key, and value projections. Default: False.

qk_scale : float

The scale factor for the query-key dot product. Default: None.

drop_rate : float

The dropout probability for all dropout layers except dropout_path. Default: 0.0.

attn_drop_rate : float

The dropout probability for the attention layer. Default: 0.0.

dropout_path_rate : float

The dropout probability for the dropout path. Default: 0.0.

use_gpsa : bool, optional

Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

locality_strength : float

The strength of the locality assumption in initialization. Default: 1.0.

use_pos_embedding : bool

Whether to use positional embeddings. Default: True.

normalized : bool

Whether to normalize the input tensor. Default: True.

forward(masked_kspace, sensitivity_map, sampling_mask)[source][source]#

Forward pass of KSpaceDomainMRIViT2D.

Parameters:
masked_kspace : torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map : torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

sampling_mask : torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

Returns:
out : torch.Tensor

The output tensor of shape (N, height, width, complex=2).

Return type:

Tensor
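
A hedged sketch for the k-space domain variant, which additionally takes the sampling mask documented above; the operator pair is again assumed to come from direct.data.transforms.

import torch

from direct.data.transforms import fft2, ifft2  # assumed operator pair
from direct.nn.transformers.transformers import KSpaceDomainMRIViT2D

model = KSpaceDomainMRIViT2D(
    forward_operator=fft2,
    backward_operator=ifft2,
    average_size=(128, 128),
    patch_size=(16, 16),
    embedding_dim=64,
    depth=2,
    num_heads=8,
    compute_per_coil=True,
)

masked_kspace = torch.randn(1, 4, 128, 128, 2)
sensitivity_map = torch.randn(1, 4, 128, 128, 2)
# Sampling mask of shape (N, 1, height, width, 1), as documented above.
sampling_mask = torch.rand(1, 1, 128, 128, 1) > 0.5

out = model(masked_kspace, sensitivity_map, sampling_mask)
print(out.shape)  # expected: torch.Size([1, 128, 128, 2])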

class direct.nn.transformers.transformers.KSpaceDomainMRIViT3D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, compute_per_coil=True, **kwargs)[source][source]#

Bases: Module

Vision Transformer for MRI reconstruction in 3D in k-space.

Parameters:
forward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Forward operator function.

backward_operator : Callable[[tuple[Any, …]], torch.Tensor]

Backward operator function.

average_size : int or tuple[int, int, int]

The average size of the input image. If an int is provided, this will be defined as (average_size, average_size, average_size). Default: 320.

patch_size : int or tuple[int, int, int]

The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.

embedding_dim : int

Dimension of the output embedding.

depth : int

Number of transformer blocks.

num_heads : int

Number of attention heads.

mlp_ratio : float

The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

qkv_bias : bool

Whether to add bias to the query, key, and value projections. Default: False.

qk_scale : float

The scale factor for the query-key dot product. Default: None.

drop_rate : float

The dropout probability for all dropout layers except dropout_path. Default: 0.0.

attn_drop_rate : float

The dropout probability for the attention layer. Default: 0.0.

dropout_path_rate : float

The dropout probability for the dropout path. Default: 0.0.

use_gpsa : bool, optional

Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

locality_strength : float

The strength of the locality assumption in initialization. Default: 1.0.

use_pos_embedding : bool

Whether to use positional embeddings. Default: True.

normalized : bool

Whether to normalize the input tensor. Default: True.

forward(masked_kspace, sensitivity_map, sampling_mask)[source][source]#

Forward pass of KSpaceDomainMRIViT3D.

Parameters:
masked_kspace : torch.Tensor

Masked k-space of shape (N, coil, slice/time, height, width, complex=2).

sensitivity_map : torch.Tensor

Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).

sampling_mask : torch.Tensor

Sampling mask of shape (N, 1, 1 or slice/time, height, width, 1).

Returns:
out : torch.Tensor

The output tensor of shape (N, slice/time, height, width, complex=2).

Return type:

Tensor

direct.nn.transformers.transformers_engine module#

DIRECT MRI transformer-based model engines.

class direct.nn.transformers.transformers_engine.ImageDomainMRIUFormerEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: ImageDomainMRIViTEngine

MRI U-Former Model Engine for Image Domain.

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

class direct.nn.transformers.transformers_engine.ImageDomainMRIViT2DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: ImageDomainMRIViTEngine

MRI ViT Model Engine for Image Domain 2D.

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

class direct.nn.transformers.transformers_engine.ImageDomainMRIViT3DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: ImageDomainMRIViTEngine

MRI ViT Model Engine for Image Domain 3D.

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

class direct.nn.transformers.transformers_engine.ImageDomainMRIViTEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: MRIModelEngine

MRI ViT Model Engine for Image Domain.

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

forward_function(data)[source][source]#

Forward function for ImageDomainMRIViTEngine.

Parameters:
data : dict[str, Any]

Input data.

Returns:
tuple[torch.Tensor, torch.Tensor]

Output image and output k-space.

Return type:

tuple[Tensor, Tensor]

class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViT2DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: KSpaceDomainMRIViTEngine

MRI ViT Model Engine for K-Space Domain 2D.

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViT3DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: KSpaceDomainMRIViTEngine

MRI ViT Model Engine for K-Space Domain 3D.

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViTEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#

Bases: MRIModelEngine

MRI ViT Model Engine for K-Space Domain.

Parameters:
cfg: BaseConfig

Configuration file.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

forward_function(data)[source][source]#

Forward function for KSpaceDomainMRIViTEngine.

Parameters:
data : dict[str, Any]

Input data.

Returns:
tuple[torch.Tensor, torch.Tensor]

Output image and output k-space.

Return type:

tuple[Tensor, Tensor]

direct.nn.transformers.uformer module#

U-Former model [1] implementation.

Adapted from [2].

References#

[1]

Wang, Zhendong, et al. “Uformer: A general u-shaped transformer for image restoration.” Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.

class direct.nn.transformers.uformer.AttentionTokenProjectionType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source][source]#

Bases: DirectEnum

CONV = 'conv'#
LINEAR = 'linear'#
class direct.nn.transformers.uformer.LeWinTransformerMLPTokenType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source][source]#

Bases: DirectEnum

FFN = 'ffn'#
LEFF = 'leff'#
MLP = 'mlp'#
class direct.nn.transformers.uformer.UFormer(patch_size=256, in_channels=2, out_channels=None, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False)[source][source]#

Bases: Module

U-Former model based on [1], code originally implemented in [2].

Parameters:
patch_size : int

Size of the patch. Default: 256.

in_channels : int

Number of input channels. Default: 2.

out_channels : int, optional

Number of output channels. Default: None.

embedding_dim : int

Size of the feature embedding. Default: 32.

encoder_depths : tuple

Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

encoder_num_heads : tuple

Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

bottleneck_depth : int

Number of layers in the bottleneck. Default: 2.

bottleneck_num_heads : int

Number of attention heads in the bottleneck. Default: 16.

win_size : int

Window size for the attention mechanism. Default: 8.

mlp_ratio : float

Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

qkv_bias : bool

Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

qk_scale : float

Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

drop_rate : float

Dropout rate for the token-level dropout layer. Default: 0.0.

attn_drop_rate : float

Dropout rate for the attention score matrix. Default: 0.0.

drop_path_rate : float

Dropout rate for the stochastic depth regularization. Default: 0.1.

patch_norm : bool

Whether to use normalization for the patch embeddings. Default: True.

token_projection : AttentionTokenProjectionType

Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.

token_mlp : LeWinTransformerMLPTokenType

Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.

shift_flag : bool

Whether to use the shift operation in the local attention mechanism. Default: True.

modulator : bool

Whether to use a modulator in the attention mechanism. Default: False.

cross_modulator : bool

Whether to use cross-modulation in the attention mechanism. Default: False.

**kwargs: Other keyword arguments to pass to the parent constructor.

References

[1]

Wang, Zhendong, et al. “Uformer: A general u-shaped transformer for image restoration.” Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.

extra_repr()[source][source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str

forward(x, mask=None)[source][source]#

Performs forward pass of UFormer.

Parameters:
x : torch.Tensor

Input tensor.

mask : torch.Tensor, optional

Mask tensor. Default: None.

Returns:
torch.Tensor
Return type:

Tensor

no_weight_decay()[source][source]#
no_weight_decay_keywords()[source][source]#
class direct.nn.transformers.uformer.UFormerModel(patch_size=256, in_channels=2, out_channels=None, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True)[source][source]#

Bases: Module

U-Former model with normalization and padding operations.

Parameters:
patch_size : int

Size of the patch. Default: 256.

in_channels : int

Number of input channels. Default: 2.

out_channels : int, optional

Number of output channels. Default: None.

embedding_dim : int

Size of the feature embedding. Default: 32.

encoder_depths : tuple

Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

encoder_num_heads : tuple

Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

bottleneck_depth : int

Number of layers in the bottleneck. Default: 2.

bottleneck_num_heads : int

Number of attention heads in the bottleneck. Default: 16.

win_size : int

Window size for the attention mechanism. Default: 8.

mlp_ratio : float

Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

qkv_bias : bool

Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

qk_scale : float

Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

drop_rate : float

Dropout rate for the token-level dropout layer. Default: 0.0.

attn_drop_rate : float

Dropout rate for the attention score matrix. Default: 0.0.

drop_path_rate : float

Dropout rate for the stochastic depth regularization. Default: 0.1.

patch_norm : bool

Whether to use normalization for the patch embeddings. Default: True.

token_projection : AttentionTokenProjectionType

Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.

token_mlp : LeWinTransformerMLPTokenType

Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.

shift_flag : bool

Whether to use the shift operation in the local attention mechanism. Default: True.

modulator : bool

Whether to use a modulator in the attention mechanism. Default: False.

cross_modulator : bool

Whether to use cross-modulation in the attention mechanism. Default: False.

normalized : bool

Whether to apply normalization before and denormalization after the forward pass. Default: True.

**kwargs: Other keyword arguments to pass to the parent constructor.

forward(x, mask=None)[source][source]#

Performs forward pass of UFormer.

Parameters:
x : torch.Tensor
mask : torch.Tensor, optional
Returns:
torch.Tensor
Return type:

Tensor
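
A hedged sketch of using UFormerModel as a standalone two-channel image denoiser. The channel-first input layout (N, in_channels, height, width) is an assumption based on in_channels/out_channels above, not something this page states explicitly.

import torch

from direct.nn.transformers.uformer import UFormerModel

model = UFormerModel(patch_size=128, in_channels=2, embedding_dim=16)

x = torch.randn(1, 2, 128, 128)  # assumed (N, in_channels, height, width)
out = model(x)
print(out.shape)  # expected: torch.Size([1, 2, 128, 128])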

direct.nn.transformers.utils module#

DIRECT module containing utility functions for the transformers models.

class direct.nn.transformers.utils.DropoutPath(drop_prob=0.0, scale_by_keep=True)[source][source]#

Bases: Module

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

extra_repr()[source][source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x)[source][source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
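
DropoutPath is typically wrapped around the residual branch of a block, so that during training whole samples skip that branch with probability drop_prob. A minimal sketch (the surrounding block is illustrative, not part of this module):

import torch
from torch import nn

from direct.nn.transformers.utils import DropoutPath


class ResidualMLPBlock(nn.Module):
    """Toy residual block showing where DropoutPath usually sits."""

    def __init__(self, dim: int, drop_prob: float = 0.1) -> None:
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.dropout_path = DropoutPath(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # During training, the branch output is zeroed per sample with
        # probability drop_prob (and rescaled if scale_by_keep=True).
        return x + self.dropout_path(self.mlp(x))


block = ResidualMLPBlock(dim=64)
out = block(torch.randn(8, 64))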

direct.nn.transformers.utils.init_weights(m)[source][source]#

Initializes the weights of the network using a truncated normal distribution.

Parameters:
m : nn.Module

A module of the network whose weights need to be initialized.

Return type:

None
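
init_weights is meant to be passed to nn.Module.apply, which calls it on every submodule; a short sketch:

from torch import nn

from direct.nn.transformers.utils import init_weights

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 2))
model.apply(init_weights)  # truncated-normal initialization of the supported layers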

direct.nn.transformers.utils.norm(x)[source][source]#

Normalize the input tensor by subtracting the mean and dividing by the standard deviation across each channel and pixel for arbitrary spatial dimensions.

Parameters:
x : torch.Tensor

Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number (e.g., 2D, 3D, etc.).

Returns:
tuple

Containing the normalized tensor, mean tensor, and standard deviation tensor.

Return type:

tuple[Tensor, Tensor, Tensor]

direct.nn.transformers.utils.pad_to_divisible(x, pad_size)[source][source]#

Pad the input tensor with zeros to make its spatial dimensions divisible by the specified pad size.

Parameters:
x : torch.Tensor

Input tensor of shape (*, spatial_1, spatial_2, …, spatial_N), where spatial dimensions can vary in number.

pad_size : tuple[int, …]

Patch size to make each spatial dimension divisible by. This is a tuple of integers for each spatial dimension.

Returns:
tuple

Containing the padded tensor and a tuple of tuples indicating the number of pixels padded in each spatial dimension.

Return type:

tuple[Tensor, tuple[tuple[int, int], ...]]

direct.nn.transformers.utils.pad_to_square(inp, factor)[source][source]#

Pad a tensor to a square shape with a given factor.

Parameters:
inp : torch.Tensor

The input tensor to pad to square shape. Expected shape is (*, height, width).

factor : float

The factor to which the input tensor will be padded.

Returns:
tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]]

A tuple containing the input tensor padded to a square shape, the corresponding mask for the padded tensor, and the width and height padding amounts, each as a (pad_before, pad_after) tuple.

Return type:

tuple[Tensor, Tensor, tuple[int, int], tuple[int, int]]

Examples

>>> x = torch.rand(1, 3, 224, 192)
>>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0)
>>> padded_x.shape, mask.shape
(torch.Size([1, 3, 224, 224]), torch.Size([1, 1, 224, 224]))

>>> x = torch.rand(3, 13, 2, 234, 180)
>>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0)
>>> padded_x.shape, wpad, hpad
(torch.Size([3, 13, 2, 240, 240]), (30, 30), (3, 3))
    
direct.nn.transformers.utils.unnorm(x, mean, std)[source][source]#

Denormalize the input tensor by multiplying by the standard deviation and adding the mean for arbitrary spatial dimensions.

Parameters:
x : torch.Tensor

Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number.

mean : torch.Tensor

Mean tensor obtained during normalization.

std : torch.Tensor

Standard deviation tensor obtained during normalization.

Returns:
torch.Tensor

Tensor with the same shape as the original input tensor, but denormalized.

Return type:

Tensor
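
norm and unnorm are intended as a pair: normalize the input before the network and denormalize the output afterwards. A minimal roundtrip sketch:

import torch

from direct.nn.transformers.utils import norm, unnorm

x = torch.randn(2, 2, 64, 64)        # (B, C, *spatial_dims)
x_normed, mean, std = norm(x)
x_restored = unnorm(x_normed, mean, std)
print(torch.allclose(x_restored, x, atol=1e-5))  # expected: True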

direct.nn.transformers.utils.unpad_to_original(x, *pads)[source][source]#

Remove the padding added to the input tensor.

Parameters:
x : torch.Tensor

Input tensor with padded spatial dimensions.

pads : tuple[int, int]

A tuple of (pad_before, pad_after) for each spatial dimension.

Returns:
torch.Tensor

Tensor with the padding removed, matching the shape of the original input tensor before padding.

Return type:

Tensor
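
pad_to_divisible and unpad_to_original are likewise intended as a pair, as the documentation above suggests; a sketch:

import torch

from direct.nn.transformers.utils import pad_to_divisible, unpad_to_original

x = torch.randn(1, 2, 50, 70)
padded, pads = pad_to_divisible(x, pad_size=(16, 16))
print(padded.shape)    # expected: torch.Size([1, 2, 64, 80])
restored = unpad_to_original(padded, *pads)
print(restored.shape)  # expected: torch.Size([1, 2, 50, 70])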

direct.nn.transformers.vit module#

DIRECT Vision Transformer module.

Implementation of the Vision Transformer model [1, 2] in PyTorch.

Code borrowed from [3], which uses code from timm [4].

References#

[1]

Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., Houlsby, N.: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, http://arxiv.org/abs/2010.11929, (2021).

[2]

Steiner, A., Kolesnikov, A., Zhai, X., Wightman, R., Uszkoreit, J., Beyer, L.: How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers, http://arxiv.org/abs/2106.10270, (2022).

class direct.nn.transformers.vit.VisionTransformer2D(average_img_size=320, patch_size=16, in_channels=2, out_channels=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source][source]#

Bases: VisionTransformer

Vision Transformer model for 2D data.

Parameters:
average_img_size : int or tuple[int, int]

The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_img_size, average_img_size) for 2D and (average_img_size, average_img_size, average_img_size) for 3D. Default: 320.

patch_size : int or tuple[int, int]

The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

in_channels : int

Number of input channels. Default: COMPLEX_SIZE.

out_channels : int or None

Number of output channels. If None, this will be set to in_channels. Default: None.

embedding_dim : int

Dimension of the output embedding.

depth : int

Number of transformer blocks.

num_heads : int

Number of attention heads.

mlp_ratio : float

The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

qkv_bias : bool

Whether to add bias to the query, key, and value projections. Default: False.

qk_scale : float

The scale factor for the query-key dot product. Default: None.

drop_rate : float

The dropout probability for all dropout layers except dropout_path. Default: 0.0.

attn_drop_rate : float

The dropout probability for the attention layer. Default: 0.0.

dropout_path_rate : float

The dropout probability for the dropout path. Default: 0.0.

use_gpsa : bool

Whether to use the GPSA layer. Default: True.

locality_strength : float

The strength of the locality assumption in initialization. Default: 1.0.

use_pos_embedding : bool

Whether to use positional embeddings. Default: True.

normalized : bool

Whether to normalize the input tensor. Default: True.

seq2img(x, img_size)[source][source]#

Converts the sequence patches tensor to an image tensor.

Parameters:
x : torch.Tensor

The sequence tensor.

img_size : tuple[int, …]

The size of the image tensor.

Returns:
torch.Tensor

The image tensor.

Return type:

Tensor
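
A hedged sketch of running VisionTransformer2D directly on a two-channel (real/imaginary) image batch; the channel-first layout (N, in_channels, height, width) is an assumption based on in_channels above.

import torch

from direct.nn.transformers.vit import VisionTransformer2D

model = VisionTransformer2D(
    average_img_size=(64, 64),
    patch_size=(8, 8),
    in_channels=2,
    embedding_dim=64,
    depth=2,
    num_heads=8,
)

x = torch.randn(1, 2, 64, 64)  # assumed (N, in_channels, height, width)
out = model(x)
print(out.shape)  # expected: torch.Size([1, 2, 64, 64])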

class direct.nn.transformers.vit.VisionTransformer3D(average_img_size=320, patch_size=16, in_channels=2, out_channels=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source][source]#

Bases: VisionTransformer

Vision Transformer model for 3D data.

Parameters:
average_img_size : int or tuple[int, int, int]

The average size of the input image. If an int is provided, this will be defined as (average_img_size, average_img_size, average_img_size). Default: 320.

patch_size : int or tuple[int, int, int]

The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.

in_channels : int

Number of input channels. Default: COMPLEX_SIZE.

out_channels : int or None

Number of output channels. If None, this will be set to in_channels. Default: None.

embedding_dim : int

Dimension of the output embedding.

depth : int

Number of transformer blocks.

num_heads : int

Number of attention heads.

mlp_ratio : float

The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

qkv_bias : bool

Whether to add bias to the query, key, and value projections. Default: False.

qk_scale : float

The scale factor for the query-key dot product. Default: None.

drop_rate : float

The dropout probability for all dropout layers except dropout_path. Default: 0.0.

attn_drop_rate : float

The dropout probability for the attention layer. Default: 0.0.

dropout_path_rate : float

The dropout probability for the dropout path. Default: 0.0.

use_gpsa : bool

Whether to use the GPSA layer. Default: True.

locality_strength : float

The strength of the locality assumption in initialization. Default: 1.0.

use_pos_embedding : bool

Whether to use positional embeddings. Default: True.

normalized : bool

Whether to normalize the input tensor. Default: True.

seq2img(x, img_size)[source][source]#

Converts the sequence of 3D patches to a 3D image tensor.

Parameters:
x : torch.Tensor

The sequence tensor, where each entry corresponds to a flattened 3D patch.

img_size : tuple of ints

The size of the 3D image tensor (depth, height, width).

Returns:
torch.Tensor

The reconstructed 3D image tensor.

Return type:

Tensor

Module contents#

DIRECT transformers models.