from pathlib import Path
import gc

import torch
from torch.utils.data.distributed import DistributedSampler

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler


def empty_cuda_cache(_):
    # Called with the engine as its single (ignored) argument.
    torch.cuda.empty_cache()
    gc.collect()


def step_transform_maker(stype: str, pairs_per_iteration=None):
    """Build a global_step_transform that reports progress in items, iterations or epochs."""
    assert stype in ["item", "iteration", "epoch"]
    if stype == "item":
        assert pairs_per_iteration is not None, "pairs_per_iteration is required for stype='item'"
        return lambda engine, _: engine.state.iteration * pairs_per_iteration
    if stype == "iteration":
        return lambda engine, _: engine.state.iteration
    if stype == "epoch":
        return lambda engine, _: engine.state.epoch


def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False,
                          use_profiler=True, to_save=None, end_event=None,
                          set_epoch_for_dist_sampler=False):
    """
    Helper to set up the trainer with common handlers:

    1. TerminateOnNan
    2. BasicTimeProfiler
    3. ProgressBar and a one-off GPU memory report
    4. Checkpoint (periodic saving and optional resume)

    :param trainer: Ignite engine running the training loop
    :param config: configuration object; see the attributes listed after this function
    :param stop_on_nan: attach TerminateOnNan after every iteration
    :param clear_cuda_cache: empty the CUDA cache after every epoch
    :param use_profiler: attach BasicTimeProfiler and print its results
    :param to_save: dict of objects to checkpoint (and to restore from config.resume_from)
    :param end_event: optional event on which the engine terminates early
    :param set_epoch_for_dist_sampler: call set_epoch on a DistributedSampler at each epoch start
    :return: None
    """
    if set_epoch_for_dist_sampler:
        @trainer.on(Events.EPOCH_STARTED)
        def distrib_set_epoch(engine):
            if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
                trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
                trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)

    trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch")

    @trainer.on(Events.EPOCH_COMPLETED(once=1))
    def print_info(engine):
        if torch.cuda.is_available():
            engine.logger.info(f"- GPU memory:\n{torch.cuda.memory_summary(0)}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if use_profiler:
        # Create the profiler and attach the engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

    ProgressBar(ncols=0).attach(trainer, "all")

    if to_save is not None:
        checkpoint_handler = Checkpoint(
            to_save,
            DiskSaver(dirname=config.output_dir, require_empty=False),
            n_saved=config.handler.checkpoint.n_saved,
            filename_prefix=config.name,
        )

        if config.resume_from is not None:
            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(config.resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                trainer.logger.info(f"load state_dict for {list(ckp.keys())}")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resumed from checkpoint {checkpoint_path}")

        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED,
            checkpoint_handler,
        )
        trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")

    if end_event is not None:
        trainer.logger.debug(f"engine will stop on {end_event}")

        @trainer.on(end_event)
        def terminate(engine):
            engine.terminate()
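
# ---------------------------------------------------------------------------
# Config attributes assumed by the helpers in this module (names gathered from
# the functions above and below; any concrete semantics beyond what the code
# shows are assumptions):
#
#   config.name                               -> checkpoint filename prefix
#   config.output_dir                         -> DiskSaver / TensorboardLogger directory
#   config.resume_from                        -> checkpoint path, or None
#   config.iterations_per_epoch               -> number of iterations per epoch
#   config.handler.checkpoint.n_saved         -> number of checkpoints to keep
#   config.handler.checkpoint.epoch_interval  -> save a checkpoint every N epochs
#   config.handler.tensorboard.scalar         -> scalar logs per epoch (section may be None to disable)
#   config.data.train.dataloader.batch_size   -> per-process batch size
# ---------------------------------------------------------------------------
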
def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"):
    if config.handler.tensorboard is None:
        return None

    if idist.get_rank() == 0:
        # Create a logger
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
        global_step_transform = step_transform_maker(step_type, pairs_per_iteration)
        # Log scalars config.handler.tensorboard.scalar times per epoch (at least every iteration)
        basic_event = Events.ITERATION_COMPLETED(
            every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1))

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="metric",
                metric_names="all",
                global_step_transform=global_step_transform,
            ),
            event_name=basic_event,
        )

        @trainer.on(basic_event)
        def log_loss(engine):
            # engine.state.output["loss"] is expected to be a dict of scalars,
            # possibly nested one level deep.
            global_step = global_step_transform(engine, None)
            output_loss = engine.state.output["loss"]
            for total_loss in output_loss:
                if isinstance(output_loss[total_loss], dict):
                    for ln in output_loss[total_loss]:
                        tb_writer.add_scalar(f"train_{total_loss}/{ln}", output_loss[total_loss][ln], global_step)
                else:
                    tb_writer.add_scalar(f"train/{total_loss}", output_loss[total_loss], global_step)

        if isinstance(optimizers, dict):
            for name in optimizers:
                tb_logger.attach(
                    trainer,
                    log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"),
                    event_name=Events.ITERATION_STARTED,
                )
        else:
            tb_logger.attach(
                trainer,
                log_handler=OptimizerParamsHandler(optimizers, tag="optimizer"),
                event_name=Events.ITERATION_STARTED,
            )

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done
            tb_logger.close()

        return tb_logger

    return None
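

# ---------------------------------------------------------------------------
# Usage sketch (not part of this module's API): one way the two helpers might
# be wired to a toy trainer. The model, loss layout, and config values below
# are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    def train_step(engine, batch):
        optimizer.zero_grad()
        loss = model(batch).pow(2).mean()
        loss.backward()
        optimizer.step()
        # setup_tensorboard_handler expects output["loss"] to be a dict of scalars
        return {"loss": {"total": loss.item()}}

    trainer = Engine(train_step)

    config = SimpleNamespace(
        name="demo",
        output_dir="/tmp/demo_output",
        resume_from=None,
        iterations_per_epoch=8,
        handler=SimpleNamespace(
            checkpoint=SimpleNamespace(n_saved=2, epoch_interval=1),
            tensorboard=SimpleNamespace(scalar=4),
        ),
        data=SimpleNamespace(
            train=SimpleNamespace(dataloader=SimpleNamespace(batch_size=4))
        ),
    )

    setup_common_handlers(trainer, config, to_save={"model": model, "optimizer": optimizer})
    setup_tensorboard_handler(trainer, config, optimizer)

    data = [torch.randn(4, 4) for _ in range(config.iterations_per_epoch)]
    trainer.run(data, max_epochs=2)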