direct.nn.conjgradnet package#
Submodules#
direct.nn.conjgradnet.config module#
- class direct.nn.conjgradnet.config.ConjGradNetConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps: int = 8, image_init: str = <InitType.ZEROS: 'zeros'>, no_parameter_sharing: bool = True, cg_tol: float = 1e-07, cg_iters: int = 10, cg_param_update_type: str = <CGUpdateType.FR: 'FR'>, denoiser_architecture: str = <ModelName.RESNET: 'resnet'>, resnet_hidden_channels: int = 128, resnet_num_blocks: int = 15, resenet_batchnorm: bool = True, resenet_scale: Optional[float] = 0.1, unet_num_filters: Optional[int] = 32, unet_num_pool_layers: Optional[int] = 4, unet_dropout: Optional[float] = 0.0, didn_hidden_channels: Optional[int] = 16, didn_num_dubs: Optional[int] = 6, didn_num_convs_recon: Optional[int] = 9, conv_hidden_channels: Optional[int] = 64, conv_n_convs: Optional[int] = 15, conv_activation: Optional[str] = <ActivationType.RELU: 'relu'>, conv_batchnorm: Optional[bool] = False)[source][source]#
Bases: ModelConfig
- cg_iters: int = 10#
- cg_param_update_type: str = 'FR'#
- cg_tol: float = 1e-07#
- conv_activation: Optional[str] = 'relu'#
- conv_batchnorm: Optional[bool] = False#
- conv_n_convs: Optional[int] = 15#
- denoiser_architecture: str = 'resnet'#
- didn_num_convs_recon: Optional[int] = 9#
- didn_num_dubs: Optional[int] = 6#
- image_init: str = 'zeros'#
- no_parameter_sharing: bool = True#
- num_steps: int = 8#
- resenet_batchnorm: bool = True#
- resenet_scale: Optional[float] = 0.1#
- resnet_num_blocks: int = 15#
- unet_dropout: Optional[float] = 0.0#
- unet_num_filters: Optional[int] = 32#
- unet_num_pool_layers: Optional[int] = 4#
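As an illustration, a minimal sketch of constructing this config directly in Python; the values below simply restate a few of the defaults shown in the signature above, and in practice the fields are typically populated from DIRECT's configuration files rather than constructed by hand.

```python
from direct.nn.conjgradnet.config import ConjGradNetConfig

# Minimal sketch: spell out a few fields explicitly; every value below
# restates a default from the signature documented above.
config = ConjGradNetConfig(
    num_steps=8,
    image_init="zeros",
    denoiser_architecture="resnet",
    resnet_hidden_channels=128,
    resnet_num_blocks=15,
    cg_iters=10,
    cg_tol=1e-07,
    cg_param_update_type="FR",
)
```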
direct.nn.conjgradnet.conjgrad module#
- class direct.nn.conjgradnet.conjgrad.CGUpdateType(value)[source][source]#
Bases: DirectEnum
An enumeration of the conjugate gradient parameter update types.
- BAN = 'BAN'#
- DY = 'DY'#
- FR = 'FR'#
- PRP = 'PRP'#
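For orientation, the member names correspond to classical rules for updating the CG parameter \(\beta\). The sketch below lists the commonly used formulas these abbreviations are associated with (Fletcher-Reeves, Polak-Ribiere-Polyak, Dai-Yuan); it is illustrative only, the exact expressions and sign conventions used inside ConjGrad are those of its implementation, and BAN is not expanded here.

```python
import torch

def beta_fr(r_new: torch.Tensor, r_old: torch.Tensor) -> torch.Tensor:
    # Fletcher-Reeves: ratio of squared residual norms.
    return (r_new @ r_new) / (r_old @ r_old)

def beta_prp(r_new: torch.Tensor, r_old: torch.Tensor) -> torch.Tensor:
    # Polak-Ribiere-Polyak.
    return (r_new @ (r_new - r_old)) / (r_old @ r_old)

def beta_dy(r_new: torch.Tensor, r_old: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Dai-Yuan, written with the residual convention r = -gradient.
    return (r_new @ r_new) / (p @ (r_old - r_new))
```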
- class direct.nn.conjgradnet.conjgrad.ConjGrad(forward_operator, backward_operator, num_iters=10, tol=1e-06, bk_update_type=CGUpdateType.FR)[source][source]#
Bases: Module
Performs the Conjugate Gradient (CG) algorithm to approximate a solution to:
\[\min_{x} f(x) = \min_{x} \frac{1}{2} \big( ||\mathcal{A}(x) - y||_2^2 + \lambda ||x - z||_2^2 \big)\]
or, equivalently, to solve the normal equation of the above:
\[\mathcal{B}(x) := \big(\mathcal{A}^{*} \circ \mathcal{A} + \lambda I\big) (x) = \mathcal{A}^{*}(y) + \lambda z =: b.\]
Notes
ConjGrad has no trainable parameters; however, gradients are still propagated through it by PyTorch.
- B_op(x, sensitivity_map, sampling_mask, lambd)[source][source]#
Computes \(\mathcal{B}(x) = (\mathcal{A}^{*} \circ \mathcal{A} + \lambda I)(x)\).
- Parameters:
- x: torch.Tensor
Image of shape (N, height, width, complex=2).
- sensitivity_map: torch.Tensor
Coil sensitivities of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- lambd: torch.Tensor
Regularization parameter of shape (1).
- Returns:
- torch.Tensor
- Return type:
Tensor
- cg(x, y, sensitivity_map, sampling_mask, lambd, z)[source][source]#
Computes the conjugate gradient algorithm.
- Parameters:
- x: torch.Tensor
Guess for \(x_0\) of shape (N, height, width, complex=2).
- y: torch.Tensor
Initial/masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- lambd: torch.Tensor
Regularization parameter of shape (1).
- z: torch.Tensor
Denoised input of shape (N, height, width, complex=2).
- Returns:
- torch.Tensor
\(x_K\), the approximate solution after the final CG iteration.
- Return type:
Tensor
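To make the algorithm concrete, here is a self-contained toy example (plain PyTorch, independent of this class and of MRI-shaped tensors): it solves the normal equation \(\mathcal{B}(x) = b\) from the class docstring with a small dense matrix standing in for \(\mathcal{A}\). The variable names and the Fletcher-Reeves-style \(\beta\) update are illustrative, not taken from the implementation.

```python
import torch

# Solve min_x 1/2 (||A x - y||^2 + lambda ||x - z||^2) via its normal equation
# B(x) = (A^* A + lambda I) x = A^* y + lambda z with conjugate gradients.
torch.manual_seed(0)
n = 8
A = torch.randn(n, n, dtype=torch.float64)
y = torch.randn(n, dtype=torch.float64)
z = torch.randn(n, dtype=torch.float64)
lambd = 0.1

B = lambda x: A.T @ (A @ x) + lambd * x   # the (SPD) normal-equation operator
b = A.T @ y + lambd * z                   # right-hand side

x = torch.zeros(n, dtype=torch.float64)   # x_0
r = b - B(x)                              # initial residual
p = r.clone()                             # initial search direction
rs_old = r @ r
for _ in range(10):                       # cf. num_iters
    Bp = B(p)
    alpha = rs_old / (p @ Bp)
    x = x + alpha * p
    r = r - alpha * Bp
    rs_new = r @ r
    if rs_new.sqrt() < 1e-7:              # cf. tol
        break
    p = r + (rs_new / rs_old) * p         # beta = rs_new / rs_old (Fletcher-Reeves)
    rs_old = rs_new

# Compare with a direct solve of the normal equation.
x_ref = torch.linalg.solve(A.T @ A + lambd * torch.eye(n, dtype=torch.float64), b)
assert torch.allclose(x, x_ref, atol=1e-5)
```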
- forward(masked_kspace, sensitivity_map, sampling_mask, z, lambd)[source][source]#
Performs forward pass of ConjGrad.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Coil sensitivities of shape (N, coil, height, width, complex=2).
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- z: torch.Tensor
Image prediction of shape (N, height, width, complex=2).
- lambd: torch.Tensor
Regularization parameter (trainable or not) of shape (1).
- Returns:
- torch.Tensor
- Return type:
Tensor
- training: bool#
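A minimal construction sketch follows. Assumption: the operators are taken directly from direct.data.transforms as T.fft2 / T.ifft2; in DIRECT's engines they are typically wrapped (e.g. with functools.partial) to fix their keyword arguments, so this is a simplification.

```python
import direct.data.transforms as T
from direct.nn.conjgradnet.conjgrad import CGUpdateType, ConjGrad

# Sketch: build the CG module with the default-style settings documented above.
cg = ConjGrad(
    forward_operator=T.fft2,
    backward_operator=T.ifft2,
    num_iters=10,
    tol=1e-06,
    bk_update_type=CGUpdateType.FR,
)
```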
direct.nn.conjgradnet.conjgradnet module#
- class direct.nn.conjgradnet.conjgradnet.ConjGradNet(forward_operator, backward_operator, num_steps, denoiser_architecture=ModelName.RESNET, image_init=InitType.SENSE, no_parameter_sharing=True, cg_iters=15, cg_tol=1e-07, cg_param_update_type=CGUpdateType.FR, **kwargs)[source][source]#
Bases: Module
Conjugate Gradient Network for MRI Reconstruction.
Solves iteratively the following:
\[z^i = \arg \min_{z} \mu ||x^i - z||_2^2 + \mathcal{R}(z)\]
\[x^{i+1} = \arg \min_{x} ||A(x) - y||_2^2 + \mu ||x - z^i||_2^2\]
where A is the forward operator of Accelerated MRI Reconstruction. The former equation is solved by a denoiser \(D_{i_\theta}\), which takes \(x^i\) as input, and the latter is solved by the conjugate gradient algorithm [1].
References
[1] Jonathan Richard Shewchuk (1994). An introduction to the conjugate gradient method without the agonizing pain. Available at: https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf.
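A toy, self-contained sketch of this alternating scheme (plain PyTorch, not this class): a hypothetical soft-thresholding "denoiser" stands in for the learned \(D_{i_\theta}\), and the data-consistency subproblem is solved in closed form here, whereas ConjGradNet solves it with ConjGrad.

```python
import torch

torch.manual_seed(0)
n, mu, num_steps = 8, 0.5, 8
A = torch.randn(n, n, dtype=torch.float64)   # toy forward operator
y = torch.randn(n, dtype=torch.float64)      # toy measurements

def denoise(x, threshold=0.05):
    # Hypothetical stand-in for D_theta: soft-thresholding, i.e. the prox of an
    # l1-type regularizer R(z).
    return torch.sign(x) * torch.clamp(x.abs() - threshold, min=0.0)

def data_consistency(z):
    # x^{i+1} = argmin_x ||A x - y||^2 + mu ||x - z||^2, solved in closed form
    # for this toy problem; ConjGradNet runs ConjGrad for this step instead.
    return torch.linalg.solve(A.T @ A + mu * torch.eye(n, dtype=torch.float64),
                              A.T @ y + mu * z)

x = torch.zeros(n, dtype=torch.float64)
for i in range(num_steps):
    z = denoise(x)            # z^i:     denoiser subproblem
    x = data_consistency(z)   # x^{i+1}: data-consistency subproblem
```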
- forward(masked_kspace, sensitivity_map, sampling_mask)[source][source]#
Computes forward pass of ConjGradNet.
- Parameters:
- masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, height, width, complex=2).
- sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, height, width, complex=2). Default: None.
- sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, height, width, 1).
- Returns:
- image: torch.Tensor
Output image of shape (N, height, width, complex=2).
- Return type:
Tensor
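For orientation, dummy inputs with the shapes documented above can be built as follows (the sizes, a batch of one with four coils at 128 x 128, are illustrative; model denotes a hypothetical instantiated ConjGradNet).

```python
import torch

# Shapes per the parameter docs: N=1, coil=4, height=width=128, trailing
# dimension of size 2 for real/imaginary parts; the mask has a single channel
# and a trailing singleton dimension.
masked_kspace = torch.zeros(1, 4, 128, 128, 2)
sensitivity_map = torch.zeros(1, 4, 128, 128, 2)
sampling_mask = torch.zeros(1, 1, 128, 128, 1)

# image = model(masked_kspace, sensitivity_map, sampling_mask)  # -> (1, 128, 128, 2)
```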
- static init_z(image_init, backward_operator, kspace, coil_dim, spatial_dims, sensitivity_map=None)[source][source]#
- Return type:
Tensor
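init_z builds the initial image for the denoiser from the k-space data according to image_init (e.g. InitType.ZEROS or InitType.SENSE). As a hedged sketch of a SENSE-style combination (the function name is hypothetical, and the actual implementation may differ in its operator arguments and normalization), the coil images obtained from the backward operator can be combined with the conjugate coil sensitivities:

```python
import torch

# Hedged sketch (not the actual implementation): combine per-coil images
# (already backward-transformed k-space) with the conjugate coil sensitivities.
# Shapes follow the docs above: (N, coil, H, W, complex=2).
def sense_initial_image(coil_images: torch.Tensor, sensitivity_map: torch.Tensor,
                        coil_dim: int = 1) -> torch.Tensor:
    imgs = torch.view_as_complex(coil_images.contiguous())
    sens = torch.view_as_complex(sensitivity_map.contiguous())
    combined = (sens.conj() * imgs).sum(dim=coil_dim)  # (N, H, W), complex
    return torch.view_as_real(combined)                # (N, H, W, 2)
```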
- training: bool#
direct.nn.conjgradnet.conjgradnet_engine module#
- class direct.nn.conjgradnet.conjgradnet_engine.ConjGradNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source][source]#
Bases: MRIModelEngine