diff --git a/configs/synthesizers/TSIT.yml b/configs/synthesizers/TSIT.yml
new file mode 100644
index 0000000..a6edfbf
--- /dev/null
+++ b/configs/synthesizers/TSIT.yml
@@ -0,0 +1,146 @@
+name: self2anime-TSIT
+engine: TSIT
+result_dir: ./result
+max_pairs: 1500000
+
+handler:
+  clear_cuda_cache: True
+  set_epoch_for_dist_sampler: True
+  checkpoint:
+    epoch_interval: 1  # checkpoint once per `epoch_interval` epoch
+    n_saved: 2
+  tensorboard:
+    scalar: 100  # log scalar `scalar` times per epoch
+    image: 2  # log image `image` times per epoch
+
+misc:
+  random_seed: 324
+
+model:
+  generator:
+    _type: TSIT-Generator
+    _bn_to_sync_bn: True
+    style_in_channels: 3
+    content_in_channels: 3
+    num_blocks: 5
+    input_layer_type: "conv7x7"
+  discriminator:
+    _type: MultiScaleDiscriminator
+    num_scale: 2
+    discriminator_cfg:
+      _type: PatchDiscriminator
+      in_channels: 3
+      base_channels: 64
+      use_spectral: True
+      need_intermediate_feature: True
+
+loss:
+  gan:
+    loss_type: hinge
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    weight: 1.0
+  perceptual:
+    layer_weights:
+      "1": 0.03125
+      "6": 0.0625
+      "11": 0.125
+      "20": 0.25
+      "29": 1
+    criterion: 'L1'
+    style_loss: False
+    perceptual_loss: True
+    weight: 1
+  style:
+    layer_weights:
+      "1": 0.03125
+      "6": 0.0625
+      "11": 0.125
+      "20": 0.25
+      "29": 1
+    criterion: 'L2'
+    style_loss: True
+    perceptual_loss: False
+    weight: 0
+  fm:
+    level: 1
+    weight: 1
+
+optimizers:
+  generator:
+    _type: Adam
+    lr: 0.0001
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+  discriminator:
+    _type: Adam
+    lr: 4e-4
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+
+data:
+  train:
+    scheduler:
+      start_proportion: 0.5
+      target_lr: 0
+    buffer_size: 50
+    dataloader:
+      batch_size: 1
+      shuffle: True
+      num_workers: 2
+      pin_memory: True
+      drop_last: True
+    dataset:
+      _type: GenerationUnpairedDatasetWithEdge
+      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
+      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
+      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
+      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
+      edge_type: "landmark_hed"
+      size: [ 128, 128 ]
+      random_pair: True
+      pipeline:
+        - Load
+        - Resize:
+            size: [ 128, 128 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+  test:
+    dataloader:
+      batch_size: 8
+      shuffle: False
+      num_workers: 1
+      pin_memory: False
+      drop_last: False
+    dataset:
+      _type: GenerationUnpairedDatasetWithEdge
+      root_a: "/data/i2i/VoxCeleb2Anime/testA"
+      root_b: "/data/i2i/VoxCeleb2Anime/testB"
+      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
+      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
+      edge_type: "landmark_hed"
+      random_pair: False
+      size: [ 128, 128 ]
+      pipeline:
+        - Load
+        - Resize:
+            size: [ 128, 128 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+  video_dataset:
+    _type: SingleFolderDataset
+    root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
+    with_path: True
+    pipeline:
+      - Load
+      - Resize:
+          size: [ 256, 256 ]
+      - ToTensor
+      - Normalize:
+          mean: [ 0.5, 0.5, 0.5 ]
+          std: [ 0.5, 0.5, 0.5 ]
diff --git a/engine/TSIT.py b/engine/TSIT.py
new file mode 100644
index 0000000..6ed2eeb
--- /dev/null
+++ b/engine/TSIT.py
@@ -0,0 +1,106 @@
+from itertools import chain
+
+import ignite.distributed as idist
+import torch
+import torch.nn as nn
+from omegaconf import OmegaConf
+
+from engine.base.i2i import EngineKernel, run_kernel
+from engine.util.build import build_model
+from loss.I2I.perceptual_loss import PerceptualLoss
+from loss.gan import GANLoss
+from model.weight_init import generation_init_weights
+
+
+class TSITEngineKernel(EngineKernel):
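+    # Training kernel for TSIT (Jiang et al., ECCV 2020): a single
+    # content-to-style generator is trained against a multi-scale patch
+    # discriminator with hinge GAN, perceptual, and feature-matching losses,
+    # each weighted through `config.loss`.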
+    def __init__(self, config):
+        super().__init__(config)
+        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
+        perceptual_loss_cfg.pop("weight")
+        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
+
+        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
+        gan_loss_cfg.pop("weight")
+        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
+
+        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
+
+    def build_models(self) -> (dict, dict):
+        generators = dict(
+            main=build_model(self.config.model.generator)
+        )
+        discriminators = dict(
+            b=build_model(self.config.model.discriminator)
+        )
+        self.logger.debug(discriminators["b"])
+        self.logger.debug(generators["main"])
+
+        for m in chain(generators.values(), discriminators.values()):
+            generation_init_weights(m)
+
+        return generators, discriminators
+
+    def setup_after_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(True)
+
+    def setup_before_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(False)
+
+    def forward(self, batch, inference=False) -> dict:
+        with torch.set_grad_enabled(not inference):
+            fake = dict(
+                b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
+            )
+        return fake
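+
+    # Generator-side objective: the perceptual loss anchors the translation to
+    # the content input, the hinge GAN loss is applied to the final prediction
+    # of every discriminator scale, and feature matching (when its weight > 0)
+    # aligns intermediate discriminator activations on fake vs. real images.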
+    def criterion_generators(self, batch, generated) -> dict:
+        loss = dict()
+        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
+        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        for phase in self.discriminators.keys():
+            pred_fake = self.discriminators[phase](generated[phase])
+            loss[f"gan_{phase}"] = 0
+            for sub_pred_fake in pred_fake:
+                # The last output of each scale is the actual prediction.
+                loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
+
+            if self.config.loss.fm.weight > 0 and phase == "b":
+                pred_real = self.discriminators[phase](batch[phase])
+                loss_fm = 0
+                num_scale_discriminator = len(pred_fake)
+                for i in range(num_scale_discriminator):
+                    # The last output is the final prediction, so exclude it here.
+                    num_intermediate_outputs = len(pred_fake[i]) - 1
+                    for j in range(num_intermediate_outputs):
+                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
+                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
+        return loss
+
+    def criterion_discriminators(self, batch, generated) -> dict:
+        loss = dict()
+        for phase in self.discriminators.keys():
+            pred_real = self.discriminators[phase](batch[phase])
+            pred_fake = self.discriminators[phase](generated[phase].detach())
+            loss[f"gan_{phase}"] = 0
+            for i in range(len(pred_fake)):
+                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
+        return loss
+
+    def intermediate_images(self, batch, generated) -> dict:
+        """Return images to visualize, keyed by phase, e.g. {"b": [content, style, fake]}."""
+        return dict(
+            b=[batch["a"].detach(), batch["b"].detach(), generated["b"].detach()]
+        )
+
+
+def run(task, config, _):
+    kernel = TSITEngineKernel(config)
+    run_kernel(task, config, kernel)
diff --git a/model/GAN/TSIT.py b/model/GAN/TSIT.py
new file mode 100644
index 0000000..1a4f429
--- /dev/null
+++ b/model/GAN/TSIT.py
@@ -0,0 +1,192 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model import MODEL
+from model.normalization import AdaptiveInstanceNorm2d
+from model.normalization import select_norm_layer
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None,
+                 use_spectral=True):
+        super().__init__()
+        self.padding_mode = padding_mode
+        self.use_bias = use_bias
+        self.use_spectral = use_spectral
+        if use_bias is None:
+            # Use bias only for IN, since IN is created without affine parameters here.
+            self.use_bias = norm_type == "IN"
+        norm_layer = select_norm_layer(norm_type)
+        self.main = nn.Sequential(
+            self.conv_block(in_channels, in_channels),
+            norm_layer(in_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            self.conv_block(in_channels, out_channels),
+            norm_layer(out_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+        self.skip = nn.Sequential(
+            self.conv_block(in_channels, out_channels, padding=0, kernel_size=1),
+            norm_layer(out_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+    def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
+        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
+                         padding_mode=self.padding_mode, bias=self.use_bias)
+        if self.use_spectral:
+            return nn.utils.spectral_norm(conv)
+        else:
+            return conv
+
+    def forward(self, x):
+        return self.main(x) + self.skip(x)
+
+
+class Interpolation(nn.Module):
+    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
+        super().__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
+                             recompute_scale_factor=False)
+
+    def __repr__(self):
+        return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
+
+
+class FADE(nn.Module):
+    """Feature-adaptive denormalization: parameter-free BatchNorm whose per-pixel
+    scale/shift maps are predicted from a content feature map (TSIT's SPADE-style layer)."""
+
+    def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
+        super().__init__()
+        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
+        self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
+                                     padding_mode="zeros")
+        self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
+                                    padding_mode="zeros")
+
+    def forward(self, x, feature):
+        alpha = self.alpha_conv(feature)
+        beta = self.beta_conv(feature)
+        x = self.bn(x)
+        return alpha * x + beta
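+
+
+# FADEResBlock is the generator's building block: every normalization inside it
+# is a FADE conditioned on the matching content-stream feature map, and the
+# block ends with a 2x nearest-neighbor upsampling, mirroring one downsampling
+# step of the content/style streams.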
+class FADEResBlock(nn.Module):
+    def __init__(self, use_spectral, features_channels, in_channels, out_channels):
+        super().__init__()
+        self.main = nn.Sequential(
+            FADE(use_spectral, features_channels, in_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
+            FADE(use_spectral, features_channels, in_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
+        )
+        self.skip = nn.Sequential(
+            FADE(use_spectral, features_channels, in_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
+        )
+        self.up_sample = Interpolation(2, mode="nearest")
+
+    @staticmethod
+    def forward_with_fade(module, x, feature):
+        # FADE layers take the content feature as a second input, so walk the
+        # Sequential manually instead of calling it directly.
+        for layer in module:
+            if isinstance(layer, FADE):
+                x = layer(x, feature)
+            else:
+                x = layer(x)
+        return x
+
+    def forward(self, x, feature):
+        out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.skip, x, feature)
+        return self.up_sample(out)
+
+
+def conv_block(use_spectral, in_channels, out_channels, **kwargs):
+    conv = nn.Conv2d(in_channels, out_channels, **kwargs)
+    return nn.utils.spectral_norm(conv) if use_spectral else conv
+
+
+@MODEL.register_module("TSIT-Generator")
+class TSITGenerator(nn.Module):
+    def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3,
+                 out_channels=3, use_spectral=True, input_layer_type="conv1x1"):
+        super().__init__()
+        self.num_blocks = num_blocks
+        self.base_channels = base_channels
+        self.use_spectral = use_spectral
+
+        self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
+        self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
+        self.content_stream = self.build_stream()
+        self.style_stream = self.build_stream()
+        self.generator = self.build_generator()
+        self.end_conv = nn.Sequential(
+            conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
+            nn.Tanh()
+        )
+
+    def build_generator(self):
+        stream_sequence = []
+        multiple_now = min(2 ** self.num_blocks, 2 ** 4)
+        for i in range(1, self.num_blocks + 1):
+            m = self.num_blocks - i
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** m, 2 ** 4)
+            stream_sequence.append(nn.Sequential(
+                AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
+                FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
+                             multiple_now * self.base_channels)
+            ))
+        return nn.ModuleList(stream_sequence)
+
+    def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
+        if input_layer_type == "conv7x7":
+            return nn.Sequential(
+                conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
+                           padding_mode="zeros", padding=3, bias=True),
+                select_norm_layer("IN")(out_channels),
+                nn.ReLU(inplace=True)
+            )
+        elif input_layer_type == "conv1x1":
+            return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+        else:
+            raise NotImplementedError(f"unsupported input_layer_type: {input_layer_type}")
+
+    def build_stream(self):
+        multiple_now = 1
+        stream_sequence = []
+        for i in range(1, self.num_blocks + 1):
+            multiple_prev = multiple_now
+            multiple_now = min(2 ** i, 2 ** 4)
+            stream_sequence.append(nn.Sequential(
+                Interpolation(scale_factor=0.5, mode="nearest"),
+                ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
+                         use_spectral=self.use_spectral)
+            ))
+        return nn.ModuleList(stream_sequence)
+
+    def forward(self, content_img, style_img):
+        c = self.content_input_layer(content_img)
+        s = self.style_input_layer(style_img)
+        content_features = []
+        style_features = []
+        for i in range(self.num_blocks):
+            s = self.style_stream[i](s)
+            c = self.content_stream[i](c)
+            content_features.append(c)
+            style_features.append(s)
+        # Start generation from noise shaped like the deepest content feature.
+        z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
+
+        for i in range(self.num_blocks):
+            m = -i - 1  # consume the saved features coarse-to-fine
+            layer = self.generator[i]
+            # AdaIN style vector = concatenated per-channel (std, mean) of the style feature.
+            layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
+            z = layer[0](z)
+            z = layer[1](z, content_features[m])
+        return self.end_conv(z)
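+
+
+# Shape sketch (illustrative, assuming the config above: num_blocks=5,
+# base_channels=64, 128x128 inputs):
+#   g = TSITGenerator(num_blocks=5)
+#   fake = g(content_img=torch.randn(1, 3, 128, 128),
+#            style_img=torch.randn(1, 3, 128, 128))
+# Both streams downsample 2x per block (128 -> 4 at the deepest level), the
+# noise z therefore starts at (1, 1024, 4, 4), and five FADEResBlocks upsample
+# it back to a (1, 3, 128, 128) image.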
diff --git a/model/__init__.py b/model/__init__.py
index b3533b4..53029b2 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -4,3 +4,4 @@ import model.GAN.TAFG
 import model.GAN.UGATIT
 import model.GAN.wrapper
 import model.GAN.base
+import model.GAN.TSIT
\ No newline at end of file
diff --git a/model/normalization.py b/model/normalization.py
index fd7a2d8..a5cbcf2 100644
--- a/model/normalization.py
+++ b/model/normalization.py
@@ -1,6 +1,7 @@
-import torch.nn as nn
 import functools
+
 import torch
+import torch.nn as nn
 
 
 def select_norm_layer(norm_type):
diff --git a/run.sh b/run.sh
index c93222d..77eddce 100644
--- a/run.sh
+++ b/run.sh
@@ -5,16 +5,18 @@ TASK=$2
 GPUS=$3
 MORE_ARG=${*:4}
 
+# Random rendezvous port so concurrent torch.distributed.launch jobs on one host do not collide.
+RANDOM_MASTER=$(shuf -i 2000-65000 -n 1)
+
 _command="print(len('${GPUS}'.split(',')))"
 GPU_COUNT=$(python3 -c "${_command}")
 echo "GPU_COUNT:${GPU_COUNT}"
 echo
 
 CUDA_VISIBLE_DEVICES=$GPUS \
-PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+    PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
     main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"
 
 CUDA_VISIBLE_DEVICES=$GPUS \
-PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+    PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+    --master_port=${RANDOM_MASTER} \
     main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed
-
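+
+# Usage sketch (illustrative; assumes CONFIG=$1 is set earlier in this script):
+#   bash run.sh configs/synthesizers/TSIT.yml train 0,1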