import functools

import torch
import torch.nn as nn


def select_norm_layer(norm_type):
    """Return a factory that builds a 2-D normalization layer of the given type.

    The returned callable is invoked with ``num_features`` (the channel count)
    and returns an ``nn.Module``.

    Args:
        norm_type: One of
            ``"BN"``    -> ``nn.BatchNorm2d`` (affine, tracks running stats),
            ``"IN"``    -> ``nn.InstanceNorm2d`` (no affine, no running stats),
            ``"LN"``    -> :class:`LayerNorm2d` (affine),
            ``"NONE"``  -> ``nn.Identity``,
            ``"AdaIN"`` -> :class:`AdaptiveInstanceNorm2d` (no affine, no running stats).

    Returns:
        A callable ``factory(num_features, ...) -> nn.Module``.

    Raises:
        NotImplementedError: if ``norm_type`` is not one of the names above.
    """
    if norm_type == "BN":
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == "IN":
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == "LN":
        return functools.partial(LayerNorm2d, affine=True)
    elif norm_type == "NONE":
        # Accept and ignore any extra arguments so this factory is
        # call-compatible with the functools.partial factories above.
        return lambda num_features, *args, **kwargs: nn.Identity()
    elif norm_type == "AdaIN":
        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
    # `NotImplemented` is a sentinel constant and is not callable; the
    # exception class is `NotImplementedError`.
    raise NotImplementedError(f'normalization layer {norm_type} is not found')


class LayerNorm2d(nn.Module):
    """Per-sample normalization over dims 1-3 with an optional per-channel affine.

    The mean and variance are computed jointly over dims ``[1, 2, 3]`` of each
    sample. The learnable scale/shift have shape ``(1, num_features, 1, 1)``, so
    the input is expected to be 4-D with ``num_features`` channels in dim 1
    (presumably ``(N, C, H, W)``).
    """

    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine

        if self.affine:
            # Uninitialized storage; filled in by reset_parameters() below.
            self.channel_gamma = nn.Parameter(torch.empty(1, num_features, 1, 1))
            self.channel_beta = nn.Parameter(torch.empty(1, num_features, 1, 1))
            self.reset_parameters()

    def reset_parameters(self):
        """Initialize gamma ~ U[0, 1) and beta = 0; no-op when ``affine`` is False."""
        if self.affine:
            nn.init.uniform_(self.channel_gamma)
            nn.init.zeros_(self.channel_beta)

    def forward(self, x):
        """Normalize ``x`` per sample and apply the affine transform if enabled."""
        # torch.var uses the unbiased (N - 1) estimator by default.
        ln_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
        ln_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
        x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
        if self.affine:
            return self.channel_gamma * x + self.channel_beta
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"


class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalization modulated by an externally supplied style tensor.

    :meth:`set_style` must be called before each :meth:`forward`. The forward
    pass consumes the style: it clears the "style set" flag, so the next
    forward call requires another ``set_style``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = False, track_running_stats: bool = False):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)

        # Modulation parameters filled in by set_style().
        self.gamma = None
        self.beta = None
        self.have_set_style = False

    def set_style(self, style):
        """Store the scale (gamma) and shift (beta) used by the next forward call.

        ``style`` is split in half along dim 1: the first half becomes gamma,
        the second half beta. Two trailing unit dims are appended so the
        halves broadcast over the spatial dims of the input. Presumably
        ``style`` is ``(N, 2 * num_features)``; TODO confirm with callers.
        """
        style = style.view(*style.size(), 1, 1)
        self.gamma, self.beta = style.chunk(2, 1)
        self.have_set_style = True

    def forward(self, x):
        """Instance-normalize ``x``, then apply ``gamma * x + beta`` from set_style().

        Raises:
            RuntimeError: if set_style() was not called since the last forward.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not self.have_set_style:
            raise RuntimeError(
                f"{self.__class__.__name__}: set_style() must be called before forward()"
            )
        x = self.norm(x)
        x = self.gamma * x + self.beta
        self.have_set_style = False
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
               f"affine={self.affine}, track_running_stats={self.track_running_stats})"