Source code for direct.config.defaults

# Copyright (c) DIRECT Contributors

from dataclasses import dataclass, field
from typing import Any, List, Optional

from omegaconf import MISSING

from direct.config import BaseConfig
from direct.data.datasets_config import DatasetConfig


[docs] @dataclass class TensorboardConfig(BaseConfig): num_images: int = 8
[docs] @dataclass class LoggingConfig(BaseConfig): log_as_image: Optional[List[str]] = None tensorboard: TensorboardConfig = TensorboardConfig()
[docs] @dataclass class FunctionConfig(BaseConfig): function: str = MISSING multiplier: float = 1.0
[docs] @dataclass class CheckpointerConfig(BaseConfig): checkpoint_steps: int = 500
[docs] @dataclass class LossConfig(BaseConfig): crop: Optional[str] = None losses: List[Any] = field(default_factory=lambda: [FunctionConfig()])
[docs] @dataclass class TrainingConfig(BaseConfig): # Dataset datasets: List[Any] = field(default_factory=lambda: [DatasetConfig()]) # model_checkpoint gives the checkpoint from which we can load the *model* weights. model_checkpoint: Optional[str] = None # Optimizer optimizer: str = "Adam" lr: float = 5e-4 weight_decay: float = 1e-6 batch_size: int = 2 # LR Scheduler lr_step_size: int = 5000 lr_gamma: float = 0.5 lr_warmup_iter: int = 500 # Stochastic weight averaging swa_start_iter: Optional[int] = None num_iterations: int = 50000 # Validation validation_steps: int = 1000 # Gradient gradient_steps: int = 1 gradient_clipping: float = 0.0 gradient_debug: bool = False # Loss loss: LossConfig = LossConfig() # Checkpointer checkpointer: CheckpointerConfig = CheckpointerConfig() # Metrics metrics: List[str] = field(default_factory=lambda: []) # Regularizers regularizers: List[str] = field(default_factory=lambda: [])
[docs] @dataclass class ValidationConfig(BaseConfig): datasets: List[Any] = field(default_factory=lambda: [DatasetConfig()]) batch_size: int = 8 metrics: List[str] = field(default_factory=lambda: []) regularizers: List[str] = field(default_factory=lambda: []) crop: Optional[str] = "training"
[docs] @dataclass class InferenceConfig(BaseConfig): dataset: DatasetConfig = DatasetConfig() batch_size: int = 1 crop: Optional[str] = None
[docs] @dataclass class ModelConfig(BaseConfig): model_name: str = MISSING engine_name: Optional[str] = None
[docs] @dataclass class PhysicsConfig(BaseConfig): forward_operator: str = "fft2" backward_operator: str = "ifft2" use_noise_matrix: bool = False noise_matrix_scaling: Optional[float] = 1.0
[docs] @dataclass class DefaultConfig(BaseConfig): model: ModelConfig = MISSING additional_models: Optional[Any] = None physics: PhysicsConfig = PhysicsConfig() training: TrainingConfig = TrainingConfig() # This should be optional. validation: ValidationConfig = ValidationConfig() # This should be optional. inference: Optional[InferenceConfig] = None logging: LoggingConfig = LoggingConfig()