Source code for direct.data.lr_scheduler
# coding=utf-8
# Copyright (c) DIRECT Contributors
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Taken from Detectron2, licensed under Apache 2.0.
# https://github.com/facebookresearch/detectron2/blob/60d7a1fd33cc48e58968659cd3301f3300b2786b/detectron2/solver/lr_scheduler.py
# Changes:
# - Docstring to match the rest of the library.
# - Removed calls to subroutines that do not exist in DIRECT.
# - Stylistic changes.
import logging
import math
from bisect import bisect_right
from typing import List
import torch
# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
# only on epoch boundaries. We typically use iteration-based schedules instead.
# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean
# "iteration" instead.
# FIXME: ideally this would be achieved with a CombinedLRScheduler, separating
# MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it.
class LRScheduler(torch.optim.lr_scheduler._LRScheduler):  # pylint: disable=protected-access
    """Base learning rate scheduler that attaches a logger and a serializable state dict."""
def __init__(self, optimizer, last_epoch=-1, verbose=False):
super().__init__(optimizer, last_epoch, verbose)
self.logger = logging.getLogger(type(self).__name__)
def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which is not the optimizer or logger.
"""
state_dict = {key: value for key, value in self.__dict__.items() if key not in ["optimizer", "logger"]}
return state_dict
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):  # pylint: disable=protected-access
    """Multi-step learning rate scheduler with warmup: after the warmup phase, the learning rate
    is scaled by ``gamma`` at each milestone iteration."""
def __init__(
self,
optimizer: torch.optim.Optimizer,
milestones: List[int],
gamma: float = 0.1,
warmup_factor: float = 0.001,
warmup_iterations: int = 1000,
warmup_method: str = "linear",
last_epoch: int = -1,
):
        if list(milestones) != sorted(milestones):
            raise ValueError(f"Milestones should be a list of increasing integers. Got {milestones}.")
self.milestones = milestones
self.gamma = gamma
self.warmup_factor = warmup_factor
self.warmup_iterations = warmup_iterations
self.warmup_method = warmup_method
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]: # type: ignore
warmup_factor = _get_warmup_factor_at_iter(
self.warmup_method,
self.last_epoch, # type: ignore
self.warmup_iterations,
self.warmup_factor,
)
return [
base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) # type: ignore
for base_lr in self.base_lrs # type: ignore
]
def _compute_values(self) -> List[float]:
# The new interface
return self.get_lr()
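
# Usage sketch (editor's illustration, not part of the original module): the model,
# optimizer, and hyperparameters below are arbitrary assumptions. As the module-level
# note explains, ``scheduler.step()`` is called once per training iteration, not once
# per epoch.
def _example_warmup_multistep() -> None:  # pragma: no cover - documentation example
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = WarmupMultiStepLR(
        optimizer,
        milestones=[3000, 4000],
        gamma=0.1,
        warmup_factor=0.001,
        warmup_iterations=1000,
    )
    for _ in range(5000):
        optimizer.step()  # normally preceded by forward and backward passes
        scheduler.step()  # warmup for 1000 iterations, then decay at 3000 and 4000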
class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler):  # pylint: disable=protected-access
    """Half-cosine learning rate decay over ``max_iters`` iterations, with an initial warmup phase."""
def __init__(
self,
optimizer: torch.optim.Optimizer,
max_iters: int,
warmup_factor: float = 0.001,
warmup_iterations: int = 1000,
warmup_method: str = "linear",
last_epoch: int = -1,
):
self.max_iters = max_iters
self.warmup_factor = warmup_factor
self.warmup_iterations = warmup_iterations
self.warmup_method = warmup_method
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]: # type: ignore
warmup_factor = _get_warmup_factor_at_iter(
self.warmup_method,
self.last_epoch, # type: ignore
self.warmup_iterations,
self.warmup_factor,
)
# Different definitions of half-cosine with warmup are possible. For
# simplicity we multiply the standard half-cosine schedule by the warmup
# factor. An alternative is to start the period of the cosine at warmup_iterations
# instead of at 0. In the case that warmup_iterations << max_iters the two are
# very close to each other.
return [
base_lr * warmup_factor * 0.5 * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)) # type: ignore
for base_lr in self.base_lrs # type: ignore
]
def _compute_values(self) -> List[float]:
# The new interface
return self.get_lr()
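
# Usage sketch (editor's illustration, not part of the original module): cosine decay
# over ``max_iters`` iterations after a linear warmup. The optimizer and the numbers
# below are arbitrary assumptions; ``get_last_lr`` is the standard PyTorch accessor
# for the most recently computed learning rates.
def _example_warmup_cosine() -> List[float]:  # pragma: no cover - documentation example
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = WarmupCosineLR(optimizer, max_iters=10000, warmup_iterations=500)
    learning_rates: List[float] = []
    for _ in range(10000):
        optimizer.step()
        scheduler.step()
        learning_rates.append(scheduler.get_last_lr()[0])
    return learning_rates  # ramps up to ~1e-3 during warmup, then decays towards 0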
def _get_warmup_factor_at_iter(method: str, curr_iter: int, warmup_iters: int, warmup_factor: float) -> float:
"""Return the learning rate warmup factor at a specific iteration.
Parameters
----------
method: str
Warmup method; either "constant" or "linear".
curr_iter: int
Iteration at which to calculate the warmup factor.
warmup_iters: int
    The length of the warmup phase, in iterations.
warmup_factor: float
The base warmup factor (the meaning changes according to the method used).
Returns
-------
    float
        The effective warmup factor at the given iteration.
"""
if curr_iter >= warmup_iters:
return 1.0
if method == "constant":
return warmup_factor
if method == "linear":
alpha = curr_iter / warmup_iters
return warmup_factor * (1 - alpha) + alpha
raise ValueError(f"Unknown warmup method: {method}")