update SPADE

commit 7b05b45156
parent 2de00d0245
@@ -53,34 +53,29 @@ class LinearBlock(nn.Module):
 class Conv2dBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, bias=None,
                  activation_type="ReLU", norm_type="NONE",
-                 additional_norm_kwargs=None, **conv_kwargs):
+                 additional_norm_kwargs=None, pre_activation=False, **conv_kwargs):
         super().__init__()
         self.norm_type = norm_type
         self.activation_type = activation_type
+        self.pre_activation = pre_activation

-        # if caller not set bias, set bias automatically.
-        conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
-
-        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
-        self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
-        self.activation = _activation(activation_type)
+        if pre_activation:
+            self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
+            self.activation = _activation(activation_type, inplace=False)
+            self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+        else:
+            # if caller not set bias, set bias automatically.
+            conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
+            self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+            self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
+            self.activation = _activation(activation_type)

     def forward(self, x):
+        if self.pre_activation:
+            return self.convolution(self.activation(self.normalization(x)))
         return self.activation(self.normalization(self.convolution(x)))


-class ReverseConv2dBlock(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int,
-                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
-        super().__init__()
-        self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
-        self.activation = _activation(activation_type, inplace=False)
-        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
-
-    def forward(self, x):
-        return self.convolution(self.activation(self.normalization(x)))
-
-
 class ResidualBlock(nn.Module):
     def __init__(self, in_channels,
                  padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
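For readers skimming the hunk above: the new `pre_activation` flag folds the old `ReverseConv2dBlock` ordering (norm → activation → conv) into `Conv2dBlock` alongside the default ordering (conv → norm → activation), and also switches the normalized channel count from `out_channels` to `in_channels`. A minimal standalone sketch of the two orderings follows; it uses plain `torch.nn` modules in place of the repo's `_normalization`/`_activation`/`_use_bias_checker` helpers, which are not shown in this diff.

```python
import torch
import torch.nn as nn


def make_block(in_ch, out_ch, pre_activation=False):
    # Fixed stand-ins for the repo's _normalization/_activation factories (assumed here).
    norm = nn.InstanceNorm2d(in_ch if pre_activation else out_ch)
    act = nn.ReLU(inplace=False)
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
    if pre_activation:
        # Ordering the removed ReverseConv2dBlock used: norm -> act -> conv.
        return nn.Sequential(norm, act, conv)
    # Default ordering: conv -> norm -> act.
    return nn.Sequential(conv, norm, act)


x = torch.randn(2, 16, 32, 32)
print(make_block(16, 32)(x).shape)                       # torch.Size([2, 32, 32, 32])
print(make_block(16, 32, pre_activation=True)(x).shape)  # torch.Size([2, 32, 32, 32])
```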
@@ -109,16 +104,15 @@ class ResidualBlock(nn.Module):

         self.learn_skip_connection = in_channels != out_channels

-        conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
         conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
-                          additional_norm_kwargs=additional_norm_kwargs,
+                          additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
                           padding_mode=padding_mode)

-        self.conv1 = conv_block(in_channels, in_channels, **conv_param)
-        self.conv2 = conv_block(in_channels, out_channels, **conv_param)
+        self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
+        self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)

         if self.learn_skip_connection:
-            self.res_conv = conv_block(in_channels, out_channels, **conv_param)
+            self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)

     def forward(self, x):
         res = x if not self.learn_skip_connection else self.res_conv(x)
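The `ResidualBlock` hunk ends before the rest of `forward`, so only the skip-connection selection is visible here. As a rough, self-contained illustration of the learned-skip pattern it configures (how the two branches are combined after this point is an assumption, not taken from the commit):

```python
import torch
import torch.nn as nn


class TinyResidualBlock(nn.Module):
    """Illustrative stand-in, not the repo's ResidualBlock."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.learn_skip_connection = in_channels != out_channels
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        if self.learn_skip_connection:
            # Projection so the shortcut matches the new channel count.
            self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        res = x if not self.learn_skip_connection else self.res_conv(x)
        return self.conv2(self.conv1(x)) + res  # assumed combination step


print(TinyResidualBlock(32, 64)(torch.randn(1, 32, 8, 8)).shape)  # torch.Size([1, 64, 8, 8])
```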
@@ -1,8 +1,12 @@
+from collections import OrderedDict
+from functools import partial
+
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
+from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock


 class StyleEncoder(nn.Module):
@@ -33,6 +37,92 @@ class StyleEncoder(nn.Module):
         return self.fc_avg(x), self.fc_var(x)


+class ImprovedSPADEGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
+                 base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
+        super().__init__()
+
+        assert output_size in (128, 256, 512, 1024)
+        self.output_size = output_size
+
+        kernel_size = 3
+
+        if have_style_input:
+            self.style_converter = nn.Sequential(
+                LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
+                LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
+            )
+
+        base_conv = partial(
+            Conv2dBlock,
+            pre_activation=pre_activation, activation_type=activation_type,
+            norm_type="AdaIN" if have_style_input else "NONE",
+            kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
+        )
+
+        base_residual_block = partial(
+            ResidualBlock,
+            padding_mode=padding_mode,
+            activation_type=activation_type,
+            norm_type="SPADE",
+            pre_activation=True,
+            additional_norm_kwargs=dict(
+                condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
+                activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
+            )
+        )
+
+        sequence = OrderedDict()
+        channels = (2 ** 4) * base_channels
+        sequence["block_head"] = nn.Sequential(OrderedDict([
+            ("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
+            ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
+            ("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
+            ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+            ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+        ]))
+
+        for i in range(4, 9 - min(int(math.log(self.output_size, 2)), 8), -1):
+            channels = (2 ** (i - 1)) * base_channels
+            sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
+                ("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
+                ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
+                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+            ]))
+        self.sequence = nn.Sequential(sequence)
+        # channels = 2*base_channels when output size is 256, 512, 1024
+        # channels = 5*base_channels when output size is 128
+        out_modules = OrderedDict()
+        out_modules["out_1"] = nn.Sequential(
+            Conv2dBlock(
+                channels, out_channels, kernel_size=5, stride=1, padding=2,
+                pre_activation=pre_activation,
+                padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
+            ),
+            nn.Tanh()
+        )
+        for i in range(int(math.log(self.output_size, 2)) - 8):
+            channels = channels // 2
+            out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
+                ("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
+                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
+                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
+            ]))
+            out_modules[f"out_{i + 2}"] = nn.Sequential(
+                Conv2dBlock(
+                    channels, out_channels, kernel_size=5, stride=1, padding=2,
+                    pre_activation=pre_activation,
+                    padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
+                ),
+                nn.Tanh()
+            )
+        self.out_modules = nn.ModuleDict(out_modules)
+
+    def forward(self, seg, style=None):
+        pass
+
+
 class SPADEGenerator(nn.Module):
     def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
                  padding_mode='reflect', activation_type="LeakyReLU"):
@@ -89,6 +179,7 @@ class SPADEGenerator(nn.Module):
             x = blk(x)
         return self.output_converter(x)

+
 if __name__ == '__main__':
     g = SPADEGenerator(3, 3, 7, False, 256)
     print(g)
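The loop bounds in `ImprovedSPADEGenerator.__init__` tie the number of upsampling stages to `output_size`. The short sketch below only re-runs that arithmetic to show the resulting channel widths; no modules are built, and the helper name `channel_schedule` is ours, not the repo's.

```python
import math


def channel_schedule(output_size, base_channels=64):
    # Mirrors the loop bounds used in ImprovedSPADEGenerator.__init__ above.
    channels = (2 ** 4) * base_channels  # width of block_head
    widths = [channels]
    for i in range(4, 9 - min(int(math.log(output_size, 2)), 8), -1):
        channels = (2 ** (i - 1)) * base_channels
        widths.append(channels)
    for _ in range(int(math.log(output_size, 2)) - 8):
        channels = channels // 2
        widths.append(channels)
    return widths


for size in (128, 256, 512, 1024):
    print(size, channel_schedule(size))
# e.g. 256 -> [1024, 512, 256, 128]; 1024 -> [1024, 512, 256, 128, 64, 32]
```

Worked through this way, a 128-pixel output leaves `channels` at 4 * base_channels before the output head, so the `5*base_channels` comment in the hunk appears to be off by one; 256, 512 and 1024 all land at 2 * base_channels as the other comment states.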
@@ -1,76 +0,0 @@
-import functools
-
-import torch
-import torch.nn as nn
-
-
-def select_norm_layer(norm_type):
-    if norm_type == "BN":
-        return functools.partial(nn.BatchNorm2d)
-    elif norm_type == "IN":
-        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
-    elif norm_type == "LN":
-        return functools.partial(LayerNorm2d, affine=True)
-    elif norm_type == "NONE":
-        return lambda num_features: nn.Identity()
-    elif norm_type == "AdaIN":
-        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
-    else:
-        raise NotImplemented(f'normalization layer {norm_type} is not found')
-
-
-class LayerNorm2d(nn.Module):
-    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
-        super().__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.affine = affine
-        if self.affine:
-            self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
-            self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
-            self.reset_parameters()
-
-    def reset_parameters(self):
-        if self.affine:
-            nn.init.uniform_(self.channel_gamma)
-            nn.init.zeros_(self.channel_beta)
-
-    def forward(self, x):
-        ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
-        x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
-        if self.affine:
-            return self.channel_gamma * x + self.channel_beta
-        return x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
-
-
-class AdaptiveInstanceNorm2d(nn.Module):
-    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
-                 affine: bool = False, track_running_stats: bool = False):
-        super().__init__()
-        self.num_features = num_features
-        self.affine = affine
-        self.track_running_stats = track_running_stats
-        self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
-
-        self.gamma = None
-        self.beta = None
-        self.have_set_style = False
-
-    def set_style(self, style):
-        style = style.view(*style.size(), 1, 1)
-        self.gamma, self.beta = style.chunk(2, 1)
-        self.have_set_style = True
-
-    def forward(self, x):
-        assert self.have_set_style
-        x = self.norm(x)
-        x = self.gamma * x + self.beta
-        self.have_set_style = False
-        return x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
-               f"affine={self.affine}, track_running_stats={self.track_running_stats})"
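For reference, the removed `AdaptiveInstanceNorm2d` expected a per-sample style vector of `2 * num_features` values, split into a per-channel scale and shift applied after instance normalization. A minimal sketch of that call pattern (a re-statement of the deleted class above, not an API that remains in the repo after this commit):

```python
import torch
import torch.nn as nn


class AdaIN2d(nn.Module):
    """Minimal re-statement of the deleted AdaptiveInstanceNorm2d call pattern."""

    def __init__(self, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        self.gamma = None
        self.beta = None

    def set_style(self, style):
        # style: (N, 2 * num_features) -> per-channel scale and shift.
        style = style.view(*style.size(), 1, 1)
        self.gamma, self.beta = style.chunk(2, dim=1)

    def forward(self, x):
        return self.gamma * self.norm(x) + self.beta


adain = AdaIN2d(8)
adain.set_style(torch.randn(2, 16))          # 2 * num_features values per sample
print(adain(torch.randn(2, 8, 4, 4)).shape)  # torch.Size([2, 8, 4, 4])
```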