from importlib import import_module
from pathlib import Path

import fire
import ignite
import ignite.distributed as idist
import torch
from ignite.utils import manual_seed
from omegaconf import OmegaConf

from util.misc import setup_logger


def log_basic_info(logger, config):
    logger.info(f"Train {config.name}")
    logger.info(f"- PyTorch version: {torch.__version__}")
    logger.info(f"- Ignite version: {ignite.__version__}")
    logger.info(f"- CUDA version: {torch.version.cuda}")
    logger.info(f"- cuDNN version: {torch.backends.cudnn.version()}")
    logger.info(f"- GPU type: {torch.cuda.get_device_name(0)}")
    logger.info(f"- GPU memory summary:\n{torch.cuda.memory_summary(0)}")

    if idist.get_world_size() > 1:
        logger.info("Distributed setting:\n")
        idist.show_config()


def running(local_rank, config, task,
            backup_config=False, setup_output_dir=False, setup_random_seed=False):
    if setup_random_seed:
        # Offset the seed by rank so each process draws a different random stream.
        manual_seed(config.misc.random_seed + idist.get_rank())

    output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
    config.output_dir = str(output_dir)

    if setup_output_dir and config.resume_from is None:
        if output_dir.exists():
            # Refuse to overwrite a previous run's tensorboard events.
            assert len(list(output_dir.glob("events*"))) == 0, f"{output_dir} contains tensorboard events"
            if (output_dir / "train.log").exists() and idist.get_rank() == 0:
                (output_dir / "train.log").unlink()
        else:
            if idist.get_rank() == 0:
                output_dir.mkdir(parents=True)
                print(f"mkdir -p {output_dir}")

    if backup_config and idist.get_rank() == 0:
        # Snapshot the merged config alongside the run outputs.
        # OmegaConf.to_yaml replaces the deprecated config.pretty() (OmegaConf >= 2.0).
        with open(output_dir / "config.yml", "w+") as f:
            print(OmegaConf.to_yaml(config), file=f)

    logger = setup_logger(name=config.name, distributed_rank=local_rank, filepath=output_dir / "train.log")
    logger.info(f"output path: {config.output_dir}")
    log_basic_info(logger, config)

    # Freeze the config so the engine cannot mutate it mid-run.
    OmegaConf.set_readonly(config, True)

    engine = import_module(f"engine.{config.engine}")
    engine.run(task, config, logger)


def run(task, config: str, *omega_options, **kwargs):
    # Command-line overrides in key=value form take precedence over the YAML file.
    omega_options = [str(o) for o in omega_options]
    cli_conf = OmegaConf.from_cli(omega_options)
    if len(cli_conf) > 0:
        print(OmegaConf.to_yaml(cli_conf))
    conf = OmegaConf.merge(OmegaConf.load(config), cli_conf)

    backend = kwargs.get("backend", "nccl")
    backup_config = kwargs.get("backup_config", False)
    setup_output_dir = kwargs.get("setup_output_dir", False)
    setup_random_seed = kwargs.get("setup_random_seed", False)

    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True

    with idist.Parallel(backend=backend) as parallel:
        parallel.run(running, conf, task,
                     backup_config=backup_config,
                     setup_output_dir=setup_output_dir,
                     setup_random_seed=setup_random_seed)


if __name__ == "__main__":
    fire.Fire(run)
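# ----------------------------------------------------------------------------
# Config shape (sketch). The launcher reads the keys below from the YAML file
# passed as `config`; the example values are illustrative, not defaults from
# this repository:
#
#   name: exp1            # run name; also the subdirectory under result_dir
#   engine: default       # resolved to module engine/<value> by import_module
#   result_dir: results   # parent of the output dir when output_dir is null
#   output_dir: null      # explicit output path; overrides result_dir/name
#   resume_from: null     # when set, output-dir setup/cleanup is skipped
#   misc:
#     random_seed: 0      # offset by rank when setup_random_seed is enabled
# ----------------------------------------------------------------------------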
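# ----------------------------------------------------------------------------
# Engine contract and usage (sketch). `running` imports `engine.{config.engine}`
# and calls its `run(task, config, logger)`, so e.g. `engine: default` must map
# to a module `engine/default.py` exposing that function. A minimal sketch,
# assuming that layout (module path and body are illustrative):
#
#   # engine/default.py
#   def run(task, config, logger):
#       logger.info(f"starting task: {task}")
#       # build model/data/trainer from `config`; it is read-only at this point
#
# Example invocation via fire (positional args map to `task` and the YAML
# path, `key=value` args become OmegaConf overrides, and `--flag=value` args
# become kwargs of `run`; the task name, paths, and values are illustrative):
#
#   python train.py train configs/base.yml misc.random_seed=42 \
#       --backend=nccl --setup_output_dir=True --setup_random_seed=True
# ----------------------------------------------------------------------------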