Source code for cpu.hooks.eval_hook

from typing import Callable

from .hookbase import HookBase


[docs] class EvalHook(HookBase): """Run an evaluation function periodically. It is executed every ``period`` epochs / iterations and after the last epoch / iteration. Args: period (int): The period to run ``eval_func``. Set to 0 to not evaluate periodically, but still after the last epoch / iteration. eval_func (callable): A function which takes no arguments, and returns a dict of evaluation metrics. """ priority = 1 def __init__(self, period: int, eval_func: Callable) -> None: self._period = period self._eval_func = eval_func def _do_eval(self) -> None: self.trainer.model.eval() res = self._eval_func() self.trainer.model.train() if res: assert isinstance(res, dict), f"Eval function must return a dict. Got {res} instead." for k, v in res.items(): try: v = float(v) except Exception as e: raise ValueError( f"Eval function should return a dict of float. Got '{k}: {v}' instead." ) from e train_by_epoch = self.trainer.train_by_epoch counter = self.trainer.cur_epoch if train_by_epoch else self.trainer.cur_iter self.log(counter, **res, smooth=False)
[docs] def after_epoch(self) -> None: if not self.trainer.train_by_epoch: return if self.every_n_epochs(self._period) or self.is_last_epoch(): self._do_eval()
[docs] def after_iter(self) -> None: if self.trainer.train_by_epoch: return if self.every_n_iters(self._period) or self.is_last_iter(): self._do_eval()