direct.nn.lpd package
Submodules
direct.nn.lpd.config module
- class direct.nn.lpd.config.LPDNetConfig(model_name='???', engine_name=None, num_iter=25, num_primal=5, num_dual=5, primal_model_architecture='MWCNN', dual_model_architecture='DIDN', primal_mwcnn_hidden_channels=16, primal_mwcnn_num_scales=4, primal_mwcnn_bias=True, primal_mwcnn_batchnorm=False, primal_unet_num_filters=8, primal_unet_num_pool_layers=4, primal_unet_dropout_probability=0.0, dual_conv_hidden_channels=16, dual_conv_n_convs=4, dual_conv_batchnorm=False, dual_didn_hidden_channels=64, dual_didn_num_dubs=6, dual_didn_num_convs_recon=9, dual_unet_num_filters=8, dual_unet_num_pool_layers=4, dual_unet_dropout_probability=0.0)
Bases: ModelConfig
- num_iter = 25
- num_primal = 5
- num_dual = 5
- primal_model_architecture = 'MWCNN'
- dual_model_architecture = 'DIDN'
- primal_mwcnn_hidden_channels = 16
- primal_mwcnn_num_scales = 4
- primal_mwcnn_bias = True
- primal_mwcnn_batchnorm = False
- primal_unet_num_filters = 8
- primal_unet_num_pool_layers = 4
- primal_unet_dropout_probability = 0.0
- dual_conv_hidden_channels = 16
- dual_conv_n_convs = 4
- dual_conv_batchnorm = False
- dual_didn_hidden_channels = 64
- dual_didn_num_dubs = 6
- dual_didn_num_convs_recon = 9
- dual_unet_num_filters = 8
- dual_unet_num_pool_layers = 4
- dual_unet_dropout_probability = 0.0
- __init__(model_name='???', engine_name=None, num_iter=25, num_primal=5, num_dual=5, primal_model_architecture='MWCNN', dual_model_architecture='DIDN', primal_mwcnn_hidden_channels=16, primal_mwcnn_num_scales=4, primal_mwcnn_bias=True, primal_mwcnn_batchnorm=False, primal_unet_num_filters=8, primal_unet_num_pool_layers=4, primal_unet_dropout_probability=0.0, dual_conv_hidden_channels=16, dual_conv_n_convs=4, dual_conv_batchnorm=False, dual_didn_hidden_channels=64, dual_didn_num_dubs=6, dual_didn_num_convs_recon=9, dual_unet_num_filters=8, dual_unet_num_pool_layers=4, dual_unet_dropout_probability=0.0)
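The configuration groups architecture-specific settings by prefix (`primal_mwcnn_*`, `primal_unet_*`, `dual_conv_*`, `dual_didn_*`, `dual_unet_*`). The sketch below uses a hypothetical, trimmed-down dataclass (`LPDNetConfigSketch`) and a hypothetical helper (`architecture_kwargs`) to show how the chosen `primal_model_architecture` and `dual_model_architecture` single out the relevant fields. It is not DIRECT's API; the library may instead forward every field as `**kwargs` and let each architecture read only what it needs.

```python
from dataclasses import asdict, dataclass


@dataclass
class LPDNetConfigSketch:
    # Hypothetical stand-in mirroring a subset of LPDNetConfig's fields and defaults.
    num_iter: int = 25
    num_primal: int = 5
    num_dual: int = 5
    primal_model_architecture: str = "MWCNN"
    dual_model_architecture: str = "DIDN"
    primal_mwcnn_hidden_channels: int = 16
    primal_mwcnn_num_scales: int = 4
    primal_unet_num_filters: int = 8
    dual_conv_hidden_channels: int = 16
    dual_didn_hidden_channels: int = 64
    dual_didn_num_dubs: int = 6
    dual_unet_num_filters: int = 8


def architecture_kwargs(cfg: LPDNetConfigSketch) -> dict:
    """Hypothetical helper: keep only the fields of the selected primal/dual architectures."""
    prefixes = (
        f"primal_{cfg.primal_model_architecture.lower()}_",
        f"dual_{cfg.dual_model_architecture.lower()}_",
    )
    return {name: value for name, value in asdict(cfg).items() if name.startswith(prefixes)}


print(architecture_kwargs(LPDNetConfigSketch()))
# {'primal_mwcnn_hidden_channels': 16, 'primal_mwcnn_num_scales': 4, 'dual_didn_hidden_channels': 64, 'dual_didn_num_dubs': 6}
```

With the defaults, only the MWCNN primal and DIDN dual settings survive; the UNet and CONV fields are ignored.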
direct.nn.lpd.lpd module
- class direct.nn.lpd.lpd.DualNet(num_dual, **kwargs)
Bases: Module
Dual network for the Learned Primal Dual network.
- __init__(num_dual, **kwargs)
Inits DualNet.
- Parameters:
num_dual (int) – Number of dual variables used by the LPD algorithm.
**kwargs – Keyword arguments.
- static compute_model_per_coil(model, data)
Applies the model to each coil of the multi-coil input separately.
- Parameters:
model (Module) – Model to apply.
data (Tensor) – Multi-coil input.
- Returns:
Multi-coil output.
- Return type:
Tensor
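`compute_model_per_coil` applies one shared model to every coil of a multi-coil input. A minimal NumPy sketch of that pattern follows, assuming the coil axis is axis 1. The documented method operates on torch Tensors, and the toy model here is a hypothetical per-pixel linear map.

```python
import numpy as np


def compute_model_per_coil(model, data):
    """Apply the same `model` to every coil slice (axis 1) and restack the outputs."""
    outputs = [model(data[:, coil]) for coil in range(data.shape[1])]
    return np.stack(outputs, axis=1)


# Hypothetical per-pixel linear "model": 6 input channels -> 2 output channels.
weights = np.full((2, 6), 1 / 6)


def toy_model(x):
    # x: (batch, channels, height, width) -> (batch, 2, height, width)
    return np.einsum("oc,bchw->bohw", weights, x)


rng = np.random.default_rng(0)
data = rng.standard_normal((2, 4, 6, 8, 8))  # (batch, coil, channels, height, width)
out = compute_model_per_coil(toy_model, data)
print(out.shape)  # (2, 4, 2, 8, 8)
```

Because the weights are shared, every coil is processed by the same parameters, and the coil count can change without changing the model.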
- forward(h, forward_f, g)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
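The `forward(h, forward_f, g)` signature suggests a dual update that combines the current dual variables `h`, the forward operator applied to a primal iterate (`forward_f`), and the measured k-space `g`. The NumPy sketch below assumes these are concatenated along a trailing channel axis (real and imaginary parts as separate channels), moved channels-first, and passed per coil through the dual model, so that `2 * (num_dual + 2)` input channels map to `2 * num_dual` outputs. The channel layout and `toy_dual_block`, a residual update `h + (forward_f - g)` standing in for the learned CONV or DIDN model, are assumptions for illustration.

```python
import numpy as np


def dual_step(h, forward_f, g, dual_block):
    """One dual update under an assumed channel layout.

    h:         (N, coil, H, W, 2 * num_dual)  dual variables, real/imag as channels
    forward_f: (N, coil, H, W, 2)             forward operator applied to a primal iterate
    g:         (N, coil, H, W, 2)             measured k-space
    """
    inp = np.concatenate([h, forward_f, g], axis=-1)  # (..., 2 * (num_dual + 2))
    inp = np.moveaxis(inp, -1, 2)                     # channels first: (N, coil, C, H, W)
    # Apply the same dual model to every coil, as compute_model_per_coil does.
    out = np.stack([dual_block(inp[:, c]) for c in range(inp.shape[1])], axis=1)
    return np.moveaxis(out, 2, -1)                    # back to (N, coil, H, W, 2 * num_dual)


num_dual = 5


def toy_dual_block(x):
    # Hypothetical stand-in for the learned dual model: add the residual
    # forward_f - g (channels -4:-2 minus -2:) to every dual real/imag pair.
    residual = x[:, -4:-2] - x[:, -2:]
    return x[:, : 2 * num_dual] + np.tile(residual, (1, num_dual, 1, 1))


rng = np.random.default_rng(0)
h = rng.standard_normal((1, 3, 8, 8, 2 * num_dual))
forward_f = rng.standard_normal((1, 3, 8, 8, 2))
g = rng.standard_normal((1, 3, 8, 8, 2))
h_new = dual_step(h, forward_f, g, dual_step.__defaults__ or toy_dual_block) if False else dual_step(h, forward_f, g, toy_dual_block)
print(h_new.shape)  # (1, 3, 8, 8, 10)
```

`np.tile` (not `np.repeat`) is used so the real/imaginary pairs stay interleaved across the dual channels.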
- class direct.nn.lpd.lpd.PrimalNet(num_primal, **kwargs)
Bases: Module
Primal network for the Learned Primal Dual network.
- __init__(num_primal, **kwargs)
Inits PrimalNet.
- Parameters:
num_primal (int) – Number of primal variables used by the LPD algorithm.
**kwargs – Keyword arguments.
- forward(f, backward_h)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
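`PrimalNet.forward(f, backward_h)` combines the primal variables `f` with the backward operator applied to a dual iterate. This NumPy sketch assumes a single coil-combined image per sample (no coil axis) and a trailing channel axis holding real and imaginary parts, so `2 * (num_primal + 1)` input channels map to `2 * num_primal`. `toy_primal_block`, a gradient-style step `f - backward_h`, is a hypothetical stand-in for the MWCNN or UNet primal model.

```python
import numpy as np


def primal_step(f, backward_h, primal_block):
    """One primal update under an assumed channel layout.

    f:          (N, H, W, 2 * num_primal)  primal variables, real/imag as channels
    backward_h: (N, H, W, 2)               backward operator applied to a dual iterate
    """
    inp = np.concatenate([f, backward_h], axis=-1)  # (N, H, W, 2 * (num_primal + 1))
    out = primal_block(np.moveaxis(inp, -1, 1))     # channels first: (N, C, H, W)
    return np.moveaxis(out, 1, -1)                  # back to (N, H, W, 2 * num_primal)


num_primal = 5


def toy_primal_block(x):
    # Hypothetical stand-in for the learned primal model: subtract the
    # backward term (last two channels) from every primal real/imag pair.
    return x[:, : 2 * num_primal] - np.tile(x[:, -2:], (1, num_primal, 1, 1))


rng = np.random.default_rng(0)
f = rng.standard_normal((2, 8, 8, 2 * num_primal))
backward_h = rng.standard_normal((2, 8, 8, 2))
f_new = primal_step(f, backward_h, toy_primal_block)
print(f_new.shape)  # (2, 8, 8, 10)
```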
- class direct.nn.lpd.lpd.LPDNet(forward_operator, backward_operator, num_iter, num_primal, num_dual, primal_model_architecture='MWCNN', dual_model_architecture='DIDN', **kwargs)
Bases: Module
Learned Primal Dual network implementation inspired by [1].
References:
- __init__(forward_operator, backward_operator, num_iter, num_primal, num_dual, primal_model_architecture='MWCNN', dual_model_architecture='DIDN', **kwargs)
Inits LPDNet.
- Parameters:
forward_operator (Callable) – Forward operator.
backward_operator (Callable) – Backward operator.
num_iter (int) – Number of unrolled iterations.
num_primal (int) – Number of primal variables used by the LPD algorithm.
num_dual (int) – Number of dual variables used by the LPD algorithm.
primal_model_architecture (str) – Primal model architecture. Default: 'MWCNN'.
dual_model_architecture (str) – Dual model architecture, for example 'CONV' or 'DIDN'. Default: 'DIDN'.
kwargs (dict) – Keyword arguments for the model architectures.
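`forward_operator` and `backward_operator` are plain callables. In multi-coil MRI they are typically a 2D FFT and its inverse, presumably combined with the sensitivity maps and the sampling mask inside LPDNet. The NumPy sketch below assumes a SENSE-style model: the forward map expands an image to coil images, applies an orthonormal FFT, and masks; the backward map is its adjoint. The names `forward_op` and `backward_op` are hypothetical, and the sketch uses native complex arrays instead of a trailing complex=2 axis and applies no FFT centering. The final lines check the adjoint identity numerically.

```python
import numpy as np


def forward_op(image, sens, mask):
    """Hypothetical forward operator: image (N, H, W) -> masked k-space (N, coil, H, W)."""
    coil_images = sens * image[:, None]              # expand the image to every coil
    kspace = np.fft.fft2(coil_images, norm="ortho")  # orthonormal 2D FFT per coil
    return np.where(mask, kspace, 0)                 # zero the unsampled locations


def backward_op(kspace, sens, mask):
    """Adjoint of forward_op: masked k-space (N, coil, H, W) -> image (N, H, W)."""
    coil_images = np.fft.ifft2(np.where(mask, kspace, 0), norm="ortho")
    return np.sum(np.conj(sens) * coil_images, axis=1)  # combine coils with conj(sens)


rng = np.random.default_rng(0)
N, C, H, W = 2, 4, 8, 8
sens = rng.standard_normal((N, C, H, W)) + 1j * rng.standard_normal((N, C, H, W))
mask = rng.random((N, 1, H, W)) < 0.5
x = rng.standard_normal((N, H, W)) + 1j * rng.standard_normal((N, H, W))
y = rng.standard_normal((N, C, H, W)) + 1j * rng.standard_normal((N, C, H, W))

# Adjoint check: <A x, y> should equal <x, A^H y> up to floating-point error.
lhs = np.vdot(forward_op(x, sens, mask), y)
rhs = np.vdot(x, backward_op(y, sens, mask))
print(np.isclose(lhs, rhs))  # True
```

Using `norm="ortho"` makes the inverse FFT the exact adjoint of the forward FFT, which is what lets the backward map serve as the adjoint of the forward map.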
- forward(masked_kspace, sensitivity_map, sampling_mask)
Computes forward pass of LPDNet.
- Parameters:
masked_kspace (Tensor) – Masked k-space of shape (N, coil, height, width, complex=2).
sensitivity_map (Tensor) – Sensitivity map of shape (N, coil, height, width, complex=2).
sampling_mask (Tensor) – Sampling mask of shape (N, 1, height, width, 1).
- Returns:
output – Output image of shape (N, height, width, complex=2).
- Return type:
torch.Tensor
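The forward pass unrolls `num_iter` dual and primal updates. The sketch below shows one plausible arrangement under stated assumptions: the dual buffer starts as copies of the masked k-space, the primal buffer starts as copies of the backward-projected image, each dual update reads the second primal slot, each primal update reads the first dual slot, and the first primal slot is returned. The operators (`fwd`, `bwd`) and the toy networks are hypothetical NumPy stand-ins using native complex arrays. With these particular toy networks the loop reduces to plain gradient descent on the data-consistency term ½‖A f − g‖², which makes its behaviour easy to check.

```python
import numpy as np


def fwd(image, sens, mask):
    # Hypothetical SENSE forward operator: image (N, H, W) -> masked k-space (N, coil, H, W).
    return np.where(mask, np.fft.fft2(sens * image[:, None], norm="ortho"), 0)


def bwd(kspace, sens, mask):
    # Adjoint of fwd: masked k-space (N, coil, H, W) -> coil-combined image (N, H, W).
    return np.sum(np.conj(sens) * np.fft.ifft2(np.where(mask, kspace, 0), norm="ortho"), axis=1)


def lpd_unrolled(g, sens, mask, dual_nets, primal_nets, num_primal, num_dual):
    """Unrolled primal-dual loop; each buffer slot holds one complex channel."""
    h = np.repeat(g[..., None], num_dual, axis=-1)                      # (N, coil, H, W, num_dual)
    f = np.repeat(bwd(g, sens, mask)[..., None], num_primal, axis=-1)  # (N, H, W, num_primal)
    for dual_net, primal_net in zip(dual_nets, primal_nets):
        h = dual_net(h, fwd(f[..., 1], sens, mask), g)  # dual step reads the 2nd primal slot
        f = primal_net(f, bwd(h[..., 0], sens, mask))   # primal step reads the 1st dual slot
    return f[..., 0]                                    # first primal slot is the output


def toy_dual(h, forward_f, g):
    # Hypothetical stand-in for DualNet: every slot becomes the residual A f - g.
    return np.repeat((forward_f - g)[..., None], h.shape[-1], axis=-1)


def toy_primal(f, backward_h):
    # Hypothetical stand-in for PrimalNet: a unit gradient step on 0.5 * ||A f - g||^2.
    return np.repeat((f[..., 1] - backward_h)[..., None], f.shape[-1], axis=-1)


rng = np.random.default_rng(0)
N, C, H, W, num_iter = 1, 4, 16, 16, 10
sens = rng.standard_normal((N, C, H, W)) + 1j * rng.standard_normal((N, C, H, W))
# Normalise so the coil sum of |sens|^2 is 1; then ||A|| <= 1 and a unit step is stable.
sens /= np.sqrt(np.sum(np.abs(sens) ** 2, axis=1, keepdims=True))
mask = rng.random((N, 1, H, W)) < 0.4
g = fwd(rng.standard_normal((N, H, W)) + 1j * rng.standard_normal((N, H, W)), sens, mask)

out = lpd_unrolled(g, sens, mask, [toy_dual] * num_iter, [toy_primal] * num_iter, num_primal=5, num_dual=5)
print(out.shape)  # (1, 16, 16)
```

In the real network the toy functions are replaced by learned PrimalNet and DualNet instances, one per iteration, so each unrolled step can learn its own update rule instead of a fixed gradient step.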
direct.nn.lpd.lpd_engine module
- class direct.nn.lpd.lpd_engine.LPDNetEngine(cfg, model, device, forward_operator=None, backward_operator=None, mixed_precision=False, **models)
Bases: MRIModelEngine
LPDNet Engine.