diff --git a/model/base/module.py b/model/base/module.py
index c168d7a..a64e5f6 100644
--- a/model/base/module.py
+++ b/model/base/module.py
@@ -53,34 +53,29 @@ class LinearBlock(nn.Module):
 
 class Conv2dBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, bias=None, activation_type="ReLU", norm_type="NONE",
-                 additional_norm_kwargs=None, **conv_kwargs):
+                 additional_norm_kwargs=None, pre_activation=False, **conv_kwargs):
         super().__init__()
         self.norm_type = norm_type
         self.activation_type = activation_type
+        self.pre_activation = pre_activation
 
-        # if caller not set bias, set bias automatically.
-        conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
-
-        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
-        self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
-        self.activation = _activation(activation_type)
+        if pre_activation:
+            self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
+            self.activation = _activation(activation_type, inplace=False)
+            self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+        else:
+            # If the caller did not set bias, choose it automatically based on the norm type.
+            conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
+            self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+            self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
+            self.activation = _activation(activation_type)
 
     def forward(self, x):
+        if self.pre_activation:
+            return self.convolution(self.activation(self.normalization(x)))
         return self.activation(self.normalization(self.convolution(x)))
 
 
-class ReverseConv2dBlock(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int,
-                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
-        super().__init__()
-        self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
-        self.activation = _activation(activation_type, inplace=False)
-        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
-
-    def forward(self, x):
-        return self.convolution(self.activation(self.normalization(x)))
-
-
 class ResidualBlock(nn.Module):
     def __init__(self, in_channels,
                  padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
@@ -109,16 +104,15 @@ class ResidualBlock(nn.Module):
 
         self.learn_skip_connection = in_channels != out_channels
 
-        conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
         conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
-                          additional_norm_kwargs=additional_norm_kwargs,
-                          padding_mode=padding_mode)
+                          additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
+                          padding_mode=padding_mode)
 
-        self.conv1 = conv_block(in_channels, in_channels, **conv_param)
-        self.conv2 = conv_block(in_channels, out_channels, **conv_param)
+        self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
+        self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
 
         if self.learn_skip_connection:
-            self.res_conv = conv_block(in_channels, out_channels, **conv_param)
+            self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
 
     def forward(self, x):
         res = x if not self.learn_skip_connection else self.res_conv(x)
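For reference, a minimal usage sketch of the unified block (not part of the patch; it assumes Conv2dBlock is importable from model.base.module as the diff header indicates, and that the "IN"/"ReLU" options used here are accepted by the _normalization/_activation helpers, as the ResidualBlock defaults suggest):

import torch
from model.base.module import Conv2dBlock

x = torch.randn(2, 32, 64, 64)

# Default (post-activation) ordering: conv -> norm -> act; bias is chosen by _use_bias_checker.
post = Conv2dBlock(32, 64, kernel_size=3, padding=1, norm_type="IN", activation_type="ReLU")

# Pre-activation ordering: norm -> act -> conv; the normalization is built for in_channels.
pre = Conv2dBlock(32, 64, kernel_size=3, padding=1, norm_type="IN", activation_type="ReLU",
                  pre_activation=True)

print(post(x).shape)  # torch.Size([2, 64, 64, 64])
print(pre(x).shape)   # torch.Size([2, 64, 64, 64])
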
diff --git a/model/image_translation/GauGAN.py b/model/image_translation/GauGAN.py
index b65e502..3b1b7e5 100644
--- a/model/image_translation/GauGAN.py
+++ b/model/image_translation/GauGAN.py
@@ -1,8 +1,12 @@
+from collections import OrderedDict
+from functools import partial
+
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
+from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
 
 
 class StyleEncoder(nn.Module):
@@ -33,6 +37,92 @@ class StyleEncoder(nn.Module):
         return self.fc_avg(x), self.fc_var(x)
 
 
+class ImprovedSPADEGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
+                 base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
+        super().__init__()
+
+        assert output_size in (128, 256, 512, 1024)
+        self.output_size = output_size
+
+        kernel_size = 3
+
+        if have_style_input:
+            self.style_converter = nn.Sequential(
+                LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
+                LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
+            )
+
+        base_conv = partial(
+            Conv2dBlock,
+            pre_activation=pre_activation, activation_type=activation_type,
+            norm_type="AdaIN" if have_style_input else "NONE",
+            kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
+        )
+
+        base_residual_block = partial(
+            ResidualBlock,
+            padding_mode=padding_mode,
+            activation_type=activation_type,
+            norm_type="SPADE",
+            pre_activation=True,
+            additional_norm_kwargs=dict(
+                condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
+                activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
+            )
+        )
+
+        sequence = OrderedDict()
+        channels = (2 ** 4) * base_channels
+        sequence["block_head"] = nn.Sequential(OrderedDict([
+            ("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
+            ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
+            ("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
+            ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+            ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+        ]))
+
+        for i in range(4, 9 - min(int(math.log(self.output_size, 2)), 8), -1):
+            channels = (2 ** (i - 1)) * base_channels
+            sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
+                ("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
+                ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
+                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+            ]))
+        self.sequence = nn.Sequential(sequence)
+        # channels = 2*base_channels when output size is 256, 512, 1024
+        # channels = 4*base_channels when output size is 128
+        out_modules = OrderedDict()
+        out_modules["out_1"] = nn.Sequential(
+            Conv2dBlock(
+                channels, out_channels, kernel_size=5, stride=1, padding=2,
+                pre_activation=pre_activation,
+                padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
+            ),
+            nn.Tanh()
+        )
+        for i in range(int(math.log(self.output_size, 2)) - 8):
+            channels = channels // 2
+            out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
+                ("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
+                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+            ]))
+            out_modules[f"out_{i + 2}"] = nn.Sequential(
+                Conv2dBlock(
+                    channels, out_channels, kernel_size=5, stride=1, padding=2,
+                    pre_activation=pre_activation,
+                    padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
+                ),
+                nn.Tanh()
+            )
+        self.out_modules = nn.ModuleDict(out_modules)
+
+    def forward(self, seg, style=None):
+        pass
+
+
 class SPADEGenerator(nn.Module):
     def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4),
                  base_channels=64, padding_mode='reflect', activation_type="LeakyReLU"):
@@ -89,7 +179,8 @@ class SPADEGenerator(nn.Module):
             x = blk(x)
         return self.output_converter(x)
 
+
 if __name__ == '__main__':
     g = SPADEGenerator(3, 3, 7, False, 256)
     print(g)
-    print(g(torch.randn(2, 3, 256, 256)).size())
\ No newline at end of file
+    print(g(torch.randn(2, 3, 256, 256)).size())
diff --git a/model/normalization.py b/model/normalization.py
deleted file mode 100644
index 9e3facf..0000000
--- a/model/normalization.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import functools
-
-import torch
-import torch.nn as nn
-
-
-def select_norm_layer(norm_type):
-    if norm_type == "BN":
-        return functools.partial(nn.BatchNorm2d)
-    elif norm_type == "IN":
-        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
-    elif norm_type == "LN":
-        return functools.partial(LayerNorm2d, affine=True)
-    elif norm_type == "NONE":
-        return lambda num_features: nn.Identity()
-    elif norm_type == "AdaIN":
-        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
-    else:
-        raise NotImplemented(f'normalization layer {norm_type} is not found')
-
-
-class LayerNorm2d(nn.Module):
-    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
-        super().__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.affine = affine
-        if self.affine:
-            self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
-            self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        if self.affine:
-            nn.init.uniform_(self.channel_gamma)
-            nn.init.zeros_(self.channel_beta)
-
-    def forward(self, x):
-        ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
-        x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
-        if self.affine:
-            return self.channel_gamma * x + self.channel_beta
-        return x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
-
-
-class AdaptiveInstanceNorm2d(nn.Module):
-    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
-                 affine: bool = False, track_running_stats: bool = False):
-        super().__init__()
-        self.num_features = num_features
-        self.affine = affine
-        self.track_running_stats = track_running_stats
-        self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
-
-        self.gamma = None
-        self.beta = None
-        self.have_set_style = False
-
-    def set_style(self, style):
-        style = style.view(*style.size(), 1, 1)
-        self.gamma, self.beta = style.chunk(2, 1)
-        self.have_set_style = True
-
-    def forward(self, x):
-        assert self.have_set_style
-        x = self.norm(x)
-        x = self.gamma * x + self.beta
-        self.have_set_style = False
-        return x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
-               f"affine={self.affine}, track_running_stats={self.track_running_stats})"
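The resolution-dependent widths in ImprovedSPADEGenerator follow directly from the two loop bounds in its __init__ (GauGAN.py diff above). A small standalone illustration of the resulting channel schedule; the function name and list layout here are not part of the patch, only the loop bounds are:

import math

def channel_schedule(output_size, base_channels=64):
    # Mirrors the loop bounds used in ImprovedSPADEGenerator.__init__.
    channels = (2 ** 4) * base_channels                    # width after block_head
    trunk = [channels]
    for i in range(4, 9 - min(int(math.log(output_size, 2)), 8), -1):
        channels = (2 ** (i - 1)) * base_channels          # halved by each upsampling block
        trunk.append(channels)
    out_branch = []
    for _ in range(int(math.log(output_size, 2)) - 8):
        channels //= 2                                     # extra blocks only past 256px output
        out_branch.append(channels)
    return trunk, out_branch

for size in (128, 256, 512, 1024):
    print(size, channel_schedule(size))
# 128  -> ([1024, 512, 256], [])
# 256  -> ([1024, 512, 256, 128], [])
# 512  -> ([1024, 512, 256, 128], [64])
# 1024 -> ([1024, 512, 256, 128], [64, 32])

With the default base_channels=64, the trunk therefore ends at 2*base_channels (128 channels) for outputs of 256px and larger, and at 4*base_channels (256 channels) for a 128px output, matching the comment in the constructor.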