22 lines
799 B
Python
22 lines
799 B
Python
from pathlib import Path
|
|
|
|
import torch
|
|
from ignite.engine import Engine
|
|
from ignite.handlers import Checkpoint
|
|
|
|
|
|
class Resumer:
    """Restore training objects from a saved checkpoint when called with an engine.

    Instances are callables taking an ignite ``Engine``. NOTE(review): presumably
    registered as an event handler (e.g. on ``Events.STARTED``); confirm with callers.
    When constructed with ``checkpoint_path=None`` calling the instance is a no-op.

    Args:
        to_load: Objects to restore. Forwarded unchanged to
            ``Checkpoint.load_objects(to_load=...)``. This is presumably a dict of
            name -> object exposing ``load_state_dict``; verify against callers.
        checkpoint_path: Path (``str`` or ``os.PathLike``) of the checkpoint file,
            or ``None`` to disable resuming.

    Raises:
        ValueError: If ``checkpoint_path`` does not exist or is not a regular file.
    """

    def __init__(self, to_load, checkpoint_path) -> None:
        self.to_load = to_load

        if checkpoint_path is not None:
            checkpoint_path = Path(checkpoint_path)
            if not checkpoint_path.exists():
                raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
            # exists() alone also accepts directories -- including Path(""),
            # which normalises to "." -- and the mistake would then only
            # surface much later, inside torch.load at resume time.
            if not checkpoint_path.is_file():
                raise ValueError(f"Checkpoint '{checkpoint_path}' is not a file")

        # Normalised to a Path (or left as None) so __call__ can test for None.
        self.checkpoint_path = checkpoint_path

    def __call__(self, engine: "Engine") -> None:
        """Load the checkpoint into ``self.to_load`` and log it on ``engine.logger``.

        Does nothing when the resumer was built with ``checkpoint_path=None``.
        """
        if self.checkpoint_path is None:
            return

        # map_location="cpu" places every stored tensor on the CPU, whatever
        # device it was saved from; load_objects then copies state into to_load.
        ckp = torch.load(self.checkpoint_path, map_location="cpu")
        Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
        engine.logger.info("resume from a checkpoint %s", self.checkpoint_path)
|