This commit is contained in:
Ray Wong 2020-10-25 20:46:34 +08:00
parent 0bec02bf6d
commit 8998c30c23
8 changed files with 174 additions and 55 deletions

View File

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="14d-python" project-jdk-type="Python SDK" />
</project>

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="14d-python" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">

View File

@ -1,7 +1,10 @@
name: VoxCeleb2Anime-TSIT
engine: TSIT
name: huawei-TSIT-1
engine: GauGAN
result_dir: ./result
max_pairs: 1500000
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
@ -16,34 +19,39 @@ handler:
random: True
images: 10
misc:
random_seed: 324
model:
generator:
_type: TSIT-Generator
_bn_to_sync_bn: True
style_in_channels: 3
content_in_channels: 3
num_blocks: 5
input_layer_type: "conv7x7"
_add_spectral_norm: True
in_channels: 3
out_channels: 3
num_blocks: 7
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
perceptual:
layer_weights:
"1": 0.03125
@ -55,25 +63,18 @@ loss:
style_loss: False
perceptual_loss: True
weight: 1
style:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L2'
style_loss: True
perceptual_loss: False
weight: 0
mgc:
weight: 5
fm:
level: 1
weight: 1
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 0.0001
lr: 1e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
@ -87,24 +88,35 @@ data:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
buffer_size: 0
dataloader:
batch_size: 8
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/faces/CelebA-Asian/trainA"
root_b: "/data/i2i/anime/your-name/faces"
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline:
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 170, 144 ]
size: [ 286, 286 ]
- RandomCrop:
size: [ 128, 128 ]
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
@ -113,22 +125,28 @@ data:
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/faces/CelebA-Asian/testA"
root_b: "/data/i2i/anime/your-name/faces"
random_pair: False
pipeline:
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 170, 144 ]
- RandomCrop:
size: [ 128, 128 ]
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]

View File

@ -16,7 +16,7 @@ class GauGANEngineKernel(EngineKernel):
self.gan_loss = gan_loss(config.loss.gan)
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in

View File

@ -38,7 +38,7 @@ def feature_match_loss(level, weight_policy):
def fm_loss(generated_features, target_features):
num_scale = len(generated_features)
loss = 0
loss = torch.zeros(1, device=idist.device())
for s_i in range(num_scale):
for i in range(len(generated_features[s_i]) - 1):
weight = 1 if weight_policy == "same" else 2 ** i

View File

@ -3,4 +3,5 @@ import model.base.normalization
import model.image_translation.UGATIT
import model.image_translation.CycleGAN
import model.image_translation.pix2pixHD
import model.image_translation.GauGAN
import model.image_translation.GauGAN
import model.image_translation.TSIT

View File

@ -119,6 +119,8 @@ class ResidualBlock(nn.Module):
self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
if self.learn_skip_connection:
conv_param['kernel_size'] = 1
conv_param['padding'] = 0
self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
def forward(self, x):

View File

@ -0,0 +1,98 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from model import MODEL
from model.base.module import ResidualBlock, Conv2dBlock
class Interpolation(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    The resampling options are fixed at construction time, which lets a
    resize step be placed inside containers such as ``nn.Sequential``.
    """

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        """Remember the options that are forwarded to ``F.interpolate``.

        Args:
            scale_factor: multiplier applied to the spatial size of the input.
            mode: interpolation algorithm name understood by ``F.interpolate``.
            align_corners: forwarded unchanged to ``F.interpolate``.
        """
        super().__init__()
        self.scale_factor, self.mode, self.align_corners = scale_factor, mode, align_corners

    def forward(self, x):
        """Return ``x`` resampled with the stored options."""
        options = dict(
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            # Always use the given scale factor directly rather than one recomputed from sizes.
            recompute_scale_factor=False,
        )
        return F.interpolate(x, **options)

    def __repr__(self):
        """Describe the module together with its three stored options."""
        settings = f"scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners}"
        return f"Interpolation({settings})"
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """Generator registered as ``"TSIT-Generator"``.

    It is built from two streams:

    * a content stream (``down_sequence``): ``num_blocks`` stages, each halving the
      spatial size and then applying a ``ResidualBlock`` with instance norm;
    * a generation stream (``up_sequence``): ``num_blocks`` stages, each a
      ``ResidualBlock`` using the ``"FADE"`` norm type followed by a x2 up-sampling.

    In ``forward`` every content feature map is handed to the normalization layers of
    the matching generation stage through ``set_feature`` before the generation stream
    is run on the latent ``z``.

    Args:
        in_channels: channel count of the input image.
        out_channels: channel count of the produced image.
        base_channels: width of the first feature map; later widths are multiples of it.
        num_blocks: number of down-sampling stages (and of up-sampling stages).
        padding_mode: padding mode passed to the convolution blocks.
        activation_type: activation name passed to the convolution blocks.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        # Channel multipliers double at every stage but are capped at this value.
        max_multiple = 2 ** 4

        self.input_layer = Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type="IN",
        )

        # Content stream: halve the resolution, then widen the features (up to the cap).
        multiple_now = 1
        stream_sequence = []
        for i in range(1, num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, max_multiple)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    multiple_prev * base_channels, out_channels=multiple_now * base_channels,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
            ))
        self.down_sequence = nn.ModuleList(stream_sequence)

        # Generation stream: mirror of the content stream. It has to start at the width the
        # content stream actually ends with; a fixed 16 is only right when num_blocks >= 4,
        # and for smaller values it mismatches both the sampled z and the conditioning feature.
        sequence = []
        multiple_now = min(2 ** num_blocks, max_multiple)
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, max_multiple)
            sequence.append(nn.Sequential(
                ResidualBlock(
                    base_channels * multiple_prev,
                    out_channels=base_channels * multiple_now,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="FADE",
                    pre_activation=True,
                    additional_norm_kwargs=dict(
                        # the conditioning feature has as many channels as the block input
                        condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
                        padding_mode="zeros", gamma_bias=0.0
                    )
                ),
                Interpolation(scale_factor=2, mode="nearest")
            ))
        self.up_sequence = nn.Sequential(*sequence)

        self.output_layer = Conv2dBlock(
            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
        )

    def forward(self, x, z=None):
        """Generate an image conditioned on the content features of ``x``.

        Args:
            x: input image batch fed to the content stream.
            z: optional latent tensor fed to the generation stream. When ``None`` it is
                drawn from a standard normal with the size of the deepest content feature.

        Returns:
            The result of ``output_layer`` applied to the generation stream output.
        """
        c = self.input_layer(x)
        contents = []
        for blk in self.down_sequence:
            c = blk(c)
            contents.append(c)

        if z is None:
            # With the defaults (base_channels=64, num_blocks=7) and a 256x256 image,
            # z size: [batch_size, 1024, 2, 2]
            z = torch.randn(size=contents[-1].size(), device=contents[-1].device)

        # Attach the content features deepest-first, so that the i-th generation stage is
        # conditioned on the feature map of matching resolution and width.
        for blk in self.up_sequence:
            res = blk[0]
            c = contents.pop()
            res.conv1.normalization.set_feature(c)
            res.conv2.normalization.set_feature(c)
            if res.learn_skip_connection:
                res.res_conv.normalization.set_feature(c)

        return self.output_layer(self.up_sequence(z))
if __name__ == '__main__':
    # Smoke test: build the generator with 3 input / 3 output channels, run one random
    # batch of two 256x256 images through it and print the output size.
    # Use the GPU when one is available, otherwise fall back to the CPU so the check
    # does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = Generator(3, 3).to(device)
    img = torch.randn(2, 3, 256, 256).to(device)
    print(g(img).size())