"""Normalization layers registered in the ``NORMALIZATION`` registry.

Registered names: ``IN``, ``BN`` (plain torch layers), ``ADE``, ``AdaIN``,
``FADE``, ``SPADE``, ``ILN`` and ``AdaILN``.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import NORMALIZATION
from model.base.module import Conv2dBlock

_VALID_NORM_AND_ABBREVIATION = dict(
    IN="InstanceNorm2d",
    BN="BatchNorm2d",
)

# Expose the plain torch normalization layers under their short names.
for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
    NORMALIZATION.register_module(module=getattr(nn, name), name=abbr)


def _per_channel(t):
    """Reshape a ``(N, C)`` or ``(N, C, 1, 1)`` tensor to ``(N, C, 1, 1)``.

    The result broadcasts over the spatial dims of an ``(N, C, H, W)`` map
    (``N`` may also be 1). Any other shape raises inside ``reshape`` instead
    of silently broadcasting into extra dimensions.
    """
    return t.reshape(t.size(0), t.size(1), 1, 1)


def _require_condition(module):
    """Raise if ``module.have_set_condition`` is False.

    An explicit raise (rather than ``assert``) keeps the check active under
    ``python -O``; otherwise a stale gamma/beta from the previous forward
    pass would be reused silently.
    """
    if not module.have_set_condition:
        raise RuntimeError(
            f"{type(module).__name__}: condition not set; call the set_* method "
            "before each forward pass"
        )


@NORMALIZATION.register_module("ADE")
class AdaptiveDenormalization(nn.Module):
    """Parameter-free normalization followed by an externally supplied affine.

    Call :meth:`set_condition` before **every** forward pass. The output is
    ``(gamma + gamma_bias) * norm(x) + beta``. ``forward`` consumes the
    condition, so it must be set again before the next call.

    Args:
        num_features: number of input channels.
        base_norm_type: ``"IN"`` (``InstanceNorm2d``) or ``"BN"``
            (``BatchNorm2d``), both built with ``affine=False``.
        gamma_bias: constant added to ``gamma`` before scaling.

    Raises:
        ValueError: if ``base_norm_type`` is not ``"IN"`` or ``"BN"``.
    """

    def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
        super().__init__()
        self.num_features = num_features
        self.base_norm_type = base_norm_type
        self.norm = self.base_norm(num_features)
        self.gamma = None
        self.gamma_bias = gamma_bias
        self.beta = None
        self.have_set_condition = False

    def base_norm(self, num_features):
        """Build the affine-free norm layer selected by ``self.base_norm_type``."""
        if self.base_norm_type == "IN":
            return nn.InstanceNorm2d(num_features, affine=False)
        if self.base_norm_type == "BN":
            return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
        # Fail at construction rather than storing None and crashing in forward().
        raise ValueError(
            f"Unsupported base_norm_type {self.base_norm_type!r}; expected 'IN' or 'BN'"
        )

    def set_condition(self, gamma, beta):
        """Store the scale/shift used by the next forward pass.

        ``gamma`` and ``beta`` must broadcast against the normalized input.
        """
        self.gamma, self.beta = gamma, beta
        self.have_set_condition = True

    def forward(self, x):
        _require_condition(self)
        x = self.norm(x)
        x = (self.gamma + self.gamma_bias) * x + self.beta
        # The condition is single-use: the next call needs a fresh set_condition().
        self.have_set_condition = False
        return x

    def extra_repr(self):
        return (f"num_features={self.num_features}, "
                f"base_norm_type={self.base_norm_type}, gamma_bias={self.gamma_bias}")


@NORMALIZATION.register_module("AdaIN")
class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
    """Instance normalization whose per-sample gamma/beta come from a style tensor."""

    def __init__(self, num_features: int):
        super().__init__(num_features, "IN")

    def set_style(self, style):
        """Split ``style`` along dim 1 into gamma (first half) and beta (second half).

        Args:
            style: tensor of shape ``(N, 2 * num_features)`` or
                ``(N, 2 * num_features, 1, 1)``.
        """
        gamma, beta = _per_channel(style).chunk(2, 1)
        super().set_condition(gamma, beta)


@NORMALIZATION.register_module("FADE")
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
    """Denormalization whose gamma/beta are 3x3 convolutions of a feature map.

    Args:
        num_features: number of channels of the normalized input.
        condition_in_channels: channels of the tensor passed to :meth:`set_feature`.
        base_norm_type: see :class:`AdaptiveDenormalization`.
        padding_mode: padding mode of both 3x3 convolutions.
        gamma_bias: see :class:`AdaptiveDenormalization`.
    """

    def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros",
                 gamma_bias=0.0):
        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
        self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
                                   padding_mode=padding_mode)
        self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
                                    padding_mode=padding_mode)

    def set_feature(self, feature):
        """Compute gamma/beta from ``feature`` and set them as the condition."""
        gamma = self.gamma_conv(feature)
        beta = self.beta_conv(feature)
        super().set_condition(gamma, beta)


@NORMALIZATION.register_module("SPADE")
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
    """Denormalization conditioned on an image through a shared conv trunk.

    The condition image passes through a ``Conv2dBlock`` (``norm_type="NONE"``)
    with ``base_channels`` outputs. Two 3x3 convolutions then produce gamma and
    beta from that shared feature.
    """

    def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
                 activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
        self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
                                           kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
        self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
        self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)

    def set_condition_image(self, condition_image):
        """Compute gamma/beta from ``condition_image`` and set them as the condition."""
        feature = self.base_conv_block(condition_image)
        gamma = self.gamma_conv(feature)
        beta = self.beta_conv(feature)
        super().set_condition(gamma, beta)


def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """Blend instance and layer normalization of ``x`` by ``rho``, then apply gamma/beta.

    Computes ``rho * IN(x) + (1 - rho) * LN(x)``, where LN normalizes over
    every non-batch dimension of ``x``. The result is scaled by ``gamma`` and
    shifted by ``beta``. Both are given as ``(N, C)`` or ``(N, C, 1, 1)`` and
    broadcast over the spatial dims.
    """
    out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
    return out * _per_channel(gamma) + _per_channel(beta)


@NORMALIZATION.register_module("ILN")
class ILN(nn.Module):
    """Instance-Layer Normalization with learned per-channel rho, gamma and beta.

    Initialization: ``rho = 0`` (pure layer norm), ``gamma = 1``, ``beta = 0``.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.empty(num_features))
        self.gamma = nn.Parameter(torch.empty(num_features))
        self.beta = nn.Parameter(torch.empty(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.rho)
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, x):
        return _instance_layer_normalization(
            x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps)


@NORMALIZATION.register_module("AdaILN")
class AdaILN(nn.Module):
    """Instance-Layer Normalization with learned rho and externally supplied gamma/beta.

    Call :meth:`set_condition` before **every** forward pass. ``forward``
    consumes the condition. gamma/beta are presumably ``(N, num_features)``
    — TODO confirm against callers.

    Args:
        num_features: number of input channels.
        eps: numerical-stability term for both normalizations.
        default_rho: initial value of every entry of the learned ``rho``.
    """

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super().__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.empty(num_features))
        # constant_ keeps the parameter's float dtype even if default_rho is an int.
        nn.init.constant_(self.rho, default_rho)
        self.gamma = None
        self.beta = None
        self.have_set_condition = False

    def set_condition(self, gamma, beta):
        """Store the scale/shift used by the next forward pass."""
        self.gamma, self.beta = gamma, beta
        self.have_set_condition = True

    def forward(self, x):
        _require_condition(self)
        out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
        # The condition is single-use: the next call needs a fresh set_condition().
        self.have_set_condition = False
        return out