import functools

import torch
import torch.nn as nn


def select_norm_layer(norm_type):
    """Return a constructor (a callable taking num_features) for the requested normalization layer."""
    if norm_type == "BN":
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == "IN":
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == "LN":
        return functools.partial(LayerNorm2d, affine=True)
    elif norm_type == "NONE":
        return lambda num_features: nn.Identity()
    elif norm_type == "AdaIN":
        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
    else:
        raise NotImplementedError(f"normalization layer {norm_type} is not found")


class LayerNorm2d(nn.Module):
    """Layer normalization over the (C, H, W) dimensions of a 4D input, with optional per-channel affine parameters."""

    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            nn.init.uniform_(self.channel_gamma)
            nn.init.zeros_(self.channel_beta)

    def forward(self, x):
        # Per-sample statistics over the channel and spatial dimensions.
        ln_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
        x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
        if self.affine:
            return self.channel_gamma * x + self.channel_beta
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"


class AdaptiveInstanceNorm2d(nn.Module):
    """Adaptive instance normalization (AdaIN): instance-normalizes the input, then applies per-sample affine parameters supplied via set_style()."""

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = False, track_running_stats: bool = False):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)

        self.gamma = None
        self.beta = None
        self.have_set_style = False

    def set_style(self, style):
        # style: (batch, 2 * num_features) tensor holding the per-sample gamma and beta.
        style = style.view(*style.size(), 1, 1)
        self.gamma, self.beta = style.chunk(2, 1)
        self.have_set_style = True

    def forward(self, x):
        # set_style() must be called before every forward pass; the flag is cleared afterwards.
        assert self.have_set_style
        x = self.norm(x)
        x = self.gamma * x + self.beta
        self.have_set_style = False
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
               f"affine={self.affine}, track_running_stats={self.track_running_stats})"
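

# Illustrative usage sketch (an addition for demonstration, not part of the original
# module): shows how a constructor returned by select_norm_layer is applied inside a
# block, and how AdaptiveInstanceNorm2d expects its style statistics to be injected
# before each forward pass. The shapes (batch 4, 64 channels, 32x32 maps) are
# arbitrary example values.
if __name__ == "__main__":
    norm_layer = select_norm_layer("LN")
    block = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), norm_layer(64))
    print(block(torch.randn(4, 3, 32, 32)).shape)  # torch.Size([4, 64, 32, 32])

    adain = select_norm_layer("AdaIN")(64)
    features = torch.randn(4, 64, 32, 32)
    style = torch.randn(4, 2 * 64)    # per-sample (gamma, beta) for 64 channels
    adain.set_style(style)            # required before every forward call
    print(adain(features).shape)      # torch.Size([4, 64, 32, 32])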