From 6070f0883595f2af259c1146798d6cdc693dc657 Mon Sep 17 00:00:00 2001
From: Ray Wong
Date: Sun, 11 Oct 2020 23:05:38 +0800
Subject: [PATCH] add GauGAN

---
 model/base/module.py              | 22 +++----
 model/base/normalization.py       | 26 +++++----
 model/image_translation/GauGAN.py | 95 +++++++++++++++++++++++++++++++
 3 files changed, 120 insertions(+), 23 deletions(-)

diff --git a/model/base/module.py b/model/base/module.py
index 9597a43..c168d7a 100644
--- a/model/base/module.py
+++ b/model/base/module.py
@@ -20,13 +20,13 @@ def _normalization(norm, num_features, additional_kwargs=None):
     return NORMALIZATION.build_with(kwargs)
 
 
-def _activation(activation):
+def _activation(activation, inplace=True):
     if activation == "NONE":
         return _DO_NO_THING_FUNC
     elif activation == "ReLU":
-        return nn.ReLU(inplace=True)
+        return nn.ReLU(inplace=inplace)
     elif activation == "LeakyReLU":
-        return nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
     elif activation == "Tanh":
         return nn.Tanh()
     else:
@@ -74,7 +74,7 @@ class ReverseConv2dBlock(nn.Module):
                  activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
         super().__init__()
         self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
-        self.activation = _activation(activation_type)
+        self.activation = _activation(activation_type, inplace=False)
         self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
 
     def forward(self, x):
@@ -84,7 +84,7 @@ class ResidualBlock(nn.Module):
 
     def __init__(self, in_channels, padding_mode='reflect',
                  activation_type="ReLU", norm_type="IN", pre_activation=False,
-                 out_channels=None, out_activation_type=None):
+                 out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
         """
         Residual Conv Block
         :param in_channels:
@@ -110,15 +110,15 @@ class ResidualBlock(nn.Module):
         self.learn_skip_connection = in_channels != out_channels
 
         conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
+        conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
+                          additional_norm_kwargs=additional_norm_kwargs,
+                          padding_mode=padding_mode)
 
-        self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                norm_type=norm_type, activation_type=activation_type)
-        self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                norm_type=norm_type, activation_type=out_activation_type)
+        self.conv1 = conv_block(in_channels, in_channels, **conv_param)
+        self.conv2 = conv_block(in_channels, out_channels, **conv_param)
 
         if self.learn_skip_connection:
-            self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                       norm_type=norm_type, activation_type=out_activation_type)
+            self.res_conv = conv_block(in_channels, out_channels, **conv_param)
 
     def forward(self, x):
         res = x if not self.learn_skip_connection else self.res_conv(x)
diff --git a/model/base/normalization.py b/model/base/normalization.py
index 30b2e12..aba7d30 100644
--- a/model/base/normalization.py
+++ b/model/base/normalization.py
@@ -16,18 +16,19 @@ for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
 
 @NORMALIZATION.register_module("ADE")
 class AdaptiveDenormalization(nn.Module):
-    def __init__(self, num_features, base_norm_type="BN"):
+    def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
         super().__init__()
         self.num_features = num_features
         self.base_norm_type = base_norm_type
         self.norm = self.base_norm(num_features)
         self.gamma = None
+        self.gamma_bias = gamma_bias
         self.beta = None
         self.have_set_condition = False
 
     def base_norm(self, num_features):
         if self.base_norm_type == "IN":
-            return nn.InstanceNorm2d(num_features)
+            return nn.InstanceNorm2d(num_features, affine=False)
         elif self.base_norm_type == "BN":
             return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
 
@@ -38,13 +39,13 @@ class AdaptiveDenormalization(nn.Module):
     def forward(self, x):
         assert self.have_set_condition
         x = self.norm(x)
-        x = self.gamma * x + self.beta
+        x = (self.gamma + self.gamma_bias) * x + self.beta
         self.have_set_condition = False
         return x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
-               f"base_norm_type={self.base_norm_type})"
+    #
+    # def __repr__(self):
+    #     return f"{self.__class__.__name__}(num_features={self.num_features}, " \
+    #            f"base_norm_type={self.base_norm_type})"
 
 
 @NORMALIZATION.register_module("AdaIN")
@@ -61,8 +62,9 @@ class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
 
 @NORMALIZATION.register_module("FADE")
 class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
-    def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
-        super().__init__(num_features, base_norm_type)
+    def __init__(self, num_features: int, condition_in_channels,
+                 base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
+        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
         self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
                                    padding_mode=padding_mode)
         self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
@@ -77,9 +79,9 @@ class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
 @NORMALIZATION.register_module("SPADE")
 class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
     def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
-                 activation_type="ReLU", padding_mode="zeros"):
-        super().__init__(num_features, base_norm_type)
-        self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
+                 activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
+        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
+        self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
                                            kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
         self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
         self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
diff --git a/model/image_translation/GauGAN.py b/model/image_translation/GauGAN.py
index e69de29..b65e502 100644
--- a/model/image_translation/GauGAN.py
+++ b/model/image_translation/GauGAN.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
+
+
+class StyleEncoder(nn.Module):
+    def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
+                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        sequence = [Conv2dBlock(
+            in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type=norm_type
+        )]
+        multiple_now = 1  # the stem conv above outputs base_channels, i.e. channel multiple 1
+        max_multiple = 3
+        for i in range(1, num_conv + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** max_multiple)
+            sequence.append(Conv2dBlock(
+                multiple_prev * base_channels, multiple_now * base_channels,
+                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
+                activation_type=activation_type, norm_type=norm_type
+            ))
+        self.sequence = nn.Sequential(*sequence)
+        self.fc_avg = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
+        self.fc_var = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
+
+    def forward(self, x):
+        x = self.sequence(x)
+        x = x.view(x.size(0), -1)
+        return self.fc_avg(x), self.fc_var(x)
+
+
+class SPADEGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
+                 padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.sx, self.sy = start_size
+        self.use_vae = use_vae
+        self.num_z_dim = num_z_dim
+        if use_vae:
+            self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
+        else:
+            self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)
+
+        sequence = []
+
+        multiple_now = 16
+        for i in range(num_blocks - 1, -1, -1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            if i != num_blocks - 1:
+                sequence.append(nn.Upsample(scale_factor=2))
+            sequence.append(ResidualBlock(
+                base_channels * multiple_prev,
+                out_channels=base_channels * multiple_now,
+                padding_mode=padding_mode,
+                activation_type=activation_type,
+                norm_type="SPADE",
+                pre_activation=True,
+                additional_norm_kwargs=dict(
+                    condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
+                    activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
+                )
+            ))
+        self.sequence = nn.Sequential(*sequence)
+        self.output_converter = nn.Sequential(
+            ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
+                               padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
+            nn.Tanh()
+        )
+
+    def forward(self, seg, z=None):
+        if self.use_vae:
+            if z is None:
+                z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
+            x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
+        else:
+            x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))
+        for blk in self.sequence:
+            if isinstance(blk, ResidualBlock):
+                downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
+                blk.conv1.normalization.set_condition_image(downsampling_seg)
+                blk.conv2.normalization.set_condition_image(downsampling_seg)
+                if blk.learn_skip_connection:
+                    blk.res_conv.normalization.set_condition_image(downsampling_seg)
+            x = blk(x)
+        return self.output_converter(x)
+
+if __name__ == '__main__':
+    g = SPADEGenerator(3, 3, 7, False, 256)
+    print(g)
+    print(g(torch.randn(2, 3, 256, 256)).size())
\ No newline at end of file
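
Usage note: a minimal sketch of how StyleEncoder and SPADEGenerator might be wired together when use_vae=True, assuming the encoder's two heads are meant as (mu, log-variance); the reparameterize helper and the hyperparameters below are illustrative assumptions, not something this patch defines.

import torch

from model.image_translation.GauGAN import SPADEGenerator, StyleEncoder


def reparameterize(mu, logvar):
    # VAE reparameterization trick; assumes fc_var predicts log-variance.
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)


if __name__ == '__main__':
    style_dim = 256
    # num_conv=6 stride-2 blocks bring a 256x256 input down to the default 4x4 end_size.
    encoder = StyleEncoder(in_channels=3, style_dim=style_dim, num_conv=6)
    # use_vae=True: the latent z is mapped to the 4x4 starting feature map,
    # while the segmentation map only drives the SPADE layers.
    generator = SPADEGenerator(3, 3, num_blocks=7, use_vae=True, num_z_dim=style_dim)

    image = torch.randn(2, 3, 256, 256)  # style image fed to the encoder
    seg = torch.randn(2, 3, 256, 256)    # segmentation map (3 channels, as in the patch's own test)

    mu, logvar = encoder(image)
    z = reparameterize(mu, logvar)
    print(generator(seg, z=z).size())    # expected: torch.Size([2, 3, 256, 256])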