import torch
import torch.optim as optim
import ignite.distributed as idist
from omegaconf import OmegaConf

from model import MODEL
from util.distributed import auto_model


def _to_plain_dict(cfg):
    """Return a fresh plain-container copy of *cfg* that is safe to mutate.

    OmegaConf containers are converted with ``resolve=True`` so ``${...}``
    interpolations are substituted instead of being forwarded as literal
    strings. Plain mappings are shallow-copied so the ``pop`` calls made by
    the builders below never mutate the caller's object.
    """
    if OmegaConf.is_config(cfg):
        return OmegaConf.to_container(cfg, resolve=True)
    return dict(cfg)


def build_model(cfg, distributed_args=None):
    """Build a model from *cfg* and wrap it for the current distributed setup.

    Args:
        cfg: model config (OmegaConf container or plain mapping). Everything
            except the optional ``_distributed`` key is passed to
            ``MODEL.build_with``. ``_distributed`` holds distributed-only
            options; only ``bn_to_syncbn`` is read here.
        distributed_args: keyword arguments forwarded to ``auto_model``.
            Replaced by ``{}`` when ``None`` or when the world size is 1.

    Returns:
        Whatever ``auto_model`` returns for the built model.
    """
    cfg = _to_plain_dict(cfg)
    # ``or {}`` tolerates an explicit ``_distributed: null`` in the config.
    model_distributed_config = cfg.pop("_distributed", {}) or {}
    model = MODEL.build_with(cfg)

    if model_distributed_config.get("bn_to_syncbn"):
        # Swap BatchNorm layers for SyncBatchNorm so batch statistics are
        # synchronized across processes instead of computed per process.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Distributed wrapper options are meaningless for a single process.
    if distributed_args is None or idist.get_world_size() == 1:
        distributed_args = {}
    return auto_model(model, **distributed_args)


def build_optimizer(params, cfg):
    """Build a ``torch.optim`` optimizer from *cfg* and adapt it via Ignite.

    Args:
        params: parameters (or parameter groups) to optimize.
        cfg: optimizer config (OmegaConf container or plain mapping). The
            ``_type`` key names a class in ``torch.optim`` (e.g. ``"Adam"``);
            all other keys are passed to that class as keyword arguments.

    Returns:
        The optimizer returned by ``idist.auto_optim``.

    Raises:
        KeyError: if *cfg* has no ``_type`` key.
        AttributeError: if ``_type`` does not name an attribute of ``torch.optim``.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if "_type" not in cfg:
        raise KeyError("optimizer config must define '_type' (a torch.optim class name)")

    cfg = _to_plain_dict(cfg)
    optimizer_name = cfg.pop("_type")
    optimizer_cls = getattr(optim, optimizer_name, None)
    if optimizer_cls is None:
        raise AttributeError(f"torch.optim has no optimizer named {optimizer_name!r}")

    optimizer = optimizer_cls(params=params, **cfg)
    return idist.auto_optim(optimizer)