add set_epoch methods

This commit is contained in:
budui 2020-08-23 20:34:16 +08:00
parent 1e7f63cf85
commit 9dfb887c86

View File

@ -1,6 +1,7 @@
from pathlib import Path
import torch
from torch.utils.data.distributed import DistributedSampler
import ignite.distributed as idist
from ignite.engine import Events, Engine
@ -34,13 +35,11 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
:return:
"""
# if train_sampler is not None:
# if not isinstance(train_sampler, DistributedSampler):
# raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
#
# @trainer.on(Events.EPOCH_STARTED)
# def distrib_set_epoch(engine):
# train_sampler.set_epoch(engine.state.epoch - 1)
# When the training dataloader uses a DistributedSampler, call set_epoch at the
# start of every epoch. Per the PyTorch docs, DistributedSampler only reshuffles
# data across epochs if set_epoch is updated each epoch; otherwise every epoch
# sees the same ordering.
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
    @trainer.on(Events.EPOCH_STARTED)
    def distrib_set_epoch(engine):
        # EPOCH_STARTED fires with state.epoch already incremented to the new
        # epoch number, hence the `- 1` to get a 0-based epoch index.
        trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
        trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
@trainer.on(Events.STARTED)
@idist.one_rank_only()
@ -91,7 +90,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
checkpoint_handler)
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
# Optional early-stop hook: if the caller supplied an `end_event`
# (an ignite Events filter, e.g. Events.ITERATION_COMPLETED(once=N)),
# terminate the engine when that event fires.
if end_event is not None:
    trainer.logger.debug(f"engine will stop on {end_event}")

    @trainer.on(end_event)
    def terminate(engine):
        # Engine.terminate() requests a graceful stop after the current
        # iteration completes.
        engine.terminate()