TSIT
This commit is contained in:
parent
0bec02bf6d
commit
8998c30c23
@ -1,4 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="14d-python" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="14d-python" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
name: VoxCeleb2Anime-TSIT
|
||||
engine: TSIT
|
||||
name: huawei-TSIT-1
|
||||
engine: GauGAN
|
||||
result_dir: ./result
|
||||
max_pairs: 1500000
|
||||
max_pairs: 1000000
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
@ -16,34 +19,39 @@ handler:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: TSIT-Generator
|
||||
_bn_to_sync_bn: True
|
||||
style_in_channels: 3
|
||||
content_in_channels: 3
|
||||
num_blocks: 5
|
||||
input_layer_type: "conv7x7"
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
num_blocks: 7
|
||||
# discriminator:
|
||||
# _type: MultiScaleDiscriminator
|
||||
# _add_spectral_norm: True
|
||||
# num_scale: 2
|
||||
# down_sample_method: "bilinear"
|
||||
# discriminator_cfg:
|
||||
# _type: PatchDiscriminator
|
||||
# in_channels: 3
|
||||
# base_channels: 64
|
||||
# num_conv: 4
|
||||
# need_intermediate_feature: True
|
||||
discriminator:
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
discriminator_cfg:
|
||||
_type: PatchDiscriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
use_spectral: True
|
||||
num_conv: 4
|
||||
need_intermediate_feature: True
|
||||
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: hinge
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
weight: 1.0
|
||||
real_label_val: 1
|
||||
fake_label_val: 0.0
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
@ -55,25 +63,18 @@ loss:
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 1
|
||||
style:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L2'
|
||||
style_loss: True
|
||||
perceptual_loss: False
|
||||
weight: 0
|
||||
mgc:
|
||||
weight: 5
|
||||
fm:
|
||||
level: 1
|
||||
weight: 1
|
||||
edge:
|
||||
weight: 0
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
lr: 1e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
@ -87,24 +88,35 @@ data:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
buffer_size: 0
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
batch_size: 1
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/faces/CelebA-Asian/trainA"
|
||||
root_b: "/data/i2i/anime/your-name/faces"
|
||||
root_a: "/data/face2cartoon/all_face"
|
||||
root_b: "/data/selfie2anime/trainB/"
|
||||
random_pair: True
|
||||
pipeline:
|
||||
pipeline_a:
|
||||
- Load
|
||||
- RandomCrop:
|
||||
size: [ 178, 178 ]
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 170, 144 ]
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
@ -113,22 +125,28 @@ data:
|
||||
test:
|
||||
which: video_dataset
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/faces/CelebA-Asian/testA"
|
||||
root_b: "/data/i2i/anime/your-name/faces"
|
||||
random_pair: False
|
||||
pipeline:
|
||||
root_a: "/data/face2cartoon/test/human"
|
||||
root_b: "/data/face2cartoon/test/anime"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 170, 144 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
|
||||
@ -16,7 +16,7 @@ class GauGANEngineKernel(EngineKernel):
|
||||
|
||||
self.gan_loss = gan_loss(config.loss.gan)
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
|
||||
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
|
||||
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
|
||||
self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
|
||||
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
|
||||
@ -38,7 +38,7 @@ def feature_match_loss(level, weight_policy):
|
||||
|
||||
def fm_loss(generated_features, target_features):
|
||||
num_scale = len(generated_features)
|
||||
loss = 0
|
||||
loss = torch.zeros(1, device=idist.device())
|
||||
for s_i in range(num_scale):
|
||||
for i in range(len(generated_features[s_i]) - 1):
|
||||
weight = 1 if weight_policy == "same" else 2 ** i
|
||||
|
||||
@ -4,3 +4,4 @@ import model.image_translation.UGATIT
|
||||
import model.image_translation.CycleGAN
|
||||
import model.image_translation.pix2pixHD
|
||||
import model.image_translation.GauGAN
|
||||
import model.image_translation.TSIT
|
||||
@ -119,6 +119,8 @@ class ResidualBlock(nn.Module):
|
||||
self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
conv_param['kernel_size'] = 1
|
||||
conv_param['padding'] = 0
|
||||
self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@ -0,0 +1,98 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import ResidualBlock, Conv2dBlock
|
||||
|
||||
|
||||
class Interpolation(nn.Module):
|
||||
def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
|
||||
super(Interpolation, self).__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
|
||||
recompute_scale_factor=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
|
||||
|
||||
|
||||
@MODEL.register_module("TSIT-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
|
||||
padding_mode='reflect', activation_type="LeakyReLU"):
|
||||
super().__init__()
|
||||
self.input_layer = Conv2dBlock(
|
||||
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type="IN",
|
||||
)
|
||||
multiple_now = 1
|
||||
stream_sequence = []
|
||||
for i in range(1, num_blocks + 1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** i, 2 ** 4)
|
||||
stream_sequence.append(nn.Sequential(
|
||||
Interpolation(scale_factor=0.5, mode="nearest"),
|
||||
ResidualBlock(
|
||||
multiple_prev * base_channels, out_channels=multiple_now * base_channels,
|
||||
padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
|
||||
))
|
||||
self.down_sequence = nn.ModuleList(stream_sequence)
|
||||
|
||||
|
||||
sequence = []
|
||||
multiple_now = 16
|
||||
for i in range(num_blocks - 1, -1, -1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** i, 2 ** 4)
|
||||
sequence.append(nn.Sequential(
|
||||
ResidualBlock(
|
||||
base_channels * multiple_prev,
|
||||
out_channels=base_channels * multiple_now,
|
||||
padding_mode=padding_mode,
|
||||
activation_type=activation_type,
|
||||
norm_type="FADE",
|
||||
pre_activation=True,
|
||||
additional_norm_kwargs=dict(
|
||||
condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
|
||||
padding_mode="zeros", gamma_bias=0.0
|
||||
)
|
||||
),
|
||||
Interpolation(scale_factor=2, mode="nearest")
|
||||
))
|
||||
self.up_sequence = nn.Sequential(*sequence)
|
||||
|
||||
self.output_layer = Conv2dBlock(
|
||||
base_channels, out_channels, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
|
||||
)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
c = self.input_layer(x)
|
||||
contents = []
|
||||
for blk in self.down_sequence:
|
||||
c = blk(c)
|
||||
contents.append(c)
|
||||
if z is None:
|
||||
# for image 256x256, z size: [batch_size, 1024, 2, 2]
|
||||
z = torch.randn(size=contents[-1].size(), device=contents[-1].device)
|
||||
|
||||
for blk in self.up_sequence:
|
||||
res = blk[0]
|
||||
c = contents.pop()
|
||||
res.conv1.normalization.set_feature(c)
|
||||
res.conv2.normalization.set_feature(c)
|
||||
if res.learn_skip_connection:
|
||||
res.res_conv.normalization.set_feature(c)
|
||||
return self.output_layer(self.up_sequence(z))
|
||||
|
||||
if __name__ == '__main__':
|
||||
g = Generator(3, 3).cuda()
|
||||
img = torch.randn(2, 3, 256, 256).cuda()
|
||||
print(g(img).size())
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user