base model, Norm&Conv&ResNet
This commit is contained in:
parent
acf243cb12
commit
0f2b67e215
0
model/base/__init__.py
Normal file
0
model/base/__init__.py
Normal file
109
model/base/module.py
Normal file
109
model/base/module.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from model.registry import NORMALIZATION
|
||||||
|
|
||||||
|
_DO_NO_THING_FUNC = lambda x: x
|
||||||
|
|
||||||
|
|
||||||
|
def _use_bias_checker(norm_type):
|
||||||
|
return norm_type not in ["IN", "BN", "AdaIN", "FADE", "SPADE"]
|
||||||
|
|
||||||
|
|
||||||
|
def _normalization(norm, num_features, additional_kwargs=None):
|
||||||
|
if norm == "NONE":
|
||||||
|
return _DO_NO_THING_FUNC
|
||||||
|
|
||||||
|
if additional_kwargs is None:
|
||||||
|
additional_kwargs = {}
|
||||||
|
kwargs = dict(_type=norm, num_features=num_features)
|
||||||
|
kwargs.update(additional_kwargs)
|
||||||
|
return NORMALIZATION.build_with(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _activation(activation):
|
||||||
|
if activation == "NONE":
|
||||||
|
return _DO_NO_THING_FUNC
|
||||||
|
elif activation == "ReLU":
|
||||||
|
return nn.ReLU(inplace=True)
|
||||||
|
elif activation == "LeakyReLU":
|
||||||
|
return nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
elif activation == "Tanh":
|
||||||
|
return nn.Tanh()
|
||||||
|
else:
|
||||||
|
raise NotImplemented(activation)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
||||||
|
activation_type="ReLU", norm_type="NONE", **conv_kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.activation_type = activation_type
|
||||||
|
|
||||||
|
# if caller not set bias, set bias automatically.
|
||||||
|
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
||||||
|
|
||||||
|
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||||
|
self.normalization = _normalization(norm_type, out_channels)
|
||||||
|
self.activation = _activation(activation_type)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.activation(self.normalization(self.convolution(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
|
||||||
|
activation_type="ReLU", out_activation_type=None, norm_type="IN"):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_type = norm_type
|
||||||
|
|
||||||
|
if out_channels is None:
|
||||||
|
out_channels = num_channels
|
||||||
|
if out_activation_type is None:
|
||||||
|
out_activation_type = "NONE"
|
||||||
|
|
||||||
|
self.learn_skip_connection = num_channels != out_channels
|
||||||
|
|
||||||
|
self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||||
|
norm_type=norm_type, activation_type=activation_type)
|
||||||
|
self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||||
|
norm_type=norm_type, activation_type=out_activation_type)
|
||||||
|
|
||||||
|
if self.learn_skip_connection:
|
||||||
|
self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||||
|
norm_type=norm_type, activation_type=out_activation_type)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||||
|
return self.conv2(self.conv1(x)) + res
|
||||||
|
|
||||||
|
|
||||||
|
class ReverseConv2dBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, out_channels: int,
|
||||||
|
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
||||||
|
self.activation = _activation(activation_type)
|
||||||
|
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.convolution(self.activation(self.normalization(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class ReverseResidualBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, padding_mode="reflect",
|
||||||
|
norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
|
||||||
|
super().__init__()
|
||||||
|
self.learn_skip_connection = in_channels != out_channels
|
||||||
|
self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||||
|
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||||
|
self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||||
|
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||||
|
if self.learn_skip_connection:
|
||||||
|
self.res_conv = ReverseConv2dBlock(
|
||||||
|
in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||||
|
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||||
|
return self.conv2(self.conv1(x)) + res
|
||||||
142
model/base/normalization.py
Normal file
142
model/base/normalization.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from model import NORMALIZATION
|
||||||
|
from model.base.module import Conv2dBlock
|
||||||
|
|
||||||
|
_VALID_NORM_AND_ABBREVIATION = dict(
|
||||||
|
IN="InstanceNorm2d",
|
||||||
|
BN="BatchNorm2d",
|
||||||
|
)
|
||||||
|
|
||||||
|
for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
|
||||||
|
NORMALIZATION.register_module(module=getattr(nn, name), name=abbr)
|
||||||
|
|
||||||
|
|
||||||
|
@NORMALIZATION.register_module("ADE")
|
||||||
|
class AdaptiveDenormalization(nn.Module):
|
||||||
|
def __init__(self, num_features, base_norm_type="BN"):
|
||||||
|
super().__init__()
|
||||||
|
self.num_features = num_features
|
||||||
|
self.base_norm_type = base_norm_type
|
||||||
|
self.norm = self.base_norm(num_features)
|
||||||
|
self.gamma = None
|
||||||
|
self.beta = None
|
||||||
|
self.have_set_condition = False
|
||||||
|
|
||||||
|
def base_norm(self, num_features):
|
||||||
|
if self.base_norm_type == "IN":
|
||||||
|
return nn.InstanceNorm2d(num_features)
|
||||||
|
elif self.base_norm_type == "BN":
|
||||||
|
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
|
||||||
|
|
||||||
|
def set_condition(self, gamma, beta):
|
||||||
|
self.gamma, self.beta = gamma, beta
|
||||||
|
self.have_set_condition = True
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert self.have_set_condition
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.gamma * x + self.beta
|
||||||
|
self.have_set_condition = False
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||||
|
f"base_norm_type={self.base_norm_type})"
|
||||||
|
|
||||||
|
|
||||||
|
@NORMALIZATION.register_module("AdaIN")
|
||||||
|
class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
|
||||||
|
def __init__(self, num_features: int):
|
||||||
|
super().__init__(num_features, "IN")
|
||||||
|
self.num_features = num_features
|
||||||
|
|
||||||
|
def set_style(self, style):
|
||||||
|
style = style.view(*style.size(), 1, 1)
|
||||||
|
gamma, beta = style.chunk(2, 1)
|
||||||
|
super().set_condition(gamma, beta)
|
||||||
|
|
||||||
|
|
||||||
|
@NORMALIZATION.register_module("FADE")
|
||||||
|
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
|
||||||
|
def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
|
||||||
|
super().__init__(num_features, base_norm_type)
|
||||||
|
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||||
|
padding_mode=padding_mode)
|
||||||
|
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||||
|
padding_mode=padding_mode)
|
||||||
|
|
||||||
|
def set_feature(self, feature):
|
||||||
|
gamma = self.gamma_conv(feature)
|
||||||
|
beta = self.beta_conv(feature)
|
||||||
|
super().set_condition(gamma, beta)
|
||||||
|
|
||||||
|
|
||||||
|
@NORMALIZATION.register_module("SPADE")
|
||||||
|
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
|
||||||
|
def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
|
||||||
|
activation_type="ReLU", padding_mode="zeros"):
|
||||||
|
super().__init__(num_features, base_norm_type)
|
||||||
|
self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
|
||||||
|
kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
|
||||||
|
self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||||
|
self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||||
|
|
||||||
|
def set_condition_image(self, condition_image):
|
||||||
|
feature = self.base_conv_block(condition_image)
|
||||||
|
gamma = self.gamma_conv(feature)
|
||||||
|
beta = self.beta_conv(feature)
|
||||||
|
super().set_condition(gamma, beta)
|
||||||
|
|
||||||
|
|
||||||
|
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
|
||||||
|
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
|
||||||
|
out = out * gamma + beta
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@NORMALIZATION.register_module("ILN")
|
||||||
|
class ILN(nn.Module):
|
||||||
|
def __init__(self, num_features, eps=1e-5):
|
||||||
|
super(ILN, self).__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.rho = nn.Parameter(torch.Tensor(num_features))
|
||||||
|
self.gamma = nn.Parameter(torch.Tensor(num_features))
|
||||||
|
self.beta = nn.Parameter(torch.Tensor(num_features))
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.zeros_(self.rho)
|
||||||
|
nn.init.ones_(self.gamma)
|
||||||
|
nn.init.zeros_(self.beta)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return _instance_layer_normalization(
|
||||||
|
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
@NORMALIZATION.register_module("AdaILN")
|
||||||
|
class AdaILN(nn.Module):
|
||||||
|
def __init__(self, num_features, eps=1e-5, default_rho=0.9):
|
||||||
|
super(AdaILN, self).__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.rho = nn.Parameter(torch.Tensor(num_features))
|
||||||
|
self.rho.data.fill_(default_rho)
|
||||||
|
|
||||||
|
self.gamma = None
|
||||||
|
self.beta = None
|
||||||
|
self.have_set_condition = False
|
||||||
|
|
||||||
|
def set_condition(self, gamma, beta):
|
||||||
|
self.gamma, self.beta = gamma, beta
|
||||||
|
self.have_set_condition = True
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert self.have_set_condition
|
||||||
|
out = _instance_layer_normalization(
|
||||||
|
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
|
||||||
|
self.have_set_condition = False
|
||||||
|
return out
|
||||||
Loading…
Reference in New Issue
Block a user