raycv/engine/util/handler.py

import gc
from pathlib import Path

import torch
from torch.utils.data.distributed import DistributedSampler

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler


def empty_cuda_cache(_):
    """Release cached CUDA memory and run the garbage collector (the unused argument is the calling engine)."""
    torch.cuda.empty_cache()
    gc.collect()


def step_transform_maker(stype: str, pairs_per_iteration=None):
    """Return a global_step_transform that counts steps in items (training pairs), iterations or epochs."""
    assert stype in ["item", "iteration", "epoch"]
    if stype == "item":
        # counting in items needs to know how many training pairs are consumed per iteration
        assert pairs_per_iteration is not None, "pairs_per_iteration is required when stype='item'"
        return lambda engine, _: engine.state.iteration * pairs_per_iteration
    if stype == "iteration":
        return lambda engine, _: engine.state.iteration
    if stype == "epoch":
        return lambda engine, _: engine.state.epoch
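

# A small illustration with hypothetical numbers: with a world size of 2 and a per-rank batch
# size of 16, pairs_per_iteration would be 32, so the "item" transform reports iteration 10
# as global step 320, i.e. the number of training pairs seen so far:
#
#     transform = step_transform_maker("item", pairs_per_iteration=32)
#     # inside a handler: transform(engine, None) == engine.state.iteration * 32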


def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
                          to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
    """
    Helper method to set up the trainer with common handlers:

    1. TerminateOnNan
    2. BasicTimeProfiler
    3. ProgressBar and basic info logging
    4. Checkpoint saving (and resuming when config.resume_from is set)

    A usage sketch with hypothetical objects follows this function.

    :param trainer: the ignite training engine to configure
    :param config: experiment config providing output_dir, name, resume_from,
        iterations_per_epoch and handler.checkpoint.*
    :param stop_on_nan: terminate training when an iteration output contains NaN or Inf
    :param clear_cuda_cache: call empty_cuda_cache after every epoch
    :param use_profiler: attach a BasicTimeProfiler and log its results
    :param to_save: mapping of objects with a state_dict to checkpoint periodically
    :param end_event: optional event on which the engine terminates early
    :param set_epoch_for_dist_sampler: call set_epoch on a DistributedSampler at each epoch start
    :return: None
    """
    if set_epoch_for_dist_sampler:
        @trainer.on(Events.EPOCH_STARTED)
        def distrib_set_epoch(engine):
            # a DistributedSampler needs the epoch number to reshuffle differently every epoch
            if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
                trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
                trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)

    trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch")

    # the GPU memory summary is only available when CUDA is present
    if torch.cuda.is_available():
        @trainer.on(Events.EPOCH_COMPLETED(once=1))
        def print_info(engine):
            engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if use_profiler:
        # create the profiler and attach it to the engine
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

    ProgressBar(ncols=0).attach(trainer, "all")

    if to_save is not None:
        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
                                        n_saved=config.handler.checkpoint.n_saved, filename_prefix=config.name)

        if config.resume_from is not None:
            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(config.resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                trainer.logger.info(f"load state_dict for {ckp.keys()}")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resume from a checkpoint {checkpoint_path}")

        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED,
            checkpoint_handler
        )
        trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")

    if end_event is not None:
        trainer.logger.debug(f"engine will stop on {end_event}")

        @trainer.on(end_event)
        def terminate(engine):
            engine.terminate()
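

# A minimal sketch of a typical call, assuming a hypothetical trainer, model, optimizer and an
# OmegaConf-style config exposing output_dir, name, resume_from and handler.checkpoint.*
# (a fuller, runnable wiring sketch sits at the bottom of this file):
#
#     setup_common_handlers(
#         trainer, config,
#         to_save={"model": model, "optimizer": optimizer},
#         end_event=Events.ITERATION_COMPLETED(once=100_000),  # hypothetical stopping point
#     )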


def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"):
    """Attach TensorBoard logging of metrics, losses and optimizer parameters; returns the logger on rank 0, else None."""
    if config.handler.tensorboard is None:
        return None

    if idist.get_rank() == 0:
        # create the logger; its writer is used directly for the nested loss scalars below
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # one "item" is a training pair; all ranks together consume batch_size * world_size pairs per iteration
        pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
        global_step_transform = step_transform_maker(step_type, pairs_per_iteration)

        # log scalars roughly config.handler.tensorboard.scalar times per epoch
        basic_event = Events.ITERATION_COMPLETED(
            every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1))

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="metric", metric_names="all",
                global_step_transform=global_step_transform
            ),
            event_name=basic_event
        )

        @trainer.on(basic_event)
        def log_loss(engine):
            # engine.state.output["loss"] maps loss names to scalars or to nested dicts of scalars
            global_step = global_step_transform(engine, None)
            output_loss = engine.state.output["loss"]
            for loss_name in output_loss:
                if isinstance(output_loss[loss_name], dict):
                    for sub_name in output_loss[loss_name]:
                        tb_writer.add_scalar(f"train_{loss_name}/{sub_name}", output_loss[loss_name][sub_name],
                                             global_step)
                else:
                    tb_writer.add_scalar(f"train/{loss_name}", output_loss[loss_name], global_step)

        # log optimizer parameters (e.g. the learning rate) at every iteration
        if isinstance(optimizers, dict):
            for name in optimizers:
                tb_logger.attach(
                    trainer,
                    log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"),
                    event_name=Events.ITERATION_STARTED
                )
        else:
            tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizers, tag="optimizer"),
                             event_name=Events.ITERATION_STARTED)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def close_tb_logger():
            # the logger has to be closed once training is done
            tb_logger.close()

        return tb_logger

    return None
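

if __name__ == "__main__":
    # Illustrative wiring sketch only: a toy model/optimizer and a SimpleNamespace stand-in for
    # the project's config object, assuming only the attributes the helpers above actually read.
    # Requires tensorboard to be installed for the TensorboardLogger.
    from types import SimpleNamespace

    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    def train_step(engine, batch):
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        # nested dict, in the shape expected by log_loss and TerminateOnNan
        return {"loss": {"total": loss.item()}}

    trainer = Engine(train_step)

    config = SimpleNamespace(
        name="demo",
        output_dir="/tmp/raycv_demo",
        resume_from=None,
        iterations_per_epoch=8,
        handler=SimpleNamespace(
            checkpoint=SimpleNamespace(n_saved=2, epoch_interval=1),
            tensorboard=SimpleNamespace(scalar=4),  # ~4 scalar logs per epoch
        ),
        data=SimpleNamespace(
            train=SimpleNamespace(dataloader=SimpleNamespace(batch_size=4)),
        ),
    )

    setup_common_handlers(trainer, config, to_save={"model": model, "optimizer": optimizer})
    tb_logger = setup_tensorboard_handler(trainer, config, optimizer)

    data = [torch.randn(4, 4) for _ in range(config.iterations_per_epoch)]
    trainer.run(data, max_epochs=2)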