diff --git a/configs/synthesizers/TSIT.yml b/configs/synthesizers/TSIT.yml
index 4b40779..0710f17 100644
--- a/configs/synthesizers/TSIT.yml
+++ b/configs/synthesizers/TSIT.yml
@@ -1,7 +1,10 @@
-name: VoxCeleb2Anime-TSIT
-engine: TSIT
+name: huawei-TSIT-1
+engine: GauGAN
 result_dir: ./result
-max_pairs: 1500000
+max_pairs: 1000000
+
+misc:
+  random_seed: 324
 handler:
   clear_cuda_cache: True
@@ -16,34 +19,39 @@ handler:
     random: True
     images: 10
-
-misc:
-  random_seed: 324
-
 model:
   generator:
     _type: TSIT-Generator
-    _bn_to_sync_bn: True
-    style_in_channels: 3
-    content_in_channels: 3
-    num_blocks: 5
-    input_layer_type: "conv7x7"
+    _add_spectral_norm: True
+    in_channels: 3
+    out_channels: 3
+    num_blocks: 7
+#  discriminator:
+#    _type: MultiScaleDiscriminator
+#    _add_spectral_norm: True
+#    num_scale: 2
+#    down_sample_method: "bilinear"
+#    discriminator_cfg:
+#      _type: PatchDiscriminator
+#      in_channels: 3
+#      base_channels: 64
+#      num_conv: 4
+#      need_intermediate_feature: True
   discriminator:
-    _type: MultiScaleDiscriminator
-    num_scale: 2
-    discriminator_cfg:
-      _type: PatchDiscriminator
-      in_channels: 3
-      base_channels: 64
-      use_spectral: True
-      need_intermediate_feature: True
+    _type: PatchDiscriminator
+    _add_spectral_norm: True
+    in_channels: 3
+    base_channels: 64
+    num_conv: 4
+    need_intermediate_feature: True
+
 loss:
   gan:
     loss_type: hinge
-    real_label_val: 1.0
-    fake_label_val: 0.0
     weight: 1.0
+    real_label_val: 1.0
+    fake_label_val: 0.0
   perceptual:
     layer_weights:
       "1": 0.03125
@@ -55,25 +63,18 @@ loss:
     style_loss: False
     perceptual_loss: True
     weight: 1
-  style:
-    layer_weights:
-      "1": 0.03125
-      "6": 0.0625
-      "11": 0.125
-      "20": 0.25
-      "29": 1
-    criterion: 'L2'
-    style_loss: True
-    perceptual_loss: False
-    weight: 0
+  mgc:
+    weight: 5
   fm:
-    level: 1
     weight: 1
+  edge:
+    weight: 0
+    hed_pretrained_model_path: ./network-bsds500.pytorch  # HED edge-detector weights used by the edge loss
 optimizers:
   generator:
     _type: Adam
-    lr: 0.0001
+    lr: 1.0e-4  # note: bare "1e-4" resolves to a string under PyYAML; exponent floats need a dot
     betas: [ 0, 0.9 ]
     weight_decay: 0.0001
   discriminator:
@@ -87,24 +88,35 @@ data:
     scheduler:
       start_proportion: 0.5
      target_lr: 0
-    buffer_size: 50
+    buffer_size: 0  # NB: the engine does "buffer_size or 50", so 0 falls back to 50 rather than disabling the buffer
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: True
       num_workers: 2
       pin_memory: True
       drop_last: True
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/faces/CelebA-Asian/trainA"
-      root_b: "/data/i2i/anime/your-name/faces"
+      root_a: "/data/face2cartoon/all_face"
+      root_b: "/data/selfie2anime/trainB/"
       random_pair: True
-      pipeline:
+      pipeline_a:
+        - Load
+        - RandomCrop:
+            size: [ 178, 178 ]
+        - Resize:
+            size: [ 256, 256 ]
+        - RandomHorizontalFlip
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
         - Load
         - Resize:
-            size: [ 170, 144 ]
+            size: [ 286, 286 ]
         - RandomCrop:
-            size: [ 128, 128 ]
+            size: [ 256, 256 ]
         - RandomHorizontalFlip
         - ToTensor
         - Normalize:
@@ -113,22 +125,28 @@ data:
   test:
     which: video_dataset
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
       drop_last: False
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/faces/CelebA-Asian/testA"
-      root_b: "/data/i2i/anime/your-name/faces"
-      random_pair: False
-      pipeline:
+      root_a: "/data/face2cartoon/test/human"
+      root_b: "/data/face2cartoon/test/anime"
+      random_pair: True
+      pipeline_a:
         - Load
         - Resize:
-            size: [ 170, 144 ]
-        - RandomCrop:
-            size: [ 128, 128 ]
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
         - ToTensor
         - Normalize:
             mean: [ 0.5, 0.5, 0.5 ]
diff --git a/engine/GauGAN.py b/engine/GauGAN.py
index 2c8b762..210b713 100644
--- a/engine/GauGAN.py
+++ b/engine/GauGAN.py
@@ -16,7 +16,7 @@ class GauGANEngineKernel(EngineKernel):
         self.gan_loss = gan_loss(config.loss.gan)
         self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
-        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
+        # "same" weights every feature-matching layer equally (see fm_loss in engine/util/loss.py)
+        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
         self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
         self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
diff --git a/engine/util/loss.py b/engine/util/loss.py
index 1b161ad..7fb6dc0 100644
--- a/engine/util/loss.py
+++ b/engine/util/loss.py
@@ -38,7 +38,7 @@ def feature_match_loss(level, weight_policy):
     def fm_loss(generated_features, target_features):
         num_scale = len(generated_features)
-        loss = 0
+        # accumulate into a tensor on the active device so the result is a tensor even
+        # when the feature lists are empty (assumes "import ignite.distributed as idist")
+        loss = torch.zeros(1, device=idist.device())
         for s_i in range(num_scale):
             for i in range(len(generated_features[s_i]) - 1):
                 weight = 1 if weight_policy == "same" else 2 ** i
diff --git a/model/__init__.py b/model/__init__.py
index c825e4a..472a5dc 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -3,4 +3,5 @@ import model.base.normalization
 import model.image_translation.UGATIT
 import model.image_translation.CycleGAN
 import model.image_translation.pix2pixHD
-import model.image_translation.GauGAN
\ No newline at end of file
+import model.image_translation.GauGAN
+import model.image_translation.TSIT
\ No newline at end of file
diff --git a/model/base/module.py b/model/base/module.py
index 91e4055..ece244d 100644
--- a/model/base/module.py
+++ b/model/base/module.py
@@ -119,6 +119,8 @@ class ResidualBlock(nn.Module):
         self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
         if self.learn_skip_connection:
+            # the skip branch only projects channels, so a 1x1 kernel with no padding suffices
+            conv_param['kernel_size'] = 1
+            conv_param['padding'] = 0
             self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
 
     def forward(self, x):
diff --git a/model/image_translation/TSIT.py b/model/image_translation/TSIT.py
index e69de29..3c4b3ce 100644
--- a/model/image_translation/TSIT.py
+++ b/model/image_translation/TSIT.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model import MODEL
+from model.base.module import ResidualBlock, Conv2dBlock
+
+
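+# thin nn.Module wrapper around F.interpolate so resizing can be used inside nn.Sequential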
+class Interpolation(nn.Module):
+    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
+        super(Interpolation, self).__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
+                             recompute_scale_factor=False)
+
+    def __repr__(self):
+        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
+
+
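+# TSIT-style generator: a downsampling content stream caches multi-scale features,
+# and a mirrored upsampling stream decodes the noise z while FADE normalization
+# layers re-inject the cached features at matching resolutions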
+@MODEL.register_module("TSIT-Generator")
+class Generator(nn.Module):
+    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
+                 padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.input_layer = Conv2dBlock(
+            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type="IN",
+        )
+        # content stream: halve the resolution each block, capping width at 16x base_channels
+        multiple_now = 1
+        stream_sequence = []
+        for i in range(1, num_blocks + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            stream_sequence.append(nn.Sequential(
+                Interpolation(scale_factor=0.5, mode="nearest"),
+                ResidualBlock(
+                    multiple_prev * base_channels, out_channels=multiple_now * base_channels,
+                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
+            ))
+        self.down_sequence = nn.ModuleList(stream_sequence)
+
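+        # up stream: mirror of the content stream; each block's FADE layers consume
+        # the cached content feature at the same spatial resolution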
+        sequence = []
+        multiple_now = min(2 ** num_blocks, 2 ** 4)  # width of the deepest content feature
+        for i in range(num_blocks - 1, -1, -1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            sequence.append(nn.Sequential(
+                ResidualBlock(
+                    base_channels * multiple_prev,
+                    out_channels=base_channels * multiple_now,
+                    padding_mode=padding_mode,
+                    activation_type=activation_type,
+                    norm_type="FADE",
+                    pre_activation=True,
+                    additional_norm_kwargs=dict(
+                        condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
+                        padding_mode="zeros", gamma_bias=0.0
+                    )
+                ),
+                Interpolation(scale_factor=2, mode="nearest")
+            ))
+        self.up_sequence = nn.Sequential(*sequence)
+
+        self.output_layer = Conv2dBlock(
+            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
+            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
+        )
+
+    def forward(self, x, z=None):
+        c = self.input_layer(x)
+        contents = []
+        for blk in self.down_sequence:
+            c = blk(c)
+            contents.append(c)
+        if z is None:
+            # for a 256x256 input, z has shape [batch_size, 1024, 2, 2]
+            z = torch.randn(size=contents[-1].size(), device=contents[-1].device)
+
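+        # hand each cached content feature to the FADE layers of the matching up block,
+        # deepest first, before running the whole up stream on z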
+        for blk in self.up_sequence:
+            res = blk[0]
+            c = contents.pop()
+            res.conv1.normalization.set_feature(c)
+            res.conv2.normalization.set_feature(c)
+            if res.learn_skip_connection:
+                res.res_conv.normalization.set_feature(c)
+        return self.output_layer(self.up_sequence(z))
+
+
+if __name__ == '__main__':
+    # smoke test: a 256x256 batch should come back at the same resolution
+    g = Generator(3, 3).cuda()
+    img = torch.randn(2, 3, 256, 256).cuda()
+    print(g(img).size())
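+    # z can also be supplied explicitly; it must match the deepest content feature,
+    # e.g. [batch, 1024, 2, 2] for 256x256 inputs with the defaults above
+    z = torch.randn(2, 1024, 2, 2).cuda()
+    print(g(img, z).size())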