# raycv/util/handler.py
# (112 lines, 4.2 KiB, Python)

from pathlib import Path
import torch
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
def empty_cuda_cache(_):
    """Free as much GPU memory as possible at the end of an epoch.

    Runs the Python garbage collector *before* emptying the CUDA cache:
    tensors whose last reference was just dropped must be collected first,
    otherwise the caching allocator still holds their blocks and
    ``empty_cache()`` cannot return that memory to the driver.

    :param _: the ignite Engine passed by the event system (unused).
    """
    import gc
    gc.collect()
    torch.cuda.empty_cache()
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
                          to_save=None, metrics_to_print=None, end_event=None):
    """
    Helper method to setup trainer with common handlers.
    1. TerminateOnNan
    2. BasicTimeProfiler
    3. Print
    4. Checkpoint
    :param trainer: ignite Engine the handlers are attached to
    :param config: config object; reads ``interval.print_per_iteration``, ``output_dir``,
        ``checkpoint.n_saved``, ``checkpoint.epoch_interval``, ``name`` and ``resume_from``
    :param stop_on_nan: if True, terminate training when the output contains NaN
    :param clear_cuda_cache: if True (and CUDA is available), empty the CUDA cache each epoch
    :param use_profiler: if True, attach a BasicTimeProfiler and print its results
        after the first epoch and at completion (rank 0 only)
    :param to_save: dict of objects to checkpoint (and to restore from ``config.resume_from``);
        None disables checkpointing
    :param metrics_to_print: names of ``engine.state.metrics`` entries to log periodically;
        None disables the periodic print
    :param end_event: optional event at which training is terminated early
    :return: None
    """

    # Log the dataloader length once at start, on rank 0 only.
    @trainer.on(Events.STARTED)
    @idist.one_rank_only()
    def print_dataloader_size(engine):
        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
    if use_profiler:
        # Create an object of the profiler and attach an engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        # Print timing once after the first epoch and again at completion.
        @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

    print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED
    ProgressBar(ncols=0).attach(trainer, "all")
    if metrics_to_print is not None:
        @trainer.on(print_interval_event)
        def print_interval(engine):
            print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
            for m in metrics_to_print:
                # Silently skip metrics that have not been computed yet.
                if m not in engine.state.metrics:
                    continue
                print_str += f"{m}={engine.state.metrics[m]:.3f} "
            engine.logger.debug(print_str)
    if to_save is not None:
        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
                                        n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
        if config.resume_from is not None:
            # Restore the objects in `to_save` from the checkpoint when training starts.
            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(config.resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                # Load on CPU; objects are moved to their devices by load_objects.
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
        # Save every `epoch_interval` epochs and once more at completion.
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
                                  checkpoint_handler)
    if end_event is not None:
        # Stop training early when the user-supplied event fires.
        @trainer.on(end_event)
        def terminate(engine):
            engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
    """Attach TensorBoard logging to the trainer on rank 0.

    Logs all engine metrics under the "metric" tag and the transformed
    engine output under the "train" tag, every
    ``config.interval.tensorboard.scalar`` iterations.

    :param trainer: ignite Engine to log from
    :param config: config object; reads ``interval.tensorboard`` and ``output_dir``
    :param output_transform: callable mapping the engine output to scalars for the "train" tag
    :return: the TensorboardLogger on rank 0, otherwise None (also None when disabled)
    """
    if config.interval.tensorboard is None:
        return None
    if idist.get_rank() != 0:
        return None

    tb_logger = TensorboardLogger(log_dir=config.output_dir)
    scalar_event = Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar)
    for log_handler in (OutputHandler(tag="metric", metric_names="all"),
                        OutputHandler(tag="train", output_transform=output_transform)):
        tb_logger.attach(trainer, log_handler=log_handler, event_name=scalar_event)

    @trainer.on(Events.COMPLETED)
    @idist.one_rank_only()
    def _close_logger():
        # We need to close the logger when we are done
        tb_logger.close()

    return tb_logger