raycv/engine/util/build.py
2020-10-11 23:35:29 +08:00

34 lines
994 B
Python

import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.optim as optim
from omegaconf import OmegaConf
from model import MODEL
def add_spectral_norm(module):
    """Apply spectral normalization to one layer when it is eligible.

    Meant to be used as the callback of ``nn.Module.apply``. Only
    ``nn.Linear``, ``nn.Conv2d`` and ``nn.ConvTranspose2d`` layers are
    considered. A layer that already has a ``weight_u`` attribute is treated
    as already normalized and left alone, so the hook is never registered
    twice. The check presumes the ``weight_u`` buffer that
    ``nn.utils.spectral_norm`` registers for the default ``weight`` parameter.

    Args:
        module: any ``nn.Module`` (typically one submodule visited by ``apply``).

    Returns:
        The value of ``nn.utils.spectral_norm(module)`` for an eligible,
        not-yet-normalized layer; otherwise ``module`` itself.
    """
    supported_layers = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)
    # Guard clauses evaluated in the same order as a short-circuit `and`,
    # so `hasattr` is only consulted for supported layer types.
    if not isinstance(module, supported_layers):
        return module
    if hasattr(module, 'weight_u'):
        return module
    return nn.utils.spectral_norm(module)
def build_model(cfg):
    """Build a model from config, optionally post-process its layers, and wrap it.

    Two underscore-prefixed keys are consumed here and are NOT forwarded to
    the model registry:

    * ``_bn_to_sync_bn`` (default ``False``): convert BatchNorm layers with
      ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``.
    * ``_add_spectral_norm`` (default ``False``): run ``add_spectral_norm``
      on every submodule via ``nn.Module.apply``.

    All remaining keys are passed to ``MODEL.build_with``.

    Args:
        cfg: an OmegaConf config, or a plain ``dict``, describing the model.
            A plain dict is shallow-copied, so the caller's mapping is not
            modified when the private keys are popped.

    Returns:
        The built (and optionally converted) model, passed through
        ``idist.auto_model`` before being returned.
    """
    if isinstance(cfg, dict):
        cfg = dict(cfg)
    else:
        # resolve=True so ${...} interpolations reach the builder as real
        # values instead of literal "${...}" strings.
        cfg = OmegaConf.to_container(cfg, resolve=True)

    bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
    add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)

    model = MODEL.build_with(cfg)
    if bn_to_sync_bn:
        # convert_sync_batchnorm returns a (possibly new) module, so rebind.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if add_spectral_norm_flag:
        # nn.Module.apply ignores the callback's return value; this relies on
        # spectral_norm modifying each layer in place.
        model.apply(add_spectral_norm)
    return idist.auto_model(model)
def build_optimizer(params, cfg):
    """Instantiate a ``torch.optim`` optimizer from config and wrap it.

    The config must contain ``_type``, the name of an attribute of
    ``torch.optim`` (e.g. ``"Adam"``). Every other key is forwarded as a
    keyword argument, together with ``params=params``.

    Args:
        params: the parameters (or parameter groups) to optimize.
        cfg: an OmegaConf config, or a plain ``dict``. A plain dict is
            shallow-copied, so the caller's mapping is not modified when
            ``_type`` is popped.

    Returns:
        The optimizer, passed through ``idist.auto_optim`` before being
        returned.

    Raises:
        ValueError: if the config has no ``_type`` key.
    """
    if isinstance(cfg, dict):
        cfg = dict(cfg)
    else:
        # resolve=True so interpolated values (e.g. lr: ${base_lr}) arrive as
        # numbers rather than literal "${...}" strings.
        cfg = OmegaConf.to_container(cfg, resolve=True)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "_type" not in cfg:
        raise ValueError(
            "optimizer config must contain a '_type' key naming a torch.optim class"
        )
    optimizer_cls = getattr(optim, cfg.pop("_type"))
    optimizer = optimizer_cls(params=params, **cfg)
    return idist.auto_optim(optimizer)