Source code for cpu.lr_scheduler

from typing import Any, Dict, List, Optional, Union

from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler


[docs] class LRWarmupScheduler: """This class wraps the standard PyTorch LR scheduler to support warmup. The usage is demonstrated in the following snippet: .. code-block:: python :emphasize-lines: 6-9 torch_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3) warmup_scheduler = LRWarmupScheduler(torch_scheduler) for epoch in range(max_epochs): for iter in range(epoch_len): train_one_iter() # call iter_update() after each iteration warmup_scheduler.iter_update() # call epoch_update() after each epoch warmup_scheduler.epoch_update() Args: torch_scheduler (_LRScheduler) by_epoch (bool): If True, the ``torch_scheduler`` is epoch-based, else iteration-based. Defaults to True. epoch_len (int): The number of iterations in an epoch. Required only when ``by_epoch=True & warmup_by_epoch=False``. warmup_t (int): How many iterations / epochs in warmup stage. If ``warmup_by_epoch=True``, "**t**" means epoch, else iteration. Defaults to 0 to disable warmup. warmup_by_epoch (bool): If True, perform warmup at each epoch end, else iteration end. Defaults to False. warmup_mode (str): "fix", "auto", or "factor". Defaults to "fix". warmup_init_lr (float): The initial warmup lr. Required in "fix" mode. Defaults to None. warmup_factor (float): The factor of initial warmup lr relative to base lr. Required in "auto" and "factor" mode. Defaults to None. """ def __init__( self, torch_scheduler: _LRScheduler, by_epoch: bool = True, epoch_len: Optional[int] = None, # the following settings are related to warmup warmup_t: int = 0, warmup_by_epoch: bool = False, warmup_mode: str = "fix", warmup_init_lr: Optional[float] = None, warmup_factor: Optional[float] = None, ): self.torch_scheduler = torch_scheduler self.by_epoch = by_epoch self.epoch_len = epoch_len self.warmup_t = warmup_t self.warmup_by_epoch = warmup_by_epoch self.warmup_mode = warmup_mode self.warmup_init_lr = warmup_init_lr self.warmup_factor = warmup_factor if warmup_by_epoch: assert by_epoch if by_epoch and warmup_t and not warmup_by_epoch: assert epoch_len is not None if self._is_plateau: assert by_epoch self.param_groups = self.torch_scheduler.optimizer.param_groups self.base_lrs = [param_group["lr"] for param_group in self.param_groups] if warmup_t: # pre-compute the regular lr if no warmup is performed max_t = warmup_t // epoch_len if by_epoch and not warmup_by_epoch else warmup_t self.regular_lrs_per_t = self._pre_compute_regular_lrs_per_t(max_t) self.last_iter = self.last_epoch = 0 self.in_iter_warmup = False if warmup_t > 0: if warmup_mode == "fix": assert isinstance(warmup_init_lr, float) self._set_lrs(warmup_init_lr) elif warmup_mode == "factor": assert isinstance(warmup_factor, float) self._set_lrs([base_lr * warmup_factor for base_lr in self.base_lrs]) elif warmup_mode == "auto": assert isinstance(warmup_factor, float) self.warmup_end_lrs = self.regular_lrs_per_t[-1] self._set_lrs([base_lr * warmup_factor for base_lr in self.base_lrs]) else: raise ValueError(f"Invalid warmup mode: {warmup_mode}") @property def _is_plateau(self) -> bool: return isinstance(self.torch_scheduler, ReduceLROnPlateau) def _pre_compute_regular_lrs_per_t(self, max_t: int) -> List[List[float]]: regular_lrs_per_t = [self.base_lrs] if self._is_plateau: return regular_lrs_per_t * (max_t + 1) for _ in range(max_t): self.torch_scheduler.step() regular_lrs_per_t.append([param_group["lr"] for param_group in self.param_groups]) return regular_lrs_per_t def _get_warmup_lrs(self, t: int, regular_lrs: List[float]) -> List[float]: alpha = t / self.warmup_t if self.warmup_mode == "fix": return [ self.warmup_init_lr * (1 - alpha) + base_lr * alpha for base_lr in self.base_lrs ] elif self.warmup_mode == "factor": factor = self.warmup_factor * (1 - alpha) + alpha return [lr * factor for lr in regular_lrs] else: return [ base_lr * self.warmup_factor * (1 - alpha) + end_lr * alpha for base_lr, end_lr in zip(self.base_lrs, self.warmup_end_lrs) ] def _set_lrs(self, lrs: Union[float, List[float]]) -> None: if not isinstance(lrs, (list, tuple)): lrs = [lrs] * len(self.param_groups) for param_group, lr in zip(self.param_groups, lrs): param_group["lr"] = lr
[docs] def epoch_update(self, metric: Optional[float] = None) -> None: """Prepare the learning rate for the next epoch. The method should be called after finishing each epoch. Args: metric (float): Metric value used in :class:`ReduceLROnPlateau`. Defaults to None. """ if not self.by_epoch: return self.last_epoch += 1 if self.warmup_by_epoch and self.last_epoch < self.warmup_t: self._set_lrs( self._get_warmup_lrs(self.last_epoch, self.regular_lrs_per_t[self.last_epoch])) elif self.warmup_by_epoch and self.last_epoch == self.warmup_t: self._set_lrs(self.regular_lrs_per_t[-1]) elif not self.in_iter_warmup: if self._is_plateau: self.torch_scheduler.step(metric) else: self.torch_scheduler.step()
[docs] def iter_update(self) -> None: """Prepare the learning rate for the next iteration. The method should be called after finishing each iteration. """ if self.warmup_by_epoch: return self.last_iter += 1 if self.last_iter < self.warmup_t: self.in_iter_warmup = True t = self.last_iter // self.epoch_len if self.by_epoch else self.last_iter self._set_lrs(self._get_warmup_lrs(self.last_iter, self.regular_lrs_per_t[t])) elif self.last_iter == self.warmup_t: self._set_lrs(self.regular_lrs_per_t[-1]) else: self.in_iter_warmup = False if not self.by_epoch: self.torch_scheduler.step()
[docs] def state_dict(self) -> Dict[str, Any]: """Returns the state of the scheduler as a dict.""" state = {key: value for key, value in self.__dict__.items() if key != "torch_scheduler"} state["torch_scheduler"] = self.torch_scheduler.state_dict() return state
[docs] def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """Loads the scheduler state. Args: state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ self.torch_scheduler.load_state_dict(state_dict.pop("torch_scheduler")) self.__dict__.update(state_dict)