27 lines
794 B
Python
27 lines
794 B
Python
import ignite.distributed as idist
|
|
import torch
|
|
import torch.optim as optim
|
|
from omegaconf import OmegaConf
|
|
|
|
from model import MODEL
|
|
from util.misc import add_spectral_norm
|
|
|
|
|
|
def build_model(cfg):
    """Build a model from a config node and prepare it for (distributed) training.

    The config is converted to a plain container. Two private switches are
    removed from it before the remaining keys are handed to ``MODEL.build_with``:

    * ``_bn_to_sync_bn`` (default ``False``): convert the model's BatchNorm
      layers with ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``.
    * ``_add_spectral_norm`` (default ``False``): call ``add_spectral_norm`` on
      every submodule via ``nn.Module.apply``.

    Args:
        cfg: An OmegaConf ``DictConfig`` (a plain mapping is also accepted)
            describing the model. The caller's object is not modified.

    Returns:
        The built model after ``ignite.distributed.auto_model`` has adapted it
        to the current device / distributed configuration.
    """
    # Resolve ``${...}`` interpolations so the builder receives concrete values
    # rather than unresolved interpolation strings. Plain mappings are copied
    # so the ``pop`` calls below never mutate the caller's object.
    if OmegaConf.is_config(cfg):
        cfg = OmegaConf.to_container(cfg, resolve=True)
    else:
        cfg = dict(cfg)

    # Underscore-prefixed keys are build options, not model arguments, so they
    # are stripped before the rest of the config reaches the builder.
    bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
    add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)

    # NOTE(review): MODEL is presumably a registry that picks the class from a
    # key inside cfg -- confirm against the `model` package.
    model = MODEL.build_with(cfg)

    if bn_to_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if add_spectral_norm_flag:
        # ``apply`` calls the function on every submodule and on the model itself.
        model.apply(add_spectral_norm)

    return idist.auto_model(model)
|
|
|
|
|
|
def build_optimizer(params, cfg):
    """Instantiate a ``torch.optim`` optimizer described by ``cfg``.

    ``cfg["_type"]`` names the optimizer class in ``torch.optim`` (for example
    ``"Adam"``). Every other key is forwarded as a keyword argument to that
    class's constructor.

    Args:
        params: Passed to the optimizer constructor as its ``params`` argument
            (an iterable of tensors or of parameter-group dicts, as
            ``torch.optim`` optimizers expect).
        cfg: An OmegaConf ``DictConfig`` (a plain mapping is also accepted)
            that must contain ``_type``. The caller's object is not modified.

    Returns:
        The optimizer after ``ignite.distributed.auto_optim`` has adapted it to
        the current distributed configuration.

    Raises:
        KeyError: If ``cfg`` has no ``_type`` key.
        AttributeError: If ``_type`` does not name an attribute of ``torch.optim``.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if "_type" not in cfg:
        raise KeyError("optimizer config must contain a '_type' key")

    # Resolve ``${...}`` interpolations so hyper-parameters arrive as concrete
    # values. Plain mappings are copied so ``pop`` never mutates the caller's object.
    if OmegaConf.is_config(cfg):
        cfg = OmegaConf.to_container(cfg, resolve=True)
    else:
        cfg = dict(cfg)

    optimizer_cls = getattr(optim, cfg.pop("_type"))
    optimizer = optimizer_cls(params=params, **cfg)
    return idist.auto_optim(optimizer)
|