add set_epoch methods

This commit is contained in:
budui 2020-08-23 20:34:16 +08:00
parent 1e7f63cf85
commit 9dfb887c86

View File

@@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
import torch import torch
from torch.utils.data.distributed import DistributedSampler
import ignite.distributed as idist import ignite.distributed as idist
from ignite.engine import Events, Engine from ignite.engine import Events, Engine
@@ -34,13 +35,11 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
:return: :return:
""" """
# if train_sampler is not None: if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
# if not isinstance(train_sampler, DistributedSampler): @trainer.on(Events.EPOCH_STARTED)
# raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method") def distrib_set_epoch(engine):
# trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
# @trainer.on(Events.EPOCH_STARTED) trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
# def distrib_set_epoch(engine):
# train_sampler.set_epoch(engine.state.epoch - 1)
@trainer.on(Events.STARTED) @trainer.on(Events.STARTED)
@idist.one_rank_only() @idist.one_rank_only()
@@ -91,7 +90,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
engine.logger.info(f"resume from a checkpoint {checkpoint_path}") engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED, trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
checkpoint_handler) checkpoint_handler)
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
if end_event is not None: if end_event is not None:
trainer.logger.debug(f"engine will stop on {end_event}")
@trainer.on(end_event) @trainer.on(end_event)
def terminate(engine): def terminate(engine):
engine.terminate() engine.terminate()