diff --git a/model/base/__init__.py b/model/base/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/model/base/module.py b/model/base/module.py
new file mode 100644
index 0000000..d0c6da6
--- /dev/null
+++ b/model/base/module.py
@@ -0,0 +1,109 @@
+import torch.nn as nn
+
+from model.registry import NORMALIZATION
+
+_DO_NO_THING_FUNC = lambda x: x
+
+
+def _use_bias_checker(norm_type):
+    return norm_type not in ["IN", "BN", "AdaIN", "FADE", "SPADE"]
+
+
+def _normalization(norm, num_features, additional_kwargs=None):
+    if norm == "NONE":
+        return _DO_NO_THING_FUNC
+
+    if additional_kwargs is None:
+        additional_kwargs = {}
+    kwargs = dict(_type=norm, num_features=num_features)
+    kwargs.update(additional_kwargs)
+    return NORMALIZATION.build_with(kwargs)
+
+
+def _activation(activation):
+    if activation == "NONE":
+        return _DO_NO_THING_FUNC
+    elif activation == "ReLU":
+        return nn.ReLU(inplace=True)
+    elif activation == "LeakyReLU":
+        return nn.LeakyReLU(negative_slope=0.2, inplace=True)
+    elif activation == "Tanh":
+        return nn.Tanh()
+    else:
+        raise NotImplementedError(activation)
+
+
+class Conv2dBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, bias=None,
+                 activation_type="ReLU", norm_type="NONE", **conv_kwargs):
+        super().__init__()
+        self.norm_type = norm_type
+        self.activation_type = activation_type
+
+        # If the caller did not set bias, choose it automatically from the normalization type.
+        conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
+
+        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+        self.normalization = _normalization(norm_type, out_channels)
+        self.activation = _activation(activation_type)
+
+    def forward(self, x):
+        return self.activation(self.normalization(self.convolution(x)))
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
+                 activation_type="ReLU", out_activation_type=None, norm_type="IN"):
+        super().__init__()
+        self.norm_type = norm_type
+
+        if out_channels is None:
+            out_channels = num_channels
+        if out_activation_type is None:
+            out_activation_type = "NONE"
+
+        self.learn_skip_connection = num_channels != out_channels
+
+        self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                 norm_type=norm_type, activation_type=activation_type)
+        self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                 norm_type=norm_type, activation_type=out_activation_type)
+
+        if self.learn_skip_connection:
+            self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                                        norm_type=norm_type, activation_type=out_activation_type)
+
+    def forward(self, x):
+        res = x if not self.learn_skip_connection else self.res_conv(x)
+        return self.conv2(self.conv1(x)) + res
+
+
+class ReverseConv2dBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int,
+                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
+        super().__init__()
+        self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
+        self.activation = _activation(activation_type)
+        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+
+    def forward(self, x):
+        return self.convolution(self.activation(self.normalization(x)))
+
+
+class ReverseResidualBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, padding_mode="reflect",
+                 norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
+        super().__init__()
+        self.learn_skip_connection = in_channels != out_channels
+        self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
+                                        kernel_size=3, padding=1, padding_mode=padding_mode)
+        self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
+                                        kernel_size=3, padding=1, padding_mode=padding_mode)
+        if self.learn_skip_connection:
+            self.res_conv = ReverseConv2dBlock(
+                in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
+                kernel_size=3, padding=1, padding_mode=padding_mode)
+
+    def forward(self, x):
+        res = x if not self.learn_skip_connection else self.res_conv(x)
+        return self.conv2(self.conv1(x)) + res
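Usage note (not part of the patch): a minimal sketch of how the blocks in module.py are meant to be composed. It assumes that importing model.base.normalization (added below) registers "IN"/"BN" with the NORMALIZATION registry, and the tensor shapes are illustrative only.

    import torch
    import model.base.normalization  # assumed to register "IN"/"BN" as a side effect
    from model.base.module import Conv2dBlock, ResidualBlock

    x = torch.randn(2, 64, 32, 32)
    conv = Conv2dBlock(64, 128, kernel_size=3, padding=1, norm_type="IN", activation_type="LeakyReLU")
    block = ResidualBlock(128)                   # identity skip connection, 128 -> 128
    proj = ResidualBlock(128, out_channels=256)  # learned skip via res_conv, 128 -> 256
    y = proj(block(conv(x)))                     # -> (2, 256, 32, 32)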
activation_type="ReLU"): + super().__init__() + self.learn_skip_connection = in_channels != out_channels + self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs, + kernel_size=3, padding=1, padding_mode=padding_mode) + self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs, + kernel_size=3, padding=1, padding_mode=padding_mode) + if self.learn_skip_connection: + self.res_conv = ReverseConv2dBlock( + in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs, + kernel_size=3, padding=1, padding_mode=padding_mode) + + def forward(self, x): + res = x if not self.learn_skip_connection else self.res_conv(x) + return self.conv2(self.conv1(x)) + res diff --git a/model/base/normalization.py b/model/base/normalization.py new file mode 100644 index 0000000..7a925f4 --- /dev/null +++ b/model/base/normalization.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model import NORMALIZATION +from model.base.module import Conv2dBlock + +_VALID_NORM_AND_ABBREVIATION = dict( + IN="InstanceNorm2d", + BN="BatchNorm2d", +) + +for abbr, name in _VALID_NORM_AND_ABBREVIATION.items(): + NORMALIZATION.register_module(module=getattr(nn, name), name=abbr) + + +@NORMALIZATION.register_module("ADE") +class AdaptiveDenormalization(nn.Module): + def __init__(self, num_features, base_norm_type="BN"): + super().__init__() + self.num_features = num_features + self.base_norm_type = base_norm_type + self.norm = self.base_norm(num_features) + self.gamma = None + self.beta = None + self.have_set_condition = False + + def base_norm(self, num_features): + if self.base_norm_type == "IN": + return nn.InstanceNorm2d(num_features) + elif self.base_norm_type == "BN": + return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True) + + def set_condition(self, gamma, beta): + self.gamma, self.beta = gamma, beta + self.have_set_condition = True + + def forward(self, x): + assert self.have_set_condition + x = self.norm(x) + x = self.gamma * x + self.beta + self.have_set_condition = False + return x + + def __repr__(self): + return f"{self.__class__.__name__}(num_features={self.num_features}, " \ + f"base_norm_type={self.base_norm_type})" + + +@NORMALIZATION.register_module("AdaIN") +class AdaptiveInstanceNorm2d(AdaptiveDenormalization): + def __init__(self, num_features: int): + super().__init__(num_features, "IN") + self.num_features = num_features + + def set_style(self, style): + style = style.view(*style.size(), 1, 1) + gamma, beta = style.chunk(2, 1) + super().set_condition(gamma, beta) + + +@NORMALIZATION.register_module("FADE") +class FeatureAdaptiveDenormalization(AdaptiveDenormalization): + def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"): + super().__init__(num_features, base_norm_type) + self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1, + padding_mode=padding_mode) + self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1, + padding_mode=padding_mode) + + def set_feature(self, feature): + gamma = self.gamma_conv(feature) + beta = self.beta_conv(feature) + super().set_condition(gamma, beta) + + +@NORMALIZATION.register_module("SPADE") +class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization): + def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN", + 
activation_type="ReLU", padding_mode="zeros"): + super().__init__(num_features, base_norm_type) + self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type, + kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE") + self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode) + self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode) + + def set_condition_image(self, condition_image): + feature = self.base_conv_block(condition_image) + gamma = self.gamma_conv(feature) + beta = self.beta_conv(feature) + super().set_condition(gamma, beta) + + +def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5): + out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps) + out = out * gamma + beta + return out + + +@NORMALIZATION.register_module("ILN") +class ILN(nn.Module): + def __init__(self, num_features, eps=1e-5): + super(ILN, self).__init__() + self.eps = eps + self.rho = nn.Parameter(torch.Tensor(num_features)) + self.gamma = nn.Parameter(torch.Tensor(num_features)) + self.beta = nn.Parameter(torch.Tensor(num_features)) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.zeros_(self.rho) + nn.init.ones_(self.gamma) + nn.init.zeros_(self.beta) + + def forward(self, x): + return _instance_layer_normalization( + x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps) + + +@NORMALIZATION.register_module("AdaILN") +class AdaILN(nn.Module): + def __init__(self, num_features, eps=1e-5, default_rho=0.9): + super(AdaILN, self).__init__() + self.eps = eps + self.rho = nn.Parameter(torch.Tensor(num_features)) + self.rho.data.fill_(default_rho) + + self.gamma = None + self.beta = None + self.have_set_condition = False + + def set_condition(self, gamma, beta): + self.gamma, self.beta = gamma, beta + self.have_set_condition = True + + def forward(self, x): + assert self.have_set_condition + out = _instance_layer_normalization( + x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps) + self.have_set_condition = False + return out