direct.config package
Submodules
direct.config.defaults module
- class direct.config.defaults.CheckpointerConfig(checkpoint_steps: int = 500) [source]
  Bases: BaseConfig
  - checkpoint_steps: int = 500
- class direct.config.defaults.DefaultConfig(model: direct.config.defaults.ModelConfig = '???', additional_models: Optional[Any] = None, physics: direct.config.defaults.PhysicsConfig = <factory>, training: direct.config.defaults.TrainingConfig = <factory>, validation: direct.config.defaults.ValidationConfig = <factory>, inference: Optional[direct.config.defaults.InferenceConfig] = None, logging: direct.config.defaults.LoggingConfig = <factory>) [source]
  Bases: BaseConfig
  - additional_models: Optional[Any] = None
  - inference: Optional[InferenceConfig] = None
  - logging: LoggingConfig
  - model: ModelConfig = '???'
  - physics: PhysicsConfig
  - training: TrainingConfig
  - validation: ValidationConfig
- class direct.config.defaults.FunctionConfig(function: str = '???', multiplier: float = 1.0) [source]
  Bases: BaseConfig
  - function: str = '???'
  - multiplier: float = 1.0
- class direct.config.defaults.InferenceConfig(dataset: direct.data.datasets_config.DatasetConfig = <factory>, batch_size: int = 1, crop: Optional[str] = None) [source]
  Bases: BaseConfig
  - batch_size: int = 1
  - crop: Optional[str] = None
  - dataset: DatasetConfig
- class direct.config.defaults.LoggingConfig(log_as_image: Optional[List[str]] = None, tensorboard: direct.config.defaults.TensorboardConfig = <factory>) [source]
  Bases: BaseConfig
  - log_as_image: Optional[List[str]] = None
  - tensorboard: TensorboardConfig
- class direct.config.defaults.LossConfig(crop: Optional[str] = None, losses: List[Any] = <factory>) [source]
  Bases: BaseConfig
  - crop: Optional[str] = None
  - losses: List[Any]
- class direct.config.defaults.ModelConfig(model_name: str = '???', engine_name: Optional[str] = None) [source]
  Bases: BaseConfig
  - engine_name: Optional[str] = None
  - model_name: str = '???'
- class direct.config.defaults.PhysicsConfig(forward_operator: str = 'fft2', backward_operator: str = 'ifft2', use_noise_matrix: bool = False, noise_matrix_scaling: Optional[float] = 1.0) [source]
  Bases: BaseConfig
  - backward_operator: str = 'ifft2'
  - forward_operator: str = 'fft2'
  - noise_matrix_scaling: Optional[float] = 1.0
  - use_noise_matrix: bool = False
- class direct.config.defaults.TensorboardConfig(num_images: int = 8) [source]
  Bases: BaseConfig
  - num_images: int = 8
- class direct.config.defaults.TrainingConfig(datasets: List[Any] = <factory>, model_checkpoint: Optional[str] = None, optimizer: str = 'Adam', lr: float = 0.0005, weight_decay: float = 1e-06, batch_size: int = 2, lr_step_size: int = 5000, lr_gamma: float = 0.5, lr_warmup_iter: int = 500, swa_start_iter: Optional[int] = None, num_iterations: int = 50000, validation_steps: int = 1000, gradient_steps: int = 1, gradient_clipping: float = 0.0, gradient_debug: bool = False, loss: direct.config.defaults.LossConfig = <factory>, checkpointer: direct.config.defaults.CheckpointerConfig = <factory>, metrics: List[str] = <factory>, regularizers: List[str] = <factory>) [source]
  Bases: BaseConfig
  - batch_size: int = 2
  - checkpointer: CheckpointerConfig
  - datasets: List[Any]
  - gradient_clipping: float = 0.0
  - gradient_debug: bool = False
  - gradient_steps: int = 1
  - loss: LossConfig
  - lr: float = 0.0005
  - lr_gamma: float = 0.5
  - lr_step_size: int = 5000
  - lr_warmup_iter: int = 500
  - metrics: List[str]
  - model_checkpoint: Optional[str] = None
  - num_iterations: int = 50000
  - optimizer: str = 'Adam'
  - regularizers: List[str]
  - swa_start_iter: Optional[int] = None
  - validation_steps: int = 1000
  - weight_decay: float = 1e-06
- class direct.config.defaults.ValidationConfig(datasets: List[Any] = <factory>, batch_size: int = 8, metrics: List[str] = <factory>, regularizers: List[str] = <factory>, crop: Optional[str] = 'training') [source]
  Bases: BaseConfig
  - batch_size: int = 8
  - crop: Optional[str] = 'training'
  - datasets: List[Any]
  - metrics: List[str]
  - regularizers: List[str]