From 9dfb887c862742747161179bec8cce22fbf8ba7b Mon Sep 17 00:00:00 2001
From: budui
Date: Sun, 23 Aug 2020 20:34:16 +0800
Subject: [PATCH] add set_epoch methods

---
 util/handler.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/util/handler.py b/util/handler.py
index a211522..7c5e868 100644
--- a/util/handler.py
+++ b/util/handler.py
@@ -1,6 +1,7 @@
 from pathlib import Path

 import torch
+from torch.utils.data.distributed import DistributedSampler
 import ignite.distributed as idist
 from ignite.engine import Events, Engine

@@ -34,13 +35,11 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
     :return:
     """
-    # if train_sampler is not None:
-    #     if not isinstance(train_sampler, DistributedSampler):
-    #         raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
-    #
-    #     @trainer.on(Events.EPOCH_STARTED)
-    #     def distrib_set_epoch(engine):
-    #         train_sampler.set_epoch(engine.state.epoch - 1)
+    if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
+        @trainer.on(Events.EPOCH_STARTED)
+        def distrib_set_epoch(engine):
+            trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
+            trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)

     @trainer.on(Events.STARTED)
     @idist.one_rank_only()
@@ -91,7 +90,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
             engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
     trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
                               checkpoint_handler)
+    trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
     if end_event is not None:
+        trainer.logger.debug(f"engine will stop on {end_event}")
+
         @trainer.on(end_event)
         def terminate(engine):
             engine.terminate()
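
Background for the change: DistributedSampler derives its per-epoch shuffle order from the value passed to set_epoch(), so a trainer that never calls it replays the same permutation every epoch; the new EPOCH_STARTED handler fixes that. Below is a minimal standalone sketch of the behavior (a hypothetical example, not from this repo; num_replicas and rank are passed explicitly so it runs without initializing a process group):

# Standalone sketch (not part of the patch): set_epoch() seeds the shuffle,
# so each epoch yields a different index order; skipping it repeats epoch 0.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
# num_replicas/rank given explicitly: no torch.distributed init needed here.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(3):
    sampler.set_epoch(epoch)  # mirrors distrib_set_epoch in the patch
    print(epoch, list(iter(sampler)))  # different index order per epoch

The handler in the patch passes engine.state.epoch - 1 because ignite counts epochs from 1, so the first epoch seeds the sampler with 0.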