direct.ssl package
Submodules
direct.ssl.mask_fillers module
SSL Mask Fillers.
This module contains functions for splitting binary masks into (disjoint) subsets for use in self-supervised learning (SSL) tasks for MRI reconstruction.
- direct.ssl.mask_fillers.gaussian_fill(nonzero_mask_count, nrow, ncol, center_x, center_y, std_scale, mask, output_mask, seed)
Generates a binary mask filled with randomly sampled positions following a 2D Gaussian distribution.
Internally calls the Cython function _gaussian_fill.
- Parameters:
  - nonzero_mask_count (int) – Number of non-zero entries in the output mask.
  - nrow (int) – Number of rows of the output mask.
  - ncol (int) – Number of columns of the output mask.
  - center_x (int) – X coordinate of the center of the Gaussian distribution.
  - center_y (int) – Y coordinate of the center of the Gaussian distribution.
  - std_scale (float) – Scaling factor for the standard deviation of the Gaussian distribution. The standard deviation is (nrow-1)/std_scale along the X axis and (ncol-1)/std_scale along the Y axis.
  - mask (ndarray) – A binary integer 2D array representing the input mask.
  - output_mask (ndarray) – A binary integer 2D array representing the output mask.
  - seed (int) – Seed for the random number generator.
- Return type: ndarray
- Returns: A 2D array representing the output mask, filled with randomly sampled positions following a 2D Gaussian distribution.
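The Cython routine itself is not shown here. As a rough illustration only, the following pure-NumPy sketch implements Gaussian rejection sampling, assuming that _gaussian_fill draws (x, y) positions from a 2D normal distribution and keeps a draw only when it lies inside the grid, sits on a position that is 1 in mask, and has not already been selected. The name gaussian_fill_sketch is hypothetical; unlike the documented function it returns a new array instead of writing into output_mask, and it uses NumPy's Generator rather than whatever generator the Cython code uses, so its outputs will not match DIRECT's for the same seed.

```python
import numpy as np


def gaussian_fill_sketch(
    nonzero_mask_count: int,
    nrow: int,
    ncol: int,
    center_x: int,
    center_y: int,
    std_scale: float,
    mask: np.ndarray,
    seed: int,
) -> np.ndarray:
    """Hypothetical NumPy stand-in for the Cython _gaussian_fill (rejection sampling)."""
    if nonzero_mask_count > int(np.count_nonzero(mask)):
        raise ValueError("Cannot select more positions than the mask contains.")
    rng = np.random.default_rng(seed)
    std_x = (nrow - 1) / std_scale  # standard deviation along X (rows)
    std_y = (ncol - 1) / std_scale  # standard deviation along Y (columns)
    output_mask = np.zeros((nrow, ncol), dtype=np.int32)
    count = 0
    while count < nonzero_mask_count:
        # Draw a candidate position from the 2D Gaussian and round it to the grid.
        x = int(np.round(rng.normal(center_x, std_x)))
        y = int(np.round(rng.normal(center_y, std_y)))
        # Reject draws that are out of bounds, not sampled in `mask`, or already chosen.
        if 0 <= x < nrow and 0 <= y < ncol and mask[x, y] == 1 and output_mask[x, y] == 0:
            output_mask[x, y] = 1
            count += 1
    return output_mask
```

Because draws near (center_x, center_y) are the most likely, the selected positions cluster around the center; std_scale controls how tightly.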
- direct.ssl.mask_fillers.uniform_fill(nonzero_mask_count, nrow, ncol, mask, rng)
Fills a binary torch.Tensor mask with the specified number of ones, placed uniformly at random.
- Parameters:
  - nonzero_mask_count (int) – The number of 1s to place in the mask.
  - nrow (int) – The number of rows in the mask.
  - ncol (int) – The number of columns in the mask.
  - mask (Tensor) – A binary mask with zeros and ones.
  - rng (RandomState) – A NumPy random state object for reproducibility.
- Return type: Tensor
- Returns: A binary mask with the specified number of 1s placed uniformly at random.
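A minimal NumPy sketch of uniform filling, assuming the ones are drawn without replacement from the positions where mask is 1, each with equal probability. The name uniform_fill_sketch is hypothetical, and it operates on NumPy arrays rather than torch.Tensor (a tensor can be moved to and from NumPy with .numpy() and torch.from_numpy).

```python
import numpy as np


def uniform_fill_sketch(
    nonzero_mask_count: int, nrow: int, ncol: int, mask: np.ndarray, rng: np.random.RandomState
) -> np.ndarray:
    """Hypothetical NumPy stand-in for uniform_fill."""
    prob = mask.reshape(-1).astype(np.float64)
    # Every sampled (value 1) position gets the same probability; unsampled ones get zero.
    flat_idx = rng.choice(nrow * ncol, size=nonzero_mask_count, replace=False, p=prob / prob.sum())
    out = np.zeros(nrow * ncol, dtype=mask.dtype)
    out[flat_idx] = 1
    return out.reshape(nrow, ncol)
```

RandomState.choice raises a ValueError when fewer than nonzero_mask_count positions have non-zero probability, so the sketch fails loudly instead of returning a mask with too few ones.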
direct.ssl.ssl module
Direct SSL mask splitters module.
This module contains classes that split sampling masks for self-supervised learning (SSL) tasks in MRI reconstruction.
For example, the GaussianMaskSplitterModule splits the input mask into two disjoint masks using a Gaussian split scheme, while the UniformMaskSplitterModule does so using a uniform split scheme. The HalfMaskSplitterModule splits the input mask into two disjoint masks along a half line (horizontal, vertical, or diagonal).
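To make "disjoint" concrete, here is a small hypothetical checker (check_split is not part of DIRECT) for the properties described above: both masks are subsets of the original sampling mask and they do not overlap. It also reports whether the two masks together cover the original mask, which is an assumption of this sketch rather than a documented guarantee. When keep_acs is True both masks share the ACS region, so the checker ignores that overlap.

```python
import numpy as np


def check_split(sampling_mask, input_mask, target_mask, acs_mask=None):
    """Hypothetical helper: report the properties a mask split is described as having."""
    sampling = sampling_mask.astype(bool)
    inp, tgt = input_mask.astype(bool), target_mask.astype(bool)
    overlap = inp & tgt
    if acs_mask is not None:
        # With keep_acs=True both masks keep the ACS region, so that overlap is expected.
        overlap &= ~acs_mask.astype(bool)
    return {
        "subset_of_sampling": bool(np.all(inp <= sampling) and np.all(tgt <= sampling)),
        "disjoint": not overlap.any(),
        # Assumption of this sketch: together the two masks cover the original mask.
        "covers_sampling": bool(np.array_equal(inp | tgt, sampling)),
    }
```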
- class direct.ssl.ssl.GaussianMaskSplitterModule(ratio=0.5, acs_region=(0, 0), keep_acs=False, use_seed=True, kspace_key=KspaceKey.MASKED_KSPACE, std_scale=3.0)
Bases: MaskSplitter
Uses a Gaussian splitting method to split the input mask into two disjoint masks.
- Parameters:
  - ratio (Union[float, list[float], tuple[float, ...]]) – Split ratio such that \(\mathrm{ratio} \approx \frac{|A|}{|B|}\). Default: 0.5.
  - acs_region (Union[list[int], tuple[int, int]]) – Size of the ACS region to include in the training (input) mask. Default: (0, 0).
  - keep_acs (bool) – If True, both the input and target masks keep the ACS region and the ratio is applied to the rest of the mask. Default: False.
  - use_seed (bool) – If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.
  - kspace_key (KspaceKey) – K-space key. Default: "masked_kspace".
  - std_scale (float) – Used to calculate the standard deviation of the Gaussian distribution. Default: 3.0.
- __init__(ratio=0.5, acs_region=(0, 0), keep_acs=False, use_seed=True, kspace_key=KspaceKey.MASKED_KSPACE, std_scale=3.0)
Inits GaussianMaskSplitterModule.
- Parameters:
  - ratio (Union[float, list[float], tuple[float, ...]]) – Split ratio such that \(\mathrm{ratio} \approx \frac{|A|}{|B|}\). Default: 0.5.
  - acs_region (Union[list[int], tuple[int, int]]) – Size of the ACS region to include in the training (input) mask. Default: (0, 0).
  - keep_acs (bool) – If True, both the input and target masks keep the ACS region and the ratio is applied to the rest of the mask. Default: False.
  - use_seed (bool) – If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.
  - kspace_key (KspaceKey) – K-space key. Default: "masked_kspace".
  - std_scale (float) – Used to calculate the standard deviation of the Gaussian distribution. Default: 3.0.
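The use_seed option ties the split to the filename so that every slice of a volume gets the same mask on every call. One way to realize this, shown here as a sketch, is to turn the filename's character codes into a seed sequence, which fits the Union[int, Iterable[int], None] type of the seed argument of split_method. The exact derivation DIRECT uses is an assumption here, and seed_from_filename and select_with_seed are hypothetical helpers; select_with_seed is a toy selection that only demonstrates reproducibility.

```python
import numpy as np


def seed_from_filename(filename: str) -> tuple:
    """Hypothetical: map a filename to a deterministic sequence of ints (its character codes)."""
    return tuple(ord(char) for char in filename)


def select_with_seed(sampling_mask: np.ndarray, n_select: int, seed) -> np.ndarray:
    """Toy selection of n_select sampled positions, used only to show reproducibility."""
    rng = np.random.RandomState(seed)  # accepts an int or a sequence of ints
    candidates = np.flatnonzero(sampling_mask)
    selected = np.zeros(sampling_mask.size, dtype=sampling_mask.dtype)
    selected[rng.choice(candidates, size=n_select, replace=False)] = 1
    return selected.reshape(sampling_mask.shape)
```

Since the seed depends only on the filename, calling the selection twice for slices of the same file yields identical masks.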
- split_method(sampling_mask, acs_mask, seed)
Splits the sampling_mask into two disjoint masks based on the Gaussian split method.
- Parameters:
  - sampling_mask (Tensor) – The input mask tensor to be split.
  - acs_mask (Optional[Tensor]) – The ACS mask. Needs to be passed if keep_acs is True. If keep_acs is False but this is passed, it will be ignored. Default: None.
  - seed (Union[int, Iterable[int], None]) – Seed to generate the split. Default: None.
- Return type: tuple[Tensor, Tensor]
- Returns: The two disjoint masks, input_mask and target_mask.
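A sketch of how a Gaussian split could combine these pieces, written in NumPy for 2D masks. Several details are assumptions not stated in the reference above: ratio is treated as the fraction of eligible sampled positions moved to the target mask, the Gaussian is centered on the array center, and the input mask is the remainder of the eligible positions. gaussian_split_sketch is a hypothetical name.

```python
import numpy as np


def gaussian_split_sketch(sampling_mask, ratio=0.5, std_scale=3.0, acs_mask=None, keep_acs=False, seed=None):
    """Hypothetical Gaussian split of a 2D binary mask into (input_mask, target_mask)."""
    rng = np.random.RandomState(seed)
    sampling = sampling_mask.astype(bool)
    use_acs = keep_acs and acs_mask is not None
    # Positions eligible for splitting; with keep_acs the ACS region is set aside.
    eligible = (sampling & ~acs_mask.astype(bool)) if use_acs else sampling.copy()
    nrow, ncol = sampling.shape
    n_eligible = int(eligible.sum())
    # Assumption: ratio is the fraction of eligible positions sent to the target mask.
    n_target = min(int(np.ceil(ratio * n_eligible)), n_eligible)
    center_x, center_y = nrow // 2, ncol // 2  # assumed center of the Gaussian
    std_x, std_y = (nrow - 1) / std_scale, (ncol - 1) / std_scale
    target = np.zeros_like(sampling)
    count = 0
    while count < n_target:
        x = int(np.round(rng.normal(center_x, std_x)))
        y = int(np.round(rng.normal(center_y, std_y)))
        # Rejection sampling: keep only in-bounds, eligible, not-yet-selected draws.
        if 0 <= x < nrow and 0 <= y < ncol and eligible[x, y] and not target[x, y]:
            target[x, y] = True
            count += 1
    input_mask = eligible & ~target
    if use_acs:
        acs = sampling & acs_mask.astype(bool)
        input_mask |= acs  # both masks keep the ACS region
        target |= acs
    return input_mask.astype(sampling_mask.dtype), target.astype(sampling_mask.dtype)
```

Because positions near the center are the most likely draws, the target mask concentrates around the k-space center, while the input mask keeps relatively more of the periphery.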
- class direct.ssl.ssl.HalfMaskSplitterModule(acs_region=(0, 0), keep_acs=False, use_seed=True, direction=HalfSplitType.VERTICAL, kspace_key=KspaceKey.MASKED_KSPACE)
Bases: MaskSplitter
Splits the input mask into two disjoint masks along a half line.
- Parameters:
  - acs_region (Union[list[int], tuple[int, int]]) – Size of the ACS region to include in the training (input) mask. Default: (0, 0).
  - keep_acs (bool) – If True, both the input and target masks keep the ACS region and the split is applied to the rest of the mask. Default: False.
  - use_seed (bool) – If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.
  - direction (HalfSplitType) – Direction of the half line split. Default: HalfSplitType.VERTICAL.
  - kspace_key (KspaceKey) – K-space key. Default: "masked_kspace".
- __init__(acs_region=(0, 0), keep_acs=False, use_seed=True, direction=HalfSplitType.VERTICAL, kspace_key=KspaceKey.MASKED_KSPACE)
Inits HalfMaskSplitterModule.
- Parameters:
  - acs_region (Union[list[int], tuple[int, int]]) – Size of the ACS region to include in the training (input) mask. Default: (0, 0).
  - keep_acs (bool) – If True, both the input and target masks keep the ACS region and the split is applied to the rest of the mask. Default: False.
  - use_seed (bool) – If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.
  - direction (HalfSplitType) – Direction of the half line split. Default: HalfSplitType.VERTICAL.
  - kspace_key (KspaceKey) – K-space key. Default: "masked_kspace".
- split_method(sampling_mask, acs_mask, seed)
Splits the sampling_mask into two disjoint masks based on the half split method.
- Parameters:
  - sampling_mask (Tensor) – The input mask tensor to be split.
  - acs_mask (Optional[Tensor]) – The ACS mask. Needs to be passed if keep_acs is True. If keep_acs is False but this is passed, it will be ignored. Default: None.
  - seed (Union[int, Iterable[int], None]) – Seed to generate the split. Default: None.
- Return type: tuple[Tensor, Tensor]
- Returns: The two disjoint masks, input_mask and target_mask.
- class direct.ssl.ssl.HalfSplitType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
SSL half mask splitter types.
These are used to define the type of half mask splitting.
- HORIZONTAL = 'horizontal'
- VERTICAL = 'vertical'
- DIAGONAL_LEFT = 'diagonal_left'
- DIAGONAL_RIGHT = 'diagonal_right'
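The enum values above name the half-line orientations. The sketch below shows one plausible reading of them in NumPy: each direction defines a boolean half of the grid, sampled positions inside it go to the input mask and the rest go to the target mask. Which half is assigned to which mask, and the exact diagonal lines used, are assumptions; HalfSplit is a local stand-in for HalfSplitType and half_split_sketch is a hypothetical name.

```python
from enum import Enum

import numpy as np


class HalfSplit(str, Enum):
    """Local stand-in with the same values as HalfSplitType."""

    HORIZONTAL = "horizontal"
    VERTICAL = "vertical"
    DIAGONAL_LEFT = "diagonal_left"
    DIAGONAL_RIGHT = "diagonal_right"


def half_region(shape, direction):
    """Boolean mask of the half assigned to the input mask (orientation is assumed)."""
    nrow, ncol = shape
    rows, cols = np.indices(shape)
    direction = HalfSplit(direction)
    if direction is HalfSplit.VERTICAL:
        return cols < ncol // 2  # left half of the columns
    if direction is HalfSplit.HORIZONTAL:
        return rows < nrow // 2  # top half of the rows
    if direction is HalfSplit.DIAGONAL_LEFT:
        # Strictly below the main diagonal, scaled so non-square grids also split roughly in two.
        return rows * ncol > cols * nrow
    # DIAGONAL_RIGHT: above-left of the anti-diagonal.
    return rows * ncol + cols * nrow < nrow * ncol


def half_split_sketch(sampling_mask, direction="vertical", acs_mask=None, keep_acs=False):
    """Hypothetical half split of a 2D binary mask into (input_mask, target_mask)."""
    sampling = sampling_mask.astype(bool)
    region = half_region(sampling.shape, direction)
    input_mask = sampling & region
    target_mask = sampling & ~region
    if keep_acs and acs_mask is not None:
        acs = sampling & acs_mask.astype(bool)
        input_mask |= acs  # both masks keep the ACS region
        target_mask |= acs
    return input_mask.astype(sampling_mask.dtype), target_mask.astype(sampling_mask.dtype)
```

Unlike the Gaussian and uniform schemes, this split is deterministic given the mask and direction, so no ratio or seed is involved in the sketch.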
- class direct.ssl.ssl.MaskSplitterType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
SSL mask splitter types.
These are used to define the type of mask splitting.
- UNIFORM = 'uniform'
- GAUSSIAN = 'gaussian'
- HALF = 'half'
- class direct.ssl.ssl.UniformMaskSplitterModule(ratio=0.5, acs_region=(0, 0), keep_acs=False, use_seed=True, kspace_key=KspaceKey.MASKED_KSPACE)
Bases: MaskSplitter
Uses a uniform splitting method to split the input mask into two disjoint masks.
- Parameters:
  - ratio (Union[float, list[float], tuple[float, ...]]) – Split ratio such that \(\mathrm{ratio} \approx \frac{|A|}{|B|}\). Default: 0.5.
  - acs_region (Union[list[int], tuple[int, int]]) – Size of the ACS region to include in the training (input) mask. Default: (0, 0).
  - keep_acs (bool) – If True, both the input and target masks keep the ACS region and the ratio is applied to the rest of the mask. Default: False.
  - use_seed (bool) – If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.
  - kspace_key (KspaceKey) – K-space key. Default: "masked_kspace".
- __init__(ratio=0.5, acs_region=(0, 0), keep_acs=False, use_seed=True, kspace_key=KspaceKey.MASKED_KSPACE)
Inits UniformMaskSplitterModule.
- Parameters:
  - ratio (Union[float, list[float], tuple[float, ...]]) – Split ratio such that \(\mathrm{ratio} \approx \frac{|A|}{|B|}\). Default: 0.5.
  - acs_region (Union[list[int], tuple[int, int]]) – Size of the ACS region to include in the training (input) mask. Default: (0, 0).
  - keep_acs (bool) – If True, both the input and target masks keep the ACS region and the ratio is applied to the rest of the mask. Default: False.
  - use_seed (bool) – If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.
  - kspace_key (KspaceKey) – K-space key. Default: "masked_kspace".
- split_method(sampling_mask, acs_mask, seed)
Splits the sampling_mask into two disjoint masks based on the uniform split method.
- Parameters:
  - sampling_mask (Tensor) – The input mask tensor to be split.
  - acs_mask (Optional[Tensor]) – The ACS mask. Needs to be passed if keep_acs is True. If keep_acs is False but this is passed, it will be ignored. Default: None.
  - seed (Union[int, Iterable[int], None]) – Seed to generate the split. Default: None.
- Return type: tuple[Tensor, Tensor]
- Returns: The two disjoint masks, input_mask and target_mask.
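Assuming that ratio is the fraction of eligible sampled positions assigned to the target mask (an interpretation, not taken from the reference above), a uniform split can be sketched in NumPy as follows. uniform_split_sketch is a hypothetical name, and the keep_acs handling mirrors the parameter description: the ACS region is set aside, the split is applied to the rest, and the ACS region is then added to both masks.

```python
import numpy as np


def uniform_split_sketch(sampling_mask, ratio=0.5, acs_mask=None, keep_acs=False, seed=None):
    """Hypothetical uniform split of a 2D binary mask into (input_mask, target_mask)."""
    rng = np.random.RandomState(seed)
    sampling = sampling_mask.astype(bool)
    use_acs = keep_acs and acs_mask is not None
    # Positions eligible for splitting; with keep_acs the ACS region is set aside.
    eligible = (sampling & ~acs_mask.astype(bool)) if use_acs else sampling.copy()
    candidates = np.flatnonzero(eligible)
    # Assumption: ratio is the fraction of eligible positions sent to the target mask.
    n_target = int(np.ceil(ratio * candidates.size))
    # Every candidate is equally likely to be chosen (sampling without replacement).
    chosen = rng.choice(candidates, size=n_target, replace=False)
    target = np.zeros(sampling.size, dtype=bool)
    target[chosen] = True
    target = target.reshape(sampling.shape)
    input_mask = eligible & ~target
    if use_acs:
        acs = sampling & acs_mask.astype(bool)
        input_mask |= acs  # both masks keep the ACS region
        target |= acs
    return input_mask.astype(sampling_mask.dtype), target.astype(sampling_mask.dtype)
```

For a mask with 40 sampled positions and ratio=0.25, the sketch sends ceil(0.25 × 40) = 10 positions to the target mask and leaves 30 in the input mask.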
- class direct.ssl.ssl.SSLTransformMaskPrefixes(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
SSL transform mask prefixes.
These are used to prefix the input (\(\Theta\)) and target (\(\Lambda\)) k-spaces/masks in the sample.
- INPUT_ = 'input_'
- TARGET_ = 'target_'
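The prefixes are prepended to existing sample keys to tell the input (\(\Theta\)) and target (\(\Lambda\)) versions apart. A minimal sketch with a local stand-in enum; prefixed_keys is a hypothetical helper, and the base key names in the example (sampling_mask, masked_kspace) are illustrative rather than a statement of the exact keys DIRECT stores.

```python
from enum import Enum


class MaskPrefix(str, Enum):
    """Local stand-in with the same values as SSLTransformMaskPrefixes."""

    INPUT_ = "input_"
    TARGET_ = "target_"


def prefixed_keys(base_keys):
    """Hypothetical helper: map each base key to its (input, target) prefixed names."""
    return {
        base: (MaskPrefix.INPUT_.value + base, MaskPrefix.TARGET_.value + base)
        for base in base_keys
    }
```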
Module contents
Direct SSL module.
Includes functions for self-supervised learning for MRI reconstruction.