direct.config package

Contents

direct.config package

Submodules

direct.config.defaults module

class direct.config.defaults.CheckpointerConfig(checkpoint_steps: int = 500) [source]

Bases: BaseConfig

checkpoint_steps: int = 500
class direct.config.defaults.DefaultConfig(model: direct.config.defaults.ModelConfig = '???', additional_models: Optional[Any] = None, physics: direct.config.defaults.PhysicsConfig = <factory>, training: direct.config.defaults.TrainingConfig = <factory>, validation: direct.config.defaults.ValidationConfig = <factory>, inference: Optional[direct.config.defaults.InferenceConfig] = None, logging: direct.config.defaults.LoggingConfig = <factory>) [source]

Bases: BaseConfig

additional_models: Optional[Any] = None
inference: Optional[InferenceConfig] = None
logging: LoggingConfig
model: ModelConfig = '???'
physics: PhysicsConfig
training: TrainingConfig
validation: ValidationConfig
class direct.config.defaults.FunctionConfig(function: str = '???', multiplier: float = 1.0) [source]

Bases: BaseConfig

function: str = '???'
multiplier: float = 1.0
class direct.config.defaults.InferenceConfig(dataset: direct.data.datasets_config.DatasetConfig = <factory>, batch_size: int = 1, crop: Optional[str] = None) [source]

Bases: BaseConfig

batch_size: int = 1
crop: Optional[str] = None
dataset: DatasetConfig
class direct.config.defaults.LoggingConfig(log_as_image: Optional[List[str]] = None, tensorboard: direct.config.defaults.TensorboardConfig = <factory>) [source]

Bases: BaseConfig

log_as_image: Optional[List[str]] = None
tensorboard: TensorboardConfig
class direct.config.defaults.LossConfig(crop: Optional[str] = None, losses: List[Any] = <factory>) [source]

Bases: BaseConfig

crop: Optional[str] = None
losses: List[Any]
class direct.config.defaults.ModelConfig(model_name: str = '???', engine_name: Optional[str] = None) [source]

Bases: BaseConfig

engine_name: Optional[str] = None
model_name: str = '???'
class direct.config.defaults.PhysicsConfig(forward_operator: str = 'fft2', backward_operator: str = 'ifft2', use_noise_matrix: bool = False, noise_matrix_scaling: Optional[float] = 1.0) [source]

Bases: BaseConfig

backward_operator: str = 'ifft2'
forward_operator: str = 'fft2'
noise_matrix_scaling: Optional[float] = 1.0
use_noise_matrix: bool = False
class direct.config.defaults.TensorboardConfig(num_images: int = 8) [source]

Bases: BaseConfig

num_images: int = 8
class direct.config.defaults.TrainingConfig(datasets: List[Any] = <factory>, model_checkpoint: Optional[str] = None, optimizer: str = 'Adam', lr: float = 0.0005, weight_decay: float = 1e-06, batch_size: int = 2, lr_step_size: int = 5000, lr_gamma: float = 0.5, lr_warmup_iter: int = 500, swa_start_iter: Optional[int] = None, num_iterations: int = 50000, validation_steps: int = 1000, gradient_steps: int = 1, gradient_clipping: float = 0.0, gradient_debug: bool = False, loss: direct.config.defaults.LossConfig = <factory>, checkpointer: direct.config.defaults.CheckpointerConfig = <factory>, metrics: List[str] = <factory>, regularizers: List[str] = <factory>) [source]

Bases: BaseConfig

batch_size: int = 2
checkpointer: CheckpointerConfig
datasets: List[Any]
gradient_clipping: float = 0.0
gradient_debug: bool = False
gradient_steps: int = 1
loss: LossConfig
lr: float = 0.0005
lr_gamma: float = 0.5
lr_step_size: int = 5000
lr_warmup_iter: int = 500
metrics: List[str]
model_checkpoint: Optional[str] = None
num_iterations: int = 50000
optimizer: str = 'Adam'
regularizers: List[str]
swa_start_iter: Optional[int] = None
validation_steps: int = 1000
weight_decay: float = 1e-06
class direct.config.defaults.ValidationConfig(datasets: List[Any] = <factory>, batch_size: int = 8, metrics: List[str] = <factory>, regularizers: List[str] = <factory>, crop: Optional[str] = 'training') [source]

Bases: BaseConfig

batch_size: int = 8
crop: Optional[str] = 'training'
datasets: List[Any]
metrics: List[str]
regularizers: List[str]

Module contents

class direct.config.BaseConfig [source]

Bases: object