direct.nn.conjgradnet package#

Submodules#

direct.nn.conjgradnet.config module#

class direct.nn.conjgradnet.config.ConjGradNetConfig(model_name='???', engine_name=None, num_steps=8, image_init=InitType.ZEROS, no_parameter_sharing=True, cg_tol=1e-07, cg_iters=10, cg_param_update_type=CGUpdateType.FR, denoiser_architecture=ModelName.RESNET, resnet_hidden_channels=128, resnet_num_blocks=15, resenet_batchnorm=True, resenet_scale=0.1, unet_num_filters=32, unet_num_pool_layers=4, unet_dropout=0.0, didn_hidden_channels=16, didn_num_dubs=6, didn_num_convs_recon=9, conv_hidden_channels=64, conv_n_convs=15, conv_activation=ActivationType.RELU, conv_batchnorm=False)[source]#

Bases: ModelConfig

num_steps = 8#
image_init = 'zeros'#
no_parameter_sharing = True#
cg_tol = 1e-07#
cg_iters = 10#
cg_param_update_type = 'FR'#
denoiser_architecture = 'resnet'#
resnet_hidden_channels = 128#
resnet_num_blocks = 15#
resenet_batchnorm = True#
resenet_scale = 0.1#
unet_num_filters = 32#
unet_num_pool_layers = 4#
unet_dropout = 0.0#
didn_hidden_channels = 16#
didn_num_dubs = 6#
didn_num_convs_recon = 9#
conv_hidden_channels = 64#
conv_n_convs = 15#
conv_activation = 'relu'#
conv_batchnorm = False#
__init__(model_name='???', engine_name=None, num_steps=8, image_init=InitType.ZEROS, no_parameter_sharing=True, cg_tol=1e-07, cg_iters=10, cg_param_update_type=CGUpdateType.FR, denoiser_architecture=ModelName.RESNET, resnet_hidden_channels=128, resnet_num_blocks=15, resenet_batchnorm=True, resenet_scale=0.1, unet_num_filters=32, unet_num_pool_layers=4, unet_dropout=0.0, didn_hidden_channels=16, didn_num_dubs=6, didn_num_convs_recon=9, conv_hidden_channels=64, conv_n_convs=15, conv_activation=ActivationType.RELU, conv_batchnorm=False)#
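For example, a variant of this config with a shallower unrolled scheme and a lighter ResNet denoiser could be written as below (a minimal sketch; only fields from the signature above are used, everything else keeps its default):

    from direct.nn.conjgradnet.config import ConjGradNetConfig

    # Override a handful of fields; the remaining attributes keep the defaults listed above.
    cfg = ConjGradNetConfig(
        num_steps=6,
        cg_iters=20,
        resnet_hidden_channels=64,
        resnet_num_blocks=10,
    )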

direct.nn.conjgradnet.conjgrad module#

class direct.nn.conjgradnet.conjgrad.CGUpdateType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: DirectEnum

FR = 'FR'#
PRP = 'PRP'#
DY = 'DY'#
BAN = 'BAN'#
class direct.nn.conjgradnet.conjgrad.ConjGrad(forward_operator, backward_operator, num_iters=10, tol=1e-6, bk_update_type=CGUpdateType.FR)[source]#

Bases: Module

Performs the Conjugate Gradient (CG) algorithm to approach a solution to:

\[\min_{x} f(x) = \min_{x} \frac{1}{2} \big( ||\mathcal{A}(x) - y||_2^2 + \lambda ||x - z||_2^2 \big)\]

or equivalently solving the normal equation of the above:

\[\mathcal{B}(x) := \big(\mathcal{A}^{*} \circ \mathcal{A} + \lambda I\big) (x) = \mathcal{A}^{*}(y) + \lambda z =: b.\]

Note

ConjGrad itself has no trainable parameters; PyTorch nevertheless backpropagates gradients through its iterations.
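As an illustration of the algorithm this module implements (not the library's code), below is a minimal PyTorch sketch of conjugate gradient for a generic linear operator \(\mathcal{B}\), where the Fletcher–Reeves ("FR") coefficient plays the role of \(b_k\); for ConjGrad, \(\mathcal{B}\) corresponds to B_op and \(b = \mathcal{A}^{*}(y) + \lambda z\):

    import torch

    def conjugate_gradient(B, b, x0, num_iters=10, tol=1e-6):
        """Illustrative CG solving B(x) = b with the Fletcher-Reeves update."""
        x = x0.clone()
        r = b - B(x)                       # residual r_0
        p = r.clone()                      # search direction p_0
        rs_old = torch.sum(r * r)
        for _ in range(num_iters):         # convergence criterion 1: iteration budget
            Bp = B(p)
            alpha = rs_old / torch.sum(p * Bp)
            x = x + alpha * p
            r = r - alpha * Bp
            rs_new = torch.sum(r * r)
            if torch.sqrt(rs_new) < tol:   # convergence criterion 2: ||r_k|| below tol
                break
            p = r + (rs_new / rs_old) * p  # Fletcher-Reeves b_k = ||r_{k+1}||^2 / ||r_k||^2
            rs_old = rs_new
        return x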

__init__(forward_operator, backward_operator, num_iters=10, tol=1e-6, bk_update_type=CGUpdateType.FR)[source]#

Inits ConjGrad.

Parameters:
  • forward_operator (Callable) – Forward operator \(\mathcal{A}\) (e.g. fft).

  • backward_operator (Callable) – Backward/adjoint operator \(\mathcal{A}^{*}\) (e.g. ifft).

  • num_iters (int) – Convergence criterion 1: number of CG iterations. Default: 10.

  • tol (float) – Convergence criterion 2: checks if CG has converged by checking r_k norm. Default: 1e-6.

  • bk_update_type (CGUpdateType) – How to compute \(b_k\). Can be "FR", "PRP", "DY" or "BAN". Default: CGUpdateType.FR.
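A hedged construction example, assuming the centered FFT operators from direct.data.transforms are used as \(\mathcal{A}\) and \(\mathcal{A}^{*}\):

    from direct.data.transforms import fft2, ifft2  # assumed location of the FFT operators
    from direct.nn.conjgradnet.conjgrad import CGUpdateType, ConjGrad

    cg_block = ConjGrad(
        forward_operator=fft2,    # in practice often wrapped to fix FFT dims/normalization
        backward_operator=ifft2,
        num_iters=10,
        tol=1e-6,
        bk_update_type=CGUpdateType.FR,
    )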

B_op(x, sensitivity_map, sampling_mask, lambd)[source]#

Computes \(\mathcal{B}(x) = (\mathcal{A}^{*} \circ \mathcal{A} + \lambda I)(x)\).

Parameters:
  • x (Tensor) – Input image of shape (N, height, width, complex=2).

  • sensitivity_map (Tensor) – Coil sensitivity map.

  • sampling_mask (Tensor) – Sampling mask.

  • lambd (Tensor) – Regularization parameter of shape (1).

Return type:

Tensor

Returns:

torch.Tensor
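Conceptually, this is the SENSE-style normal operator plus a Tikhonov term. The sketch below is not the library's implementation and uses complex tensors instead of DIRECT's real-valued complex=2 layout:

    import torch

    def normal_operator_sketch(x, sensitivity_map, sampling_mask, lambd):
        """Conceptual B(x) = A*(A(x)) + lambda * x for a masked multicoil operator A.

        Shapes (assumed): x (H, W), sensitivity_map (coil, H, W), sampling_mask (H, W),
        all complex-valued; lambd is a scalar.
        """
        coil_images = sensitivity_map * x                                    # S_c x
        kspace = sampling_mask * torch.fft.fft2(coil_images, norm="ortho")   # A(x) = M F S_c x
        backprojected = torch.fft.ifft2(kspace, norm="ortho")                # F^{-1} M F S_c x
        combined = (sensitivity_map.conj() * backprojected).sum(dim=0)       # A*(A(x))
        return combined + lambd * x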

cg(x, y, sensitivity_map, sampling_mask, lambd, z)[source]#

Performs the conjugate gradient algorithm for the normal equation above.

Parameters:
  • x (Tensor) – Initial guess \(x_0\) of shape (N, height, width, complex=2).

  • y (Tensor) – Measured (masked) k-space \(y\).

  • sensitivity_map (Tensor) – Coil sensitivity map.

  • sampling_mask (Tensor) – Sampling mask.

  • lambd (Tensor) – Regularization parameter of shape (1).

  • z (Tensor) – Denoised input \(z\).

Return type:

Tensor

Returns:

torch.Tensor – The estimate \(x_K\) after the final iteration.

forward(masked_kspace, sensitivity_map, sampling_mask, z, lambd)[source]#

Performs the forward pass of ConjGrad.

Parameters:
  • masked_kspace (Tensor) – Masked k-space \(y\).

  • sensitivity_map (Tensor) – Coil sensitivity map.

  • sampling_mask (Tensor) – Sampling mask.

  • z (Tensor) – Prediction of the image of shape (N, height, width, complex=2).

  • lambd (Tensor) – Regularization (trainable or not) parameter of shape (1).

Return type:

Tensor

Returns:

torch.Tensor
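A hedged usage sketch with dummy tensors; the (N, coil, height, width, complex=2) k-space layout and the (N, 1, height, width, 1) mask layout are assumptions following DIRECT's usual conventions, not shapes stated on this page. cg_block is the instance built in the example above:

    import torch

    N, C, H, W = 1, 8, 128, 128
    masked_kspace = torch.randn(N, C, H, W, 2)
    sensitivity_map = torch.randn(N, C, H, W, 2)
    sampling_mask = torch.randint(0, 2, (N, 1, H, W, 1)).float()
    z = torch.randn(N, H, W, 2)    # denoised image estimate
    lambd = torch.ones(1)          # regularization weight of shape (1)

    x = cg_block(masked_kspace, sensitivity_map, sampling_mask, z, lambd)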

direct.nn.conjgradnet.conjgradnet module#

class direct.nn.conjgradnet.conjgradnet.ConjGradNet(forward_operator, backward_operator, num_steps, denoiser_architecture=ModelName.RESNET, image_init=InitType.SENSE, no_parameter_sharing=True, cg_iters=15, cg_tol=1e-7, cg_param_update_type=CGUpdateType.FR, **kwargs)[source]#

Bases: Module

Conjugate Gradient Network for MRI Reconstruction.

Solves iteratively the following:

\[z^i = \arg \min_{z} \mu ||x^i - z||_2^2 + \mathcal{R}(z)\]

\[x^{i+1} = \arg \min_{x} ||\mathcal{A}(x) - y||_2^2 + \mu ||x - z^i||_2^2\]

where \(\mathcal{A}\) is the forward operator of accelerated MRI reconstruction. The former equation is solved by a denoiser \(D_{i_\theta}\), which takes \(x^i\) as input, and the latter is solved by the conjugate gradient algorithm [1].

References: [1] Jonathan Richard Shewchuk (1994). An Introduction to the Conjugate Gradient Method Without the Agonizing Pain.
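Schematically, the unrolled scheme alternates a denoising step and a CG data-consistency step. The sketch below is illustrative only; denoisers, conj_grad and mu are placeholders for the configured denoiser modules, the ConjGrad block and the regularization weight:

    def unrolled_conjgradnet_sketch(masked_kspace, sensitivity_map, sampling_mask, x0, denoisers, conj_grad, mu):
        """Illustrative unrolled loop of ConjGradNet-style alternation."""
        x = x0
        for denoiser in denoisers:
            z = denoiser(x)  # z^i = D_theta_i(x^i): denoising / regularization step
            x = conj_grad(masked_kspace, sensitivity_map, sampling_mask, z, mu)  # x^{i+1}: CG data-consistency step
        return x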

__init__(forward_operator, backward_operator, num_steps, denoiser_architecture=ModelName.RESNET, image_init=InitType.SENSE, no_parameter_sharing=True, cg_iters=15, cg_tol=1e-7, cg_param_update_type=CGUpdateType.FR, **kwargs)[source]#

Inits ConjGradNet.

Parameters:
  • forward_operator (Callable) – Forward operator.

  • backward_operator (Callable) – Backward operator.

  • num_steps (int) – Number of unrolled optimization steps.

  • denoiser_architecture (ModelName) – Type of architecture to use as a denoiser. Can be "resnet", "unet", "normunet", "didn" or "conv". Default: "resnet".

  • image_init (InitType) – Initialization type for \(z\). Can be "sense" or "zeros". Default: "sense".

  • no_parameter_sharing (bool) – If False, a single denoiser is shared across all unrolled steps. Default: True.

  • cg_iters (int) – Maximum number of conjugate gradient iterations. Default: 15.

  • cg_tol (float) – Convergence tolerance for the conjugate gradient. Default: 1e-7.

  • cg_param_update_type (CGUpdateType) – How to compute \(b_k\) in the conjugate gradient. Can be "FR", "PRP", "DY" or "BAN". Default: "FR".

  • kwargs – Keyword arguments should include the denoiser architecture parameters. For example, if denoiser_architecture is "unet" or "normunet", then unet_num_filters, unet_num_pool_layers and unet_dropout_probability should be passed (see the sketch after this list).
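The sketch referenced above: a hedged construction example, assuming fft2/ifft2 live in direct.data.transforms and the ModelName/InitType enums in direct.nn.types:

    from direct.data.transforms import fft2, ifft2   # assumed operator location
    from direct.nn.conjgradnet.conjgradnet import ConjGradNet
    from direct.nn.types import InitType, ModelName  # assumed enum location

    model = ConjGradNet(
        forward_operator=fft2,
        backward_operator=ifft2,
        num_steps=8,
        denoiser_architecture=ModelName.UNET,  # assumed member name for the "unet" option
        image_init=InitType.SENSE,
        cg_iters=15,
        cg_tol=1e-7,
        # Denoiser kwargs for the "unet" choice, as described above.
        unet_num_filters=32,
        unet_num_pool_layers=4,
        unet_dropout_probability=0.0,
    )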

forward(masked_kspace, sensitivity_map, sampling_mask)[source]#

Computes the forward pass of ConjGradNet.

Parameters:
  • masked_kspace (Tensor) – Masked k-space.

  • sensitivity_map (Tensor) – Sensitivity map.

  • sampling_mask (Tensor) – Sampling mask.

Returns:

Output image of shape (N, height, width, complex=2).

Return type:

torch.Tensor

static init_z(image_init, backward_operator, kspace, coil_dim, spatial_dims, sensitivity_map=None)[source]#
Return type:

Tensor

direct.nn.conjgradnet.conjgradnet_engine module#

class direct.nn.conjgradnet.conjgradnet_engine.ConjGradNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Bases: MRIModelEngine

__init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)[source]#

Inits ConjGradNetEngine.

Parameters:
  • cfg (BaseConfig) – Configuration file.

  • model (Module) – Model.

  • device (str) – Device. Can be “cuda:{idx}” or “cpu”.

  • forward_operator (Optional[Callable]) – The forward operator. Default: None.

  • backward_operator (Optional[Callable]) – The backward operator. Default: None.

  • mixed_precision (bool) – Use mixed precision. Default: False.

  • **models (Module) – Additional models.
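A hedged wiring sketch; cfg stands for a loaded configuration (BaseConfig) and model for a built ConjGradNet, and fft2/ifft2 are again assumed to come from direct.data.transforms:

    from direct.data.transforms import fft2, ifft2  # assumed operator location
    from direct.nn.conjgradnet.conjgradnet_engine import ConjGradNetEngine

    # `cfg` (BaseConfig) and `model` (ConjGradNet) are placeholders built elsewhere.
    engine = ConjGradNetEngine(
        cfg,
        model,
        device="cuda:0",
        forward_operator=fft2,
        backward_operator=ifft2,
        mixed_precision=True,
    )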

forward_function(data)[source]#

Performs the model's forward method given data, which contains all tensor inputs.

Must be implemented by child classes.

Return type:

Tuple[Tensor, None]

Module contents#