add new normalization
This commit is contained in:
parent
9e8e73c988
commit
0841d03b3c
@ -60,34 +60,31 @@ class GANImageBuffer(object):
|
|||||||
|
|
||||||
@MODEL.register_module()
|
@MODEL.register_module()
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
|
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
|
||||||
super(ResidualBlock, self).__init__()
|
super(ResidualBlock, self).__init__()
|
||||||
|
|
||||||
if use_bias is None:
|
if use_bias is None:
|
||||||
# Only for IN, use bias since it does not have affine parameters.
|
# Only for IN, use bias since it does not have affine parameters.
|
||||||
use_bias = norm_type == "IN"
|
use_bias = norm_type == "IN"
|
||||||
norm_layer = select_norm_layer(norm_type)
|
norm_layer = select_norm_layer(norm_type)
|
||||||
models = [nn.Sequential(
|
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
bias=use_bias)
|
||||||
norm_layer(num_channels),
|
self.norm1 = norm_layer(num_channels)
|
||||||
nn.ReLU(inplace=True),
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
)]
|
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||||
if use_dropout:
|
bias=use_bias)
|
||||||
models.append(nn.Dropout(0.5))
|
self.norm2 = norm_layer(num_channels)
|
||||||
models.append(nn.Sequential(
|
|
||||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
|
||||||
norm_layer(num_channels),
|
|
||||||
))
|
|
||||||
self.block = nn.Sequential(*models)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x + self.block(x)
|
res = x
|
||||||
|
x = self.relu1(self.norm1(self.conv1(x)))
|
||||||
|
x = self.norm2(self.conv2(x))
|
||||||
|
return x + res
|
||||||
|
|
||||||
|
|
||||||
@MODEL.register_module()
|
@MODEL.register_module()
|
||||||
class ResGenerator(nn.Module):
|
class ResGenerator(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
|
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
|
||||||
norm_type="IN", use_dropout=False):
|
norm_type="IN"):
|
||||||
super(ResGenerator, self).__init__()
|
super(ResGenerator, self).__init__()
|
||||||
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
||||||
norm_layer = select_norm_layer(norm_type)
|
norm_layer = select_norm_layer(norm_type)
|
||||||
@ -115,7 +112,7 @@ class ResGenerator(nn.Module):
|
|||||||
|
|
||||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||||
self.resnet_middle = nn.Sequential(
|
self.resnet_middle = nn.Sequential(
|
||||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
|
*[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
|
||||||
range(num_blocks)])
|
range(num_blocks)])
|
||||||
|
|
||||||
# up sampling
|
# up sampling
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import functools
|
import functools
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def select_norm_layer(norm_type):
|
def select_norm_layer(norm_type):
|
||||||
@ -7,7 +8,69 @@ def select_norm_layer(norm_type):
|
|||||||
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
||||||
elif norm_type == "IN":
|
elif norm_type == "IN":
|
||||||
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
||||||
|
elif norm_type == "LN":
|
||||||
|
return functools.partial(LayerNorm2d, affine=True)
|
||||||
elif norm_type == "NONE":
|
elif norm_type == "NONE":
|
||||||
return lambda x: nn.Identity()
|
return lambda num_features: nn.Identity()
|
||||||
|
elif norm_type == "AdaIN":
|
||||||
|
return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
|
||||||
else:
|
else:
|
||||||
raise NotImplemented(f'normalization layer {norm_type} is not found')
|
raise NotImplemented(f'normalization layer {norm_type} is not found')
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm2d(nn.Module):
|
||||||
|
def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
self.num_features = num_features
|
||||||
|
self.eps = eps
|
||||||
|
self.affine = affine
|
||||||
|
if self.affine:
|
||||||
|
self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
|
||||||
|
self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
if self.affine:
|
||||||
|
nn.init.uniform_(self.channel_gamma)
|
||||||
|
nn.init.zeros_(self.channel_beta)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
|
||||||
|
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
|
||||||
|
print(x.size())
|
||||||
|
if self.affine:
|
||||||
|
return self.channel_gamma * x + self.channel_beta
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveInstanceNorm2d(nn.Module):
|
||||||
|
def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
|
||||||
|
affine: bool = False, track_running_stats: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.num_features = num_features
|
||||||
|
self.affine = affine
|
||||||
|
self.track_running_stats = track_running_stats
|
||||||
|
self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
|
||||||
|
|
||||||
|
self.gamma = None
|
||||||
|
self.beta = None
|
||||||
|
self.have_set_style = False
|
||||||
|
|
||||||
|
def set_style(self, style):
|
||||||
|
style = style.view(*style.size(), 1, 1)
|
||||||
|
self.gamma, self.beta = style.chunk(2, 1)
|
||||||
|
self.have_set_style = True
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert self.have_set_style
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.gamma * x + self.beta
|
||||||
|
self.have_set_style = False
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||||
|
f"affine={self.affine}, track_running_stats={self.track_running_stats})"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user