diff --git a/util/misc.py b/util/misc.py index 2bd6d98..ca8b49c 100644 --- a/util/misc.py +++ b/util/misc.py @@ -2,6 +2,15 @@ import logging from pathlib import Path from typing import Optional +import torch.nn as nn + + +def add_spectral_norm(module): + if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'): + return nn.utils.spectral_norm(module) + else: + return module + def setup_logger( name: Optional[str] = None,