func to apply sn
This commit is contained in:
parent
0f2b67e215
commit
16f18ab2e2
@ -2,6 +2,15 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def add_spectral_norm(module):
|
||||||
|
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||||
|
return nn.utils.spectral_norm(module)
|
||||||
|
else:
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(
|
def setup_logger(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user