import gc
from pathlib import Path

import torch

import ignite.distributed as idist
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import OutputHandler, TensorboardLogger
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan


def empty_cuda_cache(_):
    """Free unused CUDA memory held by the caching allocator.

    Runs the Python garbage collector *first* so that tensors kept alive
    only by unreachable Python objects are actually released before
    ``torch.cuda.empty_cache()`` hands their blocks back to the driver.
    (Calling ``empty_cache`` before collecting would miss that memory.)

    The ignored positional argument receives the ``Engine`` that ignite
    passes to every event handler.
    """
    gc.collect()
    torch.cuda.empty_cache()


def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False,
                          use_profiler=True, to_save=None, metrics_to_print=None, end_event=None):
    """Attach the commonly used training handlers to ``trainer``.

    Installs, depending on the flags:

    1. ``TerminateOnNan`` after every iteration (``stop_on_nan``).
    2. CUDA cache clearing after every epoch (``clear_cuda_cache``).
    3. A ``BasicTimeProfiler`` with a rank-0 results printout
       (``use_profiler``).
    4. A progress bar plus periodic metric logging (``metrics_to_print``).
    5. Checkpointing and optional resume-from-checkpoint (``to_save``).
    6. Early termination on a custom event (``end_event``).

    :param trainer: engine to attach handlers to.
    :param config: configuration object; must expose
        ``interval.print_per_iteration``, ``output_dir``, ``name``,
        ``checkpoint.n_saved``, ``checkpoint.epoch_interval`` and
        ``resume_from`` (the checkpoint-related attributes are only read
        when ``to_save`` is given).
    :param stop_on_nan: terminate training when the output becomes NaN.
    :param clear_cuda_cache: empty the CUDA cache at each epoch end
        (only when CUDA is available).
    :param use_profiler: attach a ``BasicTimeProfiler`` and print its
        results on rank 0 after the first epoch and at completion.
    :param to_save: mapping of objects for ``Checkpoint``; ``None``
        disables checkpointing and resuming.
    :param metrics_to_print: metric names logged periodically; metrics
        absent from ``engine.state.metrics`` are skipped silently.
    :param end_event: event on which the trainer is terminated early.
    :return: ``None``; handlers are attached as a side effect.
    """

    @trainer.on(Events.STARTED)
    @idist.one_rank_only()
    def print_dataloader_size(engine):
        engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if use_profiler:
        # Create an object of the profiler and attach an engine to it.
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

    print_interval_event = (Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration)
                            | Events.COMPLETED)
    ProgressBar(ncols=0).attach(trainer, "all")

    if metrics_to_print is not None:

        @trainer.on(print_interval_event)
        def print_interval(engine):
            print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
            for m in metrics_to_print:
                # Some metrics may not be computed yet on early iterations.
                if m not in engine.state.metrics:
                    continue
                print_str += f"{m}={engine.state.metrics[m]:.3f} "
            engine.logger.debug(print_str)

    if to_save is not None:
        checkpoint_handler = Checkpoint(to_save,
                                        DiskSaver(dirname=config.output_dir, require_empty=False),
                                        n_saved=config.checkpoint.n_saved,
                                        filename_prefix=config.name)

        if config.resume_from is not None:

            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(config.resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                # Load on CPU so resuming works regardless of the saving device.
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resume from a checkpoint {checkpoint_path}")

        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval)
                                  | Events.COMPLETED,
                                  checkpoint_handler)

    if end_event is not None:

        @trainer.on(end_event)
        def terminate(engine):
            engine.terminate()


def setup_tensorboard_handler(trainer: Engine, config, output_transform):
    """Attach a TensorBoard logger to ``trainer`` on rank 0.

    Logs all engine metrics under the ``metric`` tag and the trainer
    output (via ``output_transform``) under the ``train`` tag, both every
    ``config.interval.tensorboard.scalar`` iterations. The logger is
    closed automatically when training completes.

    :param trainer: engine to attach the logger to.
    :param config: configuration object; ``interval.tensorboard`` of
        ``None`` disables logging entirely, otherwise
        ``interval.tensorboard.scalar`` and ``output_dir`` are read.
    :param output_transform: callable mapping ``engine.state.output`` to
        the scalars to log under the ``train`` tag.
    :return: the ``TensorboardLogger`` on rank 0, ``None`` on other ranks
        or when TensorBoard logging is disabled.
    """
    if config.interval.tensorboard is None:
        return None

    if idist.get_rank() == 0:
        # Create a logger; only rank 0 writes TensorBoard events.
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="metric", metric_names="all"),
                         event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="train", output_transform=output_transform),
                         event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done.
            tb_logger.close()

        return tb_logger
    return None