direct.nn.conjgradnet package

Submodules

direct.nn.conjgradnet.config module

class direct.nn.conjgradnet.config.ConjGradNetConfig(model_name: str = '???', engine_name: Optional[str] = None, num_steps: int = 8, image_init: str = <InitType.ZEROS: 'zeros'>, no_parameter_sharing: bool = True, cg_tol: float = 1e-07, cg_iters: int = 10, cg_param_update_type: str = <CGUpdateType.FR: 'FR'>, denoiser_architecture: str = <ModelName.RESNET: 'resnet'>, resnet_hidden_channels: int = 128, resnet_num_blocks: int = 15, resenet_batchnorm: bool = True, resenet_scale: Optional[float] = 0.1, unet_num_filters: Optional[int] = 32, unet_num_pool_layers: Optional[int] = 4, unet_dropout: Optional[float] = 0.0, didn_hidden_channels: Optional[int] = 16, didn_num_dubs: Optional[int] = 6, didn_num_convs_recon: Optional[int] = 9, conv_hidden_channels: Optional[int] = 64, conv_n_convs: Optional[int] = 15, conv_activation: Optional[str] = <ActivationType.RELU: 'relu'>, conv_batchnorm: Optional[bool] = False)

Bases: ModelConfig

cg_iters: int = 10
cg_param_update_type: str = 'FR'
cg_tol: float = 1e-07
conv_activation: Optional[str] = 'relu'
conv_batchnorm: Optional[bool] = False
conv_hidden_channels: Optional[int] = 64
conv_n_convs: Optional[int] = 15
denoiser_architecture: str = 'resnet'
didn_hidden_channels: Optional[int] = 16
didn_num_convs_recon: Optional[int] = 9
didn_num_dubs: Optional[int] = 6
image_init: str = 'zeros'
no_parameter_sharing: bool = True
num_steps: int = 8
resenet_batchnorm: bool = True
resenet_scale: Optional[float] = 0.1
resnet_hidden_channels: int = 128
resnet_num_blocks: int = 15
unet_dropout: Optional[float] = 0.0
unet_num_filters: Optional[int] = 32
unet_num_pool_layers: Optional[int] = 4
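
The defaults listed above can be overridden field by field. The following is a minimal sketch of that pattern using a hypothetical stand-in dataclass, `ConjGradNetConfigSketch`, which is not the library class; it only copies a few of the documented field names and default values so that the example runs without the package installed.

```python
from dataclasses import dataclass, replace


@dataclass
class ConjGradNetConfigSketch:
    """Hypothetical stand-in, NOT the library class.

    Field names and defaults are copied from the attribute listing above.
    """

    num_steps: int = 8
    image_init: str = "zeros"
    no_parameter_sharing: bool = True
    cg_tol: float = 1e-07
    cg_iters: int = 10
    cg_param_update_type: str = "FR"
    denoiser_architecture: str = "resnet"


default_cfg = ConjGradNetConfigSketch()

# Build a modified copy; the default instance is left untouched.
custom_cfg = replace(default_cfg, num_steps=12, cg_iters=15)
```

Note that the defaults documented for the config (`cg_iters=10`, `image_init='zeros'`) differ from those in the `ConjGradNet` constructor signature on this page (`cg_iters=15`, `image_init=InitType.SENSE`), so it may be worth setting these fields explicitly rather than relying on either default.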

direct.nn.conjgradnet.conjgrad module

class direct.nn.conjgradnet.conjgrad.CGUpdateType(value)

Bases: DirectEnum

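Enumeration of the update rules available for the CG parameter, as selected through `bk_update_type` in `ConjGrad` or `cg_param_update_type` in `ConjGradNet`.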

BAN = 'BAN'
DY = 'DY'
FR = 'FR'
PRP = 'PRP'

class direct.nn.conjgradnet.conjgrad.ConjGrad(forward_operator, backward_operator, num_iters=10, tol=1e-06, bk_update_type=CGUpdateType.FR)

Bases: Module

Performs the Conjugate Gradient (CG) algorithm to approximate a solution of:

\[\min_{x} f(x) = \min_{x} \frac{1}{2} \big( ||\mathcal{A}(x) - y||_2^2 + \lambda ||x - z||_2^2 \big)\]

or equivalently solving the normal equation of the above:

\[\mathcal{B}(x) := \big(\mathcal{A}^{*} \circ \mathcal{A} + \lambda I\big) (x) = \mathcal{A}^{*}(y) + \lambda z =: b.\]

Notes

ConjGrad has no trainable parameters. However, PyTorch still computes gradients through its operations.

B_op(x, sensitivity_map, sampling_mask, lambd)

Computes \(\mathcal{B}(x) = (\mathcal{A}^{*} \circ \mathcal{A}+ \lambda I) (x)\)

Parameters:
x: torch.Tensor

Image of shape (N, height, width, complex=2).

sensitivity_map: torch.Tensor

Coil sensitivities of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

lambd: torch.Tensor

Regularization parameter of shape (1).

Returns:
torch.Tensor
Return type:

Tensor
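
To make the structure of \(\mathcal{B}\) concrete, here is a minimal pure-Python sketch. It assumes the usual multi-coil model in which \(\mathcal{A}\) multiplies the image by each coil sensitivity, applies a Fourier transform, and then applies the sampling mask; this page does not spell out \(\mathcal{A}\), so that model is an assumption. The example is 1-D, uses Python `complex` values instead of a trailing `complex=2` axis, and the names `dft`, `forward_op`, `adjoint_op` and `inner` are illustrative helpers, not library functions.

```python
import cmath
import math
from typing import List

Signal = List[complex]


def dft(x: Signal, inverse: bool = False) -> Signal:
    """Unitary 1-D DFT; inverse=True gives its adjoint (toy stand-in for a 2-D FFT)."""
    n = len(x)
    sign = 1.0 if inverse else -1.0
    scale = 1.0 / math.sqrt(n)
    return [
        scale * sum(x[t] * cmath.exp(sign * 2j * math.pi * t * k / n) for t in range(n))
        for k in range(n)
    ]


def forward_op(x: Signal, sens: List[Signal], mask: List[float]) -> List[Signal]:
    """A(x): weight by each coil sensitivity, transform, then mask."""
    return [
        [m * v for m, v in zip(mask, dft([s * xi for s, xi in zip(coil, x)]))]
        for coil in sens
    ]


def adjoint_op(y: List[Signal], sens: List[Signal], mask: List[float]) -> Signal:
    """A^*(y): mask, inverse transform, then conjugate-weighted sum over coils."""
    out = [0j] * len(mask)
    for coil, y_c in zip(sens, y):
        img = dft([m * v for m, v in zip(mask, y_c)], inverse=True)
        for k, value in enumerate(img):
            out[k] += coil[k].conjugate() * value
    return out


def B_op(x: Signal, sens: List[Signal], mask: List[float], lambd: float) -> Signal:
    """B(x) = A^*(A(x)) + lambd * x."""
    normal = adjoint_op(forward_op(x, sens, mask), sens, mask)
    return [a + lambd * xi for a, xi in zip(normal, x)]


def inner(u: Signal, v: Signal) -> complex:
    """Complex inner product <u, v>, conjugating the first argument."""
    return sum(a.conjugate() * b for a, b in zip(u, v))


# Toy data: 4 samples, 2 coils, every other k-space sample kept.
x = [1 + 2j, -0.5j, 3 + 0j, 0.25 - 1j]
sens = [[1 + 0j, 0.5j, 0.2 + 0j, 1 - 1j], [0.3 + 0j, 1 + 0j, -1j, 0.5 + 0j]]
mask = [1.0, 0.0, 1.0, 0.0]
y = [[1j, 2 + 0j, -1 + 0j, 0.5 + 0j], [0.5 + 0j, -1j, 2 + 1j, 1 + 0j]]

# Adjointness check: <A(x), y> should equal <x, A^*(y)>.
lhs = sum(inner(a_c, y_c) for a_c, y_c in zip(forward_op(x, sens, mask), y))
rhs = inner(x, adjoint_op(y, sens, mask))

# <x, B(x)> = ||A(x)||^2 + lambd * ||x||^2, which is real and positive for x != 0.
quad = inner(x, B_op(x, sens, mask, lambd=0.1))
```

The positivity of `quad` is what makes \(\mathcal{B}\) suitable for conjugate gradient: for \(\lambda > 0\) the operator is Hermitian and positive definite even when the sampling mask discards most of k-space.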

cg(x, y, sensitivity_map, sampling_mask, lambd, z)

Computes the conjugate gradient algorithm.

Parameters:
x: torch.Tensor

Guess for \(x_0\) of shape (N, height, width, complex=2).

y: torch.Tensor

Initial/masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

lambd: torch.Tensor

Regularization parameter of shape (1).

z: torch.Tensor

Denoised input of shape (N, height, width, complex=2).

Returns:
torch.Tensor

x_K.

Return type:

Tensor
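
The iteration behind `cg` can be sketched on a small real-valued problem. This is a textbook conjugate gradient loop under stated assumptions, not the library's implementation: the forward operator is a dense 3x2 matrix, the stopping rule compares the residual norm with `tol`, and the direction update uses the Fletcher-Reeves ratio, which is presumably what `CGUpdateType.FR` denotes. The names `conjugate_gradient`, `A_forward`, `A_backward` and `B` are illustrative.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def conjugate_gradient(
    B: Callable[[Vector], Vector],
    b: Vector,
    x0: Vector,
    num_iters: int = 10,
    tol: float = 1e-6,
) -> Vector:
    """Approximately solve B(x) = b for a symmetric positive-definite operator B."""
    x = list(x0)
    r = [bi - Bxi for bi, Bxi in zip(b, B(x))]  # residual r_0 = b - B(x_0)
    p = list(r)  # first search direction
    rs_old = dot(r, r)
    for _ in range(num_iters):
        if rs_old ** 0.5 < tol:  # assumed stopping rule: residual norm below tol
            break
        Bp = B(p)
        alpha = rs_old / dot(p, Bp)  # step length along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * Bpi for ri, Bpi in zip(r, Bp)]
        rs_new = dot(r, r)
        beta = rs_new / rs_old  # Fletcher-Reeves ratio
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Toy problem: A is a 3x2 real matrix, so B = A^T A + lambd * I acts on 2-vectors.
A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
y = [1.0, 0.0, -1.0]
z = [0.5, -0.5]
lambd = 0.5


def A_forward(x: Vector) -> Vector:
    return [dot(row, x) for row in A]


def A_backward(v: Vector) -> Vector:
    return [dot(col, v) for col in zip(*A)]


def B(x: Vector) -> Vector:
    return [g + lambd * xi for g, xi in zip(A_backward(A_forward(x)), x)]


b = [g + lambd * zi for g, zi in zip(A_backward(y), z)]  # b = A^T y + lambd * z
x_K = conjugate_gradient(B, b, x0=[0.0, 0.0])
residual = max(abs(bi - Bxi) for bi, Bxi in zip(b, B(x_K)))
```

Because the toy system has only two unknowns, the loop reaches the tolerance after two updates; for the imaging operator the iteration count is capped by `num_iters` instead.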

forward(masked_kspace, sensitivity_map, sampling_mask, z, lambd)

Performs forward pass of ConjGrad.

Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map: torch.Tensor

Coil sensitivities of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

z: torch.Tensor

Prediction of image of shape (N, height, width, complex=2).

lambd: torch.Tensor

Regularization parameter (trainable or not) of shape (1).

Returns:
torch.Tensor
Return type:

Tensor

training: bool

direct.nn.conjgradnet.conjgradnet module

class direct.nn.conjgradnet.conjgradnet.ConjGradNet(forward_operator, backward_operator, num_steps, denoiser_architecture=ModelName.RESNET, image_init=InitType.SENSE, no_parameter_sharing=True, cg_iters=15, cg_tol=1e-07, cg_param_update_type=CGUpdateType.FR, **kwargs)

Bases: Module

Conjugate Gradient Network for MRI Reconstruction.

Iteratively solves the following:

\[\begin{aligned} z^{i} &= \arg \min_{z} \mu ||x^{i} - z||_2^2 + \mathcal{R}(z) \\ x^{i+1} &= \arg \min_{x} ||A(x) - y||_2^2 + \mu ||x - z^{i}||_2^2 \end{aligned}\]

where A is the forward operator of Accelerated MRI Reconstruction. The former equation is solved by a denoiser \(D_{i_\theta}\), which takes \(x^i\) as input, and the latter is solved by the conjugate gradient algorithm [1].

References

[1]

Jonathan Richard Shewchuk (1994) An introduction to the conjugate gradient method without the agonizing pain. Available at: https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf.
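
The alternation between the denoising step and the data-consistency step can be sketched as a short unrolled loop. This is a toy illustration under stated assumptions, not the network itself: the forward operator is a diagonal 0/1 sampling mask, so the \(x\)-update has a closed form and conjugate gradient is swapped out for it; the learned denoiser is replaced by a hypothetical 3-tap moving average; and the starting point is the zero-filled signal. All function names are illustrative.

```python
from typing import List, Tuple

Vector = List[float]


def denoise(x: Vector) -> Vector:
    """Hypothetical stand-in for the learned denoiser: 3-tap moving average."""
    n = len(x)
    return [(x[max(k - 1, 0)] + x[k] + x[min(k + 1, n - 1)]) / 3.0 for k in range(n)]


def data_consistency(y: Vector, mask: Vector, z: Vector, mu: float) -> Vector:
    """Closed-form minimiser of ||mask * x - y||^2 + mu * ||x - z||^2.

    Valid only because the toy forward operator is diagonal; a general
    operator needs an iterative solver such as conjugate gradient.
    """
    return [(m * yk + mu * zk) / (m * m + mu) for yk, m, zk in zip(y, mask, z)]


def unrolled_reconstruction(
    y: Vector, mask: Vector, num_steps: int = 8, mu: float = 0.1
) -> Tuple[Vector, Vector]:
    """Alternate z^i = D(x^i) and the x^{i+1} update for num_steps steps."""
    x = [m * yk for m, yk in zip(mask, y)]  # assumed start: zero-filled signal
    z = list(x)
    for _ in range(num_steps):
        z = denoise(x)  # z^i
        x = data_consistency(y, mask, z, mu)  # x^{i+1}
    return x, z


mask = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
truth = [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]
y = [m * t for m, t in zip(mask, truth)]
mu = 0.1
x_rec, z_last = unrolled_reconstruction(y, mask, num_steps=8, mu=mu)
```

At unsampled positions (`mask` equal to 0) the update returns the denoised value, while at sampled positions it blends the measurement with the denoised value according to `mu`; this is the same trade-off that the regularization weight controls in the full problem.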

forward(masked_kspace, sensitivity_map, sampling_mask)

Computes forward pass of ConjGradNet.

Parameters:
masked_kspace: torch.Tensor

Masked k-space of shape (N, coil, height, width, complex=2).

sensitivity_map: torch.Tensor

Sensitivity map of shape (N, coil, height, width, complex=2).

sampling_mask: torch.Tensor

Sampling mask of shape (N, 1, height, width, 1).

Returns:
image: torch.Tensor

Output image of shape (N, height, width, complex=2).

Return type:

Tensor

static init_z(image_init, backward_operator, kspace, coil_dim, spatial_dims, sensitivity_map=None)

Return type:

Tensor

training: bool

direct.nn.conjgradnet.conjgradnet_engine module

class direct.nn.conjgradnet.conjgradnet_engine.ConjGradNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)

Bases: MRIModelEngine

forward_function(data)

This method performs the model’s forward method given data which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Tensor, None]

Module contents