import ignite.distributed as idist
import torch
import torch.optim as optim
from omegaconf import OmegaConf

from model import MODEL
from util.misc import add_spectral_norm


def _to_plain_dict(cfg):
    """Return a fresh, mutable container copy of *cfg*.

    OmegaConf configs are converted with ``OmegaConf.to_container``; plain
    mappings are shallow-copied with ``dict`` so that popping private keys
    below never mutates the caller's object.
    """
    if OmegaConf.is_config(cfg):
        # NOTE(review): interpolations are left unresolved (OmegaConf's default
        # ``resolve=False``, unchanged from the original code) — confirm intended.
        return OmegaConf.to_container(cfg)
    return dict(cfg)


def build_model(cfg):
    """Build a model from a config and prepare it for (distributed) training.

    Args:
        cfg: An OmegaConf config or a plain mapping. Two private keys are
            consumed here and are NOT forwarded to ``MODEL.build_with``:

            * ``_bn_to_sync_bn`` (default ``False``): if truthy, convert the
              model's BatchNorm layers with
              ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``.
            * ``_add_spectral_norm`` (default ``False``): if truthy, call
              ``add_spectral_norm`` on every submodule via ``Module.apply``.

            All remaining keys are passed to ``MODEL.build_with``.

    Returns:
        The model after ``idist.auto_model``, which adapts it to the current
        ignite distributed configuration.
    """
    cfg = _to_plain_dict(cfg)
    bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
    add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)

    model = MODEL.build_with(cfg)
    # SyncBN conversion returns a (possibly new) module, so it runs before
    # spectral norm is applied to the final set of submodules.
    if bn_to_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if add_spectral_norm_flag:
        model.apply(add_spectral_norm)
    return idist.auto_model(model)


def build_optimizer(params, cfg):
    """Instantiate a ``torch.optim`` optimizer described by a config.

    Args:
        params: Parameters (or parameter groups), forwarded as ``params=``.
        cfg: An OmegaConf config or a plain mapping. ``_type`` names the
            optimizer class in ``torch.optim`` (e.g. ``"Adam"``); every other
            key is passed to its constructor as a keyword argument.

    Returns:
        The optimizer after ``idist.auto_optim``.

    Raises:
        KeyError: If ``cfg`` has no ``_type`` key.
        AttributeError: If ``_type`` is not an attribute of ``torch.optim``.
    """
    cfg = _to_plain_dict(cfg)
    # Explicit check rather than ``assert``: asserts are stripped under ``-O``.
    if "_type" not in cfg:
        raise KeyError("optimizer config must define '_type' (a torch.optim class name)")
    optimizer_type = cfg.pop("_type")

    optimizer_cls = getattr(optim, optimizer_type, None)
    if optimizer_cls is None:
        raise AttributeError(f"torch.optim has no optimizer named {optimizer_type!r}")

    optimizer = optimizer_cls(params=params, **cfg)
    return idist.auto_optim(optimizer)