From 7ea9c6d0d5bbfbab96f47e45624a7ecdcdb4980a Mon Sep 17 00:00:00 2001 From: budui Date: Wed, 9 Sep 2020 14:46:07 +0800 Subject: [PATCH] TAFG good result --- configs/synthesizers/TAFG.yml | 22 ++++++++--------- engine/TAFG.py | 29 ++++++++++++---------- model/GAN/TAFG.py | 45 ++++++++++++++++++++++------------- tool/dump_tensorboard.py | 35 +++++++++++++++------------ 4 files changed, 76 insertions(+), 55 deletions(-) diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml index df2aff7..2ad4419 100644 --- a/configs/synthesizers/TAFG.yml +++ b/configs/synthesizers/TAFG.yml @@ -23,7 +23,8 @@ model: _bn_to_sync_bn: False style_in_channels: 3 content_in_channels: 24 - num_blocks: 8 + num_adain_blocks: 8 + num_res_blocks: 0 discriminator: _type: MultiScaleDiscriminator num_scale: 2 @@ -47,21 +48,17 @@ loss: "11": 0.125 "20": 0.25 "29": 1 - criterion: 'L2' + criterion: 'L1' style_loss: False perceptual_loss: True - weight: 0.5 + weight: 10 style: layer_weights: - "1": 0.03125 - "6": 0.0625 - "11": 0.125 - "20": 0.25 - "29": 1 - criterion: 'L2' + "3": 1 + criterion: 'L1' style_loss: True perceptual_loss: False - weight: 0 + weight: 10 fm: level: 1 weight: 10 @@ -71,6 +68,9 @@ loss: style_recon: level: 1 weight: 0 + edge: + weight: 10 + hed_pretrained_model_path: ./network-bsds500.pytorch optimizers: generator: @@ -91,7 +91,7 @@ data: target_lr: 0 buffer_size: 50 dataloader: - batch_size: 24 + batch_size: 8 shuffle: True num_workers: 2 pin_memory: True diff --git a/engine/TAFG.py b/engine/TAFG.py index 77d3646..20a253c 100644 --- a/engine/TAFG.py +++ b/engine/TAFG.py @@ -1,20 +1,17 @@ from itertools import chain -from omegaconf import OmegaConf - +import ignite.distributed as idist import torch import torch.nn as nn -import ignite.distributed as idist from ignite.engine import Events - from omegaconf import read_write, OmegaConf -from model.weight_init import generation_init_weights -from loss.I2I.perceptual_loss import PerceptualLoss -from loss.gan import GANLoss - from engine.base.i2i import EngineKernel, run_kernel from engine.util.build import build_model +from loss.I2I.edge_loss import EdgeLoss +from loss.I2I.perceptual_loss import PerceptualLoss +from loss.gan import GANLoss +from model.weight_init import generation_init_weights class TAFGEngineKernel(EngineKernel): @@ -24,6 +21,10 @@ class TAFGEngineKernel(EngineKernel): perceptual_loss_cfg.pop("weight") self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device()) + style_loss_cfg = OmegaConf.to_container(config.loss.style) + style_loss_cfg.pop("weight") + self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device()) + gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg.pop("weight") self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) @@ -32,6 +33,9 @@ class TAFGEngineKernel(EngineKernel): self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss() self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss() + self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to( + idist.device()) + def _process_batch(self, batch, inference=False): # batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size()) return batch @@ -74,7 +78,9 @@ class TAFGEngineKernel(EngineKernel): batch = self._process_batch(batch) loss = dict() loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"]) - loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight + _, loss_style = self.style_loss(generated["a"], batch["a"]) + loss["style"] = self.config.loss.style.weight * loss_style + loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual for phase in "ab": pred_fake = self.discriminators[phase](generated[phase]) loss[f"gan_{phase}"] = 0 @@ -93,10 +99,7 @@ class TAFGEngineKernel(EngineKernel): loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"]) - # loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss( - # self.generators["main"].module.style_encoders["b"](batch["b"]), - # self.generators["main"].module.style_encoders["b"](generated["b"]) - # ) + loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :]) return loss def criterion_discriminators(self, batch, generated) -> dict: diff --git a/model/GAN/TAFG.py b/model/GAN/TAFG.py index a3cf097..33aed3f 100644 --- a/model/GAN/TAFG.py +++ b/model/GAN/TAFG.py @@ -1,9 +1,10 @@ import torch import torch.nn as nn -from .base import ResidualBlock -from model.registry import MODEL from torchvision.models import vgg19 + from model.normalization import select_norm_layer +from model.registry import MODEL +from .base import ResidualBlock class VGG19StyleEncoder(nn.Module): @@ -169,25 +170,37 @@ class StyleGenerator(nn.Module): @MODEL.register_module("TAFG-Generator") class Generator(nn.Module): - def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8, + def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, + num_adain_blocks=8, num_res_blocks=4, base_channels=64, padding_mode="reflect"): super(Generator, self).__init__() - self.num_blocks = num_blocks + self.num_adain_blocks=num_adain_blocks self.style_encoders = nn.ModuleDict({ - "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks, + "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks, base_channels=base_channels, padding_mode=padding_mode), - "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks, + "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks, base_channels=base_channels, padding_mode=padding_mode), }) - self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks, + self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8, padding_mode=padding_mode, norm_type="IN") res_block_channels = 2 ** 2 * base_channels - self.adain_resnet_a = nn.ModuleList([ - ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks) - ]) - self.adain_resnet_b = nn.ModuleList([ - ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks) - ]) + + self.resnet = nn.ModuleDict({ + "a": nn.Sequential(*[ + ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks) + ]), + "b": nn.Sequential(*[ + ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks) + ]) + }) + self.adain_resnet = nn.ModuleDict({ + "a": nn.ModuleList([ + ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks) + ]), + "b": nn.ModuleList([ + ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks) + ]) + }) self.decoders = nn.ModuleDict({ "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode), @@ -196,10 +209,10 @@ class Generator(nn.Module): def forward(self, content_img, style_img, which_decoder: str = "a"): x = self.content_encoder(content_img) + x = self.resnet[which_decoder](x) styles = self.style_encoders[which_decoder](style_img) - styles = torch.chunk(styles, self.num_blocks * 2, dim=1) - resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b - for i, ar in enumerate(resnet): + styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1) + for i, ar in enumerate(self.adain_resnet[which_decoder]): ar.norm1.set_style(styles[2 * i]) ar.norm2.set_style(styles[2 * i + 1]) x = ar(x) diff --git a/tool/dump_tensorboard.py b/tool/dump_tensorboard.py index dd9fb3b..e65c355 100644 --- a/tool/dump_tensorboard.py +++ b/tool/dump_tensorboard.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +# edit from https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2 + import argparse import pathlib @@ -8,33 +10,36 @@ import numpy as np from tensorboard.backend.event_processing import event_accumulator +def save(outdir: pathlib.Path, tag, event_acc): + events = event_acc.Images(tag) + + for index, event in enumerate(events): + s = np.frombuffer(event.encoded_image_string, dtype=np.uint8) + image = cv2.imdecode(s, cv2.IMREAD_COLOR) + outpath = outdir / f"{tag.replace('/', '_')}@{index}.png" + cv2.imwrite(outpath.as_posix(), image) + + # ffmpeg -framerate 1 -i ./tmp/test_b/%04d.jpg -vcodec mpeg4 test_b.mp4 -# https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2 def main(): parser = argparse.ArgumentParser() parser.add_argument('--path', type=str, required=True) parser.add_argument('--outdir', type=str, required=True) + parser.add_argument("--tag", type=str, required=False) args = parser.parse_args() - event_acc = event_accumulator.EventAccumulator( - args.path, size_guidance={'images': 0}) + event_acc = event_accumulator.EventAccumulator(args.path, size_guidance={'images': 0}) event_acc.Reload() outdir = pathlib.Path(args.outdir) outdir.mkdir(exist_ok=True, parents=True) - for tag in event_acc.Tags()['images']: - events = event_acc.Images(tag) - - tag_name = tag.replace('/', '_') - dirpath = outdir / tag_name - dirpath.mkdir(exist_ok=True, parents=True) - - for index, event in enumerate(events): - s = np.frombuffer(event.encoded_image_string, dtype=np.uint8) - image = cv2.imdecode(s, cv2.IMREAD_COLOR) - outpath = dirpath / '{:04}.jpg'.format(index) - cv2.imwrite(outpath.as_posix(), image) + if args.tag is None: + for tag in event_acc.Tags()['images']: + save(outdir, tag, event_acc) + else: + assert args.tag in event_acc.Tags()['images'], f"{args.tag} not in {event_acc.Tags()['images']}" + save(outdir, args.tag, event_acc) if __name__ == '__main__':