func to apply sn
This commit is contained in:
parent
0f2b67e215
commit
16f18ab2e2
@ -2,6 +2,15 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def add_spectral_norm(module):
|
||||
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||
return nn.utils.spectral_norm(module)
|
||||
else:
|
||||
return module
|
||||
|
||||
|
||||
def setup_logger(
|
||||
name: Optional[str] = None,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user