direct.nn.crossdomain package
Submodules
direct.nn.crossdomain.crossdomain module
- class direct.nn.crossdomain.crossdomain.CrossDomainNetwork(forward_operator, backward_operator, image_model_list, kspace_model_list=None, domain_sequence='KIKI', image_buffer_size=1, kspace_buffer_size=1, normalize_image=False, **kwargs)
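Bases: torch.nn.Module
Performs optimisation in both the k-space (“K”) and image (“I”) domains, in the order given by domain_sequence. For example, the default 'KIKI' alternates k-space and image blocks, starting in k-space.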
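The way a domain_sequence string turns into a schedule of correction blocks can be sketched in plain Python. This is a minimal illustration under the assumption that the i-th “K” uses kspace_model_list[i] and the i-th “I” uses image_model_list[i], each call receiving that index as block_idx; schedule_blocks is a hypothetical helper, not part of direct.

```python
def schedule_blocks(domain_sequence: str) -> list[tuple[str, int]]:
    """Pair each letter of domain_sequence with the model index it would use.

    Assumption (hypothetical helper): the i-th "K" uses kspace_model_list[i]
    and the i-th "I" uses image_model_list[i].
    """
    counters = {"K": 0, "I": 0}
    steps = []
    for domain in domain_sequence:
        if domain not in counters:
            raise ValueError(f"Unknown domain {domain!r}; expected 'K' or 'I'.")
        steps.append((domain, counters[domain]))
        counters[domain] += 1
    return steps


# The default "KIKI" alternates k-space and image blocks, starting in k-space.
for domain, block_idx in schedule_blocks("KIKI"):
    name = "kspace_correction" if domain == "K" else "image_correction"
    print(f"{name}(block_idx={block_idx}, ...)")
```

Under this assumption, the length of each model list would need to match the number of times its letter appears in domain_sequence.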
- forward(masked_kspace, sampling_mask, sensitivity_map, scaling_factor=None)
Computes the forward pass of CrossDomainNetwork.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- scaling_factor: Optional[torch.Tensor]
Scaling factor of shape (N,). If None, no scaling is applied. Default: None.
- Returns:
- out_image: torch.Tensor
Output image of shape (N, height, width, complex=2).
- Return type:
Tensor
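scaling_factor has shape (N,), while the image it scales has shape (N, height, width, complex=2). To apply it per sample, it has to be reshaped to (N, 1, 1, 1) so that it broadcasts over the remaining axes. The NumPy sketch below shows this, assuming that with normalize_image enabled the input image is divided by the factor and the output is multiplied back. The helper names are hypothetical, and torch.Tensor.reshape broadcasts the same way.

```python
import numpy as np


def broadcast_scaling(scaling_factor: np.ndarray, ndim: int) -> np.ndarray:
    """Reshape a per-sample factor of shape (N,) to (N, 1, ..., 1) with ndim axes."""
    return scaling_factor.reshape(-1, *([1] * (ndim - 1)))


def normalize(image: np.ndarray, scaling_factor) -> np.ndarray:
    # No scaling when scaling_factor is None, as documented for forward().
    if scaling_factor is None:
        return image
    return image / broadcast_scaling(scaling_factor, image.ndim)


def denormalize(image: np.ndarray, scaling_factor) -> np.ndarray:
    # Inverse of normalize(): multiply the output back by the same factor.
    if scaling_factor is None:
        return image
    return image * broadcast_scaling(scaling_factor, image.ndim)


# Two samples of shape (height=4, width=4, complex=2) with different scales.
rng = np.random.default_rng(0)
image = rng.standard_normal((2, 4, 4, 2))
scaling_factor = np.array([2.0, 10.0])

normed = normalize(image, scaling_factor)
restored = denormalize(normed, scaling_factor)
```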
- image_correction(block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map)
- Return type:
Tensor
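image_correction receives sensitivity_map, which suggests that multi-coil k-space is brought into the image domain and coil-combined before the image model runs. A common way to do this is the SENSE-style reduction x = Σ_c conj(S_c) · F⁻¹(k_c). The NumPy sketch below illustrates that operation on the (…, complex=2) layout used in this module. It is an assumption about what the backward operator does, not the documented internals, and it uses an uncentred orthonormal FFT for simplicity, whereas MRI pipelines often use a centred FFT.

```python
import numpy as np


def to_complex(x: np.ndarray) -> np.ndarray:
    """Convert a (..., 2) real/imaginary array to a complex array."""
    return x[..., 0] + 1j * x[..., 1]


def to_real(z: np.ndarray) -> np.ndarray:
    """Convert a complex array back to the (..., 2) real/imaginary layout."""
    return np.stack([z.real, z.imag], axis=-1)


def sense_reduce(kspace: np.ndarray, sensitivity_map: np.ndarray, coil_dim: int = 1) -> np.ndarray:
    """Multi-coil k-space (N, coil, H, W, 2) -> coil-combined image (N, H, W, 2)."""
    # Inverse 2D FFT over height and width gives one image per coil.
    coil_images = np.fft.ifft2(to_complex(kspace), axes=(-2, -1), norm="ortho")
    # Weight each coil image by the conjugate sensitivity and sum over coils.
    combined = np.sum(np.conj(to_complex(sensitivity_map)) * coil_images, axis=coil_dim)
    return to_real(combined)
```

When the sensitivity maps satisfy Σ_c |S_c|² = 1 at every pixel, this reduction exactly inverts the expansion k_c = F(S_c · x).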
- kspace_correction(block_idx, image_buffer, kspace_buffer, sampling_mask, sensitivity_map, masked_kspace)
- Return type:
Tensor
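kspace_correction receives the image buffer together with sensitivity_map and sampling_mask, which suggests that images are mapped back to masked multi-coil k-space before the k-space model runs. A standard forward model is k_c = M · F(S_c · x). The NumPy sketch below implements that model with the mask shape documented for forward(), (N, 1, height, width, 1), which broadcasts over the coil and complex axes. It is an assumption about the forward operator, not direct's code, and it uses an uncentred orthonormal FFT.

```python
import numpy as np


def to_complex(x: np.ndarray) -> np.ndarray:
    """Convert a (..., 2) real/imaginary array to a complex array."""
    return x[..., 0] + 1j * x[..., 1]


def to_real(z: np.ndarray) -> np.ndarray:
    """Convert a complex array back to the (..., 2) real/imaginary layout."""
    return np.stack([z.real, z.imag], axis=-1)


def sense_expand_and_mask(image: np.ndarray, sensitivity_map: np.ndarray, sampling_mask: np.ndarray) -> np.ndarray:
    """Image (N, H, W, 2) -> masked multi-coil k-space (N, coil, H, W, 2)."""
    # Multiply the image by each coil sensitivity: (N, coil, H, W), complex.
    coil_images = to_complex(sensitivity_map) * to_complex(image)[:, None]
    # Forward 2D FFT over height and width, back to the (..., 2) layout.
    kspace = to_real(np.fft.fft2(coil_images, axes=(-2, -1), norm="ortho"))
    # The (N, 1, H, W, 1) mask zeroes out unsampled k-space positions.
    return kspace * sampling_mask
```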
- training: bool
direct.nn.crossdomain.multicoil module
- class direct.nn.crossdomain.multicoil.MultiCoil(model, coil_dim=1, coil_to_batch=False)
Bases: torch.nn.Module
Applies a model to multi-coil data of shape (N, N_coils, H, W, C).
If coil_to_batch is True, the coil dimension is merged into the batch dimension, so the model processes all coils in a single call. Otherwise, the data of each coil is passed to the model individually.
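The two modes can be compared with a short NumPy sketch. toy_model and multicoil_forward are hypothetical stand-ins, not direct's implementation. For simplicity the model works on channels-last arrays, although PyTorch convolutional models usually expect channels-first input, and the coil_to_batch branch assumes that the coil axis is axis 1, which matches the default coil_dim=1.

```python
import numpy as np


def toy_model(x: np.ndarray) -> np.ndarray:
    """Hypothetical model: (batch, H, W, in_ch) -> (batch, H, W, 3), pointwise linear."""
    weights = np.arange(x.shape[-1] * 3, dtype=float).reshape(x.shape[-1], 3)
    return x @ weights


def multicoil_forward(model, x: np.ndarray, coil_to_batch: bool = False, coil_dim: int = 1) -> np.ndarray:
    """Multi-coil input (N, coil, H, W, in_ch) -> output (N, coil, H, W, out_ch)."""
    if coil_to_batch:
        # Merge batch and coil axes (assumes coil_dim == 1) and call the model once.
        n, coil, h, w, c = x.shape
        out = model(x.reshape(n * coil, h, w, c))
        return out.reshape(n, coil, h, w, -1)
    # Otherwise run the model on each coil separately and stack along coil_dim.
    per_coil = [model(np.take(x, i, axis=coil_dim)) for i in range(x.shape[coil_dim])]
    return np.stack(per_coil, axis=coil_dim)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 5, 6, 2))  # (N, coil, H, W, in_channels)
batched = multicoil_forward(toy_model, x, coil_to_batch=True)
looped = multicoil_forward(toy_model, x, coil_to_batch=False)
```

Both modes agree here because toy_model treats every batch entry independently. A model that mixes information across the batch, such as batch normalisation in training mode, can give different results in the two modes. Merging coils into the batch typically trades higher peak memory for fewer model calls.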
- forward(x)
Performs the forward pass of MultiCoil.
- Parameters:
- x: torch.Tensor
Multi-coil input of shape (N, coil, height, width, in_channels).
- Returns:
- out: torch.Tensor
Multi-coil output of shape (N, coil, height, width, out_channels).
- Return type:
Tensor
- training: bool