add set_epoch methods
This commit is contained in:
parent
1e7f63cf85
commit
9dfb887c86
@ -1,6 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
import ignite.distributed as idist
|
import ignite.distributed as idist
|
||||||
from ignite.engine import Events, Engine
|
from ignite.engine import Events, Engine
|
||||||
@ -34,13 +35,11 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# if train_sampler is not None:
|
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
|
||||||
# if not isinstance(train_sampler, DistributedSampler):
|
@trainer.on(Events.EPOCH_STARTED)
|
||||||
# raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
|
def distrib_set_epoch(engine):
|
||||||
#
|
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
|
||||||
# @trainer.on(Events.EPOCH_STARTED)
|
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
|
||||||
# def distrib_set_epoch(engine):
|
|
||||||
# train_sampler.set_epoch(engine.state.epoch - 1)
|
|
||||||
|
|
||||||
@trainer.on(Events.STARTED)
|
@trainer.on(Events.STARTED)
|
||||||
@idist.one_rank_only()
|
@idist.one_rank_only()
|
||||||
@ -91,7 +90,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
|||||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||||
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
|
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||||
checkpoint_handler)
|
checkpoint_handler)
|
||||||
|
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
|
||||||
if end_event is not None:
|
if end_event is not None:
|
||||||
|
trainer.logger.debug(f"engine will stop on {end_event}")
|
||||||
|
|
||||||
@trainer.on(end_event)
|
@trainer.on(end_event)
|
||||||
def terminate(engine):
|
def terminate(engine):
|
||||||
engine.terminate()
|
engine.terminate()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user