move sn to engine

This commit is contained in:
Ray Wong 2020-10-11 23:35:29 +08:00
parent 436bca88b4
commit 74a7cfb2d8
2 changed files with 8 additions and 11 deletions

View File

@ -1,10 +1,17 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.optim as optim
from omegaconf import OmegaConf
from model import MODEL
from util.misc import add_spectral_norm
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def build_model(cfg):

View File

@ -4,16 +4,6 @@ import pkgutil
from pathlib import Path
from typing import Optional
import torch.nn as nn
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def import_submodules(package, recursive=True):
""" Import all submodules of a module, recursively, including subpackages