diff --git a/configs/synthesizers/TSIT.yml b/configs/synthesizers/TSIT.yml
index 4b40779..0710f17 100644
--- a/configs/synthesizers/TSIT.yml
+++ b/configs/synthesizers/TSIT.yml
@@ -1,7 +1,10 @@
-name: VoxCeleb2Anime-TSIT
-engine: TSIT
+name: huawei-TSIT-1
+engine: GauGAN
 result_dir: ./result
-max_pairs: 1500000
+max_pairs: 1000000
+
+misc:
+  random_seed: 324
 
 handler:
   clear_cuda_cache: True
@@ -16,34 +19,39 @@ handler:
     random: True
     images: 10
 
-
-misc:
-  random_seed: 324
-
 model:
   generator:
     _type: TSIT-Generator
-    _bn_to_sync_bn: True
-    style_in_channels: 3
-    content_in_channels: 3
-    num_blocks: 5
-    input_layer_type: "conv7x7"
+    _add_spectral_norm: True
+    in_channels: 3
+    out_channels: 3
+    num_blocks: 7
+#  discriminator:
+#    _type: MultiScaleDiscriminator
+#    _add_spectral_norm: True
+#    num_scale: 2
+#    down_sample_method: "bilinear"
+#    discriminator_cfg:
+#      _type: PatchDiscriminator
+#      in_channels: 3
+#      base_channels: 64
+#      num_conv: 4
+#      need_intermediate_feature: True
   discriminator:
-    _type: MultiScaleDiscriminator
-    num_scale: 2
-    discriminator_cfg:
-      _type: PatchDiscriminator
-      in_channels: 3
-      base_channels: 64
-      use_spectral: True
-      need_intermediate_feature: True
+    _type: PatchDiscriminator
+    _add_spectral_norm: True
+    in_channels: 3
+    base_channels: 64
+    num_conv: 4
+    need_intermediate_feature: True
+
 loss:
   gan:
     loss_type: hinge
-    real_label_val: 1.0
-    fake_label_val: 0.0
     weight: 1.0
+    real_label_val: 1
+    fake_label_val: 0.0
   perceptual:
     layer_weights:
       "1": 0.03125
       "6": 0.0625
@@ -55,25 +63,18 @@ loss:
     style_loss: False
     perceptual_loss: True
     weight: 1
-  style:
-    layer_weights:
-      "1": 0.03125
-      "6": 0.0625
-      "11": 0.125
-      "20": 0.25
-      "29": 1
-    criterion: 'L2'
-    style_loss: True
-    perceptual_loss: False
-    weight: 0
+  mgc:
+    weight: 5
   fm:
-    level: 1
     weight: 1
+  edge:
+    weight: 0
+    hed_pretrained_model_path: ./network-bsds500.pytorch
 
 optimizers:
   generator:
     _type: Adam
-    lr: 0.0001
+    lr: 1e-4
     betas: [ 0, 0.9 ]
     weight_decay: 0.0001
   discriminator:
@@ -87,24 +88,35 @@ data:
     scheduler:
       start_proportion: 0.5
       target_lr: 0
-    buffer_size: 50
+    buffer_size: 0
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: True
      num_workers: 2
       pin_memory: True
       drop_last: True
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/faces/CelebA-Asian/trainA"
-      root_b: "/data/i2i/anime/your-name/faces"
+      root_a: "/data/face2cartoon/all_face"
+      root_b: "/data/selfie2anime/trainB/"
       random_pair: True
-      pipeline:
+      pipeline_a:
+        - Load
+        - RandomCrop:
+            size: [ 178, 178 ]
+        - Resize:
+            size: [ 256, 256 ]
+        - RandomHorizontalFlip
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
         - Load
         - Resize:
-            size: [ 170, 144 ]
+            size: [ 286, 286 ]
         - RandomCrop:
-            size: [ 128, 128 ]
+            size: [ 256, 256 ]
         - RandomHorizontalFlip
         - ToTensor
         - Normalize:
@@ -113,22 +125,28 @@ data:
   test:
     which: video_dataset
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
       drop_last: False
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/faces/CelebA-Asian/testA"
-      root_b: "/data/i2i/anime/your-name/faces"
-      random_pair: False
-      pipeline:
+      root_a: "/data/face2cartoon/test/human"
+      root_b: "/data/face2cartoon/test/anime"
+      random_pair: True
+      pipeline_a:
         - Load
         - Resize:
-            size: [ 170, 144 ]
-        - RandomCrop:
-            size: [ 128, 128 ]
+            size: [ 256, 256 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      pipeline_b:
+        - Load
+        - Resize:
+            size: [ 256, 256 ]
         - ToTensor
         - Normalize:
             mean: [ 0.5, 0.5, 0.5 ]
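Note on the GAN term above: with `loss_type: hinge`, the `real_label_val`/`fake_label_val` entries are typically ignored (they matter for the vanilla and least-squares variants). For reference, a minimal sketch of what a hinge GAN objective computes on discriminator logits; the function names here are illustrative, not this repo's `gan_loss` API:

```python
import torch.nn.functional as F

def d_hinge(real_logits, fake_logits):
    # Discriminator: push real logits above +1 and fake logits below -1.
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def g_hinge(fake_logits):
    # Generator: raise the discriminator's score on generated samples.
    return -fake_logits.mean()
```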
diff --git a/engine/GauGAN.py b/engine/GauGAN.py
index 2c8b762..210b713 100644
--- a/engine/GauGAN.py
+++ b/engine/GauGAN.py
@@ -16,7 +16,7 @@ class GauGANEngineKernel(EngineKernel):
 
         self.gan_loss = gan_loss(config.loss.gan)
         self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
-        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
+        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
         self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
 
         self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
diff --git a/engine/util/loss.py b/engine/util/loss.py
index 1b161ad..7fb6dc0 100644
--- a/engine/util/loss.py
+++ b/engine/util/loss.py
@@ -38,7 +38,7 @@ def feature_match_loss(level, weight_policy):
 
     def fm_loss(generated_features, target_features):
         num_scale = len(generated_features)
-        loss = 0
+        loss = torch.zeros(1, device=idist.device())
         for s_i in range(num_scale):
             for i in range(len(generated_features[s_i]) - 1):
                 weight = 1 if weight_policy == "same" else 2 ** i
diff --git a/model/__init__.py b/model/__init__.py
index c825e4a..472a5dc 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -3,4 +3,5 @@ import model.base.normalization
 import model.image_translation.UGATIT
 import model.image_translation.CycleGAN
 import model.image_translation.pix2pixHD
-import model.image_translation.GauGAN
\ No newline at end of file
+import model.image_translation.GauGAN
+import model.image_translation.TSIT
\ No newline at end of file
diff --git a/model/base/module.py b/model/base/module.py
index 91e4055..ece244d 100644
--- a/model/base/module.py
+++ b/model/base/module.py
@@ -119,6 +119,8 @@ class ResidualBlock(nn.Module):
         self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
 
         if self.learn_skip_connection:
+            conv_param['kernel_size'] = 1
+            conv_param['padding'] = 0
             self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
 
     def forward(self, x):
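The switch from `"exponential_decline"` to `"same"` in `engine/GauGAN.py` pairs with the weighting branch visible in `engine/util/loss.py`: with `"same"`, every intermediate discriminator feature contributes equally, while the old policy scaled layer `i` by `2 ** i`. A self-contained sketch of the effect; the accumulation step falls outside the hunk above, so the L1 criterion here is an assumption:

```python
import torch
import torch.nn.functional as F

def fm_loss_sketch(generated_features, target_features, weight_policy="same"):
    # generated_features[s][i]: i-th feature map from discriminator scale s
    # on the generated image; target_features holds the real-image maps.
    loss = torch.zeros(1)
    for s in range(len(generated_features)):
        for i in range(len(generated_features[s]) - 1):  # skip the final logits
            weight = 1 if weight_policy == "same" else 2 ** i
            loss = loss + weight * F.l1_loss(generated_features[s][i],
                                             target_features[s][i].detach())
    return loss
```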
diff --git a/model/image_translation/TSIT.py b/model/image_translation/TSIT.py
index e69de29..3c4b3ce 100644
--- a/model/image_translation/TSIT.py
+++ b/model/image_translation/TSIT.py
@@ -0,0 +1,98 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+
+from model import MODEL
+from model.base.module import ResidualBlock, Conv2dBlock
+
+
+class Interpolation(nn.Module):
+    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
+        super(Interpolation, self).__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
+                             recompute_scale_factor=False)
+
+    def __repr__(self):
+        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
+
+
+@MODEL.register_module("TSIT-Generator")
+class Generator(nn.Module):
+    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
+                 padding_mode='reflect', activation_type="LeakyReLU"):
+        super().__init__()
+        self.input_layer = Conv2dBlock(
+            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
+            activation_type=activation_type, norm_type="IN",
+        )
+        multiple_now = 1
+        stream_sequence = []
+        for i in range(1, num_blocks + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            stream_sequence.append(nn.Sequential(
+                Interpolation(scale_factor=0.5, mode="nearest"),
+                ResidualBlock(
+                    multiple_prev * base_channels, out_channels=multiple_now * base_channels,
+                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
+            ))
+        self.down_sequence = nn.ModuleList(stream_sequence)
+
+
+        sequence = []
+        multiple_now = 16
+        for i in range(num_blocks - 1, -1, -1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            sequence.append(nn.Sequential(
+                ResidualBlock(
+                    base_channels * multiple_prev,
+                    out_channels=base_channels * multiple_now,
+                    padding_mode=padding_mode,
+                    activation_type=activation_type,
+                    norm_type="FADE",
+                    pre_activation=True,
+                    additional_norm_kwargs=dict(
+                        condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
+                        padding_mode="zeros", gamma_bias=0.0
+                    )
+                ),
+                Interpolation(scale_factor=2, mode="nearest")
+            ))
+        self.up_sequence = nn.Sequential(*sequence)
+
+        self.output_layer = Conv2dBlock(
+            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
+            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
+        )
+
+    def forward(self, x, z=None):
+        c = self.input_layer(x)
+        contents = []
+        for blk in self.down_sequence:
+            c = blk(c)
+            contents.append(c)
+        if z is None:
+            # for image 256x256, z size: [batch_size, 1024, 2, 2]
+            z = torch.randn(size=contents[-1].size(), device=contents[-1].device)
+
+        for blk in self.up_sequence:
+            res = blk[0]
+            c = contents.pop()
+            res.conv1.normalization.set_feature(c)
+            res.conv2.normalization.set_feature(c)
+            if res.learn_skip_connection:
+                res.res_conv.normalization.set_feature(c)
+        return self.output_layer(self.up_sequence(z))
+
+
+if __name__ == '__main__':
+    g = Generator(3, 3).cuda()
+    img = torch.randn(2, 3, 256, 256).cuda()
+    print(g(img).size())
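A note on the `forward` above: the loop over `up_sequence` only stores the matching content feature on each FADE layer via `set_feature`; the single `self.up_sequence(z)` call afterwards then consumes them in order. The FADE normalization itself lives in `model.base.normalization` and is not part of this patch; below is a minimal sketch of the interface the generator relies on, assuming the usual TSIT/SPADE-style formulation (the class body is hypothetical, with `gamma_bias` as in the config):

```python
import torch.nn as nn

class FADESketch(nn.Module):
    # Parameter-free base norm plus a per-pixel affine modulation whose
    # gamma/beta are predicted from the stored content feature. With
    # gamma_bias=0.0 (the config's value) the output is
    # bn(x) * gamma(f) + beta(f).
    def __init__(self, num_channels, condition_in_channels, gamma_bias=0.0):
        super().__init__()
        self.base_norm = nn.BatchNorm2d(num_channels, affine=False)
        self.gamma = nn.Conv2d(condition_in_channels, num_channels, kernel_size=3, padding=1)
        self.beta = nn.Conv2d(condition_in_channels, num_channels, kernel_size=3, padding=1)
        self.gamma_bias = gamma_bias
        self.feature = None

    def set_feature(self, feature):
        # Called by Generator.forward() before the up blocks run; the content
        # feature already matches the block's spatial size by construction.
        self.feature = feature

    def forward(self, x):
        f = self.feature
        return self.base_norm(x) * (self.gamma(f) + self.gamma_bias) + self.beta(f)
```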