Compare commits
2 Commits
9e8e73c988...7a85499edf
| Author | SHA1 | Date |
|---|---|---|
| | 7a85499edf | |
| | 0841d03b3c | |
model/GAN/TAHG.py (new file, 177 lines)

@@ -0,0 +1,177 @@
import torch
import torch.nn as nn
from .residual_generator import ResidualBlock
from model.registry import MODEL
from torchvision.models import vgg19
from model.normalization import select_norm_layer


class VGG19StyleEncoder(nn.Module):
    def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
                 vgg19_layers=(0, 5, 10, 19)):
        super().__init__()
        self.vgg19_layers = vgg19_layers
        self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
        self.vgg19.requires_grad_(False)

        norm_layer = select_norm_layer(norm_type)

        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
                      bias=True),
            norm_layer(base_channels),
            nn.ReLU(True),
        )
        self.conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(base_channels * (2 ** i), base_channels * (2 ** i), kernel_size=4, stride=2, padding=1,
                          padding_mode=padding_mode, bias=True),
                norm_layer(base_channels),
                nn.ReLU(True),
            ) for i in range(1, 4)
        ])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)

    def fixed_style_features(self, x):
        features = []
        for i in range(len(self.vgg19)):
            x = self.vgg19[i](x)
            if i in self.vgg19_layers:
                features.append(x)
        return features

    def forward(self, x):
        fsf = self.fixed_style_features(x)
        x = self.conv0(x)
        for i, l in enumerate(self.conv):
            x = l(torch.cat([x, fsf[i]], dim=1))
        x = self.pool(torch.cat([x, fsf[-1]], dim=1))
        x = self.conv1x1(x)
        return x.view(x.size(0), -1)


class ContentEncoder(nn.Module):
    def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
        super().__init__()
        norm_layer = select_norm_layer(norm_type)

        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=True),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )

        # down sampling
        submodules = []
        num_down_sampling = 2
        for i in range(num_down_sampling):
            multiple = 2 ** i
            submodules += [
                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
                          kernel_size=4, stride=2, padding=1, bias=True),
                norm_layer(num_features=base_channels * multiple * 2),
                nn.ReLU(inplace=True)
            ]
        self.encoder = nn.Sequential(*submodules)
        res_block_channels = num_down_sampling ** 2 * base_channels
        self.resnet = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])

    def forward(self, x):
        x = self.start_conv(x)
        x = self.encoder(x)
        x = self.resnet(x)
        return x


class Decoder(nn.Module):
    def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
        super(Decoder, self).__init__()
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"
        # up sampling
        submodules = []
        for i in range(num_down_sampling):
            multiple = 2 ** (num_down_sampling - i)
            submodules += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
                          padding=2, padding_mode=padding_mode, bias=use_bias),
                norm_layer(num_features=base_channels * multiple // 2),
                nn.ReLU(inplace=True),
            ]
        self.decoder = nn.Sequential(*submodules)
        self.end_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.decoder(x)
        x = self.end_conv(x)
        return x


class Fusion(nn.Module):
    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
        super().__init__()
        norm_layer = select_norm_layer(norm_type)
        self.start_fc = nn.Sequential(
            nn.Linear(in_features, base_features),
            norm_layer(base_features),
            nn.ReLU(True),
        )
        self.fcs = nn.Sequential(*[
            nn.Sequential(
                nn.Linear(base_features, base_features),
                norm_layer(base_features),
                nn.ReLU(True),
            ) for _ in range(n_blocks - 2)
        ])
        self.end_fc = nn.Sequential(
            nn.Linear(base_features, out_features),
        )

    def forward(self, x):
        x = self.start_fc(x)
        x = self.fcs(x)
        return self.end_fc(x)


@MODEL.register_module("TAHG-Generator")
class Generator(nn.Module):
    def __init__(self, style_in_channels, content_in_channels, out_channels, style_dim=512, num_blocks=8,
                 base_channels=64, padding_mode="reflect"):
        super(Generator, self).__init__()
        self.num_blocks = num_blocks
        self.style_encoder = VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
                                               padding_mode=padding_mode, norm_type="NONE")
        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
                                              padding_mode=padding_mode, norm_type="IN")
        res_block_channels = 2 ** 2 * base_channels
        self.adain_res = nn.ModuleList([
            ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
        ])
        self.decoders = nn.ModuleDict({
            "a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
            "b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
        })

        self.fc = nn.Sequential(
            nn.Linear(style_dim, style_dim),
            nn.ReLU(True),
        )
        self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
                             norm_type="NONE")

    def forward(self, content_img, style_img, which_decoder: str = "a"):
        x = self.content_encoder(content_img)
        styles = self.fusion(self.fc(self.style_encoder(style_img)))
        styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
        for i, ar in enumerate(self.adain_res):
            ar.norm1.set_style(styles[2 * i])
            ar.norm2.set_style(styles[2 * i + 1])
            x = ar(x)
        return self.decoders[which_decoder](x)
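For orientation, here is a minimal usage sketch of the new `TAHG-Generator` (not part of the commit). The 3-channel inputs, 256×256 resolution, and the assumption that the repository root is on `PYTHONPATH` are mine; everything else follows the defaults in the file above. Note the style encoder downloads pretrained VGG19 weights on first use.

```python
# Hedged sketch, not from the commit: exercises model/GAN/TAHG.py with assumed
# 3-channel, 256x256 inputs and the default constructor arguments.
import torch
from model.GAN.TAHG import Generator

gen = Generator(style_in_channels=3, content_in_channels=3, out_channels=3)
gen.eval()

content = torch.randn(1, 3, 256, 256)  # content image batch
style = torch.randn(1, 3, 256, 256)    # style reference batch (fed to the frozen VGG19)

with torch.no_grad():
    out = gen(content, style, which_decoder="a")  # "a" or "b" picks one of the two decoders

print(out.shape)  # torch.Size([1, 3, 256, 256]); the two 2x down-samplings are undone by the decoder
```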
model/GAN/residual_generator.py

@@ -60,34 +60,31 @@ class GANImageBuffer(object):
 
 @MODEL.register_module()
 class ResidualBlock(nn.Module):
-    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
+    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
         super(ResidualBlock, self).__init__()
 
         if use_bias is None:
             # Only for IN, use bias since it does not have affine parameters.
             use_bias = norm_type == "IN"
         norm_layer = select_norm_layer(norm_type)
-        models = [nn.Sequential(
-            nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
-            norm_layer(num_channels),
-            nn.ReLU(inplace=True),
-        )]
-        if use_dropout:
-            models.append(nn.Dropout(0.5))
-        models.append(nn.Sequential(
-            nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
-            norm_layer(num_channels),
-        ))
-        self.block = nn.Sequential(*models)
+        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                               bias=use_bias)
+        self.norm1 = norm_layer(num_channels)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
+                               bias=use_bias)
+        self.norm2 = norm_layer(num_channels)
 
     def forward(self, x):
-        return x + self.block(x)
+        res = x
+        x = self.relu1(self.norm1(self.conv1(x)))
+        x = self.norm2(self.conv2(x))
+        return x + res
 
 
 @MODEL.register_module()
 class ResGenerator(nn.Module):
     def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
-                 norm_type="IN", use_dropout=False):
+                 norm_type="IN"):
         super(ResGenerator, self).__init__()
         assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
         norm_layer = select_norm_layer(norm_type)
@@ -115,7 +112,7 @@ class ResGenerator(nn.Module):
 
         res_block_channels = num_down_sampling ** 2 * base_channels
         self.resnet_middle = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
+            *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
               range(num_blocks)])
 
         # up sampling
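The hunk above refactors `ResidualBlock` from one fused `nn.Sequential` into separate `conv1`/`norm1`/`relu1`/`conv2`/`norm2` attributes (and drops the unused dropout path). A hedged sketch of why that matters, assuming the file is `model/GAN/residual_generator.py` as the relative import in TAHG.py suggests: the TAHG `Generator` reaches into `norm1`/`norm2` to inject AdaIN styles before each forward pass.

```python
# Hedged sketch, not part of the commit: shows the AdaIN style injection the
# refactor enables. Assumes the block lives in model/GAN/residual_generator.py.
import torch
from model.GAN.residual_generator import ResidualBlock

block = ResidualBlock(256, padding_mode='reflect', norm_type="AdaIN", use_bias=True)

style = torch.randn(1, 2 * 256)   # concatenated (gamma, beta) for one AdaIN layer
block.norm1.set_style(style)      # only possible because norm1/norm2 are now attributes
block.norm2.set_style(style)

y = block(torch.randn(1, 256, 64, 64))  # this forward pass consumes the styles
print(y.shape)  # torch.Size([1, 256, 64, 64])
```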
model/normalization.py

@@ -1,5 +1,6 @@
 import torch.nn as nn
 import functools
+import torch
 
 
 def select_norm_layer(norm_type):
@@ -7,7 +8,69 @@ def select_norm_layer(norm_type):
         return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
     elif norm_type == "IN":
         return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+    elif norm_type == "LN":
+        return functools.partial(LayerNorm2d, affine=True)
     elif norm_type == "NONE":
-        return lambda x: nn.Identity()
+        return lambda num_features: nn.Identity()
+    elif norm_type == "AdaIN":
+        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
     else:
         raise NotImplemented(f'normalization layer {norm_type} is not found')
+
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        if self.affine:
+            self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+            self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
+            self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.affine:
+            nn.init.uniform_(self.channel_gamma)
+            nn.init.zeros_(self.channel_beta)
+
+    def forward(self, x):
+        ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
+        x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
+        print(x.size())
+        if self.affine:
+            return self.channel_gamma * x + self.channel_beta
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
+
+
+class AdaptiveInstanceNorm2d(nn.Module):
+    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
+                 affine: bool = False, track_running_stats: bool = False):
+        super().__init__()
+        self.num_features = num_features
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
+
+        self.gamma = None
+        self.beta = None
+        self.have_set_style = False
+
+    def set_style(self, style):
+        style = style.view(*style.size(), 1, 1)
+        self.gamma, self.beta = style.chunk(2, 1)
+        self.have_set_style = True
+
+    def forward(self, x):
+        assert self.have_set_style
+        x = self.norm(x)
+        x = self.gamma * x + self.beta
+        self.have_set_style = False
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
+               f"affine={self.affine}, track_running_stats={self.track_running_stats})"
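`select_norm_layer` returns a factory that the encoders and decoders call with a channel count, sometimes by keyword (`norm_layer(num_features=...)`). Renaming the `NONE` lambda's parameter to `num_features` gives it the same signature as the real norm layers, which is presumably the motivation for that one-line change. A hedged sketch of the factory pattern (assumption: the module is importable as `model.normalization`; not part of the commit):

```python
# Hedged sketch, not from the commit: how the factories returned by
# select_norm_layer are used.
from model.normalization import select_norm_layer

ln_factory = select_norm_layer("LN")        # functools.partial(LayerNorm2d, affine=True)
ln = ln_factory(num_features=128)           # LayerNorm2d over a 128-channel feature map

none_factory = select_norm_layer("NONE")
identity = none_factory(num_features=128)   # nn.Identity(); keyword call now matches the
                                            # signature, unlike the old `lambda x: ...`
```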