61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
from pathlib import Path
|
|
from importlib import import_module
|
|
|
|
import torch
|
|
|
|
import ignite
|
|
import ignite.distributed as idist
|
|
from ignite.utils import manual_seed, setup_logger
|
|
|
|
import fire
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
def log_basic_info(logger, config):
    """Log the run name and library versions, plus the distributed setup if any.

    Args:
        logger: destination for the messages (anything with an ``info`` method).
        config: run configuration; only ``config.name`` is read here.
    """
    header = (
        f"Train {config.name}",
        f"- PyTorch version: {torch.__version__}",
        f"- Ignite version: {ignite.__version__}",
    )
    for message in header:
        logger.info(message)

    # A single process has no distributed configuration worth reporting.
    if idist.get_world_size() <= 1:
        return
    logger.info("Distributed setting:\n")
    idist.show_config()
|
|
|
|
|
|
def _config_to_yaml(config):
    """Render an OmegaConf config as a YAML string.

    ``DictConfig.pretty()`` was deprecated in OmegaConf 2.0 and removed in 2.1 in
    favour of ``OmegaConf.to_yaml``; ``pretty()`` is only used on OmegaConf
    versions that predate ``to_yaml``.
    """
    to_yaml = getattr(OmegaConf, "to_yaml", None)
    if to_yaml is not None:
        return to_yaml(config)
    return config.pretty()


def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
    """Per-process entry point: set up logging, seeding and output dir, then run the engine.

    Args:
        local_rank: rank index passed as the first argument by ``idist.Parallel.run``;
            forwarded to ``setup_logger`` as ``distributed_rank``.
        config: OmegaConf config. Reads ``name``, ``log.logger`` (extra keyword
            arguments for ``setup_logger``), ``misc.random_seed``, ``result_dir``,
            ``output_dir`` and ``engine``. ``output_dir`` is written back when
            ``setup_output_dir`` is true, and the whole config is made read-only
            before the engine runs.
        task: opaque value forwarded unchanged to ``engine.run``.
        backup_config: when true (together with ``setup_output_dir``), rank 0 writes
            the config as YAML to ``<output_dir>/config.yml``.
        setup_output_dir: when true, use ``config.output_dir`` if set, otherwise
            ``<result_dir>/<name>``; store it back into the config and let rank 0
            create it if it does not exist.
        setup_random_seed: when true, seed with ``config.misc.random_seed`` plus the
            global rank so every process gets a distinct seed.

    The engine is loaded as module ``engine.<config.engine>`` and must expose
    ``run(task, config, logger)``.
    """
    logger = setup_logger(name=config.name, distributed_rank=local_rank, **config.log.logger)
    log_basic_info(logger, config)

    if setup_random_seed:
        # Offset by the global rank so processes do not share a random stream.
        manual_seed(config.misc.random_seed + idist.get_rank())

    if setup_output_dir:
        # ``get`` returns None for an absent key; plain attribute access raises on
        # missing keys since OmegaConf 2.1.
        configured_dir = config.get("output_dir")
        # Always normalise to Path: a user-supplied output_dir arrives as a str,
        # which has neither ``.exists()`` nor the ``/`` operator used below.
        output_dir = Path(config.result_dir) / config.name if configured_dir is None else Path(configured_dir)
        config.output_dir = str(output_dir)

        # Only rank 0 touches the filesystem so processes do not race on the same paths.
        if idist.get_rank() == 0:
            if not output_dir.exists():
                output_dir.mkdir(parents=True)
                logger.info(f"mkdir -p {output_dir}")
            logger.info(f"output path: {config.output_dir}")
            if backup_config:
                with open(output_dir / "config.yml", "w", encoding="utf-8") as f:
                    print(_config_to_yaml(config), file=f)

    # Freeze the config so the engine cannot modify it.
    OmegaConf.set_readonly(config, True)

    engine = import_module(f"engine.{config.engine}")
    engine.run(task, config, logger)
|
|
|
|
|
|
def run(task, config: str, *omega_options, **kwargs):
    """Load a YAML config, apply CLI overrides and launch ``running`` on every process.

    Args:
        task: forwarded unchanged to ``running`` (and from there to the engine).
        config: path of a YAML file read with ``OmegaConf.load``.
        *omega_options: dotlist overrides merged on top of the file via
            ``OmegaConf.from_cli``; each is converted to ``str`` first, in case
            fire parsed it into a non-string value.
        **kwargs: recognised keys are ``backend`` (default ``"nccl"``) for
            ``idist.Parallel``, and ``backup_config``, ``setup_output_dir`` and
            ``setup_random_seed`` (each default ``False``) for ``running``.
            Any other key is ignored.
    """
    base_conf = OmegaConf.load(config)
    cli_conf = OmegaConf.from_cli(list(map(str, omega_options)))
    merged_conf = OmegaConf.merge(base_conf, cli_conf)

    flag_names = ("backup_config", "setup_output_dir", "setup_random_seed")
    flags = {name: kwargs.get(name, False) for name in flag_names}

    with idist.Parallel(backend=kwargs.get("backend", "nccl")) as parallel:
        parallel.run(running, merged_conf, task, **flags)
|
|
|
|
|
|
if __name__ == "__main__":
    # Expose ``run`` as the command-line interface: positional arguments fill
    # ``task``, ``config`` and the OmegaConf overrides; ``--flag`` options land
    # in ``**kwargs``.
    fire.Fire(component=run)
|