raycv/util/handler.py
2020-08-24 06:51:42 +08:00

110 lines
4.4 KiB
Python

from pathlib import Path
import torch
from torch.utils.data.distributed import DistributedSampler
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
def empty_cuda_cache(_):
    """Release unused CUDA memory held by PyTorch's caching allocator.

    The signature matches an event handler; the single argument (the engine) is ignored.
    """
    import gc

    # Collect garbage first: tensors kept alive only by reference cycles are freed and their
    # blocks return to the caching allocator, so empty_cache() can then release them too.
    gc.collect()
    torch.cuda.empty_cache()
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
                          to_save=None, end_event=None, set_epoch_for_dist_sampler=True):
    """Attach a standard set of event handlers to ``trainer``.

    Handlers registered (each controlled by the parameter in parentheses, if any):

    1. ``DistributedSampler.set_epoch`` at every epoch start (``set_epoch_for_dist_sampler``).
    2. Info logging of the data loader length and, when CUDA is available, a CUDA memory
       summary on start and after the first epoch (always).
    3. ``TerminateOnNan`` after every iteration (``stop_on_nan``).
    4. CUDA cache clearing after every epoch (``clear_cuda_cache``, only if CUDA is available).
    5. ``BasicTimeProfiler`` whose results are printed on one rank after the first epoch and
       on completion (``use_profiler``).
    6. ``ProgressBar`` (always).
    7. ``Checkpoint`` saving every ``config.checkpoint.epoch_interval`` epochs and on
       completion, plus optional resume from ``config.resume_from`` (``to_save``).
    8. Termination of the trainer when ``end_event`` fires (``end_event``).

    :param trainer: engine to attach the handlers to.
    :param config: run configuration; when ``to_save`` is given it reads ``output_dir``, ``name``,
        ``resume_from``, ``checkpoint.n_saved`` and ``checkpoint.epoch_interval``.
    :param stop_on_nan: attach ``TerminateOnNan``.
    :param clear_cuda_cache: free cached CUDA memory after every epoch (ignored without CUDA).
    :param use_profiler: attach a ``BasicTimeProfiler``.
    :param to_save: mapping of objects to checkpoint (and restore when resuming); ``None``
        disables checkpointing.
    :param end_event: event on which the trainer is terminated; ``None`` disables this.
    :param set_epoch_for_dist_sampler: call ``set_epoch`` on a ``DistributedSampler`` at each
        epoch start so its shuffling changes between epochs.
    :return: None
    """
    if set_epoch_for_dist_sampler:
        @trainer.on(Events.EPOCH_STARTED)
        def distrib_set_epoch(engine):
            # Not every data loader exposes a sampler (e.g. plain iterables); look it up defensively.
            sampler = getattr(engine.state.dataloader, "sampler", None)
            if isinstance(sampler, DistributedSampler):
                # state.epoch is already incremented when EPOCH_STARTED fires; pass a 0-based epoch.
                engine.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
                sampler.set_epoch(engine.state.epoch - 1)

    @trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
    def print_info(engine):
        try:
            loader_length = len(engine.state.dataloader)
        except TypeError:
            # Iterable-style loaders may not define __len__.
            loader_length = "unknown"
        engine.logger.info(f"data loader length: {loader_length}")
        # memory_summary needs CUDA; skip it on CPU-only runs instead of raising, and report the
        # device this process is using rather than always GPU 0.
        if torch.cuda.is_available():
            engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(torch.cuda.current_device())}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if use_profiler:
        # Create the profiler and attach it to the trainer.
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

    ProgressBar(ncols=0).attach(trainer, "all")

    if to_save is not None:
        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
                                        n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
        if config.resume_from is not None:
            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(config.resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                engine.logger.info(f"load state_dict for {ckp.keys()}")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resume from a checkpoint {checkpoint_path}")

        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
                                  checkpoint_handler)
        trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")

    if end_event is not None:
        trainer.logger.debug(f"engine will stop on {end_event}")

        @trainer.on(end_event)
        def terminate(engine):
            engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
    """Attach TensorBoard scalar logging to ``trainer`` on rank 0 only.

    Two ``OutputHandler``s are attached, both firing every
    ``config.interval.tensorboard.scalar`` iterations: one logs all engine metrics under the
    ``"metric"`` tag, the other logs the values produced by ``output_transform`` under
    ``"train"``. Event files go to ``config.output_dir``; the logger is closed when the
    trainer completes.

    :param trainer: engine to log from.
    :param config: reads ``interval.tensorboard`` (``None`` disables logging),
        ``interval.tensorboard.scalar`` and ``output_dir``.
    :param output_transform: callable handed to the ``"train"`` ``OutputHandler``.
    :return: the ``TensorboardLogger`` on rank 0; ``None`` on other ranks or when disabled.
    """
    tensorboard_cfg = config.interval.tensorboard
    if tensorboard_cfg is None:
        return None
    # Only the main process writes event files.
    if idist.get_rank() != 0:
        return None

    writer = TensorboardLogger(log_dir=config.output_dir)
    scalar_every = tensorboard_cfg.scalar
    writer.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
                  event_name=Events.ITERATION_COMPLETED(every=scalar_every))
    writer.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
                  event_name=Events.ITERATION_COMPLETED(every=scalar_every))

    @trainer.on(Events.COMPLETED)
    @idist.one_rank_only()
    def _():
        # Close the logger when training is done.
        writer.close()

    return writer