import os
from pathlib import Path
from typing import Any, Mapping, Optional, Union

import torch
from ignite.engine import Engine
from ignite.handlers import Checkpoint


class Resumer:
    """Callable that restores objects from a saved checkpoint file.

    The checkpoint path is validated eagerly in ``__init__`` so that a bad
    path fails at setup time. When called with an ``Engine`` (presumably as an
    attached event handler), the checkpoint is loaded onto the CPU and its
    entries are restored into ``to_load`` via
    ``ignite.handlers.Checkpoint.load_objects``. If no checkpoint path was
    given, calling the instance does nothing.
    """

    def __init__(
        self,
        to_load: Mapping[str, Any],
        checkpoint_path: Optional[Union[str, os.PathLike]],
    ) -> None:
        """Store the objects to restore and validate the checkpoint path.

        Args:
            to_load: Mapping of names to objects whose state should be
                restored; passed unchanged to ``Checkpoint.load_objects``.
            checkpoint_path: Path to the checkpoint file, or ``None`` to
                disable resuming.

        Raises:
            ValueError: If ``checkpoint_path`` is given but does not exist or
                is not a regular file.
        """
        self.to_load = to_load
        if checkpoint_path is not None:
            checkpoint_path = Path(checkpoint_path)
            if not checkpoint_path.exists():
                raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
            # exists() is also true for directories, which cannot be loaded as
            # a checkpoint file; reject them here instead of failing later when
            # the handler runs.
            if not checkpoint_path.is_file():
                raise ValueError(f"Checkpoint '{checkpoint_path}' is not a file")
        self.checkpoint_path: Optional[Path] = checkpoint_path

    def __call__(self, engine: Engine) -> None:
        """Restore ``to_load`` from the checkpoint, if one was configured.

        Args:
            engine: The engine invoking this handler; only its logger is used.
        """
        if self.checkpoint_path is None:
            return
        # NOTE(review): torch.load is pickle-based; depending on the torch
        # version / weights_only setting it can execute arbitrary code, so only
        # point this at trusted checkpoints. map_location="cpu" remaps all
        # tensors to CPU regardless of the device they were saved from.
        ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
        engine.logger.info("resume from a checkpoint %s", self.checkpoint_path)