From 0841d03b3c2bc6c3327ffefa9592614212df7f13 Mon Sep 17 00:00:00 2001
From: Ray Wong
Date: Sat, 29 Aug 2020 10:35:54 +0800
Subject: [PATCH] add new normalization

---
 model/GAN/residual_generator.py | 31 +++++++---------
 model/normalization.py          | 64 ++++++++++++++++++++++++++++++++-
 2 files changed, 77 insertions(+), 18 deletions(-)

diff --git a/model/GAN/residual_generator.py b/model/GAN/residual_generator.py
index 8a288e0..de8ae9b 100644
--- a/model/GAN/residual_generator.py
+++ b/model/GAN/residual_generator.py
@@ -60,34 +60,31 @@ class GANImageBuffer(object):
 
 @MODEL.register_module()
 class ResidualBlock(nn.Module):
-    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
+    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
         super(ResidualBlock, self).__init__()
-
         if use_bias is None:
             # Only for IN, use bias since it does not have affine parameters.
             use_bias = norm_type == "IN"
         norm_layer = select_norm_layer(norm_type)
-        models = [nn.Sequential(
-            nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
-            norm_layer(num_channels),
-            nn.ReLU(inplace=True),
-        )]
-        if use_dropout:
-            models.append(nn.Dropout(0.5))
-        models.append(nn.Sequential(
-            nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
-            norm_layer(num_channels),
-        ))
-        self.block = nn.Sequential(*models)
+        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                               bias=use_bias)
+        self.norm1 = norm_layer(num_channels)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                               bias=use_bias)
+        self.norm2 = norm_layer(num_channels)
 
     def forward(self, x):
-        return x + self.block(x)
+        res = x
+        x = self.relu1(self.norm1(self.conv1(x)))
+        x = self.norm2(self.conv2(x))
+        return x + res
 
 
 @MODEL.register_module()
 class ResGenerator(nn.Module):
     def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
-                 norm_type="IN", use_dropout=False):
+                 norm_type="IN"):
         super(ResGenerator, self).__init__()
         assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
         norm_layer = select_norm_layer(norm_type)
@@ -115,7 +112,7 @@ class ResGenerator(nn.Module):
 
         res_block_channels = num_down_sampling ** 2 * base_channels
         self.resnet_middle = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
+            *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
              range(num_blocks)])
 
         # up sampling
diff --git a/model/normalization.py b/model/normalization.py
index 36413c0..acfbbbd 100644
--- a/model/normalization.py
+++ b/model/normalization.py
@@ -1,5 +1,6 @@
 import torch.nn as nn
 import functools
+import torch
 
 
 def select_norm_layer(norm_type):
@@ -7,7 +8,68 @@
         return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
     elif norm_type == "IN":
         return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+    elif norm_type == "LN":
+        return functools.partial(LayerNorm2d, affine=True)
     elif norm_type == "NONE":
-        return lambda x: nn.Identity()
+        return lambda num_features: nn.Identity()
+    elif norm_type == "AdaIN":
+        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
     else:
         raise NotImplemented(f'normalization layer {norm_type} is not found')
+
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        if self.affine:
+            self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+            self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.affine:
+            nn.init.uniform_(self.channel_gamma)
+            nn.init.zeros_(self.channel_beta)
+
+    def forward(self, x):
+        ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
+        x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
+        if self.affine:
+            return self.channel_gamma * x + self.channel_beta
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
+
+
+class AdaptiveInstanceNorm2d(nn.Module):
+    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
+                 affine: bool = False, track_running_stats: bool = False):
+        super().__init__()
+        self.num_features = num_features
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
+
+        self.gamma = None
+        self.beta = None
+        self.have_set_style = False
+
+    def set_style(self, style):
+        style = style.view(*style.size(), 1, 1)
+        self.gamma, self.beta = style.chunk(2, 1)
+        self.have_set_style = True
+
+    def forward(self, x):
+        assert self.have_set_style
+        x = self.norm(x)
+        x = self.gamma * x + self.beta
+        self.have_set_style = False
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
+               f"affine={self.affine}, track_running_stats={self.track_running_stats})"
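
Usage note (reviewer sketch, not part of the patch): with this change a ResidualBlock is a plain
conv -> norm -> ReLU -> conv -> norm stack with a skip connection, and the normalization is picked
by name through select_norm_layer, so "BN", "IN", "LN" and "NONE" all work without extra arguments.
A minimal sketch, assuming the import paths match the file paths in the diff and that the MODEL
registry is importable:

import torch
from model.GAN.residual_generator import ResidualBlock

# "LN" maps to the new LayerNorm2d; use_bias is inferred (True only for "IN").
block = ResidualBlock(num_channels=256, norm_type="LN")
x = torch.randn(2, 256, 64, 64)
out = block(x)  # residual output, same shape as x: (2, 256, 64, 64)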
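
The AdaIN path has a different contract: select_norm_layer("AdaIN") returns a factory for
AdaptiveInstanceNorm2d, and set_style must be called before every forward with a per-sample
style vector that stacks gamma and beta, i.e. shape (N, 2 * num_features). A minimal sketch
(the style vector here is random; in practice it would come from a style encoder or MLP):

import torch
from model.normalization import select_norm_layer

adain = select_norm_layer("AdaIN")(64)  # AdaptiveInstanceNorm2d with num_features=64
content = torch.randn(4, 64, 32, 32)
style = torch.randn(4, 128)             # (N, 2 * num_features)

adain.set_style(style)                  # reshaped to (4, 128, 1, 1), split into gamma / beta
out = adain(content)                    # (4, 64, 32, 32); have_set_style resets to False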