Source code for direct.config.defaults

# Copyright 2025 AI for Oncology Research Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, List, Optional

from omegaconf import MISSING

from direct.config import BaseConfig
from direct.data.datasets_config import DatasetConfig


@dataclass
class TensorboardConfig(BaseConfig):
    """Tensorboard options, embedded in the logging configuration.

    NOTE(review): ``num_images`` presumably limits how many images get
    written to tensorboard -- confirm against the code that consumes it.
    """

    # Defaults to 8 when the user configuration does not override it.
    num_images: int = 8
@dataclass
class LoggingConfig(BaseConfig):
    """Logging options, including the nested tensorboard section.

    NOTE(review): ``log_as_image`` looks like a list of output keys that
    should be rendered as images -- confirm with the logging code.
    """

    # Optional list of strings; left unset (None) by default.
    log_as_image: Optional[List[str]] = None

    # A fresh TensorboardConfig is created for every LoggingConfig instance.
    tensorboard: TensorboardConfig = field(default_factory=TensorboardConfig)
@dataclass
class FunctionConfig(BaseConfig):
    """A named function paired with a scalar multiplier.

    Instances of this class form the default content of ``LossConfig.losses``.
    """

    # No usable default: the omegaconf MISSING marker means the value has to
    # be provided by the user configuration.
    function: str = MISSING

    # NOTE(review): presumably the weight applied to the function's output
    # -- confirm against the code that builds the loss.
    multiplier: float = 1.0
@dataclass
class CheckpointerConfig(BaseConfig):
    """Checkpointing options, embedded in the training configuration.

    NOTE(review): ``checkpoint_steps`` is presumably the number of training
    iterations between two saved checkpoints -- confirm with the trainer.
    """

    # Defaults to 500.
    checkpoint_steps: int = 500
@dataclass
class LossConfig(BaseConfig):
    """Loss section of the training configuration.

    NOTE(review): the accepted values of ``crop`` are not visible in this
    module -- check the code that reads it.
    """

    # Optional string; unset (None) by default.
    crop: Optional[str] = None

    # Each instance starts with its own one-element list holding a default
    # FunctionConfig. The element type is deliberately left as Any.
    losses: List[Any] = field(default_factory=lambda: [FunctionConfig()])
@dataclass
class TrainingConfig(BaseConfig):
    """Training section of the configuration.

    Groups the dataset list, optimizer and learning-rate schedule values,
    iteration counts, gradient options, and the nested loss and checkpointer
    sections, together with the metric and regularizer name lists.
    """

    # --- Data ---------------------------------------------------------------
    # Every instance gets its own list with a single default DatasetConfig.
    datasets: List[Any] = field(default_factory=lambda: [DatasetConfig()])

    # --- Initial weights ----------------------------------------------------
    # Checkpoint from which the *model* weights can be loaded.
    model_checkpoint: Optional[str] = None

    # --- Optimizer ----------------------------------------------------------
    optimizer: str = "Adam"
    lr: float = 5e-4
    weight_decay: float = 1e-6
    batch_size: int = 2

    # --- Learning-rate scheduler --------------------------------------------
    lr_step_size: int = 5_000
    lr_gamma: float = 0.5
    lr_warmup_iter: int = 500

    # --- Stochastic weight averaging ----------------------------------------
    # Unset (None) by default.
    swa_start_iter: Optional[int] = None

    num_iterations: int = 50_000

    # --- Validation ---------------------------------------------------------
    validation_steps: int = 1_000

    # --- Gradients ----------------------------------------------------------
    gradient_steps: int = 1
    gradient_clipping: float = 0.0
    gradient_debug: bool = False

    # --- Loss ---------------------------------------------------------------
    loss: LossConfig = field(default_factory=LossConfig)

    # --- Checkpointing ------------------------------------------------------
    checkpointer: CheckpointerConfig = field(default_factory=CheckpointerConfig)

    # --- Metrics ------------------------------------------------------------
    # Starts as a new empty list per instance.
    metrics: List[str] = field(default_factory=list)

    # --- Regularizers -------------------------------------------------------
    # Starts as a new empty list per instance.
    regularizers: List[str] = field(default_factory=list)
@dataclass
class ValidationConfig(BaseConfig):
    """Validation section of the configuration.

    NOTE(review): the default ``crop`` value ``"training"`` presumably means
    "reuse the training crop" -- confirm with the code that reads it.
    """

    # Every instance gets its own list with a single default DatasetConfig.
    datasets: List[Any] = field(default_factory=lambda: [DatasetConfig()])

    batch_size: int = 8

    # Both name lists start out as new empty lists per instance.
    metrics: List[str] = field(default_factory=list)
    regularizers: List[str] = field(default_factory=list)

    # Optional, but set to "training" unless overridden.
    crop: Optional[str] = "training"
@dataclass
class InferenceConfig(BaseConfig):
    """Inference section of the configuration.

    Unlike the training and validation sections, it holds one dataset rather
    than a list of datasets.
    """

    # A fresh default DatasetConfig is built for every instance.
    dataset: DatasetConfig = field(default_factory=DatasetConfig)

    batch_size: int = 1

    # Optional string; unset (None) by default.
    crop: Optional[str] = None
@dataclass
class ModelConfig(BaseConfig):
    """Names that identify a model and, optionally, its engine.

    NOTE(review): how the two names are resolved to classes is not visible
    in this module -- see the code that consumes this configuration.
    """

    # No usable default: the omegaconf MISSING marker means the value has to
    # be provided by the user configuration.
    model_name: str = MISSING

    # Optional string; unset (None) by default.
    engine_name: Optional[str] = None
@dataclass
class PhysicsConfig(BaseConfig):
    """Physics section: operator names and noise-matrix options.

    NOTE(review): the operator fields are plain strings here; presumably they
    are looked up by name elsewhere -- confirm with the consuming code.
    """

    # Operator names, defaulting to the "fft2" / "ifft2" pair.
    forward_operator: str = "fft2"
    backward_operator: str = "ifft2"

    # Noise-matrix usage is off by default.
    use_noise_matrix: bool = False

    # Optional float (None is accepted) that defaults to 1.0.
    noise_matrix_scaling: Optional[float] = 1.0
@dataclass
class DefaultConfig(BaseConfig):
    """Top-level configuration that ties all sections together.

    Only ``model`` lacks a default. ``additional_models`` and ``inference``
    default to None, and the other sections are built from their own
    defaults.
    """

    # No usable default: the omegaconf MISSING marker means the model section
    # has to be provided by the user configuration.
    model: ModelConfig = MISSING

    # Untyped (Any) and unset by default.
    additional_models: Optional[Any] = None

    physics: PhysicsConfig = field(default_factory=PhysicsConfig)

    # NOTE: the original author marked this section "should be optional".
    training: TrainingConfig = field(default_factory=TrainingConfig)

    # NOTE: the original author marked this section "should be optional".
    validation: ValidationConfig = field(default_factory=ValidationConfig)

    # The only section that is already optional; unset (None) by default.
    inference: Optional[InferenceConfig] = None

    logging: LoggingConfig = field(default_factory=LoggingConfig)