# raycv/model/base/normalization.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from model import NORMALIZATION
from model.base.module import Conv2dBlock

# Register the plain nn normalization layers under their short abbreviations.
_VALID_NORM_AND_ABBREVIATION = dict(
    IN="InstanceNorm2d",
    BN="BatchNorm2d",
)
for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
    NORMALIZATION.register_module(module=getattr(nn, name), name=abbr)


@NORMALIZATION.register_module("ADE")
class AdaptiveDenormalization(nn.Module):
    """Applies a parameter-free base normalization, then denormalizes with an
    externally supplied (gamma, beta) condition that is consumed per forward pass."""

    def __init__(self, num_features, base_norm_type="BN"):
        super().__init__()
        self.num_features = num_features
        self.base_norm_type = base_norm_type
        self.norm = self.base_norm(num_features)
        self.gamma = None
        self.beta = None
        self.have_set_condition = False

    def base_norm(self, num_features):
        if self.base_norm_type == "IN":
            return nn.InstanceNorm2d(num_features)
        elif self.base_norm_type == "BN":
            return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
        raise ValueError(f"Unsupported base_norm_type: {self.base_norm_type!r}")

    def set_condition(self, gamma, beta):
        self.gamma, self.beta = gamma, beta
        self.have_set_condition = True

    def forward(self, x):
        assert self.have_set_condition, "call set_condition() before each forward pass"
        x = self.norm(x)
        x = self.gamma * x + self.beta
        # The condition is consumed; it must be set again before the next forward.
        self.have_set_condition = False
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
               f"base_norm_type={self.base_norm_type})"


@NORMALIZATION.register_module("AdaIN")
class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
    def __init__(self, num_features: int):
        super().__init__(num_features, "IN")

    def set_style(self, style):
        # (N, 2 * num_features) -> (N, 2 * num_features, 1, 1), then split the
        # channel dimension into gamma and beta.
        style = style.view(*style.size(), 1, 1)
        gamma, beta = style.chunk(2, 1)
        super().set_condition(gamma, beta)
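
# Usage sketch (assumed shapes; the style vector comes from some external
# style encoder and carries 2 * num_features channels):
#   adain = AdaptiveInstanceNorm2d(num_features=64)
#   adain.set_style(style)           # style: (N, 128)
#   out = adain(content)             # content: (N, 64, H, W)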


@NORMALIZATION.register_module("FADE")
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
    def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
        super().__init__(num_features, base_norm_type)
        self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
                                   padding_mode=padding_mode)
        self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
                                    padding_mode=padding_mode)

    def set_feature(self, feature):
        gamma = self.gamma_conv(feature)
        beta = self.beta_conv(feature)
        super().set_condition(gamma, beta)
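
# Usage sketch (assumed shapes; the conditioning feature map should match the
# spatial size of the normalized input so the predicted gamma/beta broadcast):
#   fade = FeatureAdaptiveDenormalization(num_features=64, condition_in_channels=32)
#   fade.set_feature(cond)           # cond: (N, 32, H, W)
#   out = fade(x)                    # x: (N, 64, H, W)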


@NORMALIZATION.register_module("SPADE")
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
    def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
                 activation_type="ReLU", padding_mode="zeros"):
        super().__init__(num_features, base_norm_type)
        # The shared block projects the condition image to base_channels, which
        # gamma_conv and beta_conv then consume.
        self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
                                           kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
        self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
        self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)

    def set_condition_image(self, condition_image):
        feature = self.base_conv_block(condition_image)
        gamma = self.gamma_conv(feature)
        beta = self.beta_conv(feature)
        super().set_condition(gamma, beta)
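
# Usage sketch (assumed shapes; condition_image is typically a segmentation
# map resized to the spatial size of x, as in the SPADE paper):
#   spade = SpatiallyAdaptiveDenormalization(num_features=64, condition_in_channels=3)
#   spade.set_condition_image(seg)   # seg: (N, 3, H, W)
#   out = spade(x)                   # x: (N, 64, H, W)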


def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    # Blend instance and layer statistics, then apply the affine condition:
    #   out = (rho * IN(x) + (1 - rho) * LN(x)) * gamma + beta
    out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
    out = out * gamma + beta
    return out


@NORMALIZATION.register_module("ILN")
class ILN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.empty(num_features))
        self.gamma = nn.Parameter(torch.empty(num_features))
        self.beta = nn.Parameter(torch.empty(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.rho)
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, x):
        # View the per-channel parameters as (1, C, 1, 1) so they broadcast over
        # the channel dimension; a bare (C,) tensor would align with width instead.
        gamma = self.gamma.view(1, -1, 1, 1).expand_as(x)
        beta = self.beta.view(1, -1, 1, 1).expand_as(x)
        rho = self.rho.view(1, -1, 1, 1).expand_as(x)
        return _instance_layer_normalization(x, gamma, beta, rho, self.eps)


@NORMALIZATION.register_module("AdaILN")
class AdaILN(nn.Module):
    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super().__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.empty(num_features))
        nn.init.constant_(self.rho, default_rho)
        self.gamma = None
        self.beta = None
        self.have_set_condition = False

    def set_condition(self, gamma, beta):
        self.gamma, self.beta = gamma, beta
        self.have_set_condition = True

    def forward(self, x):
        assert self.have_set_condition, "call set_condition() before each forward pass"
        # rho is per-channel; view it as (1, C, 1, 1) so it broadcasts correctly.
        rho = self.rho.view(1, -1, 1, 1).expand_as(x)
        out = _instance_layer_normalization(
            x, self.gamma.expand_as(x), self.beta.expand_as(x), rho, self.eps)
        self.have_set_condition = False
        return out
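

# Minimal smoke-test sketch (assumes the `model` package imports above resolve;
# checks shapes only, not part of the public API):
if __name__ == "__main__":
    x = torch.randn(2, 8, 16, 16)

    iln = ILN(num_features=8)
    assert iln(x).shape == x.shape

    ada_iln = AdaILN(num_features=8)
    gamma = torch.ones(2, 8, 1, 1)
    beta = torch.zeros(2, 8, 1, 1)
    ada_iln.set_condition(gamma, beta)
    assert ada_iln(x).shape == x.shape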