import torch
import torch.optim as optim
import ignite.distributed as idist

from omegaconf import OmegaConf

from model import MODEL
from util.distributed import auto_model


def build_model(cfg, distributed_args=None):
    # Resolve the OmegaConf node into a plain dict so keys can be popped.
    cfg = OmegaConf.to_container(cfg)
    # Split off distributed-only options so they never reach the model constructor.
    model_distributed_config = cfg.pop("_distributed", dict())
    model = MODEL.build_with(cfg)

    # Optionally convert BatchNorm layers to SyncBatchNorm for multi-GPU training.
    if model_distributed_config.get("bn_to_syncbn"):
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Extra DDP kwargs are only meaningful when more than one process is running.
    distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
    return auto_model(model, **distributed_args)
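
# Illustrative usage (a sketch, not from this file): the config keys below,
# including the model name and "num_classes", are hypothetical; only
# "_distributed.bn_to_syncbn" is consumed by build_model itself, the rest
# is forwarded to MODEL.build_with.
#
#   cfg = OmegaConf.create({
#       "_type": "resnet18",            # hypothetical registry key
#       "num_classes": 10,              # hypothetical model kwarg
#       "_distributed": {"bn_to_syncbn": True},
#   })
#   model = build_model(cfg)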


def build_optimizer(params, cfg):
    # "_type" selects the torch.optim class to instantiate (e.g. "Adam").
    assert "_type" in cfg
    cfg = OmegaConf.to_container(cfg)
    # Look up the optimizer class by name; the remaining keys become its kwargs.
    optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
    # Let Ignite adapt the optimizer to the current distributed configuration.
    return idist.auto_optim(optimizer)
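
# Illustrative usage (a sketch): "_type" must name a class in torch.optim;
# the remaining keys are passed through as constructor kwargs. The values
# here are examples only.
#
#   opt_cfg = OmegaConf.create({"_type": "Adam", "lr": 1e-3})
#   optimizer = build_optimizer(model.parameters(), opt_cfg)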