raycv/main.py

from importlib import import_module
from pathlib import Path

import fire
import ignite
import ignite.distributed as idist
import torch
from ignite.utils import manual_seed
from omegaconf import OmegaConf

from util.misc import setup_logger


def log_basic_info(logger, config):
    """Log library versions and the GPU/distributed environment."""
    logger.info(f"Train {config.name}")
    logger.info(f"- PyTorch version: {torch.__version__}")
    logger.info(f"- Ignite version: {ignite.__version__}")
    logger.info(f"- CUDA version: {torch.version.cuda}")
    logger.info(f"- cuDNN version: {torch.backends.cudnn.version()}")
    logger.info(f"- GPU type: {torch.cuda.get_device_name(0)}")
    logger.info(f"- GPU util:\n{torch.cuda.memory_summary(0)}")
    if idist.get_world_size() > 1:
        logger.info("Distributed setting:\n")
        idist.show_config()


def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
    """Per-process entry point launched by idist.Parallel on every rank."""
    if setup_random_seed:
        # Offset the seed by rank so each process draws a different random stream.
        manual_seed(config.misc.random_seed + idist.get_rank())

    output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
    config.output_dir = str(output_dir)
    if setup_output_dir and config.resume_from is None:
        if output_dir.exists():
            # Refuse to clobber an existing run: no event files or checkpoints allowed.
            assert len(list(output_dir.glob("events*"))) == 0
            assert len(list(output_dir.glob("*.pt"))) == 0
            if (output_dir / "train.log").exists() and idist.get_rank() == 0:
                (output_dir / "train.log").unlink()
        else:
            if idist.get_rank() == 0:
                output_dir.mkdir(parents=True)
                print(f"mkdir -p {output_dir}")

    if backup_config and idist.get_rank() == 0:
        # Snapshot the merged config next to the run outputs.
        with open(output_dir / "config.yml", "w+") as f:
            print(config.pretty(), file=f)  # OmegaConf 1.x API; on 2.x use OmegaConf.to_yaml(config)

    logger = setup_logger(name=config.name, distributed_rank=local_rank, filepath=output_dir / "train.log")
    logger.info(f"output path: {config.output_dir}")
    log_basic_info(logger, config)

    OmegaConf.set_readonly(config, True)
    # Dispatch to engine.<name>, which must expose run(task, config, logger).
    engine = import_module(f"engine.{config.engine}")
    engine.run(task, config, logger)
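

# A minimal engine module sketch (hypothetical module and names, for
# illustration only): any module importable as engine.<config.engine> just
# needs a matching run() entry point.
#
#   # engine/example.py
#   def run(task, config, logger):
#       logger.info(f"engine received task={task!r}, output={config.output_dir}")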


def run(task, config: str, *omega_options, **kwargs):
    """CLI entry point: merge the YAML config with dotlist overrides, then launch."""
    # fire may auto-convert values; cast back to str so OmegaConf can parse the dotlist.
    omega_options = [str(o) for o in omega_options]
    cli_conf = OmegaConf.from_cli(omega_options)
    if len(cli_conf) > 0:
        print(cli_conf.pretty())
    conf = OmegaConf.merge(OmegaConf.load(config), cli_conf)

    backend = kwargs.get("backend", "nccl")
    backup_config = kwargs.get("backup_config", False)
    setup_output_dir = kwargs.get("setup_output_dir", False)
    setup_random_seed = kwargs.get("setup_random_seed", False)

    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True

    with idist.Parallel(backend=backend) as parallel:
        parallel.run(running, conf, task, backup_config=backup_config, setup_output_dir=setup_output_dir,
                     setup_random_seed=setup_random_seed)
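

# Keys this script reads from the merged config (values are hypothetical,
# shown only to illustrate the expected shape):
#
#   # configs/train.yml
#   name: example-run
#   engine: example          # imported as engine.example
#   result_dir: results
#   output_dir: null         # defaults to <result_dir>/<name>
#   resume_from: null
#   misc:
#     random_seed: 42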


if __name__ == '__main__':
    fire.Fire(run)
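
# Example invocation (hypothetical task/config names, for illustration):
#   python main.py classification configs/train.yml misc.random_seed=0 \
#       --backend=nccl --setup_output_dir=True --backup_config=True
# fire maps the first two positionals to task/config, the remaining key=value
# tokens become OmegaConf dotlist overrides, and --flags land in kwargs.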