TSIT

parent 0bec02bf6d
commit 8998c30c23
@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="14d-python" project-jdk-type="Python SDK" />
 </project>
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="14d-python" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="TestRunnerService">
@@ -1,7 +1,10 @@
-name: VoxCeleb2Anime-TSIT
-engine: TSIT
+name: huawei-TSIT-1
+engine: GauGAN
 result_dir: ./result
-max_pairs: 1500000
+max_pairs: 1000000

+misc:
+  random_seed: 324
+
 handler:
   clear_cuda_cache: True
@@ -16,34 +19,39 @@ handler:
     random: True
     images: 10


-misc:
-  random_seed: 324
-
 model:
   generator:
     _type: TSIT-Generator
-    _bn_to_sync_bn: True
-    style_in_channels: 3
-    content_in_channels: 3
-    num_blocks: 5
-    input_layer_type: "conv7x7"
+    _add_spectral_norm: True
+    in_channels: 3
+    out_channels: 3
+    num_blocks: 7
+#  discriminator:
+#    _type: MultiScaleDiscriminator
+#    _add_spectral_norm: True
+#    num_scale: 2
+#    down_sample_method: "bilinear"
+#    discriminator_cfg:
+#      _type: PatchDiscriminator
+#      in_channels: 3
+#      base_channels: 64
+#      num_conv: 4
+#      need_intermediate_feature: True
   discriminator:
-    _type: MultiScaleDiscriminator
-    num_scale: 2
-    discriminator_cfg:
-      _type: PatchDiscriminator
-      in_channels: 3
-      base_channels: 64
-      use_spectral: True
-      need_intermediate_feature: True
+    _type: PatchDiscriminator
+    _add_spectral_norm: True
+    in_channels: 3
+    base_channels: 64
+    num_conv: 4
+    need_intermediate_feature: True

 loss:
   gan:
     loss_type: hinge
-    real_label_val: 1.0
-    fake_label_val: 0.0
     weight: 1.0
+    real_label_val: 1
+    fake_label_val: 0.0
   perceptual:
     layer_weights:
       "1": 0.03125
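For orientation: a `_type` key such as `TSIT-Generator` above is resolved against the project's MODEL registry (the new TSIT.py later in this commit registers its Generator under exactly that name). The builder below, and its treatment of underscore-prefixed keys like `_add_spectral_norm` as framework options rather than constructor arguments, is only a hedged sketch of that pattern, not the repository's actual implementation.

# Hypothetical sketch of a dict-backed "_type" registry; names are illustrative.
REGISTRY = {}

def register_module(name):
    def decorator(cls):
        REGISTRY[name] = cls
        return cls
    return decorator

def build_from_cfg(cfg):
    cfg = dict(cfg)
    cls = REGISTRY[cfg.pop("_type")]
    # keys starting with "_" are assumed to be handled by the framework,
    # not forwarded to the module constructor
    kwargs = {k: v for k, v in cfg.items() if not k.startswith("_")}
    return cls(**kwargs)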
@@ -55,25 +63,18 @@ loss:
     style_loss: False
     perceptual_loss: True
     weight: 1
-  style:
-    layer_weights:
-      "1": 0.03125
-      "6": 0.0625
-      "11": 0.125
-      "20": 0.25
-      "29": 1
-    criterion: 'L2'
-    style_loss: True
-    perceptual_loss: False
-    weight: 0
+  mgc:
+    weight: 5
   fm:
-    level: 1
     weight: 1
+  edge:
+    weight: 0
+    hed_pretrained_model_path: ./network-bsds500.pytorch

 optimizers:
   generator:
     _type: Adam
-    lr: 0.0001
+    lr: 1e-4
     betas: [ 0, 0.9 ]
     weight_decay: 0.0001
   discriminator:
@@ -87,24 +88,35 @@ data:
     scheduler:
       start_proportion: 0.5
       target_lr: 0
-    buffer_size: 50
+    buffer_size: 0
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: True
       num_workers: 2
       pin_memory: True
       drop_last: True
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/faces/CelebA-Asian/trainA"
-      root_b: "/data/i2i/anime/your-name/faces"
+      root_a: "/data/face2cartoon/all_face"
+      root_b: "/data/selfie2anime/trainB/"
       random_pair: True
-      pipeline:
+      pipeline_a:
+        - Load
+        - RandomCrop:
+            size: [ 178, 178 ]
+        - Resize:
+            size: [ 256, 256 ]
+        - RandomHorizontalFlip
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
         - Load
         - Resize:
-            size: [ 170, 144 ]
+            size: [ 286, 286 ]
         - RandomCrop:
-            size: [ 128, 128 ]
+            size: [ 256, 256 ]
         - RandomHorizontalFlip
         - ToTensor
         - Normalize:
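For orientation: the new train-time pipeline_a above roughly corresponds to the torchvision transforms below, where "Load" stands for reading the image file rather than a transform. The YAML-to-pipeline builder itself is not part of this diff, so treat this as an approximate equivalent only.

# Approximate torchvision equivalent of pipeline_a (illustrative only).
from torchvision import transforms

pipeline_a = transforms.Compose([
    transforms.RandomCrop((178, 178)),
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])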
@@ -113,22 +125,28 @@ data:
   test:
     which: video_dataset
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
       drop_last: False
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/faces/CelebA-Asian/testA"
-      root_b: "/data/i2i/anime/your-name/faces"
-      random_pair: False
-      pipeline:
+      root_a: "/data/face2cartoon/test/human"
+      root_b: "/data/face2cartoon/test/anime"
+      random_pair: True
+      pipeline_a:
         - Load
         - Resize:
-            size: [ 170, 144 ]
-        - RandomCrop:
-            size: [ 128, 128 ]
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
         - ToTensor
         - Normalize:
             mean: [ 0.5, 0.5, 0.5 ]
@@ -16,7 +16,7 @@ class GauGANEngineKernel(EngineKernel):

         self.gan_loss = gan_loss(config.loss.gan)
         self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
-        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
+        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
         self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))

         self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
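The kernel wraps each criterion in a LossContainer built from a weight and a callable. LossContainer itself is not part of this diff; a minimal sketch of that weight-plus-callable pattern, under the assumption that a zero weight simply skips the term, is:

# Hedged sketch of the LossContainer pattern; the project's real class may differ.
class LossContainer:
    def __init__(self, weight, loss_fn):
        self.weight = weight
        self.loss_fn = loss_fn

    def __call__(self, *args, **kwargs):
        if self.weight == 0:
            return 0.0
        return self.weight * self.loss_fn(*args, **kwargs)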
@@ -38,7 +38,7 @@ def feature_match_loss(level, weight_policy):

     def fm_loss(generated_features, target_features):
         num_scale = len(generated_features)
-        loss = 0
+        loss = torch.zeros(1, device=idist.device())
         for s_i in range(num_scale):
             for i in range(len(generated_features[s_i]) - 1):
                 weight = 1 if weight_policy == "same" else 2 ** i
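The hunk above is truncated, so here is a self-contained sketch of the feature-matching pattern it shows, assuming `idist` is ignite.distributed and that features arrive as one list of intermediate maps per discriminator scale. Only the device-aware zero initialisation and the weight-policy switch come from the diff; the L1 accumulation and the final averaging are assumptions.

# Hedged, self-contained sketch of the feature-matching loss.
import torch
import torch.nn.functional as F
import ignite.distributed as idist


def feature_match_loss(level, weight_policy):
    # `level` is left unused here, mirroring the truncated original
    def fm_loss(generated_features, target_features):
        num_scale = len(generated_features)
        loss = torch.zeros(1, device=idist.device())
        for s_i in range(num_scale):
            for i in range(len(generated_features[s_i]) - 1):
                # "same" weights every layer equally; otherwise deeper layers count more
                weight = 1 if weight_policy == "same" else 2 ** i
                loss = loss + weight * F.l1_loss(generated_features[s_i][i],
                                                 target_features[s_i][i].detach())
        return loss / num_scale
    return fm_loss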
@@ -3,4 +3,5 @@ import model.base.normalization
 import model.image_translation.UGATIT
 import model.image_translation.CycleGAN
 import model.image_translation.pix2pixHD
 import model.image_translation.GauGAN
+import model.image_translation.TSIT
@@ -119,6 +119,8 @@ class ResidualBlock(nn.Module):
         self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)

         if self.learn_skip_connection:
+            conv_param['kernel_size'] = 1
+            conv_param['padding'] = 0
             self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)

     def forward(self, x):
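The two added lines reuse conv_param but switch the learned skip path to kernel_size=1 / padding=0, so res_conv becomes a 1x1 channel projection instead of another 3x3 convolution. A minimal, self-contained illustration of that idea (the real Conv2dBlock and ResidualBlock carry more options than this toy version):

# Toy residual block with a learned 1x1 skip projection; illustrative only.
import torch
import torch.nn as nn


class TinyResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_param = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_param)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_param)
        self.learn_skip_connection = in_channels != out_channels
        if self.learn_skip_connection:
            conv_param['kernel_size'] = 1
            conv_param['padding'] = 0
            self.res_conv = nn.Conv2d(in_channels, out_channels, **conv_param)

    def forward(self, x):
        skip = self.res_conv(x) if self.learn_skip_connection else x
        return skip + self.conv2(torch.relu(self.conv1(x)))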
@@ -0,0 +1,98 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+
+from model import MODEL
+from model.base.module import ResidualBlock, Conv2dBlock
+
+
+class Interpolation(nn.Module):
+    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
+        super(Interpolation, self).__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
+                             recompute_scale_factor=False)
+
+    def __repr__(self):
+        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
+
+
+@MODEL.register_module("TSIT-Generator")
+class Generator(nn.Module):
+    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
+                 padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.input_layer = Conv2dBlock(
+            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type="IN",
+        )
+        multiple_now = 1
+        stream_sequence = []
+        for i in range(1, num_blocks + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            stream_sequence.append(nn.Sequential(
+                Interpolation(scale_factor=0.5, mode="nearest"),
+                ResidualBlock(
+                    multiple_prev * base_channels, out_channels=multiple_now * base_channels,
+                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
+            ))
+        self.down_sequence = nn.ModuleList(stream_sequence)
+
+
+        sequence = []
+        multiple_now = 16
+        for i in range(num_blocks - 1, -1, -1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            sequence.append(nn.Sequential(
+                ResidualBlock(
+                    base_channels * multiple_prev,
+                    out_channels=base_channels * multiple_now,
+                    padding_mode=padding_mode,
+                    activation_type=activation_type,
+                    norm_type="FADE",
+                    pre_activation=True,
+                    additional_norm_kwargs=dict(
+                        condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
+                        padding_mode="zeros", gamma_bias=0.0
+                    )
+                ),
+                Interpolation(scale_factor=2, mode="nearest")
+            ))
+        self.up_sequence = nn.Sequential(*sequence)
+
+        self.output_layer = Conv2dBlock(
+            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
+            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
+        )
+
+    def forward(self, x, z=None):
+        c = self.input_layer(x)
+        contents = []
+        for blk in self.down_sequence:
+            c = blk(c)
+            contents.append(c)
+        if z is None:
+            # for image 256x256, z size: [batch_size, 1024, 2, 2]
+            z = torch.randn(size=contents[-1].size(), device=contents[-1].device)
+
+        for blk in self.up_sequence:
+            res = blk[0]
+            c = contents.pop()
+            res.conv1.normalization.set_feature(c)
+            res.conv2.normalization.set_feature(c)
+            if res.learn_skip_connection:
+                res.res_conv.normalization.set_feature(c)
+        return self.output_layer(self.up_sequence(z))
+
+if __name__ == '__main__':
+    g = Generator(3, 3).cuda()
+    img = torch.randn(2, 3, 256, 256).cuda()
+    print(g(img).size())
+
+
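In the forward pass above, the loop over up_sequence only hands each FADE-normalized block its content feature via set_feature(c); the actual computation then happens in the plain nn.Sequential call on z. A toy FADE/SPADE-style layer showing why that two-phase pattern works; this is an illustrative assumption, not the repository's FADE implementation.

# Toy FADE-style normalization with a set_feature hook; illustrative only.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyFADE(nn.Module):
    def __init__(self, channels, condition_channels):
        super().__init__()
        self.norm = nn.BatchNorm2d(channels, affine=False)
        self.gamma = nn.Conv2d(condition_channels, channels, kernel_size=3, padding=1)
        self.beta = nn.Conv2d(condition_channels, channels, kernel_size=3, padding=1)
        self.feature = None  # filled in by set_feature before the forward pass

    def set_feature(self, feature):
        self.feature = feature

    def forward(self, x):
        cond = F.interpolate(self.feature, size=x.shape[-2:], mode="nearest")
        return self.norm(x) * (1 + self.gamma(cond)) + self.beta(cond)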