raycv/util/handler.py

from pathlib import Path
import torch
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler


def setup_common_handlers(
    trainer: Engine,
    output_dir=None,
    stop_on_nan=True,
    use_profiler=True,
    print_interval_event=None,
    metrics_to_print=None,
    to_save=None,
    resume_from=None,
    save_interval_event=None,
    **checkpoint_kwargs
):
"""
Helper method to setup trainer with common handlers.
1. TerminateOnNan
2. BasicTimeProfiler
3. Print
4. Checkpoint
:param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary
or sequence or a single tensor.
:param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually
:param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
:param use_profiler:
:param print_interval_event:
:param metrics_to_print:
:param to_save:
:param resume_from:
:param save_interval_event:
:param checkpoint_kwargs:
:return:
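
    Example (a minimal sketch; ``trainer``, ``model``, and ``optimizer`` are assumed
    to exist, and the event choices are illustrative only):

    .. code-block:: python

        setup_common_handlers(
            trainer,
            output_dir="/tmp/output",
            metrics_to_print=["loss"],
            print_interval_event=Events.ITERATION_COMPLETED(every=100),
            to_save={"model": model, "optimizer": optimizer, "trainer": trainer},
            save_interval_event=Events.EPOCH_COMPLETED,
            n_saved=2,  # forwarded to Checkpoint via **checkpoint_kwargs
        )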
"""
    @trainer.on(Events.STARTED)
    @idist.one_rank_only()
    def print_dataloader_size(engine):
        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if use_profiler:
        # Create a profiler and attach it to the trainer
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1))
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def log_final_results():
            profiler.print_results(profiler.get_results())
            # profiler.write_results(f"{output_dir}/time_profiling.csv")
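
    # Periodically log the selected entries from engine.state.metrics.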
    if metrics_to_print is not None:
        if print_interval_event is None:
            raise ValueError(
                "If the metrics_to_print argument is provided, the print_interval_event argument must also be defined"
            )

        @trainer.on(print_interval_event)
        def print_interval(engine):
            print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
            for m in metrics_to_print:
                if m not in engine.state.metrics:
                    continue
                print_str += f"{m}={engine.state.metrics[m]:.3f} "
            engine.logger.info(print_str)
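
    # Checkpointing: save `to_save` objects to `output_dir`, optionally resuming first.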
    if to_save is not None:
        checkpoint_handler = Checkpoint(
            to_save, DiskSaver(dirname=output_dir, require_empty=False), **checkpoint_kwargs
        )

        if resume_from is not None:
            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resuming from checkpoint {checkpoint_path}")

        # Attach the checkpoint handler only when there is something to save.
        if save_interval_event is not None:
            trainer.add_event_handler(save_interval_event, checkpoint_handler)