diff --git a/engine/util/build.py b/engine/util/build.py index b423586..5fd4e5d 100644 --- a/engine/util/build.py +++ b/engine/util/build.py @@ -1,10 +1,17 @@ import ignite.distributed as idist import torch +import torch.nn as nn import torch.optim as optim from omegaconf import OmegaConf from model import MODEL -from util.misc import add_spectral_norm + + +def add_spectral_norm(module): + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'): + return nn.utils.spectral_norm(module) + else: + return module def build_model(cfg): diff --git a/util/misc.py b/util/misc.py index 7dfb5c7..7f3edb7 100644 --- a/util/misc.py +++ b/util/misc.py @@ -4,16 +4,6 @@ import pkgutil from pathlib import Path from typing import Optional -import torch.nn as nn - - -def add_spectral_norm(module): - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'): - return nn.utils.spectral_norm(module) - else: - return module - - def import_submodules(package, recursive=True): """ Import all submodules of a module, recursively, including subpackages