add GauGAN

Ray Wong 2020-10-11 23:05:38 +08:00
parent 06b2abd19a
commit 6070f08835
3 changed files with 120 additions and 23 deletions

@@ -20,13 +20,13 @@ def _normalization(norm, num_features, additional_kwargs=None):
     return NORMALIZATION.build_with(kwargs)
 
 
-def _activation(activation):
+def _activation(activation, inplace=True):
     if activation == "NONE":
         return _DO_NO_THING_FUNC
     elif activation == "ReLU":
-        return nn.ReLU(inplace=True)
+        return nn.ReLU(inplace=inplace)
     elif activation == "LeakyReLU":
-        return nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
     elif activation == "Tanh":
         return nn.Tanh()
     else:
@@ -74,7 +74,7 @@ class ReverseConv2dBlock(nn.Module):
                  activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
         super().__init__()
         self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
-        self.activation = _activation(activation_type)
+        self.activation = _activation(activation_type, inplace=False)
         self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
 
     def forward(self, x):
@@ -84,7 +84,7 @@ class ReverseConv2dBlock(nn.Module):
 class ResidualBlock(nn.Module):
     def __init__(self, in_channels,
                  padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
-                 out_channels=None, out_activation_type=None):
+                 out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
         """
         Residual Conv Block
         :param in_channels:
@@ -110,15 +110,15 @@ class ResidualBlock(nn.Module):
         self.learn_skip_connection = in_channels != out_channels
         conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
+        conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
+                          additional_norm_kwargs=additional_norm_kwargs,
+                          padding_mode=padding_mode)
-        self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                norm_type=norm_type, activation_type=activation_type)
-        self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                norm_type=norm_type, activation_type=out_activation_type)
+        self.conv1 = conv_block(in_channels, in_channels, **conv_param)
+        self.conv2 = conv_block(in_channels, out_channels, **conv_param)
         if self.learn_skip_connection:
-            self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                       norm_type=norm_type, activation_type=out_activation_type)
+            self.res_conv = conv_block(in_channels, out_channels, **conv_param)
 
     def forward(self, x):
         res = x if not self.learn_skip_connection else self.res_conv(x)
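
Note on the inplace change above: ReverseConv2dBlock runs normalization and activation before its convolution, and in the residual blocks the tensor it activates (the block input, or the normalized features) is typically still needed for the skip connection and for backprop. Running ReLU in place on such a tensor can trip autograd's in-place checks, which is presumably why the block now asks for _activation(activation_type, inplace=False). A minimal sketch of the failure mode in plain PyTorch, independent of this repository:

    import torch
    import torch.nn as nn

    x = torch.randn(4, requires_grad=True)
    h = torch.sigmoid(x)            # sigmoid's backward needs its own output
    nn.ReLU(inplace=True)(h)        # overwrites h in place
    try:
        h.sum().backward()
    except RuntimeError as err:
        print("in-place activation broke backprop:", err)

    h = torch.sigmoid(x)
    y = nn.ReLU(inplace=False)(h) + h   # out-of-place activation leaves h intact
    y.sum().backward()                  # succeeds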

@@ -16,18 +16,19 @@ for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
 @NORMALIZATION.register_module("ADE")
 class AdaptiveDenormalization(nn.Module):
-    def __init__(self, num_features, base_norm_type="BN"):
+    def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
         super().__init__()
         self.num_features = num_features
         self.base_norm_type = base_norm_type
         self.norm = self.base_norm(num_features)
         self.gamma = None
+        self.gamma_bias = gamma_bias
         self.beta = None
         self.have_set_condition = False
 
     def base_norm(self, num_features):
         if self.base_norm_type == "IN":
-            return nn.InstanceNorm2d(num_features)
+            return nn.InstanceNorm2d(num_features, affine=False)
         elif self.base_norm_type == "BN":
             return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
@@ -38,13 +39,13 @@ class AdaptiveDenormalization(nn.Module):
     def forward(self, x):
         assert self.have_set_condition
         x = self.norm(x)
-        x = self.gamma * x + self.beta
+        x = (self.gamma + self.gamma_bias) * x + self.beta
         self.have_set_condition = False
         return x
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
-               f"base_norm_type={self.base_norm_type})"
+    #
+    # def __repr__(self):
+    # return f"{self.__class__.__name__}(num_features={self.num_features}, " \
+    # f"base_norm_type={self.base_norm_type})"
 
 @NORMALIZATION.register_module("AdaIN")
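
The gamma_bias term changes the denormalization from gamma * norm(x) + beta to (gamma + gamma_bias) * norm(x) + beta. With gamma_bias=1.0 (the value SPADEGenerator below passes via additional_norm_kwargs), the layer reduces to plain normalization whenever the predicted gamma is zero, instead of zeroing the features out. A quick numeric check using only standard PyTorch:

    import torch
    import torch.nn as nn

    norm = nn.BatchNorm2d(8, affine=False)
    x = torch.randn(4, 8, 16, 16)
    nx = norm(x)
    gamma = torch.zeros_like(nx)      # pretend the gamma branch currently outputs zeros
    beta = torch.zeros_like(nx)
    out = (gamma + 1.0) * nx + beta   # gamma_bias = 1.0
    print(torch.allclose(out, nx))    # True: the modulation starts at identity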
@@ -61,8 +62,9 @@ class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
 @NORMALIZATION.register_module("FADE")
 class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
-    def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
-        super().__init__(num_features, base_norm_type)
+    def __init__(self, num_features: int, condition_in_channels,
+                 base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
+        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
         self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
                                    padding_mode=padding_mode)
         self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
@@ -77,9 +79,9 @@ class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
 @NORMALIZATION.register_module("SPADE")
 class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
     def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
-                 activation_type="ReLU", padding_mode="zeros"):
-        super().__init__(num_features, base_norm_type)
-        self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
+                 activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
+        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
+        self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
                                            kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
         self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
         self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
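
The SPADE variant above follows the GauGAN recipe: the condition (a segmentation map) goes through a shared conv block (base_conv_block, now correctly sized to base_channels) and two separate convs predict the spatial gamma and beta consumed by AdaptiveDenormalization.forward. A self-contained sketch of the same idea in plain PyTorch (MiniSPADE and its parameter names are illustrative, not part of this repo):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MiniSPADE(nn.Module):
        def __init__(self, num_features, cond_channels, hidden=128, gamma_bias=1.0):
            super().__init__()
            self.norm = nn.BatchNorm2d(num_features, affine=False)
            self.shared = nn.Sequential(nn.Conv2d(cond_channels, hidden, 3, padding=1), nn.ReLU())
            self.gamma = nn.Conv2d(hidden, num_features, 3, padding=1)
            self.beta = nn.Conv2d(hidden, num_features, 3, padding=1)
            self.gamma_bias = gamma_bias

        def forward(self, x, seg):
            seg = F.interpolate(seg, size=x.shape[2:], mode="nearest")  # match feature resolution
            h = self.shared(seg)
            return (self.gamma(h) + self.gamma_bias) * self.norm(x) + self.beta(h)

    x = torch.randn(2, 64, 16, 16)
    seg = torch.randn(2, 3, 64, 64)
    print(MiniSPADE(64, 3)(x, seg).shape)   # torch.Size([2, 64, 16, 16])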

@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
+
+
+class StyleEncoder(nn.Module):
+    def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
+                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        sequence = [Conv2dBlock(
+            in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type=norm_type
+        )]
+        multiple_now = 1  # the first block already outputs base_channels (1x); starting at 0 would create a 0-channel conv
+        max_multiple = 3
+        for i in range(1, num_conv + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** max_multiple)
+            sequence.append(Conv2dBlock(
+                multiple_prev * base_channels, multiple_now * base_channels,
+                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
+                activation_type=activation_type, norm_type=norm_type
+            ))
+        self.sequence = nn.Sequential(*sequence)
+        self.fc_avg = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
+        self.fc_var = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
+
+    def forward(self, x):
+        x = self.sequence(x)
+        x = x.view(x.size(0), -1)
+        return self.fc_avg(x), self.fc_var(x)
+
+
+class SPADEGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
+                 padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.sx, self.sy = start_size
+        self.use_vae = use_vae
+        self.num_z_dim = num_z_dim
+        if use_vae:
+            self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
+        else:
+            self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)
+        sequence = []
+        multiple_now = 16
+        for i in range(num_blocks - 1, -1, -1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            if i != num_blocks - 1:
+                sequence.append(nn.Upsample(scale_factor=2))
+            sequence.append(ResidualBlock(
+                base_channels * multiple_prev,
+                out_channels=base_channels * multiple_now,
+                padding_mode=padding_mode,
+                activation_type=activation_type,
+                norm_type="SPADE",
+                pre_activation=True,
+                additional_norm_kwargs=dict(
+                    condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
+                    activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
+                )
+            ))
+        self.sequence = nn.Sequential(*sequence)
+        self.output_converter = nn.Sequential(
+            ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
+                               padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
+            nn.Tanh()
+        )
+
+    def forward(self, seg, z=None):
+        if self.use_vae:
+            if z is None:
+                z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
+            x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
+        else:
+            x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))
+        for blk in self.sequence:
+            if isinstance(blk, ResidualBlock):
+                downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
+                blk.conv1.normalization.set_condition_image(downsampling_seg)
+                blk.conv2.normalization.set_condition_image(downsampling_seg)
+                if blk.learn_skip_connection:
+                    blk.res_conv.normalization.set_condition_image(downsampling_seg)
+            x = blk(x)
+        return self.output_converter(x)
+
+
+if __name__ == '__main__':
+    g = SPADEGenerator(3, 3, 7, False, 256)
+    print(g)
+    print(g(torch.randn(2, 3, 256, 256)).size())
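
The __main__ smoke test only exercises the non-VAE path (use_vae=False, so num_z_dim goes unused and the segmentation input doubles as a random tensor). For the VAE path, the usual GauGAN wiring pairs SPADEGenerator with the StyleEncoder above: the encoder returns a mean and a log-variance, a style code z is sampled with the reparameterization trick, and a KL term keeps the code close to a standard normal. A hedged sketch of that wiring, assuming it runs in this module and that num_conv=6 (256 -> 4 after six stride-2 convs); the reparameterize helper and the loss handling are illustrative, not part of this commit:

    import torch

    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    style_dim = 256
    enc = StyleEncoder(in_channels=3, style_dim=style_dim, num_conv=6)
    gen = SPADEGenerator(3, 3, num_blocks=7, use_vae=True, num_z_dim=style_dim)

    real = torch.randn(2, 3, 256, 256)   # style image
    seg = torch.randn(2, 3, 256, 256)    # segmentation map (one-hot in practice)

    mu, logvar = enc(real)
    z = reparameterize(mu, logvar)
    fake = gen(seg, z)
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    print(fake.size(), kl.item())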