# raycv/util/handler.py

from pathlib import Path
import torch
from ignite.engine import Engine
from ignite.handlers import Checkpoint
class Resumer:
    """Engine handler that restores objects from a checkpoint file when called.

    When invoked with an engine, it loads the checkpoint with ``torch.load``. It
    then restores ``to_load`` via ``ignite.handlers.Checkpoint.load_objects``
    and logs the resume on ``engine.logger``. If no checkpoint path was given,
    calling it does nothing.
    NOTE(review): which engine event it is attached to is decided by the caller.

    Args:
        to_load: Objects to restore, forwarded as ``to_load`` to
            ``Checkpoint.load_objects``.
        checkpoint_path: Path (str or path-like) to a checkpoint file, or
            ``None`` to disable resuming.
        map_location: Forwarded to ``torch.load``. It defaults to ``"cpu"``, so
            tensors are first materialised on the CPU.

    Raises:
        ValueError: If ``checkpoint_path`` is given but is not an existing file.
    """

    def __init__(self, to_load, checkpoint_path, map_location="cpu"):
        self.to_load = to_load
        self.map_location = map_location
        if checkpoint_path is not None:
            checkpoint_path = Path(checkpoint_path)
            # is_file() rather than exists(): a directory (including "" -> Path("."))
            # would otherwise pass validation and only fail later inside torch.load.
            if not checkpoint_path.is_file():
                raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
        self.checkpoint_path = checkpoint_path

    def __call__(self, engine: "Engine"):
        """Load the checkpoint into ``self.to_load``; no-op when no path was given."""
        if self.checkpoint_path is None:
            return
        # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints.
        ckp = torch.load(self.checkpoint_path.as_posix(), map_location=self.map_location)
        Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
        engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")