move sn to engine
This commit is contained in:
parent
436bca88b4
commit
74a7cfb2d8
@ -1,10 +1,17 @@
|
|||||||
import ignite.distributed as idist
|
import ignite.distributed as idist
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from model import MODEL
|
from model import MODEL
|
||||||
from util.misc import add_spectral_norm
|
|
||||||
|
|
||||||
|
def add_spectral_norm(module):
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||||
|
return nn.utils.spectral_norm(module)
|
||||||
|
else:
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
def build_model(cfg):
|
def build_model(cfg):
|
||||||
|
|||||||
10
util/misc.py
10
util/misc.py
@ -4,16 +4,6 @@ import pkgutil
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def add_spectral_norm(module):
|
|
||||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
|
||||||
return nn.utils.spectral_norm(module)
|
|
||||||
else:
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def import_submodules(package, recursive=True):
|
def import_submodules(package, recursive=True):
|
||||||
""" Import all submodules of a module, recursively, including subpackages
|
""" Import all submodules of a module, recursively, including subpackages
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user