From acf243cb12d7d43564b09f7868cea2072be821d5 Mon Sep 17 00:00:00 2001 From: Ray Wong Date: Fri, 25 Sep 2020 18:31:12 +0800 Subject: [PATCH] working --- configs/synthesizers/TAFG.yml | 1 + configs/synthesizers/TSIT.yml | 29 +++-- configs/synthesizers/talking_anime.yml | 171 +++++++++++++++++++++++++ engine/TAFG.py | 133 ++++++++++++------- engine/TSIT.py | 16 +-- engine/base/i2i.py | 49 ++++--- engine/talking_anime.py | 153 ++++++++++++++++++++++ model/GAN/TAFG.py | 53 ++++++++ model/GAN/TSIT.py | 25 ++-- tool/inspect_model.py | 14 ++ tool/process/permutation_face.py | 13 ++ 11 files changed, 542 insertions(+), 115 deletions(-) create mode 100644 configs/synthesizers/talking_anime.yml create mode 100644 engine/talking_anime.py create mode 100644 tool/inspect_model.py create mode 100644 tool/process/permutation_face.py diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml index 1e2d133..a43a652 100644 --- a/configs/synthesizers/TAFG.yml +++ b/configs/synthesizers/TAFG.yml @@ -19,6 +19,7 @@ handler: misc: random_seed: 1004 + add_new_loss_epoch: -1 model: generator: diff --git a/configs/synthesizers/TSIT.yml b/configs/synthesizers/TSIT.yml index b2192a3..4b40779 100644 --- a/configs/synthesizers/TSIT.yml +++ b/configs/synthesizers/TSIT.yml @@ -1,4 +1,4 @@ -name: self2anime-TSIT +name: VoxCeleb2Anime-TSIT engine: TSIT result_dir: ./result max_pairs: 1500000 @@ -11,7 +11,10 @@ handler: n_saved: 2 tensorboard: scalar: 100 # log scalar `scalar` times per epoch - image: 2 # log image `image` times per epoch + image: 4 # log image `image` times per epoch + test: + random: True + images: 10 misc: @@ -86,24 +89,23 @@ data: target_lr: 0 buffer_size: 50 dataloader: - batch_size: 1 + batch_size: 8 shuffle: True num_workers: 2 pin_memory: True drop_last: True dataset: - _type: GenerationUnpairedDatasetWithEdge - root_a: "/data/i2i/VoxCeleb2Anime/trainA" - root_b: "/data/i2i/VoxCeleb2Anime/trainB" - edges_path: "/data/i2i/VoxCeleb2Anime/edges" - landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks" - edge_type: "landmark_hed" - size: [ 128, 128 ] + _type: GenerationUnpairedDataset + root_a: "/data/i2i/faces/CelebA-Asian/trainA" + root_b: "/data/i2i/anime/your-name/faces" random_pair: True pipeline: - Load - Resize: + size: [ 170, 144 ] + - RandomCrop: size: [ 128, 128 ] + - RandomHorizontalFlip - ToTensor - Normalize: mean: [ 0.5, 0.5, 0.5 ] @@ -118,13 +120,14 @@ data: drop_last: False dataset: _type: GenerationUnpairedDataset - root_a: "/data/i2i/VoxCeleb2Anime/testA" - root_b: "/data/i2i/VoxCeleb2Anime/testB" - with_path: True + root_a: "/data/i2i/faces/CelebA-Asian/testA" + root_b: "/data/i2i/anime/your-name/faces" random_pair: False pipeline: - Load - Resize: + size: [ 170, 144 ] + - RandomCrop: size: [ 128, 128 ] - ToTensor - Normalize: diff --git a/configs/synthesizers/talking_anime.yml b/configs/synthesizers/talking_anime.yml new file mode 100644 index 0000000..c9cf4aa --- /dev/null +++ b/configs/synthesizers/talking_anime.yml @@ -0,0 +1,171 @@ +name: talking_anime +engine: talking_anime +result_dir: ./result +max_pairs: 1000000 + +handler: + clear_cuda_cache: True + set_epoch_for_dist_sampler: True + checkpoint: + epoch_interval: 1 # checkpoint once per `epoch_interval` epoch + n_saved: 2 + tensorboard: + scalar: 100 # log scalar `scalar` times per epoch + image: 100 # log image `image` times per epoch + test: + random: True + images: 10 + +misc: + random_seed: 1004 + +loss: + gan: + loss_type: hinge + real_label_val: 1.0 + fake_label_val: 0.0 + weight: 1.0 + fm: + level: 1 + 
weight: 1 + style: + layer_weights: + "3": 1 + criterion: 'L1' + style_loss: True + perceptual_loss: False + weight: 10 + perceptual: + layer_weights: + "1": 0.03125 + "6": 0.0625 + "11": 0.125 + "20": 0.25 + "29": 1 + criterion: 'L1' + style_loss: False + perceptual_loss: True + weight: 0 + context: + layer_weights: + #"13": 1 + "22": 1 + weight: 5 + recon: + level: 1 + weight: 10 + edge: + weight: 5 + hed_pretrained_model_path: ./network-bsds500.pytorch + +model: + face_generator: + _type: TAFG-SingleGenerator + _bn_to_sync_bn: False + style_in_channels: 3 + content_in_channels: 1 + use_spectral_norm: True + style_encoder_type: VGG19StyleEncoder + num_style_conv: 4 + style_dim: 512 + num_adain_blocks: 8 + num_res_blocks: 8 + anime_generator: + _type: TAFG-ResGenerator + _bn_to_sync_bn: False + in_channels: 6 + use_spectral_norm: True + num_res_blocks: 8 + + discriminator: + _type: MultiScaleDiscriminator + num_scale: 2 + discriminator_cfg: + _type: PatchDiscriminator + in_channels: 3 + base_channels: 64 + use_spectral: True + need_intermediate_feature: True + +optimizers: + generator: + _type: Adam + lr: 0.0001 + betas: [ 0, 0.9 ] + weight_decay: 0.0001 + discriminator: + _type: Adam + lr: 4e-4 + betas: [ 0, 0.9 ] + weight_decay: 0.0001 + +data: + train: + scheduler: + start_proportion: 0.5 + target_lr: 0 + dataloader: + batch_size: 8 + shuffle: True + num_workers: 1 + pin_memory: True + drop_last: True + dataset: + _type: PoseFacesWithSingleAnime + root_face: "/data/i2i/VoxCeleb2Anime/trainA" + root_anime: "/data/i2i/VoxCeleb2Anime/trainB" + landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks" + num_face: 2 + img_size: [ 128, 128 ] + with_order: False + face_pipeline: + - Load + - Resize: + size: [ 128, 128 ] + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + anime_pipeline: + - Load + - Resize: + size: [ 144, 144 ] + - RandomCrop: + size: [ 128, 128 ] + - RandomHorizontalFlip + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + test: + which: dataset + dataloader: + batch_size: 1 + shuffle: False + num_workers: 1 + pin_memory: False + drop_last: False + dataset: + _type: PoseFacesWithSingleAnime + root_face: "/data/i2i/VoxCeleb2Anime/testA" + root_anime: "/data/i2i/VoxCeleb2Anime/testB" + landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks" + num_face: 2 + img_size: [ 128, 128 ] + with_order: False + face_pipeline: + - Load + - Resize: + size: [ 128, 128 ] + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + anime_pipeline: + - Load + - Resize: + size: [ 128, 128 ] + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] \ No newline at end of file diff --git a/engine/TAFG.py b/engine/TAFG.py index f1d7379..a854875 100644 --- a/engine/TAFG.py +++ b/engine/TAFG.py @@ -76,11 +76,14 @@ class TAFGEngineKernel(EngineKernel): contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b") for ph in "ab": images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph) - images["a2b"] = generator.decode(contents["a"], styles["b"], "b") - contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]), - images["a2b"], "b", "b") - images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b") - images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a") + + if self.engine.state.epoch > self.config.misc.add_new_loss_epoch: + styles[f"random_b"] = 
torch.randn_like(styles["b"]).to(idist.device()) + images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b") + contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]), + images["a2b"], "b", "b") + images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b") + images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a") return dict(styles=styles, contents=contents, images=images) def criterion_generators(self, batch, generated) -> dict: @@ -91,50 +94,76 @@ class TAFGEngineKernel(EngineKernel): loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss( generated["images"][f"{ph}2{ph}"], batch[ph]["img"]) - pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"]) + pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"]) loss[f"gan_{ph}"] = 0 for sub_pred_fake in pred_fake: # last output is actual prediction loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight - loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss( - generated["contents"]["a"], generated["contents"]["recon_a"] - ) - loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss( - generated["styles"]["b"], generated["styles"]["recon_b"] - ) - if self.config.loss.perceptual.weight > 0: - loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss( - batch["a"]["img"], generated["images"]["a2b"] + if self.engine.state.epoch == self.config.misc.add_new_loss_epoch: + self.generators["main"].style_converters.requires_grad_(False) + self.generators["main"].style_encoders.requires_grad_(False) + + if self.engine.state.epoch > self.config.misc.add_new_loss_epoch: + + pred_fake = self.discriminators[ph](generated["images"]["a2b"]) + loss["gan_a2b"] = 0 + for sub_pred_fake in pred_fake: + # last output is actual prediction + loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight + + loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss( + generated["contents"]["a"], generated["contents"]["recon_a"] + ) + loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss( + generated["styles"]["random_b"], generated["styles"]["recon_b"] ) - for ph in "ab": + if self.config.loss.perceptual.weight > 0: + loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss( + batch["a"]["img"], generated["images"]["a2b"] + ) + if self.config.loss.cycle.weight > 0: - loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss( - batch[ph]["img"], generated["images"][f"cycle_{ph}"] - ) - if self.config.loss.style.weight > 0: - loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss( - batch[ph]["img"], generated["images"][f"a2{ph}"] + loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss( + batch["a"]["img"], generated["images"][f"cycle_a"] ) - if self.config.loss.edge.weight > 0: - loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss( - generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :] - ) + # for ph in "ab": + # + # if self.config.loss.style.weight > 0: + # loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss( + # batch[ph]["img"], generated["images"][f"a2{ph}"] + # ) + + if self.config.loss.edge.weight > 0: + loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss( + generated["images"]["a2b"], batch["a"]["edge"][:, 
0:1, :, :] + ) return loss def criterion_discriminators(self, batch, generated) -> dict: loss = dict() - # batch = self._process_batch(batch) - for phase in self.discriminators.keys(): - pred_real = self.discriminators[phase](batch[phase]["img"]) - pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach()) - loss[f"gan_{phase}"] = 0 - for i in range(len(pred_fake)): - loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) - + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2 + + if self.engine.state.epoch > self.config.misc.add_new_loss_epoch: + for phase in self.discriminators.keys(): + pred_real = self.discriminators[phase](batch[phase]["img"]) + pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach()) + pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach()) + loss[f"gan_{phase}"] = 0 + for i in range(len(pred_fake)): + loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) + + self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True) + + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3 + else: + for phase in self.discriminators.keys(): + pred_real = self.discriminators[phase](batch[phase]["img"]) + pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach()) + loss[f"gan_{phase}"] = 0 + for i in range(len(pred_fake)): + loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) + + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2 return loss def intermediate_images(self, batch, generated) -> dict: @@ -145,18 +174,30 @@ class TAFGEngineKernel(EngineKernel): :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} """ batch = self._process_batch(batch) - return dict( - a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(), - batch["a"]["img"].detach(), - generated["images"]["a2a"].detach(), - generated["images"]["a2b"].detach(), - generated["images"]["cycle_a"].detach(), - ], - b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(), - batch["b"]["img"].detach(), - generated["images"]["b2b"].detach(), - generated["images"]["cycle_b"].detach()] - ) + if self.engine.state.epoch > self.config.misc.add_new_loss_epoch: + return dict( + a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(), + batch["a"]["img"].detach(), + generated["images"]["a2a"].detach(), + generated["images"]["a2b"].detach(), + generated["images"]["cycle_a"].detach(), + ], + b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(), + batch["b"]["img"].detach(), + generated["images"]["b2b"].detach(), + generated["images"]["cycle_b"].detach()] + ) + else: + return dict( + a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(), + batch["a"]["img"].detach(), + generated["images"]["a2a"].detach(), + ], + b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(), + batch["b"]["img"].detach(), + generated["images"]["b2b"].detach(), + ] + ) def change_engine(self, config, trainer): pass diff --git a/engine/TSIT.py b/engine/TSIT.py index f93838c..7dc40af 100644 --- a/engine/TSIT.py +++ b/engine/TSIT.py @@ -51,31 +51,19 @@ class TSITEngineKernel(EngineKernel): def forward(self, batch, inference=False) -> dict: with torch.set_grad_enabled(not inference): fake = dict( - b=self.generators["main"](content_img=batch["a"], style_img=batch["b"]) + b=self.generators["main"](content_img=batch["a"]) ) return fake def criterion_generators(self, batch, 
generated) -> dict: loss = dict() - loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"]) - loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight + loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight for phase in "b": pred_fake = self.discriminators[phase](generated[phase]) loss[f"gan_{phase}"] = 0 for sub_pred_fake in pred_fake: # last output is actual prediction loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True) - - if self.config.loss.fm.weight > 0 and phase == "b": - pred_real = self.discriminators[phase](batch[phase]) - loss_fm = 0 - num_scale_discriminator = len(pred_fake) - for i in range(num_scale_discriminator): - # last output is the final prediction, so we exclude it - num_intermediate_outputs = len(pred_fake[i]) - 1 - for j in range(num_intermediate_outputs): - loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator - loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm return loss def criterion_discriminators(self, batch, generated) -> dict: diff --git a/engine/base/i2i.py b/engine/base/i2i.py index 3454e00..95c5897 100644 --- a/engine/base/i2i.py +++ b/engine/base/i2i.py @@ -189,34 +189,33 @@ def get_trainer(config, kernel: EngineKernel): for i in range(len(image_list)): test_images[k].append([]) - with torch.no_grad(): - g = torch.Generator() - g.manual_seed(config.misc.random_seed + engine.state.epoch - if config.handler.test.random else config.misc.random_seed) - random_start = \ - torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0] - for i in range(random_start, random_start + config.handler.test.images): - batch = convert_tensor(engine.state.test_dataset[i], idist.device()) - for k in batch: - if isinstance(batch[k], torch.Tensor): - batch[k] = batch[k].unsqueeze(0) - elif isinstance(batch[k], dict): - for kk in batch[k]: - if isinstance(batch[k][kk], torch.Tensor): - batch[k][kk] = batch[k][kk].unsqueeze(0) + g = torch.Generator() + g.manual_seed(config.misc.random_seed + engine.state.epoch + if config.handler.test.random else config.misc.random_seed) + random_start = \ + torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0] + for i in range(random_start, random_start + config.handler.test.images): + batch = convert_tensor(engine.state.test_dataset[i], idist.device()) + for k in batch: + if isinstance(batch[k], torch.Tensor): + batch[k] = batch[k].unsqueeze(0) + elif isinstance(batch[k], dict): + for kk in batch[k]: + if isinstance(batch[k][kk], torch.Tensor): + batch[k][kk] = batch[k][kk].unsqueeze(0) - generated = kernel.forward(batch) - images = kernel.intermediate_images(batch, generated) + generated = kernel.forward(batch, inference=True) + images = kernel.intermediate_images(batch, generated) - for k in test_images: - for j in range(len(images[k])): - test_images[k][j].append(images[k][j]) for k in test_images: - tensorboard_handler.writer.add_image( - f"test/{k}", - make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)), - engine.state.iteration * pairs_per_iteration - ) + for j in range(len(images[k])): + test_images[k][j].append(images[k][j]) + for k in test_images: + tensorboard_handler.writer.add_image( + f"test/{k}", + make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)), + engine.state.iteration * pairs_per_iteration + ) return trainer diff --git 
a/engine/talking_anime.py b/engine/talking_anime.py new file mode 100644 index 0000000..6264a28 --- /dev/null +++ b/engine/talking_anime.py @@ -0,0 +1,153 @@ +from itertools import chain + +import ignite.distributed as idist +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +from engine.base.i2i import EngineKernel, run_kernel +from engine.util.build import build_model +from loss.I2I.context_loss import ContextLoss +from loss.I2I.edge_loss import EdgeLoss +from loss.I2I.perceptual_loss import PerceptualLoss +from loss.gan import GANLoss +from model.weight_init import generation_init_weights + + +class TAEngineKernel(EngineKernel): + def __init__(self, config): + super().__init__(config) + + perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual) + perceptual_loss_cfg.pop("weight") + self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device()) + + style_loss_cfg = OmegaConf.to_container(config.loss.style) + style_loss_cfg.pop("weight") + self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device()) + + gan_loss_cfg = OmegaConf.to_container(config.loss.gan) + gan_loss_cfg.pop("weight") + self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) + + context_loss_cfg = OmegaConf.to_container(config.loss.context) + context_loss_cfg.pop("weight") + self.context_loss = ContextLoss(**context_loss_cfg).to(idist.device()) + + self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss() + self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss() + + self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to( + idist.device()) + + def build_models(self) -> (dict, dict): + generators = dict( + anime=build_model(self.config.model.anime_generator), + face=build_model(self.config.model.face_generator) + ) + discriminators = dict( + anime=build_model(self.config.model.discriminator), + face=build_model(self.config.model.discriminator) + ) + self.logger.debug(discriminators["face"]) + self.logger.debug(generators["face"]) + + for m in chain(generators.values(), discriminators.values()): + generation_init_weights(m) + + return generators, discriminators + + def setup_after_g(self): + for discriminator in self.discriminators.values(): + discriminator.requires_grad_(True) + + def setup_before_g(self): + for discriminator in self.discriminators.values(): + discriminator.requires_grad_(False) + + def forward(self, batch, inference=False) -> dict: + with torch.set_grad_enabled(not inference): + target_pose_anime = self.generators["anime"]( + torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1)) + target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"]) + + return dict(fake_anime=target_pose_anime, fake_face=target_pose_face) + + def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None): + pred_fake = discriminator(generated_img) + loss_gan = 0 + for sub_pred_fake in pred_fake: + # last output is actual prediction + loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True) + + if match_img is None: + # do not cal feature match loss + return loss_gan, 0 + + pred_real = discriminator(match_img) + loss_fm = 0 + num_scale_discriminator = len(pred_fake) + for i in range(num_scale_discriminator): + # last output is the final prediction, so we exclude it + num_intermediate_outputs = len(pred_fake[i]) - 1 + for j in range(num_intermediate_outputs): + loss_fm += 
self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator + loss_fm = self.config.loss.fm.weight * loss_fm + return loss_gan, loss_fm + + def criterion_generators(self, batch, generated) -> dict: + loss = dict() + loss["face_style"] = self.config.loss.style.weight * self.style_loss( + generated["fake_face"], batch["face_1"] + ) + loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss( + generated["fake_face"], batch["face_1"] + ) + loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss( + self.discriminators["face"], generated["fake_face"], batch["face_1"]) + loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss( + self.discriminators["anime"], generated["fake_anime"], batch["anime_img"]) + + loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss( + generated["fake_anime"], batch["face_1"], gt_is_edge=False, + ) + if self.config.loss.perceptual.weight > 0: + loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss( + generated["fake_anime"], batch["anime_img"] + ) + if self.config.loss.context.weight > 0: + loss["anime_context"] = self.config.loss.context.weight * self.context_loss( + generated["fake_anime"], batch["anime_img"], + ) + + return loss + + def criterion_discriminators(self, batch, generated) -> dict: + loss = dict() + real = {"anime": "anime_img", "face": "face_1"} + for phase in self.discriminators.keys(): + pred_real = self.discriminators[phase](batch[real[phase]]) + pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach()) + loss[f"gan_{phase}"] = 0 + for i in range(len(pred_fake)): + loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) + + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2 + return loss + + def intermediate_images(self, batch, generated) -> dict: + """ + returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} + :param batch: + :param generated: dict of images + :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} + """ + images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(), + generated["fake_face"].detach()] + return dict( + b=[img for img in images] + ) + + +def run(task, config, _): + kernel = TAEngineKernel(config) + run_kernel(task, config, kernel) diff --git a/model/GAN/TAFG.py b/model/GAN/TAFG.py index 5b44b74..a1fb0ec 100644 --- a/model/GAN/TAFG.py +++ b/model/GAN/TAFG.py @@ -53,6 +53,59 @@ class VGG19StyleEncoder(nn.Module): return x.view(x.size(0), -1) +@MODEL.register_module("TAFG-ResGenerator") +class ResGenerator(nn.Module): + def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64): + super().__init__() + self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks, + use_spectral_norm=use_spectral_norm) + resnet_channels = 2 ** 2 * base_channels + self.decoder = Decoder(resnet_channels, out_channels, 2, + 0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect") + + def forward(self, x): + return self.decoder(self.content_encoder(x)) + + +@MODEL.register_module("TAFG-SingleGenerator") +class SingleGenerator(nn.Module): + def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False, + style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8, + num_res_blocks=8, base_channels=64, padding_mode="reflect"): + super().__init__() + self.num_adain_blocks = 
num_adain_blocks + if style_encoder_type == "StyleEncoder": + self.style_encoder = StyleEncoder( + style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm, + max_multiple=4, padding_mode=padding_mode, norm_type="NONE" + ) + elif style_encoder_type == "VGG19StyleEncoder": + self.style_encoder = VGG19StyleEncoder( + style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE" + ) + else: + raise NotImplemented(f"do not support {style_encoder_type}") + + resnet_channels = 2 ** 2 * base_channels + self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, + n_blocks=3, norm_type="NONE") + self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks, + use_spectral_norm=use_spectral_norm) + + self.decoder = Decoder(resnet_channels, out_channels, 2, + num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode) + + def forward(self, content_img, style_img): + content = self.content_encoder(content_img) + style = self.style_encoder(style_img) + as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1) + # set style for decoder + for i, blk in enumerate(self.decoder.res_blocks): + blk.conv1.normalization.set_style(as_param_style[2 * i]) + blk.conv2.normalization.set_style(as_param_style[2 * i + 1]) + return self.decoder(content) + + @MODEL.register_module("TAFG-Generator") class Generator(nn.Module): def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False, diff --git a/model/GAN/TSIT.py b/model/GAN/TSIT.py index 1a4f429..de9d467 100644 --- a/model/GAN/TSIT.py +++ b/model/GAN/TSIT.py @@ -3,7 +3,6 @@ import torch.nn as nn import torch.nn.functional as F from model import MODEL -from model.normalization import AdaptiveInstanceNorm2d from model.normalization import select_norm_layer @@ -62,7 +61,9 @@ class Interpolation(nn.Module): class FADE(nn.Module): def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True): super().__init__() - self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats) + # self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats) + self.norm = nn.InstanceNorm2d(num_features=in_channels) + self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1, padding_mode="zeros") self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1, @@ -71,7 +72,7 @@ class FADE(nn.Module): def forward(self, x, feature): alpha = self.alpha_conv(feature) beta = self.beta_conv(feature) - x = self.bn(x) + x = self.norm(x) return alpha * x + beta @@ -122,9 +123,7 @@ class TSITGenerator(nn.Module): self.use_spectral = use_spectral self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type) - self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type) self.content_stream = self.build_stream() - self.style_stream = self.build_stream() self.generator = self.build_generator() self.end_conv = nn.Sequential( conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"), @@ -138,11 +137,9 @@ class TSITGenerator(nn.Module): m = self.num_blocks - i multiple_prev = multiple_now multiple_now = min(2 ** m, 2 ** 4) - 
stream_sequence.append(nn.Sequential( - AdaptiveInstanceNorm2d(multiple_prev * self.base_channels), + stream_sequence.append( FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels, - multiple_now * self.base_channels) - )) + multiple_now * self.base_channels)) return nn.ModuleList(stream_sequence) def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"): @@ -171,22 +168,16 @@ class TSITGenerator(nn.Module): )) return nn.ModuleList(stream_sequence) - def forward(self, content_img, style_img): + def forward(self, content_img): c = self.content_input_layer(content_img) - s = self.style_input_layer(style_img) content_features = [] - style_features = [] for i in range(self.num_blocks): - s = self.style_stream[i](s) c = self.content_stream[i](c) content_features.append(c) - style_features.append(s) z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device) for i in range(self.num_blocks): m = - i - 1 layer = self.generator[i] - layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1)) - z = layer[0](z) - z = layer[1](z, content_features[m]) + z = layer(z, content_features[m]) return self.end_conv(z) diff --git a/tool/inspect_model.py b/tool/inspect_model.py new file mode 100644 index 0000000..d4bfe6d --- /dev/null +++ b/tool/inspect_model.py @@ -0,0 +1,14 @@ +import sys +import torch +from omegaconf import OmegaConf + +from engine.util.build import build_model + +config = OmegaConf.load(sys.argv[1]) + + +generator = build_model(config.model.generator) + +ckp = torch.load(sys.argv[2], map_location="cpu") + +generator.module.load_state_dict(ckp["generator_main"]) diff --git a/tool/process/permutation_face.py b/tool/process/permutation_face.py new file mode 100644 index 0000000..9b9d219 --- /dev/null +++ b/tool/process/permutation_face.py @@ -0,0 +1,13 @@ +from pathlib import Path +import sys +from collections import defaultdict +from itertools import permutations + +pids = defaultdict(list) +for p in Path(sys.argv[1]).glob("*.jpg"): + pids[p.stem[:7]].append(p.stem) + +data = [] +for p in pids: + data.extend(list(permutations(pids[p], 2))) +
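
Editor's note (not part of the patch): a minimal smoke-test sketch for the new TAFG-SingleGenerator added above. It assumes the registry-based build_model helper that tool/inspect_model.py imports from engine.util.build, and the channel/size settings declared in configs/synthesizers/talking_anime.yml (content_in_channels: 1, style_in_channels: 3, 128x128 crops). As tool/inspect_model.py's use of generator.module suggests, build_model may return a wrapped module, so the raw generator might need to be reached via .module; adjust paths and shapes to your checkout.

    import torch
    from omegaconf import OmegaConf

    from engine.util.build import build_model

    # Load the new talking_anime config and build the face generator
    # (_type: TAFG-SingleGenerator, registered in model/GAN/TAFG.py).
    cfg = OmegaConf.load("configs/synthesizers/talking_anime.yml")
    face_generator = build_model(cfg.model.face_generator).eval()

    # Dummy inputs matching the config: 1-channel content map, 3-channel style face.
    content = torch.randn(1, 1, 128, 128)
    style = torch.randn(1, 3, 128, 128)

    with torch.no_grad():
        fake = face_generator(content, style)

    # If the sketch's assumptions hold, this should be a 3-channel image batch.
    print(fake.shape)

This mirrors how TAEngineKernel.forward in engine/talking_anime.py drives the same generator, where the content input is the grayscale mean of the anime generator's output and the style input is the reference face frame.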