direct.nn.transformers package#

Submodules#

direct.nn.transformers.config module#

class direct.nn.transformers.config.UFormerModelConfig(model_name='???', engine_name=None, in_channels=2, out_channels=None, patch_size=256, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True)[source]#

Bases: ModelConfig

in_channels = 2#
out_channels = None#
patch_size = 256#
embedding_dim = 32#
encoder_depths = (2, 2, 2, 2)#
encoder_num_heads = (1, 2, 4, 8)#
bottleneck_depth = 2#
bottleneck_num_heads = 16#
win_size = 8#
mlp_ratio = 4.0#
qkv_bias = True#
qk_scale = None#
drop_rate = 0.0#
attn_drop_rate = 0.0#
drop_path_rate = 0.1#
patch_norm = True#
token_projection = 'linear'#
token_mlp = 'leff'#
shift_flag = True#
modulator = False#
cross_modulator = False#
normalized = True#
__init__(model_name='???', engine_name=None, in_channels=2, out_channels=None, patch_size=256, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True)#
class direct.nn.transformers.config.ImageDomainMRIUFormerConfig(model_name='???', engine_name=None, patch_size=256, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True)[source]#

Bases: ModelConfig

patch_size = 256#
embedding_dim = 32#
encoder_depths = (2, 2, 2, 2)#
encoder_num_heads = (1, 2, 4, 8)#
bottleneck_depth = 2#
bottleneck_num_heads = 16#
win_size = 8#
mlp_ratio = 4.0#
qkv_bias = True#
qk_scale = None#
drop_rate = 0.0#
attn_drop_rate = 0.0#
drop_path_rate = 0.1#
patch_norm = True#
token_projection = 'linear'#
token_mlp = 'leff'#
shift_flag = True#
modulator = False#
cross_modulator = False#
normalized = True#
__init__(model_name='???', engine_name=None, patch_size=256, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True)#
class direct.nn.transformers.config.MRIViTConfig(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source]#

Bases: ModelConfig

embedding_dim = 64#
depth = 8#
num_heads = 9#
mlp_ratio = 4.0#
qkv_bias = False#
qk_scale = None#
drop_rate = 0.0#
attn_drop_rate = 0.0#
dropout_path_rate = 0.0#
use_gpsa = True#
locality_strength = 1.0#
use_pos_embedding = True#
normalized = True#
__init__(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)#
class direct.nn.transformers.config.VisionTransformer2DConfig(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, in_channels=2, out_channels=None, average_img_size='???', patch_size=(16, 16))[source]#

Bases: MRIViTConfig

in_channels = 2#
out_channels = None#
average_img_size = '???'#
patch_size = (16, 16)#
__init__(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, in_channels=2, out_channels=None, average_img_size='???', patch_size=(16, 16))#
class direct.nn.transformers.config.VisionTransformer3DConfig(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, in_channels=2, out_channels=None, average_img_size='???', patch_size=(16, 16, 16))[source]#

Bases: MRIViTConfig

in_channels = 2#
out_channels = None#
average_img_size = '???'#
patch_size = (16, 16, 16)#
__init__(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, in_channels=2, out_channels=None, average_img_size='???', patch_size=(16, 16, 16))#
class direct.nn.transformers.config.ImageDomainMRIViT2DConfig(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, average_size=(320, 320), patch_size=(16, 16))[source]#

Bases: MRIViTConfig

average_size = (320, 320)#
patch_size = (16, 16)#
__init__(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, average_size=(320, 320), patch_size=(16, 16))#
class direct.nn.transformers.config.ImageDomainMRIViT3DConfig(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, average_size=(320, 320, 320), patch_size=(16, 16, 16))[source]#

Bases: MRIViTConfig

average_size = (320, 320, 320)#
patch_size = (16, 16, 16)#
__init__(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, average_size=(320, 320, 320), patch_size=(16, 16, 16))#
class direct.nn.transformers.config.KSpaceDomainMRIViT2DConfig(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, average_size=(320, 320), patch_size=(16, 16), compute_per_coil=True)[source]#

Bases: MRIViTConfig

average_size = (320, 320)#
patch_size = (16, 16)#
compute_per_coil = True#
__init__(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, average_size=(320, 320), patch_size=(16, 16), compute_per_coil=True)#
class direct.nn.transformers.config.KSpaceDomainMRIViT3DConfig(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, average_size=(320, 320, 320), patch_size=(16, 16, 16), compute_per_coil=True)[source]#

Bases: MRIViTConfig

average_size = (320, 320, 320)#
patch_size = (16, 16, 16)#
compute_per_coil = True#
__init__(model_name='???', engine_name=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True, average_size=(320, 320, 320), patch_size=(16, 16, 16), compute_per_coil=True)#
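
Example. The config classes above are dataclasses; the following is a minimal sketch of instantiating one directly. In a full DIRECT setup they are typically populated from a YAML experiment file via OmegaConf, and the overrides below are purely illustrative.

from direct.nn.transformers.config import ImageDomainMRIViT2DConfig

cfg = ImageDomainMRIViT2DConfig(
    embedding_dim=64,
    num_heads=8,              # illustrative override of the default (9)
    average_size=(320, 320),
    patch_size=(16, 16),
)
print(cfg.depth, cfg.patch_size)  # 8 (16, 16)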

direct.nn.transformers.transformers module#

DIRECT Vision Transformer models for MRI reconstruction.

class direct.nn.transformers.transformers.ImageDomainMRIUFormer(forward_operator, backward_operator, patch_size=256, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True, **kwargs)[source]#

Bases: Module

U-Former model for MRI reconstruction in the image domain.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • patch_size (int) – Size of the patch. Default: 256.

  • embedding_dim (int) – Size of the feature embedding. Default: 32.

  • encoder_depths (tuple[int, ...]) – Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

  • encoder_num_heads (tuple[int, ...]) – Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

  • bottleneck_depth (int) – Number of layers for the bottleneck of the U-former. Default: 2.

  • bottleneck_num_heads (int) – Number of attention heads for the bottleneck of the U-former. Default: 16.

  • win_size (int) – Window size for the attention mechanism. Default: 8.

  • mlp_ratio (float) – Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

  • qkv_bias (bool) – Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

  • qk_scale (Optional[float]) – Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

  • drop_rate (float) – Dropout rate for the token-level dropout layer. Default: 0.0.

  • attn_drop_rate (float) – Dropout rate for the attention score matrix. Default: 0.0.

  • drop_path_rate (float) – Dropout rate for the stochastic depth regularization. Default: 0.1.

  • patch_norm (bool) – Whether to use normalization for the patch embeddings. Default: True.

  • token_projection (AttentionTokenProjectionType) – Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.

  • token_mlp (LeWinTransformerMLPTokenType) – Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.

  • shift_flag (bool) – Whether to use shift operation in the local attention mechanism. Default: True.

  • modulator (bool) – Whether to use a modulator in the attention mechanism. Default: False.

  • cross_modulator (bool) – Whether to use cross-modulation in the attention mechanism. Default: False.

  • normalized (bool) – Whether to apply normalization before and denormalization after the forward pass. Default: True.

  • **kwargs – Additional keyword arguments.

__init__(forward_operator, backward_operator, patch_size=256, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True, **kwargs)[source]#

Inits ImageDomainMRIUFormer.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • patch_size (int) – Size of the patch. Default: 256.

  • embedding_dim (int) – Size of the feature embedding. Default: 32.

  • encoder_depths (tuple[int, ...]) – Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

  • encoder_num_heads (tuple[int, ...]) – Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

  • bottleneck_depth (int) – Number of layers for the bottleneck of the U-former. Default: 2.

  • bottleneck_num_heads (int) – Number of attention heads for the bottleneck of the U-former. Default: 16.

  • win_size (int) – Window size for the attention mechanism. Default: 8.

  • mlp_ratio (float) – Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

  • qkv_bias (bool) – Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

  • qk_scale (Optional[float]) – Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

  • drop_rate (float) – Dropout rate for the token-level dropout layer. Default: 0.0.

  • attn_drop_rate (float) – Dropout rate for the attention score matrix. Default: 0.0.

  • drop_path_rate (float) – Dropout rate for the stochastic depth regularization. Default: 0.1.

  • patch_norm (bool) – Whether to use normalization for the patch embeddings. Default: True.

  • token_projection (AttentionTokenProjectionType) – Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR.

  • token_mlp (LeWinTransformerMLPTokenType) – Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF.

  • shift_flag (bool) – Whether to use shift operation in the local attention mechanism. Default: True.

  • modulator (bool) – Whether to use a modulator in the attention mechanism. Default: False.

  • cross_modulator (bool) – Whether to use cross-modulation in the attention mechanism. Default: False.

  • normalized (bool) – Whether to apply normalization before and denormalization after the forward pass. Default: True.

  • **kwargs – Additional keyword arguments.

forward(masked_kspace, sensitivity_map)[source]#

Forward pass of ImageDomainMRIUFormer.

Parameters:
  • masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).

  • sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, height, width, complex=2).

Return type:

Tensor

Returns:

The output tensor of shape (N, height, width, complex=2).
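
Example. A minimal sketch of constructing and running the model. It assumes direct.data.transforms.fft2 and ifft2 as the Fourier operators, as is common for DIRECT models; batch, coil, and spatial sizes are illustrative (the 256x256 input matches the default patch_size of 256).

import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.transformers.transformers import ImageDomainMRIUFormer

model = ImageDomainMRIUFormer(forward_operator=fft2, backward_operator=ifft2)
masked_kspace = torch.rand(1, 4, 256, 256, 2)    # (N, coil, height, width, complex=2)
sensitivity_map = torch.rand(1, 4, 256, 256, 2)  # (N, coil, height, width, complex=2)
out = model(masked_kspace, sensitivity_map)      # (N, height, width, complex=2)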

class direct.nn.transformers.transformers.ImageDomainMRIViT2D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, **kwargs)[source]#

Bases: Module

Vision Transformer for MRI reconstruction in 2D.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • average_size (int | tuple[int, int]) – The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.

  • patch_size (int | tuple[int, int]) – The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (float) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (tuple[int, int]) – Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

__init__(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, **kwargs)[source]#

Inits ImageDomainMRIViT2D.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • average_size (int | tuple[int, int]) – The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.

  • patch_size (int | tuple[int, int]) – The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (float) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (tuple[int, int]) – Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

forward(masked_kspace, sensitivity_map)[source]#

Forward pass of ImageDomainMRIViT2D.

Parameters:
  • masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).

  • sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, height, width, complex=2).

Return type:

Tensor

Returns:

The output tensor of shape (N, height, width, complex=2).
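
Example. A minimal sketch, again assuming direct.data.transforms.fft2/ifft2 as the operators; the 64x64 size is illustrative and divisible by the 16x16 patches.

import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.transformers.transformers import ImageDomainMRIViT2D

model = ImageDomainMRIViT2D(
    forward_operator=fft2,
    backward_operator=ifft2,
    average_size=(64, 64),   # illustrative; the default assumes ~320x320 images
    patch_size=(16, 16),
)
masked_kspace = torch.rand(1, 4, 64, 64, 2)      # (N, coil, height, width, complex=2)
sensitivity_map = torch.rand(1, 4, 64, 64, 2)
out = model(masked_kspace, sensitivity_map)      # (N, height, width, complex=2)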

class direct.nn.transformers.transformers.ImageDomainMRIViT3D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, **kwargs)[source]#

Bases: Module

Vision Transformer for MRI reconstruction in 3D.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • average_size (int | tuple[int, int, int]) – The average size of the input image. If an int is provided, this will be defined as (average_size, average_size, average_size). Default: 320.

  • patch_size (int | tuple[int, int, int]) – The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (float) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (tuple[int, int]) – Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

__init__(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, **kwargs)[source]#

Inits ImageDomainMRIViT3D.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • average_size (int | tuple[int, int, int]) – The average size of the input image. If an int is provided, this will be defined as (average_size, average_size, average_size). Default: 320.

  • patch_size (int | tuple[int, int, int]) – The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (float) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (tuple[int, int]) – Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

forward(masked_kspace, sensitivity_map)[source]#

Forward pass of ImageDomainMRIViT3D.

Parameters:
  • masked_kspace (Tensor) – Masked k-space of shape (N, coil, slice/time, height, width, complex=2).

  • sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).

Return type:

Tensor

Returns:

The output tensor of shape (N, slice/time, height, width, complex=2).
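
Example. A minimal 3D sketch under the same assumptions as the 2D example above (direct.data.transforms.fft2/ifft2 as operators, illustrative sizes); each spatial dimension is divisible by the corresponding patch dimension.

import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.transformers.transformers import ImageDomainMRIViT3D

model = ImageDomainMRIViT3D(
    forward_operator=fft2,
    backward_operator=ifft2,
    average_size=(16, 64, 64),
    patch_size=(8, 16, 16),
)
masked_kspace = torch.rand(1, 4, 16, 64, 64, 2)   # (N, coil, slice/time, height, width, complex=2)
sensitivity_map = torch.rand(1, 4, 16, 64, 64, 2)
out = model(masked_kspace, sensitivity_map)       # (N, slice/time, height, width, complex=2)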

class direct.nn.transformers.transformers.KSpaceDomainMRIViT2D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, compute_per_coil=True, **kwargs)[source]#

Bases: Module

Vision Transformer for MRI reconstruction in 2D in k-space.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • average_size (int | tuple[int, int]) – The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.

  • patch_size (int | tuple[int, int]) – The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (float) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (tuple[int, int]) – Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

  • compute_per_coil (bool) – Whether to compute the output per coil. Default: True.

__init__(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, compute_per_coil=True, **kwargs)[source]#

Inits KSpaceDomainMRIViT2D.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • average_size (int | tuple[int, int]) – The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.

  • patch_size (int | tuple[int, int]) – The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (float) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (tuple[int, int]) – Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

  • compute_per_coil (bool) – Whether to compute the output per coil. Default: True.

forward(masked_kspace, sensitivity_map, sampling_mask)[source]#

Forward pass of KSpaceDomainMRIViT2D.

Parameters:
  • masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).

  • sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, height, width, complex=2).

  • sampling_mask (Tensor) – Sampling mask of shape (N, 1, height, width, 1).

Return type:

Tensor

Returns:

The output tensor of shape (N, height, width, complex=2).
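
Example. A minimal sketch of the k-space domain variant, which additionally takes a binary sampling mask; operators and sizes are illustrative, as in the examples above.

import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.transformers.transformers import KSpaceDomainMRIViT2D

model = KSpaceDomainMRIViT2D(
    forward_operator=fft2,
    backward_operator=ifft2,
    average_size=(64, 64),
    patch_size=(16, 16),
)
masked_kspace = torch.rand(1, 4, 64, 64, 2)           # (N, coil, height, width, complex=2)
sensitivity_map = torch.rand(1, 4, 64, 64, 2)
sampling_mask = torch.rand(1, 1, 64, 64, 1) > 0.5     # (N, 1, height, width, 1)
out = model(masked_kspace, sensitivity_map, sampling_mask)  # (N, height, width, complex=2)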

class direct.nn.transformers.transformers.KSpaceDomainMRIViT3D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, compute_per_coil=True, **kwargs)[source]#

Bases: Module

Vision Transformer for MRI reconstruction in 3D in k-space.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • average_size (int | tuple[int, int]) – The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.

  • patch_size (int | tuple[int, int]) – The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (float) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (tuple[int, int]) – Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

  • compute_per_coil (bool) – Whether to compute the output per coil. Default: True.

__init__(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, compute_per_coil=True, **kwargs)[source]#

Inits KSpaceDomainMRIViT3D.

Parameters:
  • forward_operator (Callable[[tuple[Any, ...]], Tensor]) – Forward operator function.

  • backward_operator (Callable[[tuple[Any, ...]], Tensor]) – Backward operator function.

  • average_size (int | tuple[int, int]) – The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.

  • patch_size (int | tuple[int, int]) – The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (float) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (tuple[int, int]) – Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

  • compute_per_coil (bool) – Whether to compute the output per coil. Default: True.

forward(masked_kspace, sensitivity_map, sampling_mask)[source]#

Forward pass of KSpaceDomainMRIViT3D.

Parameters:
  • masked_kspace (Tensor) – Masked k-space of shape (N, coil, slice/time, height, width, complex=2).

  • sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).

  • sampling_mask (Tensor) – Sampling mask of shape (N, 1, 1 or slice/time, height, width, 1).

Return type:

Tensor

Returns:

The output tensor of shape (N, slice/time, height, width, complex=2).
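
Example. The 3D k-space variant, mirroring the 2D sketch above with an added slice/time dimension; note the sampling mask may broadcast over slice/time.

import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.transformers.transformers import KSpaceDomainMRIViT3D

model = KSpaceDomainMRIViT3D(
    forward_operator=fft2,
    backward_operator=ifft2,
    average_size=(16, 64, 64),
    patch_size=(8, 16, 16),
)
masked_kspace = torch.rand(1, 4, 16, 64, 64, 2)        # (N, coil, slice/time, height, width, complex=2)
sensitivity_map = torch.rand(1, 4, 16, 64, 64, 2)
sampling_mask = torch.rand(1, 1, 1, 64, 64, 1) > 0.5   # (N, 1, 1 or slice/time, height, width, 1)
out = model(masked_kspace, sensitivity_map, sampling_mask)  # (N, slice/time, height, width, complex=2)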

direct.nn.transformers.transformers_engine module#

DIRECT MRI transformer-based model engines.

class direct.nn.transformers.transformers_engine.ImageDomainMRIViTEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: MRIModelEngine

MRI ViT Model Engine for Image Domain.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits ImageDomainMRIViTEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

forward_function(data)[source]#

Forward function for ImageDomainMRIViTEngine.

Parameters:

data (dict[str, Any]) – Input data.

Return type:

tuple[Tensor, Tensor]

Returns:

Output image and output k-space.

class direct.nn.transformers.transformers_engine.ImageDomainMRIUFormerEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: ImageDomainMRIViTEngine

MRI U-Former Model Engine for Image Domain.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits ImageDomainMRIUFormerEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

class direct.nn.transformers.transformers_engine.ImageDomainMRIViT2DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: ImageDomainMRIViTEngine

MRI ViT Model Engine for Image Domain 2D.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits ImageDomainMRIViT2DEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

class direct.nn.transformers.transformers_engine.ImageDomainMRIViT3DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: ImageDomainMRIViTEngine

MRI ViT Model Engine for Image Domain 3D.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits ImageDomainMRIViT3DEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViTEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: MRIModelEngine

MRI ViT Model Engine for K-Space Domain.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits KSpaceDomainMRIViTEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

forward_function(data)[source]#

Forward function for KSpaceDomainMRIViTEngine.

Parameters:

data (dict[str, Any]) – Input data.

Return type:

tuple[Tensor, Tensor]

Returns:

Output image and output k-space.

class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViT2DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: KSpaceDomainMRIViTEngine

MRI ViT Model Engine for K-Space Domain 2D.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits KSpaceDomainMRIViT2DEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViT3DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: KSpaceDomainMRIViTEngine

MRI ViT Model Engine for K-Space Domain 3D.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits KSpaceDomainMRIViT3DEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda: {idx}” or “cpu”.

  • forward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable[[tuple[Any, ...]], Tensor]]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.

direct.nn.transformers.uformer module#

U-Former model [1] implementation.

Adapted from [2].

References

[1] Wang, Zhendong, et al. "Uformer: A general U-shaped transformer for image restoration." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.

[2] https://github.com/ZhendongWang6/Uformer

class direct.nn.transformers.uformer.AttentionTokenProjectionType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: DirectEnum

CONV = 'conv'#
LINEAR = 'linear'#
class direct.nn.transformers.uformer.LeWinTransformerMLPTokenType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: DirectEnum

MLP = 'mlp'#
FFN = 'ffn'#
LEFF = 'leff'#
class direct.nn.transformers.uformer.UFormer(patch_size=256, in_channels=2, out_channels=None, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False)[source]#

Bases: Module

U-Former model based on [1], code originally implemented in [2].

Parameters:
  • patch_size (int) – Size of the patch. Default: 256.

  • in_channels (int) – Number of input channels. Default: 2.

  • out_channels (Optional[int]) – Number of output channels. Default: None.

  • embedding_dim (int) – Size of the feature embedding. Default: 32.

  • encoder_depths (tuple[int, ...]) – Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

  • encoder_num_heads (tuple[int, ...]) – Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

  • bottleneck_depth (int) – Number of layers for the bottleneck of the U-former. Default: 2.

  • bottleneck_num_heads (int) – Number of attention heads for the bottleneck of the U-former. Default: 16.

  • win_size (int) – Window size for the attention mechanism. Default: 8.

  • mlp_ratio (float) – Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

  • qkv_bias (bool) – Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

  • qk_scale (Optional[float]) – Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

  • drop_rate (float) – Dropout rate for the token-level dropout layer. Default: 0.0.

  • attn_drop_rate (float) – Dropout rate for the attention score matrix. Default: 0.0.

  • drop_path_rate (float) – Dropout rate for the stochastic depth regularization. Default: 0.1.

  • patch_norm (bool) – Whether to use normalization for the patch embeddings. Default: True.

  • token_projection (AttentionTokenProjectionType) – Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.

  • token_mlp (LeWinTransformerMLPTokenType) – Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.

  • shift_flag (bool) – Whether to use shift operation in the local attention mechanism. Default: True.

  • modulator (bool) – Whether to use a modulator in the attention mechanism. Default: False.

  • cross_modulator (bool) – Whether to use cross-modulation in the attention mechanism. Default: False.

References

[1] Wang, Zhendong, et al. "Uformer: A general U-shaped transformer for image restoration." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.

[2] https://github.com/ZhendongWang6/Uformer

__init__(patch_size=256, in_channels=2, out_channels=None, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False)[source]#

Inits UFormer.

Parameters:
  • patch_size (int) – Size of the patch. Default: 256.

  • in_channels (int) – Number of input channels. Default: 2.

  • out_channels (Optional[int]) – Number of output channels. Default: None.

  • embedding_dim (int) – Size of the feature embedding. Default: 32.

  • encoder_depths (tuple[int, ...]) – Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

  • encoder_num_heads (tuple[int, ...]) – Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

  • bottleneck_depth (int) – Number of layers for the bottleneck of the U-former. Default: 2.

  • bottleneck_num_heads (int) – Number of attention heads for the bottleneck of the U-former. Default: 16.

  • win_size (int) – Window size for the attention mechanism. Default: 8.

  • mlp_ratio (float) – Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

  • qkv_bias (bool) – Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

  • qk_scale (Optional[float]) – Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

  • drop_rate (float) – Dropout rate for the token-level dropout layer. Default: 0.0.

  • attn_drop_rate (float) – Dropout rate for the attention score matrix. Default: 0.0.

  • drop_path_rate (float) – Dropout rate for the stochastic depth regularization. Default: 0.1.

  • patch_norm (bool) – Whether to use normalization for the patch embeddings. Default: True.

  • token_projection (AttentionTokenProjectionType) – Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.

  • token_mlp (LeWinTransformerMLPTokenType) – Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.

  • shift_flag (bool) – Whether to use shift operation in the local attention mechanism. Default: True.

  • modulator (bool) – Whether to use a modulator in the attention mechanism. Default: False.

  • cross_modulator (bool) – Whether to use cross-modulation in the attention mechanism. Default: False.

no_weight_decay()[source]#
no_weight_decay_keywords()[source]#
extra_repr()[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str

forward(x, mask=None)[source]#

Performs forward pass of UFormer.

Parameters:
  • x (Tensor) – Input tensor.

  • mask (Optional[Tensor]) – Mask tensor. Default: None.

Return type:

Tensor

Returns:

Output tensor.
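
Example. A minimal sketch of the bare UFormer module on a 2-channel input; the 256x256 spatial size matches the default patch_size and stays divisible by win_size at every encoder stage.

import torch

from direct.nn.transformers.uformer import UFormer

net = UFormer()                 # defaults: in_channels=2, patch_size=256
x = torch.rand(1, 2, 256, 256)  # (N, in_channels, height, width)
y = net(x)                      # output spatial size matches the input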

class direct.nn.transformers.uformer.UFormerModel(patch_size=256, in_channels=2, out_channels=None, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True)[source]#

Bases: Module

U-Former model with normalization and padding operations.

Parameters:
  • patch_size (int) – Size of the patch. Default: 256.

  • in_channels (int) – Number of input channels. Default: 2.

  • out_channels (Optional[int]) – Number of output channels. Default: None.

  • embedding_dim (int) – Size of the feature embedding. Default: 32.

  • encoder_depths (tuple[int, ...]) – Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

  • encoder_num_heads (tuple[int, ...]) – Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

  • bottleneck_depth (int) – Number of layers for the bottleneck of the U-former. Default: 2.

  • bottleneck_num_heads (int) – Number of attention heads for the bottleneck of the U-former. Default: 16.

  • win_size (int) – Window size for the attention mechanism. Default: 8.

  • mlp_ratio (float) – Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

  • qkv_bias (bool) – Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

  • qk_scale (Optional[float]) – Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

  • drop_rate (float) – Dropout rate for the token-level dropout layer. Default: 0.0.

  • attn_drop_rate (float) – Dropout rate for the attention score matrix. Default: 0.0.

  • drop_path_rate (float) – Dropout rate for the stochastic depth regularization. Default: 0.1.

  • patch_norm (bool) – Whether to use normalization for the patch embeddings. Default: True.

  • token_projection (AttentionTokenProjectionType) – Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.

  • token_mlp (LeWinTransformerMLPTokenType) – Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.

  • shift_flag (bool) – Whether to use shift operation in the local attention mechanism. Default: True.

  • modulator (bool) – Whether to use a modulator in the attention mechanism. Default: False.

  • cross_modulator (bool) – Whether to use cross-modulation in the attention mechanism. Default: False.

  • normalized (bool) – Whether to apply normalization before and denormalization after the forward pass. Default: True.

__init__(patch_size=256, in_channels=2, out_channels=None, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True)[source]#

Inits UFormer.

Parameters:
  • patch_size (int) – Size of the patch. Default: 256.

  • in_channels (int) – Number of input channels. Default: 2.

  • out_channels (Optional[int]) – Number of output channels. Default: None.

  • embedding_dim (int) – Size of the feature embedding. Default: 32.

  • encoder_depths (tuple[int, ...]) – Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2).

  • encoder_num_heads (tuple[int, ...]) – Number of attention heads for each layer of the encoder of the U-former, from top to bottom. Default: (1, 2, 4, 8).

  • bottleneck_depth (int) – Number of layers in the bottleneck stage of the U-former. Default: 2.

  • bottleneck_num_heads (int) – Number of attention heads in the bottleneck stage. Default: 16.

  • win_size (int) – Window size for the attention mechanism. Default: 8.

  • mlp_ratio (float) – Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.

  • qkv_bias (bool) – Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.

  • qk_scale (Optional[float]) – Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.

  • drop_rate (float) – Dropout rate for the token-level dropout layer. Default: 0.0.

  • attn_drop_rate (float) – Dropout rate for the attention score matrix. Default: 0.0.

  • drop_path_rate (float) – Dropout rate for the stochastic depth regularization. Default: 0.1.

  • patch_norm (bool) – Whether to use normalization for the patch embeddings. Default: True.

  • token_projection (AttentionTokenProjectionType) – Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.

  • token_mlp (LeWinTransformerMLPTokenType) – Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.

  • shift_flag (bool) – Whether to use shift operation in the local attention mechanism. Default: True.

  • modulator (bool) – Whether to use a modulator in the attention mechanism. Default: False.

  • cross_modulator (bool) – Whether to use cross-modulation in the attention mechanism. Default: False.

  • normalized (bool) – Whether to apply normalization before and denormalization after the forward pass. Default: True.

  • **kwargs – Other keyword arguments to pass to the parent constructor.

forward(x, mask=None)[source]#

Performs forward pass of UFormerModel.

Parameters:
  • x (Tensor) – Input tensor.

  • mask (Optional[Tensor]) – Mask tensor. Default: None.

Return type:

Tensor

Returns:

Output tensor.
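
A minimal usage sketch (assuming the class is importable as direct.nn.transformers.uformer.UFormerModel and that, with normalized=True, the output has the same shape as the input):

>>> import torch
>>> from direct.nn.transformers.uformer import UFormerModel  # assumed import path
>>> model = UFormerModel(in_channels=2)
>>> x = torch.rand(1, 2, 256, 256)  # matches the default patch_size of 256
>>> model(x).shape
torch.Size([1, 2, 256, 256])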

direct.nn.transformers.utils module#

DIRECT module containing utility functions for the transformers models.

direct.nn.transformers.utils.init_weights(m)[source]#

Initializes the weights of the network using a truncated normal distribution.

Parameters:
  • m (Module) – A module of the network whose weights need to be initialized.

Return type:

None
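
For example, init_weights can be applied recursively to all submodules via nn.Module.apply (a minimal sketch):

>>> import torch.nn as nn
>>> from direct.nn.transformers.utils import init_weights
>>> block = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
>>> _ = block.apply(init_weights)  # initializes every submodule in place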

direct.nn.transformers.utils.norm(x)[source]#

Normalize the input tensor by subtracting the mean and dividing by the standard deviation across each channel and pixel for arbitrary spatial dimensions.

Parameters:
  • x (Tensor) – Input tensor.

Return type:

tuple[Tensor, Tensor, Tensor]

Returns:

Tuple containing the normalized tensor, the mean tensor, and the standard deviation tensor.
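
A minimal sketch of the expected behavior (normalization preserves the input shape; the exact shapes of the returned mean and std tensors are implementation details):

>>> import torch
>>> from direct.nn.transformers.utils import norm
>>> x = 5.0 + 3.0 * torch.rand(1, 2, 32, 32)
>>> x_norm, mean, std = norm(x)
>>> x_norm.shape == x.shape
True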

direct.nn.transformers.utils.pad_to_divisible(x, pad_size)[source]#

Pad the input tensor with zeros to make its spatial dimensions divisible by the specified pad size.

Parameters:
  • x (Tensor) – Input tensor.

  • pad_size (tuple[int, ...]) – Patch size to make each spatial dimension divisible by, given as a tuple of integers, one per spatial dimension.

Return type:

tuple[Tensor, tuple[tuple[int, int], ...]]

Returns:

Tuple containing the padded tensor and a tuple of tuples indicating the number of pixels padded (before, after) in each spatial dimension.
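
For example (a sketch; the padded sizes are deterministic, while how the padding is split between the two sides of each dimension is an implementation detail):

>>> import torch
>>> from direct.nn.transformers.utils import pad_to_divisible
>>> x = torch.rand(1, 2, 220, 186)
>>> padded, pads = pad_to_divisible(x, pad_size=(16, 16))
>>> padded.shape  # 220 -> 224 and 186 -> 192, the next multiples of 16
torch.Size([1, 2, 224, 192])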

direct.nn.transformers.utils.pad_to_square(inp, factor)[source]#

Pad a tensor to a square shape with a given factor.

Parameters:
  • inp (Tensor) – The input tensor to pad to a square shape. Expected shape is (*, height, width).

  • factor (float) – The factor to which the input tensor will be padded.

Return type:

tuple[Tensor, Tensor, tuple[int, int], tuple[int, int]]

Returns:

Tuple containing the input tensor padded to a square shape, the corresponding mask for the padded tensor, and the (before, after) padding amounts applied to the width and height.

Examples

>>> x = torch.rand(1, 3, 224, 192)
>>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0)
>>> padded_x.shape, mask.shape
(torch.Size([1, 3, 224, 224]), torch.Size([1, 1, 224, 224]))

>>> x = torch.rand(3, 13, 2, 234, 180)
>>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0)
>>> padded_x.shape, wpad, hpad
(torch.Size([3, 13, 2, 240, 240]), (30, 30), (3, 3))
    
direct.nn.transformers.utils.unnorm(x, mean, std)[source]#

Denormalize the input tensor by multiplying by the standard deviation and adding the mean for arbitrary spatial dimensions.

Parameters:
  • x (Tensor) – Input tensor.

  • mean (Tensor) – Mean tensor obtained during normalization.

  • std (Tensor) – Standard deviation tensor obtained during normalization.

Return type:

Tensor

Returns:

Tensor with the same shape as the original input tensor, but denormalized.
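
Together with norm, this yields a round trip that recovers the original tensor (a sketch, assuming unnorm exactly inverts norm):

>>> import torch
>>> from direct.nn.transformers.utils import norm, unnorm
>>> x = torch.rand(2, 2, 32, 32)
>>> x_norm, mean, std = norm(x)
>>> torch.allclose(unnorm(x_norm, mean, std), x, atol=1e-5)
True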

direct.nn.transformers.utils.unpad_to_original(x, *pads)[source]#

Remove the padding added to the input tensor.

Parameters:
  • x (Tensor) – Input tensor with padded spatial dimensions.

  • pads (tuple[int, int]) – Tuples of (before, after) padding amounts, one for each spatial dimension.

Return type:

Tensor

Returns:

Tensor with the padding removed, matching the shape of the original input tensor before padding.
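
Combined with pad_to_divisible, a pad/unpad round trip restores the original shape (a sketch, assuming the returned pads can be unpacked directly into unpad_to_original):

>>> import torch
>>> from direct.nn.transformers.utils import pad_to_divisible, unpad_to_original
>>> x = torch.rand(1, 2, 220, 186)
>>> padded, pads = pad_to_divisible(x, pad_size=(16, 16))
>>> unpad_to_original(padded, *pads).shape
torch.Size([1, 2, 220, 186])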

class direct.nn.transformers.utils.DropoutPath(drop_prob=0.0, scale_by_keep=True)[source]#

Bases: Module

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

__init__(drop_prob=0.0, scale_by_keep=True)[source]#

Inits DropoutPath.

Parameters:
  • drop_prob (float) – Probability of dropping the activations. Default: 0.0.

  • scale_by_keep (bool) – Whether to scale the remaining activations by 1 / (1 - drop_prob). Default: True.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

extra_repr()[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
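
A typical use is on the residual branch of a transformer block, where the branch output is randomly dropped per sample during training (a minimal sketch):

>>> import torch
>>> from direct.nn.transformers.utils import DropoutPath
>>> drop_path = DropoutPath(drop_prob=0.1)
>>> branch = torch.nn.Linear(16, 16)
>>> x = torch.rand(4, 16)
>>> out = x + drop_path(branch(x))  # stochastic depth on the residual branch
>>> out.shape
torch.Size([4, 16])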

direct.nn.transformers.vit module#

DIRECT Vision Transformer module.

Implementation of the Vision Transformer model [1, 2] in PyTorch.

Code borrowed from [3], which uses code from timm [4].

References

class direct.nn.transformers.vit.VisionTransformer2D(average_img_size=320, patch_size=16, in_channels=COMPLEX_SIZE, out_channels=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source]#

Bases: VisionTransformer

Vision Transformer model for 2D data.

Parameters:
  • average_img_size (int | tuple[int, int]) – The average size of the input image. If an int is provided, this will be defined as (average_img_size, average_img_size). Default: 320.

  • patch_size (int | tuple[int, int]) – The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size). Default: 16.

  • in_channels (int) – Number of input channels. Default: COMPLEX_SIZE.

  • out_channels (Optional[int]) – Number of output channels. If None, this will be set to in_channels. Default: None.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (Optional[float]) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (bool) – Whether to use GPSA layer. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

__init__(average_img_size=320, patch_size=16, in_channels=COMPLEX_SIZE, out_channels=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source]#

Inits VisionTransformer2D.

Parameters:
  • average_img_size (int | tuple[int, int]) – The average size of the input image. If an int is provided, this will be defined as (average_img_size, average_img_size). Default: 320.

  • patch_size (int | tuple[int, int]) – The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size). Default: 16.

  • in_channels (int) – Number of input channels. Default: COMPLEX_SIZE.

  • out_channels (Optional[int]) – Number of output channels. If None, this will be set to in_channels. Default: None.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (Optional[float]) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (bool) – Whether to use GPSA layer. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

seq2img(x, img_size)[source]#

Converts the sequence patches tensor to an image tensor.

Parameters:
  • x (Tensor) – The sequence tensor.

  • img_size (tuple[int, ...]) – The size of the image tensor.

Return type:

Tensor

Returns:

The image tensor.
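
A minimal 2D usage sketch (assuming an input matching average_img_size and an output of the same shape as the input):

>>> import torch
>>> from direct.nn.transformers.vit import VisionTransformer2D
>>> model = VisionTransformer2D(average_img_size=64, patch_size=8, in_channels=2)
>>> x = torch.rand(1, 2, 64, 64)
>>> model(x).shape
torch.Size([1, 2, 64, 64])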

class direct.nn.transformers.vit.VisionTransformer3D(average_img_size=320, patch_size=16, in_channels=COMPLEX_SIZE, out_channels=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source]#

Bases: VisionTransformer

Vision Transformer model for 3D data.

Parameters:
  • average_img_size (int | tuple[int, int, int]) – The average size of the input image. If an int is provided, this will be defined as (average_img_size, average_img_size, average_img_size). Default: 320.

  • patch_size (int | tuple[int, int, int]) – The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.

  • in_channels (int) – Number of input channels. Default: COMPLEX_SIZE.

  • out_channels (Optional[int]) – Number of output channels. If None, this will be set to in_channels. Default: None.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (Optional[float]) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (bool) – Whether to use GPSA layer. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

__init__(average_img_size=320, patch_size=16, in_channels=COMPLEX_SIZE, out_channels=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source]#

Inits VisionTransformer3D.

Parameters:
  • average_img_size (int | tuple[int, int, int]) – The average size of the input image. If an int is provided, this will be defined as (average_img_size, average_img_size, average_img_size). Default: 320.

  • patch_size (int | tuple[int, int, int]) – The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.

  • in_channels (int) – Number of input channels. Default: COMPLEX_SIZE.

  • out_channels (Optional[int]) – Number of output channels. If None, this will be set to in_channels. Default: None.

  • embedding_dim (int) – Dimension of the output embedding.

  • depth (int) – Number of transformer blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (float) – The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.

  • qkv_bias (bool) – Whether to add bias to the query, key, and value projections. Default: False.

  • qk_scale (Optional[float]) – The scale factor for the query-key dot product. Default: None.

  • drop_rate (float) – The dropout probability for all dropout layers except dropout_path. Default: 0.0.

  • attn_drop_rate (float) – The dropout probability for the attention layer. Default: 0.0.

  • dropout_path_rate (float) – The dropout probability for the dropout path. Default: 0.0.

  • use_gpsa (bool) – Whether to use GPSA layer. Default: True.

  • locality_strength (float) – The strength of the locality assumption in initialization. Default: 1.0.

  • use_pos_embedding (bool) – Whether to use positional embeddings. Default: True.

  • normalized (bool) – Whether to normalize the input tensor. Default: True.

seq2img(x, img_size)[source]#

Converts the sequence of 3D patches to a 3D image tensor.

Parameters:
  • x (Tensor) – The sequence tensor, where each entry corresponds to a flattened 3D patch.

  • img_size (tuple[int, ...]) – The size of the 3D image tensor (depth, height, width).

Return type:

Tensor

Returns:

The reconstructed 3D image tensor.
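
The 3D variant is used analogously on volumetric inputs (a sketch under the same assumptions as the 2D example):

>>> import torch
>>> from direct.nn.transformers.vit import VisionTransformer3D
>>> model = VisionTransformer3D(average_img_size=32, patch_size=8, in_channels=2)
>>> x = torch.rand(1, 2, 32, 32, 32)
>>> model(x).shape
torch.Size([1, 2, 32, 32, 32])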

Module contents#

DIRECT transformers models.