func to apply sn

This commit is contained in:
Ray Wong 2020-09-26 17:47:24 +08:00
parent 0f2b67e215
commit 16f18ab2e2

View File

@ -2,6 +2,15 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import torch.nn as nn
def add_spectral_norm(module):
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def setup_logger( def setup_logger(
name: Optional[str] = None, name: Optional[str] = None,