add new normalization

parent 9e8e73c988
commit 0841d03b3c
@@ -60,34 +60,31 @@ class GANImageBuffer(object):
 
 
 @MODEL.register_module()
 class ResidualBlock(nn.Module):
-    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
+    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
         super(ResidualBlock, self).__init__()
 
         if use_bias is None:
             # Only for IN, use bias since it does not have affine parameters.
             use_bias = norm_type == "IN"
         norm_layer = select_norm_layer(norm_type)
-        models = [nn.Sequential(
-            nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
-            norm_layer(num_channels),
-            nn.ReLU(inplace=True),
-        )]
-        if use_dropout:
-            models.append(nn.Dropout(0.5))
-        models.append(nn.Sequential(
-            nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
-            norm_layer(num_channels),
-        ))
-        self.block = nn.Sequential(*models)
+        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                               bias=use_bias)
+        self.norm1 = norm_layer(num_channels)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                               bias=use_bias)
+        self.norm2 = norm_layer(num_channels)
 
     def forward(self, x):
-        return x + self.block(x)
+        res = x
+        x = self.relu1(self.norm1(self.conv1(x)))
+        x = self.norm2(self.conv2(x))
+        return x + res
 
 
 @MODEL.register_module()
 class ResGenerator(nn.Module):
     def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
-                 norm_type="IN", use_dropout=False):
+                 norm_type="IN"):
         super(ResGenerator, self).__init__()
         assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
         norm_layer = select_norm_layer(norm_type)
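
As a quick sanity check of the refactored block (not part of the commit; a minimal sketch assuming ResidualBlock and select_norm_layer from the diff above are importable), the explicit conv1/norm1/relu1/conv2/norm2 path preserves the input shape just like the old nn.Sequential version:

import torch

# Hypothetical smoke test; ResidualBlock is the class defined in the diff above.
block = ResidualBlock(num_channels=256, padding_mode='reflect', norm_type="IN")
x = torch.randn(2, 256, 64, 64)
out = block(x)                   # conv -> IN -> ReLU -> conv -> IN, then skip connection
assert out.shape == x.shape      # the residual block keeps (N, C, H, W) unchanged
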
@@ -115,7 +112,7 @@ class ResGenerator(nn.Module):
 
         res_block_channels = num_down_sampling ** 2 * base_channels
         self.resnet_middle = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
+            *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
              range(num_blocks)])
 
         # up sampling
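
A hedged usage sketch of the generator after the use_dropout removal; the constructor arguments are the ones visible in the diff, while the 256x256 input/output size is an illustrative assumption based on the usual ResNet-generator layout:

import torch

gen = ResGenerator(in_channels=3, out_channels=3, base_channels=64, num_blocks=9,
                   padding_mode='reflect', norm_type="IN")
fake = gen(torch.randn(1, 3, 256, 256))  # assumed to map to (1, 3, 256, 256), as in CycleGAN-style generators
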
@@ -1,5 +1,6 @@
 import torch.nn as nn
 import functools
+import torch
 
 
 def select_norm_layer(norm_type):
@@ -7,7 +8,69 @@ def select_norm_layer(norm_type):
         return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
     elif norm_type == "IN":
         return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+    elif norm_type == "LN":
+        return functools.partial(LayerNorm2d, affine=True)
     elif norm_type == "NONE":
-        return lambda x: nn.Identity()
+        return lambda num_features: nn.Identity()
+    elif norm_type == "AdaIN":
+        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
     else:
         raise NotImplementedError(f'normalization layer {norm_type} is not found')
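
select_norm_layer returns a factory (a functools.partial or a lambda) that is later called with the number of channels; a minimal sketch of that call pattern using the branches added here:

import torch

ln_layer = select_norm_layer("LN")        # functools.partial(LayerNorm2d, affine=True)
norm = ln_layer(64)                       # instantiated for 64 channels
y = norm(torch.randn(4, 64, 32, 32))

identity = select_norm_layer("NONE")(64)  # the lambda ignores num_features and returns nn.Identity()
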
+
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        if self.affine:
+            self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+            self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.affine:
+            nn.init.uniform_(self.channel_gamma)
+            nn.init.zeros_(self.channel_beta)
+
+    def forward(self, x):
+        ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
+        x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
+        if self.affine:
+            return self.channel_gamma * x + self.channel_beta
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
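
For reference, LayerNorm2d normalizes each sample over all of (C, H, W) and, when affine=True, applies a per-channel scale and shift (unlike nn.LayerNorm, whose affine parameters are per element). A minimal numerical check, assuming the class above is in scope:

import torch

ln = LayerNorm2d(num_features=8, affine=False)
y = ln(torch.randn(2, 8, 16, 16))
print(y.mean(dim=[1, 2, 3]))  # per-sample mean, roughly 0
print(y.var(dim=[1, 2, 3]))   # per-sample variance, roughly 1
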
+
+
+class AdaptiveInstanceNorm2d(nn.Module):
+    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
+                 affine: bool = False, track_running_stats: bool = False):
+        super().__init__()
+        self.num_features = num_features
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
+
+        self.gamma = None
+        self.beta = None
+        self.have_set_style = False
+
+    def set_style(self, style):
+        style = style.view(*style.size(), 1, 1)
+        self.gamma, self.beta = style.chunk(2, 1)
+        self.have_set_style = True
+
+    def forward(self, x):
+        assert self.have_set_style
+        x = self.norm(x)
+        x = self.gamma * x + self.beta
+        self.have_set_style = False
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
+               f"affine={self.affine}, track_running_stats={self.track_running_stats})"
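
AdaptiveInstanceNorm2d expects the style to be set before every forward call, since have_set_style is cleared after each pass. A hedged sketch of the intended flow, where the per-sample style vector (2 * num_features values, typically produced by a separate style encoder or MLP) supplies gamma and beta:

import torch

adain = AdaptiveInstanceNorm2d(num_features=64)
content = torch.randn(2, 64, 32, 32)
style = torch.randn(2, 128)       # 2 * num_features values: chunked into gamma and beta

adain.set_style(style)            # reshaped to (2, 128, 1, 1), then split along dim 1
out = adain(content)              # instance-normalize, then per-channel scale and shift
# calling adain(content) again without set_style() would trip the assertion
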