direct.nn.crossdomain package

Submodules

direct.nn.crossdomain.crossdomain module

class direct.nn.crossdomain.crossdomain.CrossDomainNetwork(forward_operator, backward_operator, image_model_list, kspace_model_list=None, domain_sequence='KIKI', image_buffer_size=1, kspace_buffer_size=1, normalize_image=False, **kwargs)

Bases: Module

Performs optimisation in both the k-space (“K”) and image (“I”) domains, in the order given by domain_sequence.
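A minimal sketch of how a domain_sequence string such as 'KIKI' can be turned into an ordered list of correction steps. It assumes (not confirmed by this page) that each character selects one block, that there is one image model per “I” step, and one k-space model per “K” step when kspace_model_list is given. The helpers validate_domain_sequence and schedule are hypothetical and not part of DIRECT.

```python
from typing import List, Optional, Sequence, Tuple


def validate_domain_sequence(
    domain_sequence: str,
    image_model_list: Sequence[object],
    kspace_model_list: Optional[Sequence[object]] = None,
) -> List[str]:
    """Hypothetical helper: split domain_sequence into steps and check model counts."""
    steps = list(domain_sequence.strip())
    if not set(steps).issubset({"K", "I"}):
        raise ValueError(f"domain_sequence may only contain 'K' and 'I', got {domain_sequence!r}")
    # Assumed rule: one k-space model per "K" step (if k-space models are given).
    if kspace_model_list is not None and len(kspace_model_list) != steps.count("K"):
        raise ValueError("Number of k-space models does not match the number of 'K' steps.")
    # Assumed rule: one image model per "I" step.
    if len(image_model_list) != steps.count("I"):
        raise ValueError("Number of image models does not match the number of 'I' steps.")
    return steps


def schedule(domain_sequence: str) -> List[Tuple[str, int]]:
    """Return (domain, block_idx) pairs in execution order, with a separate index per domain."""
    kspace_idx, image_idx = 0, 0
    order = []
    for domain in domain_sequence.strip():
        if domain == "K":
            order.append(("K", kspace_idx))
            kspace_idx += 1
        else:
            order.append(("I", image_idx))
            image_idx += 1
    return order


print(schedule("KIKI"))  # [('K', 0), ('I', 0), ('K', 1), ('I', 1)]
```

The separate per-domain counters mirror the block_idx argument of image_correction and kspace_correction, which suggests each domain indexes its own list of models.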

forward(masked_kspace, sampling_mask, sensitivity_map, scaling_factor=None)

Computes the forward pass of CrossDomainNetwork.

Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

scaling_factor: Optional[torch.Tensor]

Scaling factor of shape (N,). If None, no scaling is applied. Default: None.

Returns:
out_image: torch.Tensor

Output image of shape (N, height, width, complex=2).

Return type:

Tensor
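The shape contract of forward can be illustrated with NumPy arrays standing in for torch.Tensor. This is a sketch under assumptions: the (N, 1, height, width, 1) mask is applied by broadcasting over the coil and complex axes, and the per-coil images are combined with a SENSE-style reduction (sum over coils of conj(S_c) * x_c). The helpers to_complex, to_real and sense_reduce are hypothetical, and the network's actual forward_operator and backward_operator may use different (e.g. centred) FFT conventions.

```python
import numpy as np

N, coils, H, W = 2, 4, 8, 8
rng = np.random.default_rng(0)

# Dummy inputs in the documented layout; the trailing axis of size 2 holds (real, imag).
kspace = rng.standard_normal((N, coils, H, W, 2))
sensitivity_map = rng.standard_normal((N, coils, H, W, 2))
sampling_mask = rng.random((N, 1, H, W, 1)) < 0.3  # boolean, shape (N, 1, height, width, 1)

# Broadcasting the mask zeroes unsampled k-space locations in every coil and both components.
masked_kspace = np.where(sampling_mask, kspace, 0.0)


def to_complex(x):
    """(..., 2) real array -> (...) complex array."""
    return x[..., 0] + 1j * x[..., 1]


def to_real(z):
    """(...) complex array -> (..., 2) real array."""
    return np.stack([z.real, z.imag], axis=-1)


def sense_reduce(coil_images, sens):
    """Hypothetical SENSE-style combination: sum_c conj(S_c) * x_c over the coil axis."""
    combined = np.sum(np.conj(to_complex(sens)) * to_complex(coil_images), axis=1)
    return to_real(combined)


# Per-coil images from the masked k-space (uncentred inverse FFT over height and width).
coil_images = to_real(np.fft.ifft2(to_complex(masked_kspace), axes=(-2, -1)))
out_image = sense_reduce(coil_images, sensitivity_map)
print(out_image.shape)  # (2, 8, 8, 2), i.e. (N, height, width, complex=2)
```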

image_correction(block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map)
Return type:

Tensor

kspace_correction(block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace)
Return type:

Tensor

training: bool

direct.nn.crossdomain.multicoil module

class direct.nn.crossdomain.multicoil.MultiCoil(model, coil_dim=1, coil_to_batch=False)

Bases: Module

Passes multi-coil data of shape (N, N_coils, H, W, C) through a model.

If coil_to_batch is True, the coil dimension is moved into the batch dimension. Otherwise, the data of each coil is passed to the model individually.
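A NumPy sketch of the two coil_to_batch modes. It assumes, as is usual for torch convolutions, that the wrapped model takes channels-first input of shape (B, in_channels, H, W); multicoil_forward and toy_model are hypothetical names, not DIRECT functions. For a model that processes each sample independently, both modes give the same result.

```python
import numpy as np


def multicoil_forward(model, x, coil_to_batch=False):
    """Hypothetical re-implementation of the MultiCoil idea.

    x has shape (N, coil, height, width, in_channels); model maps
    (B, in_channels, height, width) -> (B, out_channels, height, width).
    """
    batch, coil, height, width, channels = x.shape
    if coil_to_batch:
        # Fold the coil axis into the batch axis, run the model once, then unfold.
        y = x.reshape(batch * coil, height, width, channels).transpose(0, 3, 1, 2)
        y = model(y).transpose(0, 2, 3, 1)
        return y.reshape(batch, coil, height, width, -1)
    # Otherwise run the model on each coil separately and stack along the coil axis.
    outputs = [model(x[:, c].transpose(0, 3, 1, 2)).transpose(0, 2, 3, 1) for c in range(coil)]
    return np.stack(outputs, axis=1)


def toy_model(y):
    # Stand-in "network": an element-wise function, so samples are processed independently.
    return np.tanh(y)


x = np.random.default_rng(0).standard_normal((2, 4, 8, 8, 2))
a = multicoil_forward(toy_model, x, coil_to_batch=True)
b = multicoil_forward(toy_model, x, coil_to_batch=False)
print(a.shape, np.allclose(a, b))  # (2, 4, 8, 8, 2) True
```

Folding coils into the batch runs the model once on a larger batch, which is typically faster; the per-coil loop keeps peak memory closer to that of a single-coil input.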

forward(x)

Performs the forward pass of MultiCoil.

Parameters:
x: torch.Tensor

Multi-coil input of shape (N, coil, height, width, in_channels).

Returns:
out: torch.Tensor

Multi-coil output of shape (N, coil, height, width, out_channels).

Return type:

Tensor
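The channel count may change between input and output, while the (N, coil, height, width) layout is preserved. A minimal sketch, assuming a hypothetical pointwise (1x1) convolution with random weights as the wrapped model and the coil-to-batch folding described above:

```python
import numpy as np

rng = np.random.default_rng(1)
in_channels, out_channels = 2, 4
weight = rng.standard_normal((out_channels, in_channels))  # hypothetical 1x1-conv weights
bias = np.zeros(out_channels)


def pointwise_conv(y):
    """1x1 convolution on channels-first input: (B, in_channels, H, W) -> (B, out_channels, H, W)."""
    return np.einsum("oi,bihw->bohw", weight, y) + bias[None, :, None, None]


x = rng.standard_normal((2, 4, 8, 8, in_channels))  # (N, coil, height, width, in_channels)
batch, coil, height, width, _ = x.shape

# Fold coils into the batch, apply the model channels-first, then restore the multi-coil layout.
y = pointwise_conv(x.reshape(batch * coil, height, width, in_channels).transpose(0, 3, 1, 2))
out = y.transpose(0, 2, 3, 1).reshape(batch, coil, height, width, out_channels)
print(out.shape)  # (2, 4, 8, 8, 4), i.e. (N, coil, height, width, out_channels)
```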

training: bool

Module contents