add GauGAN

commit 6070f08835
parent 06b2abd19a
@@ -20,13 +20,13 @@ def _normalization(norm, num_features, additional_kwargs=None):
     return NORMALIZATION.build_with(kwargs)


-def _activation(activation):
+def _activation(activation, inplace=True):
     if activation == "NONE":
         return _DO_NO_THING_FUNC
     elif activation == "ReLU":
-        return nn.ReLU(inplace=True)
+        return nn.ReLU(inplace=inplace)
     elif activation == "LeakyReLU":
-        return nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
     elif activation == "Tanh":
         return nn.Tanh()
     else:
@@ -74,7 +74,7 @@ class ReverseConv2dBlock(nn.Module):
                  activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
         super().__init__()
         self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
-        self.activation = _activation(activation_type)
+        self.activation = _activation(activation_type, inplace=False)
         self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)

     def forward(self, x):
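For context (not part of the commit): the new inplace flag matters for pre-activation blocks such as ReverseConv2dBlock, where the activation can be applied to a tensor the residual path still holds a reference to (for example when the normalization is "NONE" and simply returns its input). An in-place ReLU would then silently clip the skip connection as well, which is presumably why the block passes inplace=False. A minimal sketch of that hazard in plain PyTorch, with illustrative tensors:

import torch
import torch.nn as nn

x = torch.zeros(1, 4, 2, 2)
x[0, 0, 0, 0] = -1.0
res = x                       # the skip connection keeps a reference, not a copy
h = nn.ReLU(inplace=True)(x)  # mutates x in place
print(res[0, 0, 0, 0])        # tensor(0.) -- the residual was silently clipped too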
@@ -84,7 +84,7 @@ class ReverseConv2dBlock(nn.Module):
 class ResidualBlock(nn.Module):
     def __init__(self, in_channels,
                  padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
-                 out_channels=None, out_activation_type=None):
+                 out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
         """
         Residual Conv Block
         :param in_channels:
@@ -110,15 +110,15 @@ class ResidualBlock(nn.Module):
         self.learn_skip_connection = in_channels != out_channels

         conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
+        conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
+                          additional_norm_kwargs=additional_norm_kwargs,
+                          padding_mode=padding_mode)

-        self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                norm_type=norm_type, activation_type=activation_type)
-        self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                norm_type=norm_type, activation_type=out_activation_type)
+        self.conv1 = conv_block(in_channels, in_channels, **conv_param)
+        self.conv2 = conv_block(in_channels, out_channels, **conv_param)

         if self.learn_skip_connection:
-            self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
-                                       norm_type=norm_type, activation_type=out_activation_type)
+            self.res_conv = conv_block(in_channels, out_channels, **conv_param)

     def forward(self, x):
         res = x if not self.learn_skip_connection else self.res_conv(x)
@@ -16,18 +16,19 @@ for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():

 @NORMALIZATION.register_module("ADE")
 class AdaptiveDenormalization(nn.Module):
-    def __init__(self, num_features, base_norm_type="BN"):
+    def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
         super().__init__()
         self.num_features = num_features
         self.base_norm_type = base_norm_type
         self.norm = self.base_norm(num_features)
         self.gamma = None
+        self.gamma_bias = gamma_bias
         self.beta = None
         self.have_set_condition = False

     def base_norm(self, num_features):
         if self.base_norm_type == "IN":
-            return nn.InstanceNorm2d(num_features)
+            return nn.InstanceNorm2d(num_features, affine=False)
         elif self.base_norm_type == "BN":
             return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)

@@ -38,13 +39,13 @@ class AdaptiveDenormalization(nn.Module):
     def forward(self, x):
         assert self.have_set_condition
         x = self.norm(x)
-        x = self.gamma * x + self.beta
+        x = (self.gamma + self.gamma_bias) * x + self.beta
         self.have_set_condition = False
         return x

     def __repr__(self):
         return f"{self.__class__.__name__}(num_features={self.num_features}, " \
                f"base_norm_type={self.base_norm_type})"
     #
     # def __repr__(self):
     #     return f"{self.__class__.__name__}(num_features={self.num_features}, " \
     #            f"base_norm_type={self.base_norm_type})"


 @NORMALIZATION.register_module("AdaIN")
@@ -61,8 +62,9 @@ class AdaptiveInstanceNorm2d(AdaptiveDenormalization):

 @NORMALIZATION.register_module("FADE")
 class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
-    def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
-        super().__init__(num_features, base_norm_type)
+    def __init__(self, num_features: int, condition_in_channels,
+                 base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
+        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
         self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
                                    padding_mode=padding_mode)
         self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
@@ -77,9 +79,9 @@ class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
 @NORMALIZATION.register_module("SPADE")
 class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
     def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
-                 activation_type="ReLU", padding_mode="zeros"):
-        super().__init__(num_features, base_norm_type)
-        self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
+                 activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
+        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
+        self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
                                            kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
         self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
         self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
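For context (not part of the commit): with the new gamma_bias term the adaptive denormalization computes (gamma + gamma_bias) * norm(x) + beta instead of gamma * norm(x) + beta. The SPADE blocks in the generator added below pass gamma_bias=1.0, so a freshly initialized, near-zero gamma/beta prediction leaves the normalized activations roughly unchanged. A tiny numeric sketch of just that formula, with illustrative values:

import torch

x_norm = torch.tensor([0.5, -1.0, 2.0])  # already-normalized activations
gamma = torch.zeros(3)                    # e.g. what an untrained conv predicts
beta = torch.zeros(3)
gamma_bias = 1.0

out = (gamma + gamma_bias) * x_norm + beta
print(torch.equal(out, x_norm))           # True: the block starts out as an identity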
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
+
+
+class StyleEncoder(nn.Module):
+    def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
+                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        sequence = [Conv2dBlock(
+            in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type=norm_type
+        )]
+        multiple_now = 1
+        max_multiple = 3
+        for i in range(1, num_conv + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** max_multiple)
+            sequence.append(Conv2dBlock(
+                multiple_prev * base_channels, multiple_now * base_channels,
+                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
+                activation_type=activation_type, norm_type=norm_type
+            ))
+        self.sequence = nn.Sequential(*sequence)
+        self.fc_avg = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
+        self.fc_var = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
+
+    def forward(self, x):
+        x = self.sequence(x)
+        x = x.view(x.size(0), -1)
+        return self.fc_avg(x), self.fc_var(x)
+
+
+class SPADEGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
+                 padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.sx, self.sy = start_size
+        self.use_vae = use_vae
+        self.num_z_dim = num_z_dim
+        if use_vae:
+            self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
+        else:
+            self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)
+
+        sequence = []
+
+        multiple_now = 16
+        for i in range(num_blocks - 1, -1, -1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            if i != num_blocks - 1:
+                sequence.append(nn.Upsample(scale_factor=2))
+            sequence.append(ResidualBlock(
+                base_channels * multiple_prev,
+                out_channels=base_channels * multiple_now,
+                padding_mode=padding_mode,
+                activation_type=activation_type,
+                norm_type="SPADE",
+                pre_activation=True,
+                additional_norm_kwargs=dict(
+                    condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
+                    activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
+                )
+            ))
+        self.sequence = nn.Sequential(*sequence)
+        self.output_converter = nn.Sequential(
+            ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
+                               padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
+            nn.Tanh()
+        )
+
+    def forward(self, seg, z=None):
+        if self.use_vae:
+            if z is None:
+                z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
+            x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
+        else:
+            x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))
+        for blk in self.sequence:
+            if isinstance(blk, ResidualBlock):
+                downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
+                blk.conv1.normalization.set_condition_image(downsampling_seg)
+                blk.conv2.normalization.set_condition_image(downsampling_seg)
+                if blk.learn_skip_connection:
+                    blk.res_conv.normalization.set_condition_image(downsampling_seg)
+            x = blk(x)
+        return self.output_converter(x)
+
+
+if __name__ == '__main__':
+    g = SPADEGenerator(3, 3, 7, False, 256)
+    print(g)
+    print(g(torch.randn(2, 3, 256, 256)).size())
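The __main__ check above only exercises the non-VAE path. A minimal sketch (not from this commit) of how the two new modules could be wired together for the VAE/style path, assuming StyleEncoder and SPADEGenerator are imported from this new file; the channel, depth, and dimension values below are illustrative assumptions:

import torch

# hypothetical wiring; StyleEncoder and SPADEGenerator are assumed importable from this module
encoder = StyleEncoder(in_channels=3, style_dim=256, num_conv=6)            # 256x256 input -> 4x4 features
generator = SPADEGenerator(3, 3, num_blocks=7, use_vae=True, num_z_dim=256)

image = torch.randn(2, 3, 256, 256)   # style image fed to the encoder
seg = torch.randn(2, 3, 256, 256)     # segmentation-map stand-in (3 channels)

mu, logvar = encoder(image)
z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
fake = generator(seg, z)              # -> (2, 3, 256, 256)
print(fake.size())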