Source code for cpu.hooks.checkpoint_hook

import os
import os.path as osp
from typing import Any, Dict, List, Optional

from .hookbase import HookBase


[docs] class CheckpointHook(HookBase): """Save checkpoint periodically. Save checkpoint, if current epoch / iteration is a multiple of ``period`` or ``max_epochs`` / ``max_iters`` is reached. Args: period (int): Save checkpoint every ``period`` epochs. max_to_keep (int): Maximum number of most current checkpoints to keep, previous checkpoints will be deleted. If None, save all checkpoints. """ def __init__(self, period: int, max_to_keep: Optional[int] = None) -> None: self._period = period assert max_to_keep is None or max_to_keep > 0 self._max_to_keep = max_to_keep self._recent_checkpoints: List[str] = []
[docs] def after_iter(self) -> None: if self.trainer.train_by_epoch: return if self.every_n_iters(self._period) or self.is_last_iter(): iter = self.trainer.cur_iter # ranged in [0, max_iters - 1] checkpoint_name = f"iter_{iter}.pth" self.trainer.save_checkpoint(checkpoint_name) self._delete_oldest_checkpoint(checkpoint_name)
[docs] def after_epoch(self) -> None: if not self.trainer.train_by_epoch: return if self.every_n_epochs(self._period) or self.is_last_epoch(): epoch = self.trainer.cur_epoch # ranged in [0, max_epochs - 1] checkpoint_name = f"epoch_{epoch}.pth" self.trainer.save_checkpoint(checkpoint_name) self._delete_oldest_checkpoint(checkpoint_name)
def _delete_oldest_checkpoint(self, checkpoint_name): if self._max_to_keep is not None: self._recent_checkpoints.append(checkpoint_name) if len(self._recent_checkpoints) > self._max_to_keep: file_name = self._recent_checkpoints.pop(0) file_path = osp.join(self.trainer.ckpt_dir, file_name) if os.path.exists(file_path): os.remove(file_path) def state_dict(self) -> Dict[str, Any]: return {key: value for key, value in self.__dict__.items() if key != "trainer"} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.__dict__.update(state_dict)