direct.nn.transformers package#
Submodules#
direct.nn.transformers.config module#
- class direct.nn.transformers.config.ImageDomainMRIUFormerConfig(model_name: str = '???', engine_name: Optional[str] = None, patch_size: 'int' = 256, embedding_dim: 'int' = 32, encoder_depths: 'tuple[int, ...]' = (2, 2, 2, 2), encoder_num_heads: 'tuple[int, ...]' = (1, 2, 4, 8), bottleneck_depth: 'int' = 2, bottleneck_num_heads: 'int' = 16, win_size: 'int' = 8, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = True, qk_scale: 'Optional[float]' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, drop_path_rate: 'float' = 0.1, patch_norm: 'bool' = True, token_projection: 'AttentionTokenProjectionType' = <AttentionTokenProjectionType.LINEAR: 'linear'>, token_mlp: 'LeWinTransformerMLPTokenType' = <LeWinTransformerMLPTokenType.LEFF: 'leff'>, shift_flag: 'bool' = True, modulator: 'bool' = False, cross_modulator: 'bool' = False, normalized: 'bool' = True)[source][source]#
Bases:
ModelConfig
- attn_drop_rate: float = 0.0#
- bottleneck_depth: int = 2#
- bottleneck_num_heads: int = 16#
- cross_modulator: bool = False#
- drop_path_rate: float = 0.1#
- drop_rate: float = 0.0#
- embedding_dim: int = 32#
- encoder_depths: tuple[int, ...] = (2, 2, 2, 2)#
- encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8)#
- mlp_ratio: float = 4.0#
- modulator: bool = False#
- normalized: bool = True#
- patch_norm: bool = True#
- patch_size: int = 256#
- qk_scale: Optional[float] = None#
- qkv_bias: bool = True#
- shift_flag: bool = True#
- token_mlp: LeWinTransformerMLPTokenType = 'leff'#
- token_projection: AttentionTokenProjectionType = 'linear'#
- win_size: int = 8#
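As a usage illustration (not part of the generated reference), the config above is a plain dataclass and can be instantiated directly; the field values below are arbitrary examples and every other field keeps its documented default.

```python
from direct.nn.transformers.config import ImageDomainMRIUFormerConfig

# Minimal sketch: override a few fields, keep the documented defaults for the rest.
config = ImageDomainMRIUFormerConfig(
    patch_size=128,                  # instead of the default 256
    embedding_dim=16,                # smaller feature embedding
    encoder_depths=(2, 2, 2, 2),     # same as the default, shown for clarity
    encoder_num_heads=(1, 2, 4, 8),
)
print(config.win_size)  # -> 8 (unchanged default)
```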
- class direct.nn.transformers.config.ImageDomainMRIViT2DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, average_size: 'tuple[int, int]' = (320, 320), patch_size: 'tuple[int, int]' = (16, 16))[source][source]#
Bases:
MRIViTConfig
- average_size: tuple[int, int] = (320, 320)#
- patch_size: tuple[int, int] = (16, 16)#
- class direct.nn.transformers.config.ImageDomainMRIViT3DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, average_size: 'tuple[int, int, int]' = (320, 320, 320), patch_size: 'tuple[int, int, int]' = (16, 16, 16))[source][source]#
Bases:
MRIViTConfig
- average_size: tuple[int, int, int] = (320, 320, 320)#
- patch_size: tuple[int, int, int] = (16, 16, 16)#
- class direct.nn.transformers.config.KSpaceDomainMRIViT2DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, average_size: 'tuple[int, int]' = (320, 320), patch_size: 'tuple[int, int]' = (16, 16), compute_per_coil: 'bool' = True)[source][source]#
Bases:
MRIViTConfig
- average_size: tuple[int, int] = (320, 320)#
- compute_per_coil: bool = True#
- patch_size: tuple[int, int] = (16, 16)#
- class direct.nn.transformers.config.KSpaceDomainMRIViT3DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, average_size: 'tuple[int, int, int]' = (320, 320, 320), patch_size: 'tuple[int, int, int]' = (16, 16, 16), compute_per_coil: 'bool' = True)[source][source]#
Bases:
MRIViTConfig
- average_size: tuple[int, int, int] = (320, 320, 320)#
- compute_per_coil: bool = True#
- patch_size: tuple[int, int, int] = (16, 16, 16)#
- class direct.nn.transformers.config.MRIViTConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True)[source][source]#
Bases:
ModelConfig
- attn_drop_rate: float = 0.0#
- depth: int = 8#
- drop_rate: float = 0.0#
- dropout_path_rate: float = 0.0#
- embedding_dim: int = 64#
- locality_strength: float = 1.0#
- mlp_ratio: float = 4.0#
- normalized: bool = True#
- num_heads: int = 9#
- qk_scale: float = None#
- qkv_bias: bool = False#
- use_gpsa: bool = True#
- use_pos_embedding: bool = True#
- class direct.nn.transformers.config.UFormerModelConfig(model_name: str = '???', engine_name: Optional[str] = None, in_channels: 'int' = 2, out_channels: 'Optional[int]' = None, patch_size: 'int' = 256, embedding_dim: 'int' = 32, encoder_depths: 'tuple[int, ...]' = (2, 2, 2, 2), encoder_num_heads: 'tuple[int, ...]' = (1, 2, 4, 8), bottleneck_depth: 'int' = 2, bottleneck_num_heads: 'int' = 16, win_size: 'int' = 8, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = True, qk_scale: 'Optional[float]' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, drop_path_rate: 'float' = 0.1, patch_norm: 'bool' = True, token_projection: 'AttentionTokenProjectionType' = <AttentionTokenProjectionType.LINEAR: 'linear'>, token_mlp: 'LeWinTransformerMLPTokenType' = <LeWinTransformerMLPTokenType.LEFF: 'leff'>, shift_flag: 'bool' = True, modulator: 'bool' = False, cross_modulator: 'bool' = False, normalized: 'bool' = True)[source][source]#
Bases:
ModelConfig
- attn_drop_rate: float = 0.0#
- bottleneck_depth: int = 2#
- bottleneck_num_heads: int = 16#
- cross_modulator: bool = False#
- drop_path_rate: float = 0.1#
- drop_rate: float = 0.0#
- embedding_dim: int = 32#
- encoder_depths: tuple[int, ...] = (2, 2, 2, 2)#
- encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8)#
- in_channels: int = 2#
- mlp_ratio: float = 4.0#
- modulator: bool = False#
- normalized: bool = True#
- out_channels: Optional[int] = None#
- patch_norm: bool = True#
- patch_size: int = 256#
- qk_scale: Optional[float] = None#
- qkv_bias: bool = True#
- shift_flag: bool = True#
- token_mlp: LeWinTransformerMLPTokenType = 'leff'#
- token_projection: AttentionTokenProjectionType = 'linear'#
- win_size: int = 8#
- class direct.nn.transformers.config.VisionTransformer2DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, in_channels: 'int' = 2, out_channels: 'Optional[int]' = None, average_img_size: 'tuple[int, int]' = '???', patch_size: 'tuple[int, int]' = (16, 16))[source][source]#
Bases:
MRIViTConfig
- average_img_size: tuple[int, int] = '???'#
- in_channels: int = 2#
- out_channels: Optional[int] = None#
- patch_size: tuple[int, int] = (16, 16)#
- class direct.nn.transformers.config.VisionTransformer3DConfig(model_name: str = '???', engine_name: str | None = None, embedding_dim: 'int' = 64, depth: 'int' = 8, num_heads: 'int' = 9, mlp_ratio: 'float' = 4.0, qkv_bias: 'bool' = False, qk_scale: 'float' = None, drop_rate: 'float' = 0.0, attn_drop_rate: 'float' = 0.0, dropout_path_rate: 'float' = 0.0, use_gpsa: 'bool' = True, locality_strength: 'float' = 1.0, use_pos_embedding: 'bool' = True, normalized: 'bool' = True, in_channels: 'int' = 2, out_channels: 'Optional[int]' = None, average_img_size: 'tuple[int, int, int]' = '???', patch_size: 'tuple[int, int, int]' = (16, 16, 16))[source][source]#
Bases:
MRIViTConfig
- average_img_size: tuple[int, int, int] = '???'#
- in_channels: int = 2#
- out_channels: Optional[int] = None#
- patch_size: tuple[int, int, int] = (16, 16, 16)#
direct.nn.transformers.transformers module#
DIRECT Vision Transformer models for MRI reconstruction.
- class direct.nn.transformers.transformers.ImageDomainMRIUFormer(forward_operator, backward_operator, patch_size=256, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True, **kwargs)[source][source]#
Bases:
Module
U-Former model for MRI reconstruction in the image domain.
- Parameters:
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Forward operator function.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Backward operator function.
- patch_size: int
Size of the patch. Default: 256.
- in_channels: int
Number of input channels. Default: 2.
- out_channels: int, optional
Number of output channels. Default: None.
- embedding_dim: int
Size of the feature embedding. Default: 32.
- encoder_depths: tuple
Number of layers for each stage of the encoder of the U-Former, from top to bottom. Default: (2, 2, 2, 2).
- encoder_num_heads: tuple
Number of attention heads for each layer of the encoder of the U-Former, from top to bottom. Default: (1, 2, 4, 8).
- bottleneck_depth: int
Number of layers in the bottleneck of the U-Former. Default: 2.
- bottleneck_num_heads: int
Number of attention heads in the bottleneck of the U-Former. Default: 16.
- win_size: int
Window size for the attention mechanism. Default: 8.
- mlp_ratio: float
Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.
- qkv_bias: bool
Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.
- qk_scale: float
Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.
- drop_rate: float
Dropout rate for the token-level dropout layer. Default: 0.0.
- attn_drop_rate: float
Dropout rate for the attention score matrix. Default: 0.0.
- drop_path_rate: float
Dropout rate for the stochastic depth regularization. Default: 0.1.
- patch_norm: bool
Whether to use normalization for the patch embeddings. Default: True.
- token_projection: AttentionTokenProjectionType
Type of token projection. Must be one of [“linear”, “conv”]. Default: AttentionTokenProjectionType.LINEAR.
- token_mlp: LeWinTransformerMLPTokenType
Type of token-level MLP. Must be one of [“leff”, “mlp”, “ffn”]. Default: LeWinTransformerMLPTokenType.LEFF.
- shift_flag: bool
Whether to use shift operation in the local attention mechanism. Default: True.
- modulator: bool
Whether to use a modulator in the attention mechanism. Default: False.
- cross_modulator: bool
Whether to use cross-modulation in the attention mechanism. Default: False.
- normalized: bool
Whether to apply normalization before and denormalization after the forward pass. Default: True.
- forward(masked_kspace, sensitivity_map)[source][source]#
Forward pass of ImageDomainMRIUFormer.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- Returns:
- out: torch.Tensor
The output tensor of shape (N, height, width, complex=2).
- Return type:
Tensor
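A minimal usage sketch follows. It assumes the centered FFT pair from direct.data.transforms (fft2/ifft2) as the forward and backward operators, and uses small dummy tensors whose shapes follow the forward-pass documentation above; the sizes are illustrative only.

```python
import torch

from direct.data.transforms import fft2, ifft2  # assumed operator pair
from direct.nn.transformers.transformers import ImageDomainMRIUFormer

model = ImageDomainMRIUFormer(
    forward_operator=fft2,
    backward_operator=ifft2,
    patch_size=128,    # illustrative, smaller than the default 256
    embedding_dim=16,
)

# Dummy multi-coil 2D input: (N, coil, height, width, complex=2)
masked_kspace = torch.randn(1, 4, 128, 128, 2)
sensitivity_map = torch.randn(1, 4, 128, 128, 2)

out = model(masked_kspace, sensitivity_map)
print(out.shape)  # expected: torch.Size([1, 128, 128, 2])
```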
- class direct.nn.transformers.transformers.ImageDomainMRIViT2D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, **kwargs)[source][source]#
Bases:
Module
Vision Transformer for MRI reconstruction in 2D.
- Parameters:
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Forward operator function.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Backward operator function.
- average_size: int or tuple[int, int]
The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.
- patch_size: int or tuple[int, int]
The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.
- embedding_dim: int
Dimension of the output embedding.
- depth: int
Number of transformer blocks.
- num_heads: int
Number of attention heads.
- mlp_ratio: float
The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
- qkv_bias: bool
Whether to add bias to the query, key, and value projections. Default: False.
- qk_scale: float
The scale factor for the query-key dot product. Default: None.
- drop_rate: float
The dropout probability for all dropout layers except dropout_path. Default: 0.0.
- attn_drop_rate: float
The dropout probability for the attention layer. Default: 0.0.
- dropout_path_rate: float
The dropout probability for the dropout path. Default: 0.0.
- use_gpsa: bool, optional
Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.
- locality_strength: float
The strength of the locality assumption in initialization. Default: 1.0.
- use_pos_embedding: bool
Whether to use positional embeddings. Default: True.
- normalized: bool
Whether to normalize the input tensor. Default: True.
- forward(masked_kspace, sensitivity_map)[source][source]#
Forward pass of ImageDomainMRIViT2D.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- Returns:
- out: torch.Tensor
The output tensor of shape (N, height, width, complex=2).
- Return type:
Tensor
- class direct.nn.transformers.transformers.ImageDomainMRIViT3D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, **kwargs)[source][source]#
Bases:
Module
Vision Transformer for MRI reconstruction in 3D.
- Parameters:
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Forward operator function.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Backward operator function.
- average_size: int or tuple[int, int, int]
The average size of the input image. If an int is provided, this will be defined as (average_size, average_size, average_size). Default: 320.
- patch_size: int or tuple[int, int, int]
The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.
- embedding_dim: int
Dimension of the output embedding.
- depth: int
Number of transformer blocks.
- num_heads: int
Number of attention heads.
- mlp_ratio: float
The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
- qkv_bias: bool
Whether to add bias to the query, key, and value projections. Default: False.
- qk_scale: float
The scale factor for the query-key dot product. Default: None.
- drop_rate: float
The dropout probability for all dropout layers except dropout_path. Default: 0.0.
- attn_drop_rate: float
The dropout probability for the attention layer. Default: 0.0.
- dropout_path_rate: float
The dropout probability for the dropout path. Default: 0.0.
- use_gpsa: bool, optional
Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.
- locality_strength: float
The strength of the locality assumption in initialization. Default: 1.0.
- use_pos_embedding: bool
Whether to use positional embeddings. Default: True.
- normalized: bool
Whether to normalize the input tensor. Default: True.
- forward(masked_kspace, sensitivity_map)[source][source]#
Forward pass of ImageDomainMRIViT3D.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, slice/time, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).
- Returns:
- out: torch.Tensor
The output tensor of shape (N, slice/time, height, width, complex=2).
- Return type:
Tensor
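The sketch below illustrates the 3D image-domain variant with small dummy volumes; the fft2/ifft2 import and the chosen sizes are assumptions for illustration only.

```python
import torch

from direct.data.transforms import fft2, ifft2  # assumed operator pair
from direct.nn.transformers.transformers import ImageDomainMRIViT3D

model = ImageDomainMRIViT3D(
    forward_operator=fft2,
    backward_operator=ifft2,
    average_size=(8, 64, 64),   # (slice/time, height, width), illustrative
    patch_size=(4, 16, 16),
)

# Dummy multi-coil 3D input: (N, coil, slice/time, height, width, complex=2)
masked_kspace = torch.randn(1, 4, 8, 64, 64, 2)
sensitivity_map = torch.randn(1, 4, 8, 64, 64, 2)

out = model(masked_kspace, sensitivity_map)
print(out.shape)  # expected: torch.Size([1, 8, 64, 64, 2])
```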
- class direct.nn.transformers.transformers.KSpaceDomainMRIViT2D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, compute_per_coil=True, **kwargs)[source][source]#
Bases:
Module
Vision Transformer for MRI reconstruction in 2D in k-space.
- Parameters:
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Forward operator function.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Backward operator function.
- average_size: int or tuple[int, int]
The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_size, average_size) for 2D and (average_size, average_size, average_size) for 3D. Default: 320.
- patch_size: int or tuple[int, int]
The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.
- embedding_dim: int
Dimension of the output embedding.
- depth: int
Number of transformer blocks.
- num_heads: int
Number of attention heads.
- mlp_ratio: float
The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
- qkv_bias: bool
Whether to add bias to the query, key, and value projections. Default: False.
- qk_scale: float
The scale factor for the query-key dot product. Default: None.
- drop_rate: float
The dropout probability for all dropout layers except dropout_path. Default: 0.0.
- attn_drop_rate: float
The dropout probability for the attention layer. Default: 0.0.
- dropout_path_rate: float
The dropout probability for the dropout path. Default: 0.0.
- use_gpsa: bool, optional
Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.
- locality_strength: float
The strength of the locality assumption in initialization. Default: 1.0.
- use_pos_embedding: bool
Whether to use positional embeddings. Default: True.
- normalized: bool
Whether to normalize the input tensor. Default: True.
- forward(masked_kspace, sensitivity_map, sampling_mask)[source][source]#
Forward pass of KSpaceDomainMRIViT2D.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- Returns:
- out: torch.Tensor
The output tensor of shape (N, height, width, complex=2).
- Return type:
Tensor
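A hedged sketch of the k-space domain 2D model follows; note the additional sampling_mask argument. The operator imports and tensor sizes are assumptions for illustration.

```python
import torch

from direct.data.transforms import fft2, ifft2  # assumed operator pair
from direct.nn.transformers.transformers import KSpaceDomainMRIViT2D

model = KSpaceDomainMRIViT2D(
    forward_operator=fft2,
    backward_operator=ifft2,
    average_size=(128, 128),
    patch_size=(16, 16),
    compute_per_coil=True,
)

masked_kspace = torch.randn(1, 4, 128, 128, 2)         # (N, coil, height, width, complex=2)
sensitivity_map = torch.randn(1, 4, 128, 128, 2)       # (N, coil, height, width, complex=2)
sampling_mask = torch.rand(1, 1, 128, 128, 1) > 0.5    # (N, 1, height, width, 1), boolean

out = model(masked_kspace, sensitivity_map, sampling_mask)
print(out.shape)  # expected: torch.Size([1, 128, 128, 2])
```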
- class direct.nn.transformers.transformers.KSpaceDomainMRIViT3D(forward_operator, backward_operator, average_size=320, patch_size=16, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=(-1, -1), locality_strength=1.0, use_pos_embedding=True, normalized=True, compute_per_coil=True, **kwargs)[source][source]#
Bases:
Module
Vision Transformer for MRI reconstruction in 3D in k-space.
- Parameters:
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Forward operator function.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor]
Backward operator function.
- average_size: int or tuple[int, int, int]
The average size of the input image. If an int is provided, this will be defined as (average_size, average_size, average_size). Default: 320.
- patch_size: int or tuple[int, int, int]
The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.
- embedding_dim: int
Dimension of the output embedding.
- depth: int
Number of transformer blocks.
- num_heads: int
Number of attention heads.
- mlp_ratio: float
The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
- qkv_bias: bool
Whether to add bias to the query, key, and value projections. Default: False.
- qk_scale: float
The scale factor for the query-key dot product. Default: None.
- drop_rate: float
The dropout probability for all dropout layers except dropout_path. Default: 0.0.
- attn_drop_rate: float
The dropout probability for the attention layer. Default: 0.0.
- dropout_path_rate: float
The dropout probability for the dropout path. Default: 0.0.
- use_gpsa: bool, optional
Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True.
- locality_strength: float
The strength of the locality assumption in initialization. Default: 1.0.
- use_pos_embedding: bool
Whether to use positional embeddings. Default: True.
- normalized: bool
Whether to normalize the input tensor. Default: True.
- forward(masked_kspace, sensitivity_map, sampling_mask)[source][source]#
Forward pass of KSpaceDomainMRIViT3D.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, slice/time, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, 1 or slice/time, height, width, 1).
- Returns:
- out: torch.Tensor
The output tensor of shape (N, slice/time, height, width, complex=2).
- Return type:
Tensor
direct.nn.transformers.transformers_engine module#
DIRECT MRI transformer-based model engines.
- class direct.nn.transformers.transformers_engine.ImageDomainMRIUFormerEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
ImageDomainMRIViTEngine
MRI U-Former Model Engine for Image Domain.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
- class direct.nn.transformers.transformers_engine.ImageDomainMRIViT2DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
ImageDomainMRIViTEngine
MRI ViT Model Engine for Image Domain 2D.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
- class direct.nn.transformers.transformers_engine.ImageDomainMRIViT3DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
ImageDomainMRIViTEngine
MRI ViT Model Engine for Image Domain 3D.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
- class direct.nn.transformers.transformers_engine.ImageDomainMRIViTEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
MRIModelEngine
MRI ViT Model Engine for Image Domain.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
- forward_function(data)[source][source]#
Forward function for ImageDomainMRIViTEngine.
- Parameters:
- data: dict[str, Any]
Input data.
- Returns:
- tuple[torch.Tensor, torch.Tensor]
Output image and output k-space.
- Return type:
tuple[Tensor, Tensor]
- class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViT2DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
KSpaceDomainMRIViTEngine
MRI ViT Model Engine for K-Space Domain 2D.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
- class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViT3DEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
KSpaceDomainMRIViTEngine
MRI ViT Model Engine for K-Space Domain 3D.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
- class direct.nn.transformers.transformers_engine.KSpaceDomainMRIViTEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases:
MRIModelEngine
MRI ViT Model Engine for K-Space Domain.
- Parameters:
- cfg: BaseConfig
Configuration file.
- model: nn.Module
Model.
- device: str
Device. Can be “cuda:{idx}” or “cpu”.
- forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The forward operator. Default: None.
- backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional
The backward operator. Default: None.
- mixed_precision: bool
Use mixed precision. Default: False.
- **models: nn.Module
Additional models.
- forward_function(data)[source][source]#
Forward function for KSpaceDomainMRIViTEngine.
- Parameters:
- data: dict[str, Any]
Input data.
- Returns:
- tuple[torch.Tensor, torch.Tensor]
Output image and output k-space.
- Return type:
tuple[Tensor, Tensor]
direct.nn.transformers.uformer module#
U-Former model [1] implementation.
Adapted from [2].
References#
Wang, Zhendong, et al. “Uformer: A general u-shaped transformer for image restoration.” Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
- class direct.nn.transformers.uformer.AttentionTokenProjectionType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source][source]#
Bases:
DirectEnum
- CONV = 'conv'#
- LINEAR = 'linear'#
- class direct.nn.transformers.uformer.LeWinTransformerMLPTokenType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source][source]#
Bases:
DirectEnum
- FFN = 'ffn'#
- LEFF = 'leff'#
- MLP = 'mlp'#
- class direct.nn.transformers.uformer.UFormer(patch_size=256, in_channels=2, out_channels=None, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False)[source][source]#
Bases:
Module
U-Former model based on [1], code originally implemented in [2].
- Parameters:
- patch_size: int
Size of the patch. Default: 256.
- in_channels: int
Number of input channels. Default: 2.
- out_channels: int, optional
Number of output channels. Default: None.
- embedding_dim: int
Size of the feature embedding. Default: 32.
- encoder_depths: tuple
Number of layers for each stage of the encoder of the U-Former, from top to bottom. Default: (2, 2, 2, 2).
- encoder_num_heads: tuple
Number of attention heads for each layer of the encoder of the U-Former, from top to bottom. Default: (1, 2, 4, 8).
- bottleneck_depth: int
Number of layers in the bottleneck of the U-Former. Default: 2.
- bottleneck_num_heads: int
Number of attention heads in the bottleneck of the U-Former. Default: 16.
- win_size: int
Window size for the attention mechanism. Default: 8.
- mlp_ratio: float
Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.
- qkv_bias: bool
Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.
- qk_scale: float
Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.
- drop_rate: float
Dropout rate for the token-level dropout layer. Default: 0.0.
- attn_drop_rate: float
Dropout rate for the attention score matrix. Default: 0.0.
- drop_path_rate: float
Dropout rate for the stochastic depth regularization. Default: 0.1.
- patch_norm: bool
Whether to use normalization for the patch embeddings. Default: True.
- token_projection: AttentionTokenProjectionType
Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
- token_mlp: LeWinTransformerMLPTokenType
Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
- shift_flag: bool
Whether to use shift operation in the local attention mechanism. Default: True.
- modulator: bool
Whether to use a modulator in the attention mechanism. Default: False.
- cross_modulator: bool
Whether to use cross-modulation in the attention mechanism. Default: False.
- **kwargs: Other keyword arguments to pass to the parent constructor.
References
[1]Wang, Zhendong, et al. “Uformer: A general u-shaped transformer for image restoration.” Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
- extra_repr()[source][source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- Return type:
str
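As a rough usage sketch (the forward signature is not reproduced in this reference, so the call below assumes a channel-first image tensor whose spatial size is compatible with win_size and the number of encoder stages; all sizes are illustrative):

```python
import torch

from direct.nn.transformers.uformer import (
    AttentionTokenProjectionType,
    LeWinTransformerMLPTokenType,
    UFormer,
)

model = UFormer(
    patch_size=128,
    in_channels=2,
    embedding_dim=16,
    token_projection=AttentionTokenProjectionType.LINEAR,
    token_mlp=LeWinTransformerMLPTokenType.LEFF,
)

x = torch.randn(1, 2, 128, 128)  # (N, in_channels, height, width), assumed input layout
out = model(x)
print(out.shape)  # expected: torch.Size([1, 2, 128, 128]) when out_channels is None
```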
- class direct.nn.transformers.uformer.UFormerModel(patch_size=256, in_channels=2, out_channels=None, embedding_dim=32, encoder_depths=(2, 2, 2, 2), encoder_num_heads=(1, 2, 4, 8), bottleneck_depth=2, bottleneck_num_heads=16, win_size=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, patch_norm=True, token_projection=AttentionTokenProjectionType.LINEAR, token_mlp=LeWinTransformerMLPTokenType.LEFF, shift_flag=True, modulator=False, cross_modulator=False, normalized=True)[source][source]#
Bases:
Module
U-Former model with normalization and padding operations.
- Parameters:
- patch_size: int
Size of the patch. Default: 256.
- in_channels: int
Number of input channels. Default: 2.
- out_channels: int, optional
Number of output channels. Default: None.
- embedding_dim: int
Size of the feature embedding. Default: 32.
- encoder_depths: tuple
Number of layers for each stage of the encoder of the U-Former, from top to bottom. Default: (2, 2, 2, 2).
- encoder_num_heads: tuple
Number of attention heads for each layer of the encoder of the U-Former, from top to bottom. Default: (1, 2, 4, 8).
- bottleneck_depth: int
Number of layers in the bottleneck of the U-Former. Default: 2.
- bottleneck_num_heads: int
Number of attention heads in the bottleneck of the U-Former. Default: 16.
- win_size: int
Window size for the attention mechanism. Default: 8.
- mlp_ratio: float
Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0.
- qkv_bias: bool
Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True.
- qk_scale: float
Scale factor for the query and key projection vectors. If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None.
- drop_rate: float
Dropout rate for the token-level dropout layer. Default: 0.0.
- attn_drop_rate: float
Dropout rate for the attention score matrix. Default: 0.0.
- drop_path_rate: float
Dropout rate for the stochastic depth regularization. Default: 0.1.
- patch_norm: bool
Whether to use normalization for the patch embeddings. Default: True.
- token_projection: AttentionTokenProjectionType
Type of token projection. Must be one of AttentionTokenProjectionType.LINEAR or AttentionTokenProjectionType.CONV. Default: AttentionTokenProjectionType.LINEAR.
- token_mlp: LeWinTransformerMLPTokenType
Type of token-level MLP. Must be one of LeWinTransformerMLPTokenType.LEFF or LeWinTransformerMLPTokenType.MLP. Default: LeWinTransformerMLPTokenType.LEFF.
- shift_flag: bool
Whether to use shift operation in the local attention mechanism. Default: True.
- modulator: bool
Whether to use a modulator in the attention mechanism. Default: False.
- cross_modulator: bool
Whether to use cross-modulation in the attention mechanism. Default: False.
- normalized: bool
Whether to apply normalization before and denormalization after the forward pass. Default: True.
- **kwargs: Other keyword arguments to pass to the parent constructor.
direct.nn.transformers.utils module#
DIRECT module containing utility functions for the transformers models.
- class direct.nn.transformers.utils.DropoutPath(drop_prob=0.0, scale_by_keep=True)[source][source]#
Bases:
Module
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- extra_repr()[source][source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x)[source][source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
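An illustrative residual-block pattern is sketched below; drop_prob and scale_by_keep come from the signature above, while the surrounding arithmetic is an assumption about typical usage.

```python
import torch

from direct.nn.transformers.utils import DropoutPath

# Stochastically drop the residual branch per sample during training;
# kept samples are rescaled by 1 / (1 - drop_prob) when scale_by_keep=True.
dropout_path = DropoutPath(drop_prob=0.1, scale_by_keep=True)

x = torch.randn(4, 64)
residual_branch = torch.randn(4, 64)
out = x + dropout_path(residual_branch)  # in eval mode, dropout_path acts as the identity
```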
- direct.nn.transformers.utils.init_weights(m)[source][source]#
Initializes the weights of the network using a truncated normal distribution.
- Parameters:
- m: nn.Module
A module of the network whose weights need to be initialized.
- Return type:
None
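A short sketch of the usual nn.Module.apply pattern for this initializer; the toy model is illustrative only.

```python
import torch.nn as nn

from direct.nn.transformers.utils import init_weights

# Recursively apply the truncated-normal initialization to every submodule.
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 2))
model.apply(init_weights)
```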
- direct.nn.transformers.utils.norm(x)[source][source]#
Normalize the input tensor by subtracting the mean and dividing by the standard deviation across each channel and pixel for arbitrary spatial dimensions.
- Parameters:
- x: torch.Tensor
Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number (e.g., 2D, 3D, etc.).
- Returns:
- tuple
Containing the normalized tensor, mean tensor, and standard deviation tensor.
- Return type:
tuple[Tensor, Tensor, Tensor]
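A sketch of the round trip with unnorm (documented further below); the tensor shape is illustrative.

```python
import torch

from direct.nn.transformers.utils import norm, unnorm

x = torch.randn(2, 2, 64, 64)             # (B, C, H, W)
x_normed, mean, std = norm(x)              # per-sample, per-channel normalization
x_restored = unnorm(x_normed, mean, std)   # invert the normalization
assert torch.allclose(x_restored, x, atol=1e-5)
```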
- direct.nn.transformers.utils.pad_to_divisible(x, pad_size)[source][source]#
Pad the input tensor with zeros to make its spatial dimensions divisible by the specified pad size.
- Parameters:
- x: torch.Tensor
Input tensor of shape (*, spatial_1, spatial_2, …, spatial_N), where spatial dimensions can vary in number.
- pad_size: tuple[int, …]
Patch size to make each spatial dimension divisible by. This is a tuple of integers for each spatial dimension.
- Returns:
- tuple
Containing the padded tensor and a tuple of tuples indicating the number of pixels padded in each spatial dimension.
- Return type:
tuple[Tensor, tuple[tuple[int, int], ...]]
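A sketch pairing pad_to_divisible with unpad_to_original (documented below); the tensor sizes are illustrative.

```python
import torch

from direct.nn.transformers.utils import pad_to_divisible, unpad_to_original

x = torch.randn(1, 2, 100, 90)                        # spatial dims not divisible by 16
padded, pads = pad_to_divisible(x, pad_size=(16, 16))
print(padded.shape)                                    # expected: torch.Size([1, 2, 112, 96])
restored = unpad_to_original(padded, *pads)            # pads holds (before, after) per spatial dim
assert restored.shape == x.shape
```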
- direct.nn.transformers.utils.pad_to_square(inp, factor)[source][source]#
Pad a tensor to a square shape with a given factor.
- Parameters:
- inp: torch.Tensor
The input tensor to pad to square shape. Expected shape is (*, height, width).
- factor: float
The factor that the padded (square) spatial dimensions must be divisible by.
- Returns:
- tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]]
A tuple containing the input tensor padded to a square shape, the corresponding mask for the padded tensor, and the (before, after) padding applied to the width and height dimensions.
- Return type:
tuple[Tensor, Tensor, tuple[int, int], tuple[int, int]]
Examples
>>> x = torch.rand(1, 3, 224, 192)
>>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0)
>>> padded_x.shape, mask.shape
(torch.Size([1, 3, 224, 224]), torch.Size([1, 1, 224, 224]))
>>> x = torch.rand(3, 13, 2, 234, 180)
>>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0)
>>> padded_x.shape, wpad, hpad
(torch.Size([3, 13, 2, 240, 240]), (30, 30), (3, 3))
- direct.nn.transformers.utils.unnorm(x, mean, std)[source][source]#
Denormalize the input tensor by multiplying by the standard deviation and adding the mean for arbitrary spatial dimensions.
- Parameters:
- x: torch.Tensor
Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number.
- mean: torch.Tensor
Mean tensor obtained during normalization.
- std: torch.Tensor
Standard deviation tensor obtained during normalization.
- Returns:
- torch.Tensor
Tensor with the same shape as the original input tensor, but denormalized.
- Return type:
Tensor
- direct.nn.transformers.utils.unpad_to_original(x, *pads)[source][source]#
Remove the padding added to the input tensor.
- Parameters:
- x: torch.Tensor
Input tensor with padded spatial dimensions.
- pads: tuple[int, int]
A tuple of (pad_before, pad_after) for each spatial dimension.
- Returns:
- torch.Tensor
Tensor with the padding removed, matching the shape of the original input tensor before padding.
- Return type:
Tensor
direct.nn.transformers.vit module#
DIRECT Vision Transformer module.
Implementation of the Vision Transformer model [1, 2] in PyTorch.
Code borrowed from [3], which uses code from timm [4].
References#
Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., Houlsby, N.: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, http://arxiv.org/abs/2010.11929, (2021).
Steiner, A., Kolesnikov, A., Zhai, X., Wightman, R., Uszkoreit, J., Beyer, L.: How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers, http://arxiv.org/abs/2106.10270, (2022).
- class direct.nn.transformers.vit.VisionTransformer2D(average_img_size=320, patch_size=16, in_channels=2, out_channels=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source][source]#
Bases:
VisionTransformer
Vision Transformer model for 2D data.
- Parameters:
- average_img_size: int or tuple[int, int]
The average size of the input image. If an int is provided, this will be determined by the dimensionality, i.e., (average_img_size, average_img_size) for 2D and (average_img_size, average_img_size, average_img_size) for 3D. Default: 320.
- patch_size: int or tuple[int, int]
The size of the patch. If an int is provided, this will be determined by the dimensionality, i.e., (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16.
- in_channels: int
Number of input channels. Default: COMPLEX_SIZE.
- out_channels: int or None
Number of output channels. If None, this will be set to in_channels. Default: None.
- embedding_dim: int
Dimension of the output embedding.
- depth: int
Number of transformer blocks.
- num_heads: int
Number of attention heads.
- mlp_ratio: float
The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
- qkv_bias: bool
Whether to add bias to the query, key, and value projections. Default: False.
- qk_scale: float
The scale factor for the query-key dot product. Default: None.
- drop_rate: float
The dropout probability for all dropout layers except dropout_path. Default: 0.0.
- attn_drop_rate: float
The dropout probability for the attention layer. Default: 0.0.
- dropout_path_rate: float
The dropout probability for the dropout path. Default: 0.0.
- use_gpsa: bool
Whether to use GPSA layer. Default: True.
- locality_strength: float
The strength of the locality assumption in initialization. Default: 1.0.
- use_pos_embedding: bool
Whether to use positional embeddings. Default: True.
- normalized: bool
Whether to normalize the input tensor. Default: True.
- class direct.nn.transformers.vit.VisionTransformer3D(average_img_size=320, patch_size=16, in_channels=2, out_channels=None, embedding_dim=64, depth=8, num_heads=9, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_gpsa=True, locality_strength=1.0, use_pos_embedding=True, normalized=True)[source][source]#
Bases:
VisionTransformer
Vision Transformer model for 3D data.
- Parameters:
- average_img_size: int or tuple[int, int, int]
The average size of the input image. If an int is provided, this will be defined as (average_img_size, average_img_size, average_img_size). Default: 320.
- patch_size: int or tuple[int, int, int]
The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). Default: 16.
- in_channels: int
Number of input channels. Default: COMPLEX_SIZE.
- out_channels: int or None
Number of output channels. If None, this will be set to in_channels. Default: None.
- embedding_dim: int
Dimension of the output embedding.
- depth: int
Number of transformer blocks.
- num_heads: int
Number of attention heads.
- mlp_ratio: float
The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0.
- qkv_bias: bool
Whether to add bias to the query, key, and value projections. Default: False.
- qk_scale: float
The scale factor for the query-key dot product. Default: None.
- drop_rate: float
The dropout probability for all dropout layers except dropout_path. Default: 0.0.
- attn_drop_rate: float
The dropout probability for the attention layer. Default: 0.0.
- dropout_path_rate: float
The dropout probability for the dropout path. Default: 0.0.
- use_gpsa: bool
Whether to use GPSA layer. Default: True.
- locality_strength: float
The strength of the locality assumption in initialization. Default: 1.0.
- use_pos_embedding: bool
Whether to use positional embeddings. Default: True.
- normalized: bool
Whether to normalize the input tensor. Default: True.
- seq2img(x, img_size)[source][source]#
Converts the sequence of 3D patches to a 3D image tensor.
- Parameters:
- x: torch.Tensor
The sequence tensor, where each entry corresponds to a flattened 3D patch.
- img_size: tuple of ints
The size of the 3D image tensor (depth, height, width).
- Returns:
- torch.Tensor
The reconstructed 3D image tensor.
- Return type:
Tensor
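A rough usage sketch for the 2D variant follows (the forward signature is not reproduced in this reference, so the channel-first image input below is an assumption; sizes are illustrative).

```python
import torch

from direct.nn.transformers.vit import VisionTransformer2D

model = VisionTransformer2D(
    average_img_size=(128, 128),
    patch_size=(16, 16),
    in_channels=2,
    embedding_dim=64,
    depth=4,
    num_heads=8,
)

x = torch.randn(1, 2, 128, 128)  # (N, in_channels, height, width), assumed layout
out = model(x)
print(out.shape)  # expected: torch.Size([1, 2, 128, 128]) since out_channels defaults to in_channels
```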
Module contents#
DIRECT transformers models.