from pathlib import Path

import torch

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler


def setup_common_handlers(
    trainer: Engine,
    output_dir=None,
    stop_on_nan=True,
    use_profiler=True,
    print_interval_event=None,
    metrics_to_print=None,
    to_save=None,
    resume_from=None,
    save_interval_event=None,
    **checkpoint_kwargs,
):
    """Helper function to set up a trainer with common handlers:

    1. TerminateOnNan
    2. BasicTimeProfiler
    3. Metrics printing
    4. Checkpointing

    :param trainer: trainer engine. The output of the trainer's `update_function` should be
        a dictionary, a sequence, or a single tensor.
    :param output_dir: output path indicating where `to_save` objects are stored by
        :class:`~ignite.handlers.DiskSaver`. Required if `to_save` is provided.
    :param stop_on_nan: if True, a :class:`~ignite.handlers.TerminateOnNan` handler is added
        to the trainer.
    :param use_profiler: if True, a :class:`~ignite.contrib.handlers.BasicTimeProfiler` is
        attached to the trainer and its results are printed after the first epoch and when
        the run completes.
    :param print_interval_event: event on which `metrics_to_print` are logged. Required if
        `metrics_to_print` is provided.
    :param metrics_to_print: list of metric names from `engine.state.metrics` to log on
        `print_interval_event`.
    :param to_save: dictionary of objects to checkpoint, e.g.
        `{"model": model, "optimizer": optimizer, "trainer": trainer}`.
    :param resume_from: path to a checkpoint file used to restore `to_save` objects when the
        run starts.
    :param save_interval_event: event on which the checkpoint handler is called.
    :param checkpoint_kwargs: additional keyword arguments passed to
        :class:`~ignite.handlers.Checkpoint`.
    """

    @trainer.on(Events.STARTED)
    @idist.one_rank_only()
    def print_dataloader_size(engine):
        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if use_profiler:
        # Create a profiler and attach the engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1))
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            profiler.print_results(profiler.get_results())
            # profiler.write_results(f"{output_dir}/time_profiling.csv")

    if metrics_to_print is not None:
        if print_interval_event is None:
            raise ValueError(
                "If the metrics_to_print argument is provided, "
                "then print_interval_event must also be defined"
            )

        @trainer.on(print_interval_event)
        def print_interval(engine):
            print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
            for m in metrics_to_print:
                if m not in engine.state.metrics:
                    continue
                print_str += f"{m}={engine.state.metrics[m]:.3f} "
            engine.logger.info(print_str)

    if to_save is not None:
        checkpoint_handler = Checkpoint(
            to_save, DiskSaver(dirname=output_dir, require_empty=False), **checkpoint_kwargs
        )

        if resume_from is not None:

            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"Resumed from checkpoint {checkpoint_path}")

        if save_interval_event is not None:
            trainer.add_event_handler(save_interval_event, checkpoint_handler)
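

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): shows one way
# setup_common_handlers could be wired into a minimal training run. The model,
# update function, metric name "loss", and temporary checkpoint directory
# below are placeholder assumptions, not values taken from the source.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import logging
    import tempfile

    import torch.nn as nn

    # Make engine.logger.info(...) messages visible for this demo run.
    logging.basicConfig(level=logging.INFO)

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.MSELoss()

    def train_step(engine, batch):
        # Dummy supervised update; returns the loss as a single tensor so that
        # TerminateOnNan can inspect the engine output.
        x, y = batch
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        # Expose the loss under the name used in metrics_to_print below.
        engine.state.metrics["loss"] = loss.item()
        return loss

    trainer = Engine(train_step)

    with tempfile.TemporaryDirectory() as tmp_dir:
        setup_common_handlers(
            trainer,
            output_dir=tmp_dir,
            stop_on_nan=True,
            use_profiler=True,
            print_interval_event=Events.EPOCH_COMPLETED,
            metrics_to_print=["loss"],
            to_save={"model": model, "optimizer": optimizer, "trainer": trainer},
            save_interval_event=Events.EPOCH_COMPLETED,
            n_saved=2,  # forwarded to Checkpoint via **checkpoint_kwargs
        )

        data = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(10)]
        trainer.run(data, max_epochs=2)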