import gc
from pathlib import Path

import torch

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler


def empty_cuda_cache(_):
    # Release cached GPU memory held by PyTorch, then run a garbage collection pass.
    torch.cuda.empty_cache()
    gc.collect()


def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
                          to_save=None, metrics_to_print=None, end_event=None):
    """
    Helper method to set up a trainer with common handlers:

    1. TerminateOnNan
    2. BasicTimeProfiler
    3. Print
    4. Checkpoint

    :param trainer: trainer engine the handlers are attached to
    :param config: configuration object; expected to expose ``output_dir``, ``name``, ``resume_from``,
        ``interval.print_per_iteration``, ``checkpoint.n_saved`` and ``checkpoint.epoch_interval``
    :param stop_on_nan: if True, terminate training when the output contains NaN
    :param clear_cuda_cache: if True and CUDA is available, empty the CUDA cache after every epoch
    :param use_profiler: if True, attach a BasicTimeProfiler and print its results
    :param to_save: dict of objects to checkpoint (and to restore from ``config.resume_from``)
    :param metrics_to_print: names of metrics to log at the configured print interval
    :param end_event: optional event at which the trainer is terminated
    :return: None
    """

    @trainer.on(Events.STARTED)
    @idist.one_rank_only()
    def print_dataloader_size(engine):
        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if use_profiler:
        # Create a profiler and attach it to the trainer
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

    print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED

    ProgressBar(ncols=0).attach(trainer, "all")

    if metrics_to_print is not None:
        @trainer.on(print_interval_event)
        def print_interval(engine):
            print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
            for m in metrics_to_print:
                if m not in engine.state.metrics:
                    continue
                print_str += f"{m}={engine.state.metrics[m]:.3f} "
            engine.logger.debug(print_str)

    if to_save is not None:
        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
                                        n_saved=config.checkpoint.n_saved, filename_prefix=config.name)

        if config.resume_from is not None:
            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(config.resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' was not found")
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resumed from checkpoint {checkpoint_path}")

        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
                                  checkpoint_handler)

    if end_event is not None:
        @trainer.on(end_event)
        def terminate(engine):
            engine.terminate()


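# Example usage (a minimal sketch, not part of this module): ``build_trainer``, ``model``
# and ``optimizer`` are hypothetical stand-ins for your own training setup.
#
#   trainer = build_trainer(model, optimizer)  # any ignite Engine running the training step
#   to_save = {"model": model, "optimizer": optimizer, "trainer": trainer}
#   setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss"])
#
# ``config`` is assumed to expose the attributes read above: ``output_dir``, ``name``,
# ``resume_from``, ``interval.print_per_iteration``, ``checkpoint.n_saved`` and
# ``checkpoint.epoch_interval``.
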
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
    if config.interval.tensorboard is None:
        return None

    if idist.get_rank() == 0:
        # Create a logger
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
                         event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
                         event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def close_tb_logger():
            # We need to close the logger when training is done
            tb_logger.close()

        return tb_logger

    return None
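
# Example usage (a minimal sketch): ``output_transform`` maps the trainer's output to a
# dict of scalars; the output layout below is an assumption about your training step.
#
#   tb_logger = setup_tensorboard_handler(
#       trainer, config, output_transform=lambda out: {"batch_loss": out["loss"]}
#   )
#
# Returns the TensorboardLogger on rank 0, or None when tensorboard logging is disabled
# (``config.interval.tensorboard`` is None) or on non-zero ranks.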