# Model / optimizer factory utilities (torch + ignite.distributed + OmegaConf).
import ignite.distributed as idist
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from omegaconf import OmegaConf
|
|
|
|
from model import MODEL
|
|
|
|
|
|
def add_spectral_norm(module):
    """Apply spectral normalization to *module* when applicable.

    Linear, Conv2d, and ConvTranspose2d layers that are not already
    spectrally normalized (detected via the ``weight_u`` attribute the
    wrapper registers) are wrapped with ``nn.utils.spectral_norm``.
    Any other module is returned unchanged. Intended for use with
    ``nn.Module.apply``; spectral_norm mutates the module in place.
    """
    eligible = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)
    already_normed = hasattr(module, 'weight_u')
    if not isinstance(module, eligible) or already_normed:
        return module
    return nn.utils.spectral_norm(module)
|
|
|
|
|
|
def build_model(cfg):
    """Instantiate a model from an OmegaConf config and wrap it for
    distributed execution via ``idist.auto_model``.

    Two private flags are popped from the config before construction:
      ``_bn_to_sync_bn``    -- convert BatchNorm layers to SyncBatchNorm.
      ``_add_spectral_norm`` -- apply spectral norm to eligible layers.
    Both default to False. The remaining keys are forwarded to
    ``MODEL.build_with``.
    """
    config = OmegaConf.to_container(cfg)
    use_sync_bn = config.pop("_bn_to_sync_bn", False)
    use_spectral_norm = config.pop("_add_spectral_norm", False)

    model = MODEL.build_with(config)

    if use_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if use_spectral_norm:
        # nn.Module.apply discards return values; add_spectral_norm
        # relies on spectral_norm mutating each module in place.
        model.apply(add_spectral_norm)

    return idist.auto_model(model)
|
|
|
|
|
|
def build_optimizer(params, cfg):
    """Build a ``torch.optim`` optimizer from an OmegaConf config and
    wrap it for distributed execution via ``idist.auto_optim``.

    Parameters
    ----------
    params : iterable
        Parameters (or parameter groups) to optimize.
    cfg : omegaconf config
        Must contain a ``_type`` key naming a class in ``torch.optim``;
        every other key is passed through as an optimizer kwarg.

    Raises
    ------
    ValueError
        If ``_type`` is missing or does not name a ``torch.optim`` class.
    """
    cfg = OmegaConf.to_container(cfg)
    # Explicit validation instead of `assert`: asserts are silently
    # stripped when Python runs with -O, which would turn a missing
    # `_type` into an opaque KeyError below.
    if "_type" not in cfg:
        raise ValueError("optimizer config must define '_type'")
    opt_name = cfg.pop("_type")
    opt_cls = getattr(optim, opt_name, None)
    if opt_cls is None:
        raise ValueError(
            f"torch.optim has no optimizer named {opt_name!r}"
        )
    optimizer = opt_cls(params=params, **cfg)
    return idist.auto_optim(optimizer)
|