move sn to engine
This commit is contained in:
parent
436bca88b4
commit
74a7cfb2d8
@ -1,10 +1,17 @@
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
from util.misc import add_spectral_norm
|
||||
|
||||
|
||||
def add_spectral_norm(module):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||
return nn.utils.spectral_norm(module)
|
||||
else:
|
||||
return module
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
|
||||
10
util/misc.py
10
util/misc.py
@ -4,16 +4,6 @@ import pkgutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def add_spectral_norm(module):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||
return nn.utils.spectral_norm(module)
|
||||
else:
|
||||
return module
|
||||
|
||||
|
||||
def import_submodules(package, recursive=True):
|
||||
""" Import all submodules of a module, recursively, including subpackages
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user