direct.nn.conjgradnet package
Submodules
direct.nn.conjgradnet.config module
- class direct.nn.conjgradnet.config.ConjGradNetConfig(model_name='???', engine_name=None, num_steps=8, image_init=InitType.ZEROS, no_parameter_sharing=True, cg_tol=1e-07, cg_iters=10, cg_param_update_type=CGUpdateType.FR, denoiser_architecture=ModelName.RESNET, resnet_hidden_channels=128, resnet_num_blocks=15, resenet_batchnorm=True, resenet_scale=0.1, unet_num_filters=32, unet_num_pool_layers=4, unet_dropout=0.0, didn_hidden_channels=16, didn_num_dubs=6, didn_num_convs_recon=9, conv_hidden_channels=64, conv_n_convs=15, conv_activation=ActivationType.RELU, conv_batchnorm=False)
Bases: ModelConfig
- num_steps = 8
- image_init = 'zeros'
- no_parameter_sharing = True
- cg_tol = 1e-07
- cg_iters = 10
- cg_param_update_type = 'FR'
- denoiser_architecture = 'resnet'
- resnet_num_blocks = 15
- resenet_batchnorm = True
- resenet_scale = 0.1
- unet_num_filters = 32
- unet_num_pool_layers = 4
- unet_dropout = 0.0
- didn_num_dubs = 6
- didn_num_convs_recon = 9
- conv_n_convs = 15
- conv_activation = 'relu'
- conv_batchnorm = False
- __init__(model_name='???', engine_name=None, num_steps=8, image_init=InitType.ZEROS, no_parameter_sharing=True, cg_tol=1e-07, cg_iters=10, cg_param_update_type=CGUpdateType.FR, denoiser_architecture=ModelName.RESNET, resnet_hidden_channels=128, resnet_num_blocks=15, resenet_batchnorm=True, resenet_scale=0.1, unet_num_filters=32, unet_num_pool_layers=4, unet_dropout=0.0, didn_hidden_channels=16, didn_num_dubs=6, didn_num_convs_recon=9, conv_hidden_channels=64, conv_n_convs=15, conv_activation=ActivationType.RELU, conv_batchnorm=False)
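To illustrate how these defaults might be overridden, here is a hypothetical stand-in dataclass that mirrors a few ConjGradNetConfig fields. It is not the DIRECT class. The local InitType and CGUpdateType are assumed to be str-valued enums, as the string defaults listed above ('zeros', 'FR') suggest.

```python
from dataclasses import dataclass, replace
from enum import Enum


class InitType(str, Enum):
    # Hypothetical stand-in: only the two init values mentioned on this page.
    SENSE = "sense"
    ZEROS = "zeros"


class CGUpdateType(str, Enum):
    # Hypothetical stand-in mirroring the documented members.
    FR = "FR"
    PRP = "PRP"
    DY = "DY"
    BAN = "BAN"


@dataclass
class ConjGradNetConfigSketch:
    """Hypothetical subset of ConjGradNetConfig fields with the documented defaults."""

    num_steps: int = 8
    image_init: InitType = InitType.ZEROS
    no_parameter_sharing: bool = True
    cg_tol: float = 1e-07
    cg_iters: int = 10
    cg_param_update_type: CGUpdateType = CGUpdateType.FR
    resnet_hidden_channels: int = 128
    resnet_num_blocks: int = 15


default = ConjGradNetConfigSketch()
# replace() returns a new instance; the defaults object is left untouched.
custom = replace(default, num_steps=4, cg_iters=20, cg_param_update_type=CGUpdateType.PRP)

print(default.image_init == "zeros")  # a str-based enum member compares equal to its value
print(custom.num_steps, custom.cg_iters, custom.cg_param_update_type.value)
```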
direct.nn.conjgradnet.conjgrad module
- class direct.nn.conjgradnet.conjgrad.CGUpdateType(value, names=_not_given, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: DirectEnum
- FR = 'FR'
- PRP = 'PRP'
- DY = 'DY'
- BAN = 'BAN'
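CGUpdateType selects how the coefficient \(b_k\), which mixes the previous search direction into the next one, is computed. The sketch below writes the textbook Fletcher-Reeves, Polak-Ribiere-Polyak and Dai-Yuan formulas in residual form. The helper beta is hypothetical and not DIRECT's code, and BAN is omitted because its formula is not given on this page. After an exact line-search step on a quadratic, the three coincide. In exact arithmetic, linear CG makes them equivalent, and they only differ once rounding erodes the orthogonality of the residuals.

```python
import numpy as np


def beta(update_type, r_new, r_old, p_old):
    """Hypothetical helper: coefficient b_k in residual form (r = b - Bx, p = search direction)."""
    if update_type == "FR":  # Fletcher-Reeves
        return (r_new @ r_new) / (r_old @ r_old)
    if update_type == "PRP":  # Polak-Ribiere-Polyak
        return (r_new @ (r_new - r_old)) / (r_old @ r_old)
    if update_type == "DY":  # Dai-Yuan
        return (r_new @ r_new) / (p_old @ (r_old - r_new))
    raise ValueError(f"no formula sketched for {update_type!r}")


rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
B = M @ M.T + 5 * np.eye(5)  # symmetric positive definite system matrix
b = rng.standard_normal(5)

# One linear CG step from x_0 = 0 with an exact line search.
x = np.zeros(5)
r = b - B @ x  # residual r_0
p = r.copy()  # search direction p_0
alpha = (r @ r) / (p @ B @ p)
x = x + alpha * p
r_new = r - alpha * (B @ p)  # residual r_1

betas = {name: beta(name, r_new, r, p) for name in ("FR", "PRP", "DY")}
p_next = r_new + betas["FR"] * p  # next search direction, B-conjugate to p
print(betas)
```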
- class direct.nn.conjgradnet.conjgrad.ConjGrad(forward_operator, backward_operator, num_iters=10, tol=1e-6, bk_update_type=CGUpdateType.FR)
Bases: Module
Performs the Conjugate Gradient (CG) algorithm to approach a solution to:
\[\min_{x} f(x) = \min_{x} \frac{1}{2} \big( ||\mathcal{A}(x) - y||_2^2 + \lambda ||x - z||_2^2 \big)\]
or, equivalently, solves the normal equation of the above:
\[\mathcal{B}(x) := \big(\mathcal{A}^{*} \circ \mathcal{A} + \lambda I\big) (x) = \mathcal{A}^{*}(y) + \lambda z =: b.\]
Note
ConjGrad has no trainable parameters. However, PyTorch still tracks its operations, so gradients can be backpropagated through the CG iterations to its inputs (for example z and lambd).
- __init__(forward_operator, backward_operator, num_iters=10, tol=1e-6, bk_update_type=CGUpdateType.FR)
Inits ConjGrad.
- Parameters:
forward_operator (Callable) – Forward operator \(\mathcal{A}\) (e.g. fft).
backward_operator (Callable) – Backward/adjoint operator \(\mathcal{A}^{*}\) (e.g. ifft).
num_iters (int) – Convergence criterion 1: number of CG iterations. Default: 10.
tol (float) – Convergence criterion 2: checks whether CG has converged by checking the norm of the residual \(r_k\). Default: 1e-6.
bk_update_type (CGUpdateType) – How to compute \(b_k\). Can be "FR", "PRP", "DY" or "BAN". Default: CGUpdateType.FR.
- B_op(x, sensitivity_map, sampling_mask, lambd)
Computes \(\mathcal{B}(x) = (\mathcal{A}^{*} \circ \mathcal{A} + \lambda I) (x)\).
- Parameters:
x (Tensor) – Image of shape (N, height, width, complex=2).
sensitivity_map (Tensor) – Coil sensitivity map.
sampling_mask (Tensor) – Sampling mask.
lambd (Tensor) – Regularization parameter \(\lambda\) of shape (1).
- Return type:
Tensor
- Returns:
\(\mathcal{B}(x)\), of the same shape as x.
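The following minimal NumPy sketch of \(\mathcal{B}\) assumes a multi-coil Cartesian model \(\mathcal{A}(x) = M \odot \mathcal{F}(S_c x)\) with orthonormal 2-D FFTs. It uses native complex arrays instead of the (..., complex=2) layout. The names A, A_adj and B_op_sketch are illustrative and are not DIRECT's API. The check confirms the two properties CG relies on: \(\mathcal{B}\) is self-adjoint, and it is positive definite for \(\lambda > 0\).

```python
import numpy as np


def A(x, smaps, mask):
    """Forward operator: image (H, W) -> masked multi-coil k-space (C, H, W)."""
    return mask * np.fft.fft2(smaps * x, norm="ortho")


def A_adj(k, smaps, mask):
    """Adjoint operator: multi-coil k-space (C, H, W) -> coil-combined image (H, W)."""
    return np.sum(np.conj(smaps) * np.fft.ifft2(mask * k, norm="ortho"), axis=0)


def B_op_sketch(x, smaps, mask, lambd):
    """B(x) = (A^* o A + lambda I)(x)."""
    return A_adj(A(x, smaps, mask), smaps, mask) + lambd * x


rng = np.random.default_rng(0)
C, H, W = 4, 8, 8
smaps = rng.standard_normal((C, H, W)) + 1j * rng.standard_normal((C, H, W))
mask = (rng.random((H, W)) < 0.5).astype(float)  # binary sampling mask
lambd = 0.1
u = rng.standard_normal((H, W)) + 1j * rng.standard_normal((H, W))
v = rng.standard_normal((H, W)) + 1j * rng.standard_normal((H, W))

lhs = np.vdot(B_op_sketch(u, smaps, mask, lambd), v)  # <B u, v>
rhs = np.vdot(u, B_op_sketch(v, smaps, mask, lambd))  # <u, B v>
print(np.isclose(lhs, rhs))  # self-adjointness
print(np.vdot(u, B_op_sketch(u, smaps, mask, lambd)).real > 0)  # positivity
```

The orthonormal FFT matters here: with norm="ortho" the inverse FFT is exactly the adjoint of the forward FFT, so A_adj is the true adjoint of A.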
- cg(x, y, sensitivity_map, sampling_mask, lambd, z)
Runs the conjugate gradient algorithm to solve \(\mathcal{B}(x) = b\).
- Parameters:
x (Tensor) – Initial guess \(x_0\) of shape (N, height, width, complex=2).
y (Tensor) – Masked k-space \(y\).
sensitivity_map (Tensor) – Coil sensitivity map.
sampling_mask (Tensor) – Sampling mask.
lambd (Tensor) – Regularization parameter \(\lambda\) of shape (1).
z (Tensor) – Denoised input \(z\).
- Return type:
Tensor
- Returns:
The final CG iterate \(x_K\).
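The sketch below implements linear CG with the two stopping criteria described for ConjGrad: an iteration budget num_iters and a tolerance tol on the residual norm. It uses the Fletcher-Reeves coefficient and assumes that B is a self-adjoint, positive definite callable. It is not DIRECT's implementation. It uses the real inner product \(\mathrm{Re}\langle a, b\rangle\), which equals the plain dot product of (real, imaginary)-stacked tensors such as the complex=2 layout.

```python
import numpy as np


def inner(a, b):
    # Real inner product Re<a, b>; equals the dot product of (real, imag)-stacked arrays.
    return np.real(np.vdot(a, b))


def cg(B, b, x0, num_iters=10, tol=1e-6):
    """Linear CG (Fletcher-Reeves form) for B(x) = b with B self-adjoint positive definite."""
    x = x0.copy()
    r = b - B(x)  # residual r_0
    p = r.copy()  # search direction p_0
    rr = inner(r, r)
    for _ in range(num_iters):  # criterion 1: iteration budget
        if np.sqrt(rr) < tol:  # criterion 2: residual norm
            break
        Bp = B(p)
        alpha = rr / inner(p, Bp)  # exact line search along p
        x = x + alpha * p
        r = r - alpha * Bp
        rr_new = inner(r, r)
        p = r + (rr_new / rr) * p  # b_k = ||r_{k+1}||^2 / ||r_k||^2
        rr = rr_new
    return x


rng = np.random.default_rng(1)
M = rng.standard_normal((6, 6)) + 1j * rng.standard_normal((6, 6))
lambd = 0.5
Bmat = M.conj().T @ M + lambd * np.eye(6)  # Hermitian positive definite
b = rng.standard_normal(6) + 1j * rng.standard_normal(6)

x = cg(lambda v: Bmat @ v, b, np.zeros(6, dtype=complex), num_iters=50, tol=1e-10)
print(np.allclose(Bmat @ x, b))
```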
- forward(masked_kspace, sensitivity_map, sampling_mask, z, lambd)
Performs the forward pass of ConjGrad.
- Parameters:
masked_kspace (Tensor) – Masked k-space \(y\).
sensitivity_map (Tensor) – Coil sensitivity map.
sampling_mask (Tensor) – Sampling mask.
z (Tensor) – Prediction of the image \(z\), of shape (N, height, width, complex=2).
lambd (Tensor) – Regularization parameter \(\lambda\) (trainable or not) of shape (1).
- Return type:
Tensor
- Returns:
Approximate solution \(x\) of \(\mathcal{B}(x) = b\) computed by CG.
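forward receives \(y\) (masked_kspace), \(z\) and \(\lambda\). These are exactly the quantities that enter the right-hand side \(b\) of the normal equation. That equation follows from setting the gradient of the objective \(f\) to zero:

\[
\begin{aligned}
\nabla f(x) &= \mathcal{A}^{*}\big(\mathcal{A}(x) - y\big) + \lambda (x - z) = 0 \\
\iff \big(\mathcal{A}^{*} \circ \mathcal{A} + \lambda I\big)(x) &= \mathcal{A}^{*}(y) + \lambda z = b.
\end{aligned}
\]

Because \(\mathcal{A}^{*} \circ \mathcal{A}\) is positive semidefinite, adding \(\lambda I\) with \(\lambda > 0\) makes \(\mathcal{B}\) self-adjoint and positive definite, which is the setting in which CG is guaranteed to converge.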
direct.nn.conjgradnet.conjgradnet module
- class direct.nn.conjgradnet.conjgradnet.ConjGradNet(forward_operator, backward_operator, num_steps, denoiser_architecture=ModelName.RESNET, image_init=InitType.SENSE, no_parameter_sharing=True, cg_iters=15, cg_tol=1e-7, cg_param_update_type=CGUpdateType.FR, **kwargs)
Bases: Module
Conjugate Gradient Network for MRI Reconstruction.
Iteratively solves the following:
\[
\begin{aligned}
z^{i} &= \arg \min_{z} \mu ||x^{i} - z||_2^2 + \mathcal{R}(z), \\
x^{i+1} &= \arg \min_{x} ||\mathcal{A}(x) - y||_2^2 + \mu ||x - z^{i}||_2^2,
\end{aligned}
\]
where \(\mathcal{A}\) is the forward operator of accelerated MRI reconstruction. The first problem is solved by a denoiser \(D_{i_\theta}\), which takes \(x^i\) as input. The second is solved by the conjugate gradient algorithm [1].
References
[1] Jonathan Richard Shewchuk (1994). An Introduction to the Conjugate Gradient Method Without the Agonizing Pain.
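The following minimal single-coil NumPy sketch implements the alternation above under several hypothetical assumptions. It uses \(\mathcal{A}(x) = M \odot \mathcal{F}(x)\) with an orthonormal FFT. A 3x3 periodic box filter stands in for the learned denoiser \(D_{i_\theta}\). The zero-filled image \(\mathcal{A}^{*}(y)\) serves as \(x^0\). Each x-update runs CG on \((\mathcal{A}^{*} \circ \mathcal{A} + \mu I)(x) = \mathcal{A}^{*}(y) + \mu z^i\). None of the function names are DIRECT's API.

```python
import numpy as np


def fft2c(x):
    return np.fft.fft2(x, norm="ortho")


def ifft2c(k):
    return np.fft.ifft2(k, norm="ortho")


def cg(B, b, x0, num_iters, tol):
    """Linear CG (Fletcher-Reeves) for B(x) = b with the real inner product Re<a, c>."""
    inner = lambda a, c: np.real(np.vdot(a, c))
    x, r = x0.copy(), b - B(x0)
    p, rr = r.copy(), inner(r, r)
    for _ in range(num_iters):
        if np.sqrt(rr) < tol:
            break
        Bp = B(p)
        alpha = rr / inner(p, Bp)
        x, r = x + alpha * p, r - alpha * Bp
        rr_new = inner(r, r)
        p, rr = r + (rr_new / rr) * p, rr_new
    return x


def box_denoiser(x):
    """Hypothetical stand-in for the learned denoiser: 3x3 periodic box filter."""
    shifts = [np.roll(x, (i, j), axis=(0, 1)) for i in (-1, 0, 1) for j in (-1, 0, 1)]
    return sum(shifts) / 9.0


def conjgradnet_sketch(y, mask, num_steps=4, mu=0.1, cg_iters=15, cg_tol=1e-10):
    Aty = ifft2c(mask * y)  # A^*(y), also used as x^0 (assumption)
    B = lambda v: ifft2c(mask * fft2c(v)) + mu * v  # (A^* o A + mu I)(v)
    x, z = Aty, None
    for _ in range(num_steps):
        z = box_denoiser(x)  # z^i = D(x^i)
        x = cg(B, Aty + mu * z, x, cg_iters, cg_tol)  # x^{i+1} via CG, warm-started
    return x, z


rng = np.random.default_rng(0)
image = rng.standard_normal((16, 16))
mask = (rng.random((16, 16)) < 0.4).astype(float)
y = mask * fft2c(image)  # undersampled k-space
x, z = conjgradnet_sketch(y, mask)

# Closed-form x-update for this single-coil Cartesian case, used as a reference.
x_ref = ifft2c((mask * y + 0.1 * fft2c(z)) / (mask + 0.1))
print(np.allclose(x, x_ref))
```

In this single-coil Cartesian toy, \(\mathcal{A}^{*} \circ \mathcal{A}\) is diagonal in k-space, so the x-update has the closed form used for the check. With coil sensitivities it is no longer diagonal, which is where an iterative solver such as CG earns its place.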
- __init__(forward_operator, backward_operator, num_steps, denoiser_architecture=ModelName.RESNET, image_init=InitType.SENSE, no_parameter_sharing=True, cg_iters=15, cg_tol=1e-7, cg_param_update_type=CGUpdateType.FR, **kwargs)
Inits ConjGradNet.
- Parameters:
forward_operator (Callable) – Forward operator.
backward_operator (Callable) – Backward operator.
num_steps (int) – Number of unrolled optimization steps.
denoiser_architecture (ModelName) – Type of architecture to use as a denoiser. Can be "resnet", "unet", "normunet", "didn" or "conv". Default: "resnet".
image_init (InitType) – Initialization type for z. Can be "sense" or "zeros". Default: "sense".
no_parameter_sharing (bool) – If True, each unrolled step uses its own denoiser; if False, a single denoiser is shared across all steps. Default: True.
cg_iters (int) – Maximum number of conjugate gradient iterations. Default: 15.
cg_tol (float) – Convergence tolerance for conjugate gradient. Default: 1e-7.
cg_param_update_type (CGUpdateType) – How to compute \(b_k\) in conjugate gradient. Can be "FR", "PRP", "DY" or "BAN". Default: "FR".
kwargs (dict) – Keyword arguments, which should include the denoiser architecture parameters. For example, if denoiser_architecture is "unet" or "norm_unet", then unet_num_filters, unet_num_pool_layers and unet_dropout_probability should be passed.
- forward(masked_kspace, sensitivity_map, sampling_mask)
Computes the forward pass of ConjGradNet.
- Parameters:
masked_kspace (Tensor) – Masked k-space.
sensitivity_map (Tensor) – Sensitivity map.
sampling_mask (Tensor) – Sampling mask.
- Return type:
Tensor
- Returns:
Output image of shape (N, height, width, complex=2).
direct.nn.conjgradnet.conjgradnet_engine module
- class direct.nn.conjgradnet.conjgradnet_engine.ConjGradNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
- __init__(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Inits ConjGradNetEngine.
- Parameters:
cfg (BaseConfig) – Configuration file.
model (Module) – Model.
device (str) – Device. Can be "cuda:{idx}" or "cpu".
forward_operator (Optional[Callable]) – The forward operator. Default: None.
backward_operator (Optional[Callable]) – The backward operator. Default: None.
mixed_precision (bool) – Use mixed precision. Default: False.
**models (Module) – Additional models.