direct.ssl package

Submodules

direct.ssl.mask_fillers module

SSL Mask Fillers.

This module contains functions for splitting binary masks into (disjoint) subsets for self-supervised learning (SSL) in MRI reconstruction.

direct.ssl.mask_fillers.gaussian_fill(nonzero_mask_count, nrow, ncol, center_x, center_y, std_scale, mask, output_mask, seed)

Generates a binary mask filled with randomly sampled positions following a 2D Gaussian distribution.

Calls the Cython function _gaussian_fill.

Parameters:
nonzero_mask_count: int

Number of non-zero entries in the output mask.

nrow: int

Number of rows of the output mask.

ncol: int

Number of columns of the output mask.

center_x: int

X coordinate of the center of the Gaussian distribution.

center_y: int

Y coordinate of the center of the Gaussian distribution.

std_scale: float

Scaling factor for the standard deviation of the Gaussian distribution. The standard deviation of the Gaussian distribution will be (nrow-1)/std_scale and (ncol-1)/std_scale along the X and Y axes, respectively.

mask: np.ndarray

A binary integer 2D array representing the input mask.

output_mask: np.ndarray

A binary integer 2D array representing the output mask.

seed: int

Seed for the random number generator.

Returns:
np.ndarray

A 2D array representing the output mask filled with randomly sampled positions following a 2D Gaussian distribution.

Return type:

ndarray
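The sampling procedure above can be illustrated with a minimal pure-NumPy sketch. It is not the Cython _gaussian_fill implementation, and gaussian_fill_sketch is a hypothetical name. The sketch differs from the documented function in three ways: it allocates output_mask itself instead of receiving it, it uses NumPy's default_rng, and it assumes a rejection-sampling loop. That loop draws a rounded Gaussian position and keeps it only if the position lies inside the grid, is set in mask, and has not been chosen yet. Under the documented formula, nrow = 320 and std_scale = 3.0 give a row standard deviation of 319/3 ≈ 106.3.

```python
import numpy as np


def gaussian_fill_sketch(nonzero_mask_count, nrow, ncol, center_x, center_y, std_scale, mask, seed):
    """Hypothetical NumPy stand-in for gaussian_fill (not DIRECT's Cython code)."""
    if nonzero_mask_count > int(mask.sum()):
        raise ValueError("nonzero_mask_count exceeds the number of ones in mask")
    rng = np.random.default_rng(seed)
    # Standard deviations follow the documented formula (n - 1) / std_scale.
    std_x = (nrow - 1) / std_scale
    std_y = (ncol - 1) / std_scale
    output_mask = np.zeros_like(mask)
    count = 0
    while count < nonzero_mask_count:
        # Draw a continuous Gaussian position and round it to a grid index.
        x = int(np.round(rng.normal(center_x, std_x)))
        y = int(np.round(rng.normal(center_y, std_y)))
        # Keep it only if it is in bounds, sampled in `mask` and not chosen yet.
        if 0 <= x < nrow and 0 <= y < ncol and mask[x, y] == 1 and output_mask[x, y] == 0:
            output_mask[x, y] = 1
            count += 1
    return output_mask


# Usage: pick 6 positions from a mask that samples every other column.
mask = np.zeros((8, 8), dtype=int)
mask[:, ::2] = 1
target = gaussian_fill_sketch(6, 8, 8, 4, 4, 3.0, mask, seed=0)
# target has exactly 6 ones, all of them inside mask
```

Because every integer position has nonzero probability under the Gaussian, the loop terminates whenever mask contains at least nonzero_mask_count ones. Positions far from the centre may take many draws to reach.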

direct.ssl.mask_fillers.uniform_fill(nonzero_mask_count, nrow, ncol, mask, rng)

Fills a binary torch.Tensor mask with the specified number of ones in a uniform random manner.

Parameters:
nonzero_mask_count: int

The number of 1s to place in the mask.

nrow: int

The number of rows in the mask.

ncol: int

The number of columns in the mask.

mask: torch.Tensor

A binary mask with zeros and ones.

rng: np.random.RandomState

A NumPy random state object for reproducibility.

Returns:
torch.Tensor

A binary mask with the specified number of 1s placed in a uniform random manner.

Return type:

Tensor
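A minimal sketch of the same idea follows. It assumes the ones are drawn uniformly without replacement from the nonzero entries of mask. uniform_fill_sketch is a hypothetical name, not the DIRECT implementation.

```python
import numpy as np
import torch


def uniform_fill_sketch(nonzero_mask_count, nrow, ncol, mask, rng):
    """Hypothetical stand-in for uniform_fill (not DIRECT's implementation)."""
    # Assumption: only positions that are already 1 in `mask` are eligible.
    prob = mask.flatten().numpy().astype(np.float64)
    flat_idx = rng.choice(nrow * ncol, size=nonzero_mask_count, replace=False, p=prob / prob.sum())
    # Convert the flat indices back to (row, column) coordinates.
    rows, cols = np.unravel_index(flat_idx, (nrow, ncol))
    output_mask = torch.zeros_like(mask)
    output_mask[torch.as_tensor(rows, dtype=torch.long), torch.as_tensor(cols, dtype=torch.long)] = 1
    return output_mask


# Usage: choose 5 of the 18 sampled positions of a column-subsampled mask.
rng = np.random.RandomState(0)
mask = torch.zeros(6, 6, dtype=torch.int64)
mask[:, ::2] = 1
target = uniform_fill_sketch(5, 6, 6, mask, rng)
```

Passing a seeded RandomState makes the selection reproducible. RandomState.choice raises an error if nonzero_mask_count exceeds the number of ones in mask.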

direct.ssl.ssl module

Direct SSL mask splitters module.

This module contains classes that split sampling masks for self-supervised learning in MRI reconstruction. The GaussianMaskSplitterModule splits the input mask into two disjoint masks using a Gaussian split scheme, and the UniformMaskSplitterModule does so using a uniform split scheme. The HalfMaskSplitterModule splits the input mask into two disjoint halves along a line whose direction is given by HalfSplitType.

class direct.ssl.ssl.GaussianMaskSplitterModule(ratio=0.5, acs_region=(0, 0), keep_acs=False, use_seed=True, kspace_key=KspaceKey.MASKED_KSPACE, std_scale=3.0)

Bases: MaskSplitter

Uses a Gaussian splitting method to split the input mask into two disjoint masks.

Parameters:
ratio: float, list[float] or tuple[float, …], optional

Split ratio such that \(ratio \approx \frac{|A|}{|B|}\). Default: 0.5.

acs_region: list[int] or tuple[int, int], optional

Size of the ACS region to include in the training (input) mask. Default: (0, 0).

keep_acs: bool, optional

If True, both the input and target masks keep the ACS region, and the ratio is applied to the rest of the mask. Assumes acs_mask is present in the sample. Default: False.

use_seed: bool, optional

If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.

kspace_key: str, optional

K-space key. Default: “masked_kspace”.

std_scale: float, optional

This is used to calculate the standard deviation of the Gaussian distribution. Default: 3.0.
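The split_method signatures accept seed as an int, an iterable of ints or None. The minimal sketch below shows why a filename-derived seed reproduces the same mask for every slice of a volume. It assumes the filename is mapped to a tuple of character codes. filename_seed, draw_positions and the example filename are hypothetical, not DIRECT functions or data.

```python
import numpy as np


def filename_seed(filename):
    """Hypothetical helper: turn a filename into a tuple of ints usable as a seed."""
    # One seed word per character, so a given filename always maps to the same seed.
    return tuple(ord(char) for char in str(filename))


def draw_positions(filename, population, count):
    """Hypothetical helper: draw `count` distinct indices reproducibly per filename."""
    # np.random.RandomState accepts an int or a sequence of ints as its seed.
    rng = np.random.RandomState(filename_seed(filename))
    return rng.choice(population, size=count, replace=False)


# Two slices of the same volume share the filename, hence the same draw.
slice_0 = draw_positions("file_brain_001.h5", 100, 5)
slice_1 = draw_positions("file_brain_001.h5", 100, 5)
```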

split_method(sampling_mask, acs_mask, seed)

Splits the sampling_mask into two disjoint masks based on the Gaussian split method.

Parameters:
sampling_mask: torch.Tensor

The input mask tensor to be split.

acs_mask: torch.Tensor or None

The ACS mask. Needs to be passed if keep_acs is True. If keep_acs is False but this is passed, it will be ignored. Default: None.

seed: int, iterable of ints or None

Seed to generate split.

Returns:
tuple[torch.Tensor, torch.Tensor]:

The two disjoint masks, input_mask and target_mask.

Return type:

tuple[Tensor, Tensor]

training: bool
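The following sketch shows one way the pieces of a Gaussian split can fit together; it is not DIRECT's code. It rests on these assumptions:

- ratio sets the size of the target mask to about ceil(ratio × |candidates|).
- keep_acs removes the ACS positions from sampling, then adds them back to both masks.
- The rejection sampler is swapped for a discretized 2D Gaussian weighting passed to RandomState.choice.

gaussian_split_sketch is a hypothetical name, and acs_region is not modeled.

```python
import numpy as np
import torch


def gaussian_split_sketch(sampling_mask, ratio=0.5, std_scale=3.0, acs_mask=None, keep_acs=False, seed=None):
    """Hypothetical sketch of a Gaussian input/target split (not DIRECT's code)."""
    rng = np.random.RandomState(seed)
    mask = sampling_mask.bool()
    candidates = mask.clone()
    if keep_acs and acs_mask is not None:
        # ACS positions are shared by both masks, so only sample outside them.
        candidates &= ~acs_mask.bool()
    nrow, ncol = mask.shape
    # Discretized 2D Gaussian weights centred on the middle of the grid.
    x = np.arange(nrow)[:, None]
    y = np.arange(ncol)[None, :]
    std_x, std_y = (nrow - 1) / std_scale, (ncol - 1) / std_scale
    weights = np.exp(-0.5 * (((x - nrow // 2) / std_x) ** 2 + ((y - ncol // 2) / std_y) ** 2))
    weights = weights * candidates.numpy()  # zero weight outside the candidates
    # Assumption: the target mask receives about ratio * |candidates| positions.
    count = int(np.ceil(ratio * int(candidates.sum())))
    flat = rng.choice(nrow * ncol, size=count, replace=False, p=(weights / weights.sum()).ravel())
    target = torch.zeros_like(mask)
    target.view(-1)[torch.as_tensor(flat, dtype=torch.long)] = True
    input_ = mask & ~target  # everything not sent to the target stays in the input
    if keep_acs and acs_mask is not None:
        input_ |= acs_mask.bool()
        target |= acs_mask.bool()
    return input_, target


# Usage: split a mask that samples every other column (128 positions).
sampling_mask = torch.zeros(16, 16, dtype=torch.bool)
sampling_mask[:, ::2] = True
input_mask, target_mask = gaussian_split_sketch(sampling_mask, ratio=0.4, seed=42)
```

In this sketch the target positions are removed from the input, so without keep_acs the two masks are disjoint and their union is the original sampling mask.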
class direct.ssl.ssl.HalfMaskSplitterModule(acs_region=(0, 0), keep_acs=False, use_seed=True, direction=HalfSplitType.VERTICAL, kspace_key=KspaceKey.MASKED_KSPACE)

Bases: MaskSplitter

Splits the input mask into two disjoint halves along a line whose direction is set by direction (a HalfSplitType).

Parameters:
acs_region: list[int] or tuple[int, int], optional

Size of the ACS region to include in the training (input) mask. Default: (0, 0).

keep_acs: bool, optional

If True, both the input and target masks keep the ACS region, and the split is applied to the rest of the mask. Assumes acs_mask is present in the sample. Default: False.

use_seed: bool, optional

If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.

kspace_key: str, optional

K-space key. Default: “masked_kspace”.

split_method(sampling_mask, acs_mask, seed)

Splits the sampling_mask into two disjoint masks using the half split method, along the direction set in the constructor.

Parameters:
sampling_mask: torch.Tensor

The input mask tensor to be split.

acs_mask: torch.Tensor or None

The ACS mask. Needs to be passed if keep_acs is True. If keep_acs is False but this is passed, it will be ignored. Default: None.

seed: int, iterable of ints or None

Seed to generate split.

Returns:
tuple[torch.Tensor, torch.Tensor]:

The two disjoint masks, input_mask and target_mask.

Return type:

tuple[Tensor, Tensor]

training: bool
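A half split along the directions listed in HalfSplitType can be illustrated with the minimal sketch below. half_split_sketch is a hypothetical name. Two things are assumptions rather than DIRECT's definitions: which side of the line counts as the input or target mask, and how the diagonal_left and diagonal_right lines are oriented. ACS handling is omitted.

```python
import torch


def half_split_sketch(sampling_mask, direction="vertical"):
    """Hypothetical sketch: split a 2D mask into two halves along a line."""
    nrow, ncol = sampling_mask.shape
    rows = torch.arange(nrow).view(-1, 1).expand(nrow, ncol)
    cols = torch.arange(ncol).view(1, -1).expand(nrow, ncol)
    # Boolean region selecting one side of the dividing line (orientation assumed).
    if direction == "vertical":
        half = cols < ncol // 2  # vertical line: left vs right columns
    elif direction == "horizontal":
        half = rows < nrow // 2  # horizontal line: top vs bottom rows
    elif direction == "diagonal_left":
        half = rows * ncol > cols * nrow  # below the top-left to bottom-right diagonal
    elif direction == "diagonal_right":
        half = rows * ncol + cols * nrow < nrow * ncol  # above the top-right to bottom-left diagonal
    else:
        raise ValueError(f"Unknown direction: {direction}")
    mask = sampling_mask.bool()
    return mask & half, mask & ~half


first, second = half_split_sketch(torch.ones(4, 6, dtype=torch.bool), "vertical")
```

Both halves are intersected with the same sampling mask through complementary regions, so they are disjoint and together recover the original mask for every direction.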
class direct.ssl.ssl.HalfSplitType(value)

Bases: DirectEnum

SSL half mask splitter types.

These are used to define the type of half mask splitting.

DIAGONAL_LEFT = 'diagonal_left'
DIAGONAL_RIGHT = 'diagonal_right'
HORIZONTAL = 'horizontal'
VERTICAL = 'vertical'

class direct.ssl.ssl.MaskSplitterType(value)

Bases: DirectEnum

SSL mask splitter types.

These are used to define the type of mask splitting.

GAUSSIAN = 'gaussian'
HALF = 'half'
UNIFORM = 'uniform'

class direct.ssl.ssl.SSLTransformMaskPrefixes(value)

Bases: DirectEnum

SSL Transform mask prefixes.

These are used to prefix the input (\(\Theta\)) and target (\(\Lambda\)) k-spaces/masks in the sample.

INPUT_ = 'input_'
TARGET_ = 'target_'
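The sketch below shows how such prefixes can be combined with sample keys. SSLPrefix mirrors the two documented values using a plain str Enum instead of DirectEnum. split_sample and the default key "sampling_mask" are hypothetical names used only for illustration.

```python
from enum import Enum


class SSLPrefix(str, Enum):
    """Hypothetical mirror of SSLTransformMaskPrefixes (values copied from the docs)."""

    INPUT_ = "input_"
    TARGET_ = "target_"


def prefixed_key(prefix, key):
    # Use .value so the result does not depend on how str() formats an Enum member.
    return prefix.value + key


def split_sample(sample, input_mask, target_mask, mask_key="sampling_mask"):
    """Hypothetical helper: store both split masks under prefixed keys."""
    out = dict(sample)
    out[prefixed_key(SSLPrefix.INPUT_, mask_key)] = input_mask
    out[prefixed_key(SSLPrefix.TARGET_, mask_key)] = target_mask
    return out


sample = split_sample({"filename": "volume.h5"}, "theta_mask", "lambda_mask")
```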
class direct.ssl.ssl.UniformMaskSplitterModule(ratio=0.5, acs_region=(0, 0), keep_acs=False, use_seed=True, kspace_key=KspaceKey.MASKED_KSPACE)

Bases: MaskSplitter

Uses a uniform splitting method to split the input mask into two disjoint masks.

Parameters:
ratio: float, list[float] or tuple[float, …], optional

Split ratio such that \(ratio \approx \frac{|A|}{|B|}\). Default: 0.5.

acs_region: list[int] or tuple[int, int], optional

Size of the ACS region to include in the training (input) mask. Default: (0, 0).

keep_acs: bool, optional

If True, both the input and target masks keep the ACS region, and the ratio is applied to the rest of the mask. Assumes acs_mask is present in the sample. Default: False.

use_seed: bool, optional

If True, a pseudo-random number based on the filename is computed so that every slice of the volume gets the same mask every time. Default: True.

kspace_key: str, optional

K-space key. Default: “masked_kspace”.

split_method(sampling_mask, acs_mask, seed)

Splits the sampling_mask into two disjoint masks based on the uniform split method.

Parameters:
sampling_mask: torch.Tensor

The input mask tensor to be split.

acs_mask: torch.Tensor or None

The ACS mask. Needs to be passed if keep_acs is True. If keep_acs is False but this is passed, it will be ignored. Default: None.

seed: int, iterable of ints or None

Seed to generate split.

Returns:
tuple[torch.Tensor, torch.Tensor]:

The two disjoint masks, input_mask and target_mask.

Return type:

tuple[Tensor, Tensor]

training: bool
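A minimal end-to-end sketch of a uniform split follows. It assumes the target mask receives about ceil(ratio × |mask|) positions drawn uniformly without replacement, and the input mask keeps the rest. uniform_split_sketch is hypothetical and ignores acs_region and keep_acs. It passes seed straight to np.random.RandomState, which accepts an int, a sequence of ints or None, matching the documented seed types.

```python
import numpy as np
import torch


def uniform_split_sketch(sampling_mask, ratio=0.5, seed=None):
    """Hypothetical sketch of a uniform input/target split (not DIRECT's code)."""
    rng = np.random.RandomState(seed)
    mask = sampling_mask.bool()
    # Flat indices of every sampled position.
    flat_idx = torch.nonzero(mask.flatten()).flatten().numpy()
    # Assumption: the target mask receives about ratio * |mask| positions.
    count = int(np.ceil(ratio * flat_idx.size))
    chosen = rng.choice(flat_idx, size=count, replace=False)
    target = torch.zeros_like(mask)
    target.view(-1)[torch.as_tensor(chosen, dtype=torch.long)] = True
    # The input mask keeps every sampled position not assigned to the target.
    return mask & ~target, target


# Usage: 40 sampled positions (every third column), a quarter go to the target.
sampling_mask = torch.zeros(10, 10, dtype=torch.bool)
sampling_mask[:, ::3] = True
input_mask, target_mask = uniform_split_sketch(sampling_mask, ratio=0.25, seed=(1, 2, 3))
```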

Module contents

Direct SSL module.

Includes functions for self-supervised learning for MRI reconstruction.