direct.ssl package
Submodules
direct.ssl.mask_fillers module
SSL Mask Fillers.
This module contains functions that split binary masks into (disjoint) subsets for use in self-supervised learning (SSL) MRI reconstruction tasks.
- direct.ssl.mask_fillers.gaussian_fill(nonzero_mask_count, nrow, ncol, center_x, center_y, std_scale, mask, output_mask, seed)
Generates a binary mask filled with randomly sampled positions following a 2D Gaussian distribution.
Calls the Cython function _gaussian_fill.
- Parameters:
- nonzero_mask_count: int
Number of non-zero entries in the output mask.
- nrow: int
Number of rows of the output mask.
- ncol: int
Number of columns of the output mask.
- center_x: int
X coordinate of the center of the Gaussian distribution.
- center_y: int
Y coordinate of the center of the Gaussian distribution.
- std_scale: float
Scaling factor for the standard deviation of the Gaussian distribution. The standard deviation of the Gaussian distribution will be (nrow-1)/std_scale and (ncol-1)/std_scale along the X and Y axes, respectively.
- mask: np.ndarray
A binary integer 2D array representing the input mask.
- output_mask: np.ndarray
A binary integer 2D array representing the output mask.
- seed: int
Seed for the random number generator.
- Returns:
- np.ndarray
A 2D array representing the output mask filled with randomly sampled positions following a 2D Gaussian distribution.
- Return type:
ndarray
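As a rough illustration of what gaussian_fill describes, here is a pure-NumPy sketch using rejection sampling. The name `gaussian_fill_sketch` is hypothetical, it is not the Cython routine, and it drops the `output_mask` argument in favour of returning a fresh array. It assumes that accepted positions must lie inside the grid, on a nonzero entry of `mask`, and not already be selected, and it reads X as the row axis to match the (nrow-1)/std_scale convention above.

```python
import numpy as np


def gaussian_fill_sketch(nonzero_mask_count, nrow, ncol, center_x, center_y, std_scale, mask, seed=None):
    """Hypothetical pure-NumPy stand-in for gaussian_fill (not the Cython implementation)."""
    if nonzero_mask_count > np.count_nonzero(mask):
        raise ValueError("nonzero_mask_count exceeds the number of nonzero entries in mask")
    rng = np.random.default_rng(seed)
    output_mask = np.zeros((nrow, ncol), dtype=int)
    # Standard deviations as documented: (nrow-1)/std_scale along X (rows),
    # (ncol-1)/std_scale along Y (columns).
    std_x = (nrow - 1) / std_scale
    std_y = (ncol - 1) / std_scale
    count = 0
    while count < nonzero_mask_count:
        # Draw a continuous Gaussian sample and round it to the nearest grid cell.
        x = int(np.round(rng.normal(center_x, std_x)))
        y = int(np.round(rng.normal(center_y, std_y)))
        # Rejection step (assumption): keep the draw only if it is inside the grid,
        # on a sampled position of the input mask, and not selected before.
        if 0 <= x < nrow and 0 <= y < ncol and mask[x, y] == 1 and output_mask[x, y] == 0:
            output_mask[x, y] = 1
            count += 1
    return output_mask


mask = np.ones((32, 32), dtype=int)
out = gaussian_fill_sketch(100, 32, 32, 16, 16, 3.0, mask, seed=0)
```

Because draws are rejected until enough distinct positions are found, the loop is guarded by a count check; without it, asking for more positions than `mask` offers would never terminate.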
- direct.ssl.mask_fillers.uniform_fill(nonzero_mask_count, nrow, ncol, mask, rng)
Fills a binary torch.Tensor mask with the specified number of ones in a uniform random manner.
- Parameters:
- nonzero_mask_count: int
The number of 1s to place in the mask.
- nrow: int
The number of rows in the mask.
- ncol: int
The number of columns in the mask.
- mask: torch.Tensor
A binary mask with zeros and ones.
- rng: np.random.RandomState
A NumPy random state object for reproducibility.
- Returns:
- torch.Tensor
A binary mask with the specified number of 1s placed in a uniform random manner.
- Return type:
Tensor
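A minimal sketch of uniform filling follows. It uses a NumPy array in place of the documented torch.Tensor so that it is self-contained, and `uniform_fill_sketch` is a hypothetical name. It assumes that candidate positions are restricted to the nonzero entries of `mask`, each chosen with equal probability and without replacement.

```python
import numpy as np


def uniform_fill_sketch(nonzero_mask_count, nrow, ncol, mask, rng):
    """Hypothetical NumPy stand-in for uniform_fill (the documented function uses torch.Tensor)."""
    # Equal probability for every nonzero entry of `mask`, zero elsewhere (assumption).
    prob = mask.reshape(-1).astype(float)
    prob /= prob.sum()
    # Sample flat indices without replacement; RandomState.choice raises if fewer
    # nonzero probabilities than requested positions are available.
    flat = rng.choice(nrow * ncol, size=nonzero_mask_count, replace=False, p=prob)
    rows, cols = np.unravel_index(flat, (nrow, ncol))
    output_mask = np.zeros((nrow, ncol), dtype=bool)
    output_mask[rows, cols] = True
    return output_mask


mask = np.zeros((10, 10), dtype=int)
mask[:, :5] = 1  # only the left half is available
out = uniform_fill_sketch(20, 10, 10, mask, np.random.RandomState(0))
```

Passing an explicit `np.random.RandomState` keeps the selection reproducible, which mirrors the role the `rng` parameter plays above.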
direct.ssl.ssl module
Direct SSL mask splitters module.
This module contains classes that split masks for self-supervised learning tasks in MRI reconstruction. Each splitter divides the input mask into two disjoint masks: the GaussianMaskSplitterModule uses a Gaussian split scheme, the UniformMaskSplitterModule uses a uniform split scheme, and the HalfMaskSplitterModule splits the mask along a half line (vertical, horizontal or diagonal).
- class direct.ssl.ssl.GaussianMaskSplitterModule(ratio=0.5, acs_region=(0, 0), keep_acs=False, use_seed=True, kspace_key=KspaceKey.MASKED_KSPACE, std_scale=3.0)
Bases: MaskSplitter
Uses the Gaussian splitting method to split the input mask into two disjoint masks.
- Parameters:
- ratio: float, list[float] or tuple[float, …], optional
Split ratio such that \(ratio \approx \frac{|A|}{|B|}\). Default: 0.5.
- acs_region: list[int] or tuple[int, int], optional
Size of the ACS region to include in the training (input) mask. Default: (0, 0).
- keep_acs: bool, optional
If True, both the input and target masks keep the ACS region, and the ratio is applied to the rest of the mask. Assumes acs_mask is present in the sample. Default: False.
- use_seed: bool, optional
If True, a pseudo-random number is computed from the filename so that every slice of the volume gets the same mask every time. Default: True.
- kspace_key: str, optional
K-space key. Default: “masked_kspace”.
- std_scale: float, optional
Scaling factor used to compute the standard deviation of the Gaussian distribution. Default: 3.0.
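The effect of use_seed can be sketched as follows. Both helpers are hypothetical: `seed_from_filename` turns each character of the filename into its code point (an assumption about how the filename maps to a seed, chosen because split_method accepts an iterable of ints), and `uniform_target_mask` is a simplified uniform selection used only to show that the same filename yields the same split.

```python
import numpy as np


def seed_from_filename(filename):
    """Hypothetical helper: derive a deterministic seed (one int per character) from a filename."""
    return tuple(ord(c) for c in str(filename))


def uniform_target_mask(sampling_mask, ratio, seed):
    """Hypothetical simplified split: pick a fraction `ratio` of sampled positions uniformly."""
    rng = np.random.default_rng(seed)  # accepts an int or a sequence of ints
    idx = np.flatnonzero(sampling_mask)
    chosen = rng.choice(idx, size=int(round(ratio * idx.size)), replace=False)
    target = np.zeros(sampling_mask.size, dtype=bool)
    target[chosen] = True
    return target.reshape(sampling_mask.shape)


# Two slices of the same volume share a filename, hence the same seed and the same split.
sampling_mask = np.ones((8, 8), dtype=bool)
seed = seed_from_filename("file_brain_001.h5")
slice_0 = uniform_target_mask(sampling_mask, 0.5, seed)
slice_1 = uniform_target_mask(sampling_mask, 0.5, seed)
```

Seeding from the filename, rather than from a global generator, is what makes the split independent of the order in which slices are loaded.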
- split_method(sampling_mask, acs_mask, seed)
Splits the sampling_mask into two disjoint masks based on the Gaussian split method.
- Parameters:
- sampling_mask: torch.Tensor
The input mask tensor to be split.
- acs_mask: torch.Tensor or None
The ACS mask. Must be passed if keep_acs is True; if keep_acs is False, it is ignored. Default: None.
- seed: int, iterable of ints or None
Seed used to generate the split.
- Returns:
- tuple[torch.Tensor, torch.Tensor]:
The two disjoint masks, input_mask and target_mask.
- Return type:
tuple[Tensor, Tensor]
- training: bool
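The Gaussian split can be sketched in NumPy as follows. This is not the library's implementation: `gaussian_split_sketch` is hypothetical, the Gaussian is assumed to be centred at (nrow // 2, ncol // 2), and instead of rejection sampling it uses Gaussian-weighted sampling without replacement via `Generator.choice` with a `p` argument, a swapped-in approximation. A fraction `ratio` of the candidate sampled positions goes to the target mask and the rest stays in the input mask; with keep_acs, ACS positions are excluded from the draw and then added to both masks.

```python
import numpy as np


def gaussian_split_sketch(sampling_mask, ratio=0.5, acs_mask=None, keep_acs=False, std_scale=3.0, seed=None):
    """Hypothetical Gaussian-weighted split of a 2D sampling mask into (input, target)."""
    rng = np.random.default_rng(seed)
    sm = np.asarray(sampling_mask, dtype=bool)
    nrow, ncol = sm.shape
    acs = np.zeros_like(sm) if acs_mask is None else np.asarray(acs_mask, dtype=bool) & sm
    # With keep_acs, ACS positions are never drawn into the target (ratio applies to the rest).
    candidates = sm & ~acs if keep_acs else sm.copy()
    # Gaussian weights with the documented std convention: (n - 1) / std_scale per axis.
    rows, cols = np.meshgrid(np.arange(nrow), np.arange(ncol), indexing="ij")
    sx, sy = (nrow - 1) / std_scale, (ncol - 1) / std_scale
    weights = np.exp(-0.5 * (((rows - nrow // 2) / sx) ** 2 + ((cols - ncol // 2) / sy) ** 2))
    weights = np.where(candidates, weights, 0.0)
    n_target = int(round(ratio * candidates.sum()))
    # Weighted sampling without replacement; positions near the centre are favoured.
    flat = rng.choice(nrow * ncol, size=n_target, replace=False, p=(weights / weights.sum()).ravel())
    target = np.zeros(sm.shape, dtype=bool)
    target.flat[flat] = True
    input_mask = sm & ~target
    if keep_acs:
        target |= acs  # both masks keep the ACS region
    return input_mask, target


gen = np.random.default_rng(1)
sm = gen.random((16, 16)) < 0.4
acs = np.zeros((16, 16), dtype=bool)
acs[6:10, 6:10] = True
sm |= acs
inp, tgt = gaussian_split_sketch(sm, ratio=0.4, seed=0)
```

Without keep_acs the two outputs are disjoint and together reproduce the sampling mask; with keep_acs they overlap only on the ACS block.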
- class direct.ssl.ssl.HalfMaskSplitterModule(acs_region=(0, 0), keep_acs=False, use_seed=True, direction=HalfSplitType.VERTICAL, kspace_key=KspaceKey.MASKED_KSPACE)
Bases: MaskSplitter
Splits the input mask into two disjoint masks along a half line in the given direction.
- Parameters:
- acs_region: list[int] or tuple[int, int], optional
Size of the ACS region to include in the training (input) mask. Default: (0, 0).
- keep_acs: bool, optional
If True, both the input and target masks keep the ACS region, and the split is applied to the rest of the mask. Assumes acs_mask is present in the sample. Default: False.
- use_seed: bool, optional
If True, a pseudo-random number is computed from the filename so that every slice of the volume gets the same mask every time. Default: True.
- direction: HalfSplitType, optional
Direction of the half split. Default: HalfSplitType.VERTICAL.
- kspace_key: str, optional
K-space key. Default: “masked_kspace”.
- split_method(sampling_mask, acs_mask, seed)
Splits the sampling_mask into two disjoint masks based on the half split method.
- Parameters:
- sampling_mask: torch.Tensor
The input mask tensor to be split.
- acs_mask: torch.Tensor or None
The ACS mask. Must be passed if keep_acs is True; if keep_acs is False, it is ignored. Default: None.
- seed: int, iterable of ints or None
Seed used to generate the split.
- Returns:
- tuple[torch.Tensor, torch.Tensor]:
The two disjoint masks, input_mask and target_mask.
- Return type:
tuple[Tensor, Tensor]
- training: bool
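A half split can be pictured with the following NumPy sketch. The direction strings match the HalfSplitType values, but which side of each line goes to the input mask, and which diagonal "diagonal_left" and "diagonal_right" refer to, are assumptions made for illustration; `half_split_sketch` is hypothetical and ignores ACS handling and seeding.

```python
import numpy as np


def half_split_sketch(sampling_mask, direction="vertical"):
    """Hypothetical half split: positions inside `region` go to the input mask, the rest to the target."""
    sm = np.asarray(sampling_mask, dtype=bool)
    nrow, ncol = sm.shape
    rows, cols = np.meshgrid(np.arange(nrow), np.arange(ncol), indexing="ij")
    if direction == "vertical":
        region = cols < ncol // 2  # left half of the columns (assumption)
    elif direction == "horizontal":
        region = rows < nrow // 2  # top half of the rows (assumption)
    elif direction == "diagonal_left":
        region = rows * ncol > cols * nrow  # strictly below the main diagonal (assumption)
    elif direction == "diagonal_right":
        region = rows * ncol + cols * nrow < nrow * ncol  # above the anti-diagonal (assumption)
    else:
        raise ValueError(f"unknown direction: {direction}")
    return sm & region, sm & ~region


sm = np.ones((8, 8), dtype=bool)
inp, tgt = half_split_sketch(sm, "vertical")
```

Whatever the direction, the region and its complement partition the grid, so the two masks are always disjoint and together recover the sampling mask.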
- class direct.ssl.ssl.HalfSplitType(value)
Bases: DirectEnum
SSL half mask splitter types.
These are used to define the type of half mask splitting.
- DIAGONAL_LEFT = 'diagonal_left'
- DIAGONAL_RIGHT = 'diagonal_right'
- HORIZONTAL = 'horizontal'
- VERTICAL = 'vertical'
- class direct.ssl.ssl.MaskSplitterType(value)
Bases: DirectEnum
SSL mask splitter types.
These are used to define the type of mask splitting.
- GAUSSIAN = 'gaussian'
- HALF = 'half'
- UNIFORM = 'uniform'
- class direct.ssl.ssl.SSLTransformMaskPrefixes(value)
Bases: DirectEnum
SSL Transform mask prefixes.
These are used to prefix the input (\(\Theta\)) and target (\(\Lambda\)) k-spaces/masks in the sample.
- INPUT_ = 'input_'
- TARGET_ = 'target_'
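To show how these prefixes could be used to name the split k-spaces and masks in a sample, here is a small sketch. `MaskPrefix` is a hypothetical stand-in for SSLTransformMaskPrefixes, `add_split_to_sample` is a hypothetical helper, and the "sampling_mask" key name is an assumption; only the prefix values and the default k-space key "masked_kspace" come from this page.

```python
from enum import Enum

import numpy as np


class MaskPrefix(str, Enum):
    """Hypothetical stand-in mirroring SSLTransformMaskPrefixes."""

    INPUT_ = "input_"
    TARGET_ = "target_"


def add_split_to_sample(sample, input_mask, target_mask, kspace_key="masked_kspace", mask_key="sampling_mask"):
    """Hypothetical helper: store prefixed masks and masked k-spaces in the sample dict."""
    kspace = sample[kspace_key]
    for prefix, mask in ((MaskPrefix.INPUT_, input_mask), (MaskPrefix.TARGET_, target_mask)):
        sample[prefix.value + mask_key] = mask
        # Zero out k-space positions that are not in this subset.
        sample[prefix.value + kspace_key] = kspace * mask
    return sample


kspace = np.arange(4, dtype=complex).reshape(2, 2)
theta = np.array([[True, False], [False, True]])  # input mask
lam = ~theta  # target mask
sample = add_split_to_sample({"masked_kspace": kspace}, theta, lam)
```

Keeping the original key and adding prefixed copies means later transforms can look up the input and target data by prefix without overwriting the full acquisition.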
- class direct.ssl.ssl.UniformMaskSplitterModule(ratio=0.5, acs_region=(0, 0), keep_acs=False, use_seed=True, kspace_key=KspaceKey.MASKED_KSPACE)
Bases: MaskSplitter
Uses the uniform splitting method to split the input mask into two disjoint masks.
- Parameters:
- ratio: float, list[float] or tuple[float, …], optional
Split ratio such that \(ratio \approx \frac{|A|}{|B|}\). Default: 0.5.
- acs_region: list[int] or tuple[int, int], optional
Size of the ACS region to include in the training (input) mask. Default: (0, 0).
- keep_acs: bool, optional
If True, both the input and target masks keep the ACS region, and the ratio is applied to the rest of the mask. Assumes acs_mask is present in the sample. Default: False.
- use_seed: bool, optional
If True, a pseudo-random number is computed from the filename so that every slice of the volume gets the same mask every time. Default: True.
- kspace_key: str, optional
K-space key. Default: “masked_kspace”.
- split_method(sampling_mask, acs_mask, seed)
Splits the sampling_mask into two disjoint masks based on the uniform split method.
- Parameters:
- sampling_mask: torch.Tensor
The input mask tensor to be split.
- acs_mask: torch.Tensor or None
The ACS mask. Must be passed if keep_acs is True; if keep_acs is False, it is ignored. Default: None.
- seed: int, iterable of ints or None
Seed used to generate the split.
- Returns:
- tuple[torch.Tensor, torch.Tensor]:
The two disjoint masks, input_mask and target_mask.
- Return type:
tuple[Tensor, Tensor]
- training: bool
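A uniform split with an ACS region can be sketched as below. `uniform_split_sketch` is hypothetical, and two behaviours are assumptions: the acs_region block is centred at (nrow // 2, ncol // 2), and it is excluded from the draw so that it always stays in the input mask. A fraction `ratio` of the remaining sampled positions is moved to the target mask, uniformly at random.

```python
import numpy as np


def uniform_split_sketch(sampling_mask, ratio=0.5, acs_region=(0, 0), seed=None):
    """Hypothetical uniform split of a 2D sampling mask into (input, target)."""
    rng = np.random.default_rng(seed)
    sm = np.asarray(sampling_mask, dtype=bool)
    nrow, ncol = sm.shape
    candidates = sm.copy()
    # Protect a centred acs_region block so it stays in the input mask (assumption).
    ah, aw = acs_region
    r0, c0 = nrow // 2 - ah // 2, ncol // 2 - aw // 2
    candidates[r0:r0 + ah, c0:c0 + aw] = False
    # Uniformly pick a fraction `ratio` of the remaining sampled positions as target.
    idx = np.flatnonzero(candidates)
    chosen = rng.choice(idx, size=int(round(ratio * idx.size)), replace=False)
    target = np.zeros(sm.shape, dtype=bool)
    target.flat[chosen] = True
    return sm & ~target, target


sm = np.ones((16, 16), dtype=bool)
inp, tgt = uniform_split_sketch(sm, ratio=0.5, acs_region=(4, 4), seed=0)
```

With acs_region=(0, 0) the protected block is empty, so every sampled position is eligible for the target mask.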
Module contents
Direct SSL module.
Includes functions for self-supervised learning for MRI reconstruction.