direct.nn.varnet package

Submodules

direct.nn.varnet.config module

class direct.nn.varnet.config.EndToEndVarNetConfig(model_name: str = '???', engine_name: str | None = None, num_layers: int = 8, regularizer_num_filters: int = 18, regularizer_num_pull_layers: int = 4, regularizer_dropout: float = 0.0)[source]

Bases: ModelConfig

num_layers: int = 8
regularizer_dropout: float = 0.0
regularizer_num_filters: int = 18
regularizer_num_pull_layers: int = 4
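For illustration, the config can be constructed directly as a dataclass; a minimal sketch, assuming the defaults listed above (in a real DIRECT experiment these fields are populated from the YAML experiment configuration instead):

```python
from direct.nn.varnet.config import EndToEndVarNetConfig

# Minimal sketch: build the model config in code. In practice these
# values come from the YAML experiment configuration.
cfg = EndToEndVarNetConfig(
    num_layers=8,                   # number of unrolled variational blocks
    regularizer_num_filters=18,     # first-scale filters of the regularizer U-Net
    regularizer_num_pull_layers=4,  # pooling depth of the regularizer U-Net
    regularizer_dropout=0.0,        # dropout probability in the regularizer
)
print(cfg.num_layers)  # 8
```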

direct.nn.varnet.varnet module

class direct.nn.varnet.varnet.EndToEndVarNet(forward_operator, backward_operator, num_layers, regularizer_num_filters=18, regularizer_num_pull_layers=4, regularizer_dropout=0.0, in_channels=2, **kwargs)[source]

Bases: Module

End-to-End Variational Network based on [1].

References

[1]

Sriram, Anuroop, et al. “End-to-End Variational Networks for Accelerated MRI Reconstruction.” ArXiv:2004.06688 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.06688.

forward(masked_kspace, sampling_mask, sensitivity_map)[source]

Performs the forward pass of EndToEndVarNet.

Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

Returns:
kspace_prediction: torch.Tensor

K-space prediction of shape (N, coil, height, width, complex=2).

Return type:

Tensor

training: bool
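As a usage illustration, the model can be run on dummy inputs with the shapes documented above; a minimal sketch, assuming direct.data.transforms.fft2 and ifft2 are suitable forward/backward operators (the training engine normally supplies these):

```python
import torch

from direct.data.transforms import fft2, ifft2
from direct.nn.varnet.varnet import EndToEndVarNet

# Minimal sketch: fft2/ifft2 are assumed to be appropriate operators here;
# in a full experiment the engine wires these in from the configuration.
model = EndToEndVarNet(
    forward_operator=fft2,
    backward_operator=ifft2,
    num_layers=8,
)

N, coil, height, width = 1, 4, 64, 64
masked_kspace = torch.randn(N, coil, height, width, 2)
sampling_mask = torch.randint(0, 2, (N, 1, height, width, 1)).bool()
sensitivity_map = torch.randn(N, coil, height, width, 2)

kspace_prediction = model(masked_kspace, sampling_mask, sensitivity_map)
print(kspace_prediction.shape)  # torch.Size([1, 4, 64, 64, 2])
```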
class direct.nn.varnet.varnet.EndToEndVarNetBlock(forward_operator, backward_operator, regularizer_model)[source]

Bases: Module

End-to-End Variational Network block.

forward(current_kspace, masked_kspace, sampling_mask, sensitivity_map)[source]

Performs the forward pass of EndToEndVarNetBlock.

Parameters:
current_kspace: torch.Tensor

Current k-space prediction of shape (N, coil, height, width, complex=2).

masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

Returns:
torch.Tensor

Next k-space prediction of shape (N, coil, height, width, complex=2).

Return type:

Tensor

training: bool
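Conceptually, each block applies a soft data-consistency step plus a learned image-space regularizer, following the update rule of [1]. The function below is a schematic plain-PyTorch paraphrase of that rule, not the library's verbatim implementation; expand_op and reduce_op stand in for sensitivity-map expansion (image to per-coil images) and reduction (coil combination):

```python
import torch

# Schematic paraphrase of one variational block update from [1]:
#   k_{t+1} = k_t - eta * M o (k_t - k~) + F(E(CNN(R(F^{-1}(k_t)))))
# Not the library's verbatim code; the operator arguments are assumptions.
def varnet_block_update(current_kspace, masked_kspace, sampling_mask, eta,
                        regularizer, forward_op, backward_op,
                        expand_op, reduce_op):
    # Soft data consistency: only penalize deviation on sampled k-space entries.
    kspace_error = torch.where(
        sampling_mask,
        current_kspace - masked_kspace,
        torch.zeros_like(current_kspace),
    )
    # Regularizer path: k-space -> image -> coil-combine -> CNN -> expand -> k-space.
    image = reduce_op(backward_op(current_kspace))
    regularization = forward_op(expand_op(regularizer(image)))
    return current_kspace - eta * kspace_error + regularization
```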

direct.nn.varnet.varnet_engine module

Engines for End-to-End Variational Network model.

Includes supervised, self-supervised and joint supervised and self-supervised learning engines.

class direct.nn.varnet.varnet_engine.EndToEndVarNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]

Bases: MRIModelEngine

End-to-End Variational Network Engine.

Parameters:
cfg: BaseConfig

Configuration object.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

forward_function(data)[source]

Performs the model’s forward pass given data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

tuple[Tensor, Tensor]
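Schematically, the supervised engine runs the model on the masked k-space and derives an image from the predicted k-space. A sketch under assumed key names, not the library's verbatim code:

```python
import torch

# Sketch of a supervised forward_function; the data keys and the image
# reconstruction step are assumptions, not the library's verbatim code.
def forward_function(self, data: dict) -> tuple[torch.Tensor, torch.Tensor]:
    output_kspace = self.model(
        masked_kspace=data["masked_kspace"],
        sampling_mask=data["sampling_mask"],
        sensitivity_map=data["sensitivity_map"],
    )
    # Back-transform to image space (schematically; the actual engine also
    # coil-combines the result, e.g. via a root-sum-of-squares).
    output_image = self.backward_operator(output_kspace, dim=(2, 3))
    return output_image, output_kspace
```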

class direct.nn.varnet.varnet_engine.EndToEndVarNetJSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]

Bases: JSSLMRIModelEngine

Joint Supervised and Self-supervised Learning End-to-End Variational Network Engine.

Used for the supplementary experiments on the End-to-End Variational Network model with JSSL in the JSSL paper [1].

Parameters:
cfg: BaseConfig

Configuration object.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

References

[1]

Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.

forward_function(data)[source]

Forward function for EndToEndVarNetJSSLEngine.

Parameters:
data: dict[str, Any]

Data dictionary. Should contain the following keys:

- “is_ssl”: boolean tensor indicating whether training is SSL
- “input_kspace” if training with SSL, “masked_kspace” if inference
- “input_sampling_mask” if training with SSL, “sampling_mask” if inference
- “sensitivity_map”

Returns:
tuple[None, torch.Tensor]

None in place of an image, and the output k-space prediction.

Return type:

tuple[None, Tensor]
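For illustration, a JSSL training batch in SSL mode might be assembled as follows; a sketch following the key list above, with tensor shapes assumed to match the model's conventions:

```python
import torch

N, coil, height, width = 1, 4, 64, 64

# Sketch of a JSSL batch in SSL training mode, following the keys above.
data = {
    "is_ssl": torch.tensor([True]),  # this sample trains self-supervised
    "input_kspace": torch.randn(N, coil, height, width, 2),
    "input_sampling_mask": torch.randint(0, 2, (N, 1, height, width, 1)).bool(),
    "sensitivity_map": torch.randn(N, coil, height, width, 2),
}
```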

class direct.nn.varnet.varnet_engine.EndToEndVarNetSSLEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]

Bases: SSLMRIModelEngine

Self-supervised Learning End-to-End Variational Network Engine.

Used for the supplementary experiments on the End-to-End Variational Network model with SSL in the JSSL paper [1].

Parameters:
cfg: BaseConfig

Configuration object.

model: nn.Module

Model.

device: str

Device. Can be “cuda:{idx}” or “cpu”.

forward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The forward operator. Default: None.

backward_operator: Callable[[tuple[Any, …]], torch.Tensor], optional

The backward operator. Default: None.

mixed_precision: bool

Use mixed precision. Default: False.

**models: nn.Module

Additional models.

References

[1]

Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). https://doi.org/10.48550/arXiv.2311.15856.

forward_function(data)[source]

Forward function for EndToEndVarNetSSLEngine.

Parameters:
data: dict[str, Any]

Data dictionary. Should contain the following keys:

- “input_kspace” if training, “masked_kspace” if inference
- “input_sampling_mask” if training, “sampling_mask” if inference
- “sensitivity_map”

Returns:
tuple[None, torch.Tensor]

None in place of an image, and the output k-space prediction.

Return type:

tuple[None, Tensor]
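Correspondingly, an inference batch for the SSL engine carries the inference-time keys; a sketch following the list above, with assumed shapes:

```python
import torch

N, coil, height, width = 1, 4, 64, 64

# Sketch of an inference batch for the SSL engine, following the keys above.
data = {
    "masked_kspace": torch.randn(N, coil, height, width, 2),
    "sampling_mask": torch.randint(0, 2, (N, 1, height, width, 1)).bool(),
    "sensitivity_map": torch.randn(N, coil, height, width, 2),
}
```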

Module contents