Add set_epoch handler for DistributedSampler
This commit is contained in:
parent
1e7f63cf85
commit
9dfb887c86
@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
@ -34,13 +35,11 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
:return:
|
||||
"""
|
||||
|
||||
# if train_sampler is not None:
|
||||
# if not isinstance(train_sampler, DistributedSampler):
|
||||
# raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
|
||||
#
|
||||
# @trainer.on(Events.EPOCH_STARTED)
|
||||
# def distrib_set_epoch(engine):
|
||||
# train_sampler.set_epoch(engine.state.epoch - 1)
|
||||
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
|
||||
@trainer.on(Events.EPOCH_STARTED)
|
||||
def distrib_set_epoch(engine):
|
||||
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
|
||||
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
|
||||
|
||||
@trainer.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
@ -91,7 +90,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||
checkpoint_handler)
|
||||
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
|
||||
if end_event is not None:
|
||||
trainer.logger.debug(f"engine will stop on {end_event}")
|
||||
|
||||
@trainer.on(end_event)
|
||||
def terminate(engine):
|
||||
engine.terminate()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user