# Source code for cpu.hooks.distributed_hook

from .hookbase import HookBase


class DistributedHook(HookBase):
    """Call :meth:`DistributedSampler.set_epoch` before each epoch.

    The sampler is looked up on ``self.trainer.data_loader``. If no sampler
    exposing ``set_epoch`` is found, the hook does nothing.
    """

    def before_epoch(self) -> None:
        """Forward ``self.trainer.cur_epoch`` to the data loader's sampler.

        Lookup order:

        1. ``data_loader.sampler``, if it has a ``set_epoch`` method.
        2. ``data_loader.batch_sampler.sampler``, if it has one. A PyTorch
           batch sampler wraps the real sampler as its ``sampler`` attribute.

        Missing attributes along either path are skipped rather than raising.
        For example, ``batch_sampler`` may be ``None`` when a PyTorch
        ``DataLoader`` is built with ``batch_size=None``.
        """
        data_loader = self.trainer.data_loader
        epoch = self.trainer.cur_epoch

        sampler = getattr(data_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
            return

        # Use getattr with a default so that a None or custom batch sampler
        # without a `.sampler` attribute does not raise AttributeError.
        batch_sampler = getattr(data_loader, "batch_sampler", None)
        inner_sampler = getattr(batch_sampler, "sampler", None)
        if hasattr(inner_sampler, "set_epoch"):
            inner_sampler.set_epoch(epoch)