add new normalization

This commit is contained in:
Ray Wong 2020-08-29 10:35:54 +08:00
parent 9e8e73c988
commit 0841d03b3c
2 changed files with 78 additions and 18 deletions

View File

@ -60,34 +60,31 @@ class GANImageBuffer(object):
@MODEL.register_module()
class ResidualBlock(nn.Module):
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
super(ResidualBlock, self).__init__()
if use_bias is None:
# Only for IN, use bias since it does not have affine parameters.
use_bias = norm_type == "IN"
norm_layer = select_norm_layer(norm_type)
models = [nn.Sequential(
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
norm_layer(num_channels),
nn.ReLU(inplace=True),
)]
if use_dropout:
models.append(nn.Dropout(0.5))
models.append(nn.Sequential(
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
norm_layer(num_channels),
))
self.block = nn.Sequential(*models)
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
bias=use_bias)
self.norm1 = norm_layer(num_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
bias=use_bias)
self.norm2 = norm_layer(num_channels)
def forward(self, x):
return x + self.block(x)
res = x
x = self.relu1(self.norm1(self.conv1(x)))
x = self.norm2(self.conv2(x))
return x + res
@MODEL.register_module()
class ResGenerator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
norm_type="IN", use_dropout=False):
norm_type="IN"):
super(ResGenerator, self).__init__()
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
norm_layer = select_norm_layer(norm_type)
@ -115,7 +112,7 @@ class ResGenerator(nn.Module):
res_block_channels = num_down_sampling ** 2 * base_channels
self.resnet_middle = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
*[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
range(num_blocks)])
# up sampling

View File

@ -1,5 +1,6 @@
import torch.nn as nn
import functools
import torch
def select_norm_layer(norm_type):
@ -7,7 +8,69 @@ def select_norm_layer(norm_type):
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == "IN":
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == "LN":
return functools.partial(LayerNorm2d, affine=True)
elif norm_type == "NONE":
return lambda x: nn.Identity()
return lambda num_features: nn.Identity()
elif norm_type == "AdaIN":
return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
else:
raise NotImplemented(f'normalization layer {norm_type} is not found')
class LayerNorm2d(nn.Module):
def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
if self.affine:
self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.affine:
nn.init.uniform_(self.channel_gamma)
nn.init.zeros_(self.channel_beta)
def forward(self, x):
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
print(x.size())
if self.affine:
return self.channel_gamma * x + self.channel_beta
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
class AdaptiveInstanceNorm2d(nn.Module):
def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
affine: bool = False, track_running_stats: bool = False):
super().__init__()
self.num_features = num_features
self.affine = affine
self.track_running_stats = track_running_stats
self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
self.gamma = None
self.beta = None
self.have_set_style = False
def set_style(self, style):
style = style.view(*style.size(), 1, 1)
self.gamma, self.beta = style.chunk(2, 1)
self.have_set_style = True
def forward(self, x):
assert self.have_set_style
x = self.norm(x)
x = self.gamma * x + self.beta
self.have_set_style = False
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
f"affine={self.affine}, track_running_stats={self.track_running_stats})"