move sn to engine

This commit is contained in:
Ray Wong 2020-10-11 23:35:29 +08:00
parent 436bca88b4
commit 74a7cfb2d8
2 changed files with 8 additions and 11 deletions

View File

@ -1,10 +1,17 @@
import ignite.distributed as idist import ignite.distributed as idist
import torch import torch
import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from omegaconf import OmegaConf from omegaconf import OmegaConf
from model import MODEL from model import MODEL
from util.misc import add_spectral_norm
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def build_model(cfg): def build_model(cfg):

View File

@ -4,16 +4,6 @@ import pkgutil
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import torch.nn as nn
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def import_submodules(package, recursive=True): def import_submodules(package, recursive=True):
""" Import all submodules of a module, recursively, including subpackages """ Import all submodules of a module, recursively, including subpackages