TAFG good result

budui 2020-09-09 14:46:07 +08:00
parent 87cbcf34d3
commit 7ea9c6d0d5
4 changed files with 76 additions and 55 deletions

View File

@@ -23,7 +23,8 @@ model:
     _bn_to_sync_bn: False
     style_in_channels: 3
     content_in_channels: 24
-    num_blocks: 8
+    num_adain_blocks: 8
+    num_res_blocks: 0
   discriminator:
     _type: MultiScaleDiscriminator
     num_scale: 2
@@ -47,21 +48,17 @@ loss:
       "11": 0.125
       "20": 0.25
       "29": 1
-    criterion: 'L2'
+    criterion: 'L1'
     style_loss: False
     perceptual_loss: True
-    weight: 0.5
+    weight: 10
   style:
     layer_weights:
-      "1": 0.03125
-      "6": 0.0625
-      "11": 0.125
-      "20": 0.25
-      "29": 1
-    criterion: 'L2'
+      "3": 1
+    criterion: 'L1'
     style_loss: True
     perceptual_loss: False
-    weight: 0
+    weight: 10
   fm:
     level: 1
     weight: 10
@@ -71,6 +68,9 @@ loss:
   style_recon:
     level: 1
     weight: 0
+  edge:
+    weight: 10
+    hed_pretrained_model_path: ./network-bsds500.pytorch

 optimizers:
   generator:
@@ -91,7 +91,7 @@ data:
     target_lr: 0
     buffer_size: 50
   dataloader:
-    batch_size: 24
+    batch_size: 8
     shuffle: True
     num_workers: 2
     pin_memory: True
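
The rebalanced loss section above switches the perceptual criterion to L1, collapses the style layer weights to a single VGG layer, and adds an HED-based edge term. A minimal sketch of reading the new edge block the way the engine does, via OmegaConf (the config filename here is an assumption):

    from omegaconf import OmegaConf

    config = OmegaConf.load("config.yaml")             # assumed filename for this experiment config
    print(config.loss.edge.weight)                     # -> 10
    print(config.loss.edge.hed_pretrained_model_path)  # -> ./network-bsds500.pytorch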

View File

@@ -1,20 +1,17 @@
 from itertools import chain
-from omegaconf import OmegaConf
+import ignite.distributed as idist
 import torch
 import torch.nn as nn
-import ignite.distributed as idist
 from ignite.engine import Events
 from omegaconf import read_write, OmegaConf
-from model.weight_init import generation_init_weights
-from loss.I2I.perceptual_loss import PerceptualLoss
-from loss.gan import GANLoss
 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
+from loss.I2I.edge_loss import EdgeLoss
+from loss.I2I.perceptual_loss import PerceptualLoss
+from loss.gan import GANLoss
+from model.weight_init import generation_init_weights


 class TAFGEngineKernel(EngineKernel):
@@ -24,6 +21,10 @@ class TAFGEngineKernel(EngineKernel):
         perceptual_loss_cfg.pop("weight")
         self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
+        style_loss_cfg = OmegaConf.to_container(config.loss.style)
+        style_loss_cfg.pop("weight")
+        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
         gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
         gan_loss_cfg.pop("weight")
         self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
@@ -32,6 +33,9 @@ class TAFGEngineKernel(EngineKernel):
         self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
         self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
+        self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
+            idist.device())

     def _process_batch(self, batch, inference=False):
         # batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
         return batch
@@ -74,7 +78,9 @@ class TAFGEngineKernel(EngineKernel):
         batch = self._process_batch(batch)
         loss = dict()
         loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        _, loss_style = self.style_loss(generated["a"], batch["a"])
+        loss["style"] = self.config.loss.style.weight * loss_style
+        loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
         for phase in "ab":
             pred_fake = self.discriminators[phase](generated[phase])
             loss[f"gan_{phase}"] = 0
@@ -93,10 +99,7 @@ class TAFGEngineKernel(EngineKernel):
                     loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
             loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
         loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
-        # loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
-        #     self.generators["main"].module.style_encoders["b"](batch["b"]),
-        #     self.generators["main"].module.style_encoders["b"](generated["b"])
-        # )
+        loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
         return loss

     def criterion_discriminators(self, batch, generated) -> dict:
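
The kernel now builds EdgeLoss("HED", hed_pretrained_model_path=...) and calls it as edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :]). The implementation in loss/I2I/edge_loss.py is not part of this diff; below is only a sketch of a module with that call shape, assuming a frozen HED detector that maps an RGB batch (N, 3, H, W) to a one-channel edge map (N, 1, H, W). The real module may differ.

    import torch.nn as nn

    class EdgeLossSketch(nn.Module):
        # Hypothetical stand-in for loss.I2I.edge_loss.EdgeLoss (not shown in this commit).
        def __init__(self, detector: nn.Module):
            super().__init__()
            self.detector = detector.eval()
            for p in self.detector.parameters():
                p.requires_grad_(False)  # HED stays frozen; gradients only flow into the generated image
            self.criterion = nn.L1Loss()

        def forward(self, generated, target_edges):
            # compare HED edges of the generated image against the dataset's precomputed edge map
            return self.criterion(self.detector(generated), target_edges)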

View File

@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
-from .base import ResidualBlock
-from model.registry import MODEL
 from torchvision.models import vgg19
+
 from model.normalization import select_norm_layer
+from model.registry import MODEL
+from .base import ResidualBlock


 class VGG19StyleEncoder(nn.Module):
@@ -169,25 +170,37 @@
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
-    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
+    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
+                 num_adain_blocks=8, num_res_blocks=4,
                  base_channels=64, padding_mode="reflect"):
         super(Generator, self).__init__()
-        self.num_blocks = num_blocks
+        self.num_adain_blocks = num_adain_blocks
         self.style_encoders = nn.ModuleDict({
-            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
+            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
                                 base_channels=base_channels, padding_mode=padding_mode),
-            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
+            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
                                 base_channels=base_channels, padding_mode=padding_mode),
         })
-        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
+        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
                                               padding_mode=padding_mode, norm_type="IN")
         res_block_channels = 2 ** 2 * base_channels
-        self.adain_resnet_a = nn.ModuleList([
-            ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
-        ])
-        self.adain_resnet_b = nn.ModuleList([
-            ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
-        ])
+        self.resnet = nn.ModuleDict({
+            "a": nn.Sequential(*[
+                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
+            ]),
+            "b": nn.Sequential(*[
+                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
+            ])
+        })
+        self.adain_resnet = nn.ModuleDict({
+            "a": nn.ModuleList([
+                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
+            ]),
+            "b": nn.ModuleList([
+                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
+            ])
+        })

         self.decoders = nn.ModuleDict({
             "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
@@ -196,10 +209,10 @@
     def forward(self, content_img, style_img, which_decoder: str = "a"):
         x = self.content_encoder(content_img)
+        x = self.resnet[which_decoder](x)
         styles = self.style_encoders[which_decoder](style_img)
-        styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
-        resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b
-        for i, ar in enumerate(resnet):
+        styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
+        for i, ar in enumerate(self.adain_resnet[which_decoder]):
             ar.norm1.set_style(styles[2 * i])
             ar.norm2.set_style(styles[2 * i + 1])
             x = ar(x)
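
In the reworked forward pass, the style code is split into num_adain_blocks * 2 chunks, one per AdaIN layer (each residual block consumes two, via norm1 and norm2). A quick check of the chunk arithmetic, assuming the style code has shape (batch, style_dim):

    import torch

    style_dim, num_adain_blocks = 512, 8          # defaults from Generator.__init__ above
    styles = torch.randn(4, style_dim)            # stand-in for a StyleGenerator output
    chunks = torch.chunk(styles, num_adain_blocks * 2, dim=1)
    print(len(chunks), chunks[0].shape)           # 16 torch.Size([4, 32]): one 32-dim style per AdaIN layer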

View File

@@ -1,5 +1,7 @@
 #!/usr/bin/env python
+# edit from https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
+
 import argparse
 import pathlib
@@ -8,33 +10,36 @@ import numpy as np
 from tensorboard.backend.event_processing import event_accumulator


+def save(outdir: pathlib.Path, tag, event_acc):
+    events = event_acc.Images(tag)
+    for index, event in enumerate(events):
+        s = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
+        image = cv2.imdecode(s, cv2.IMREAD_COLOR)
+        outpath = outdir / f"{tag.replace('/', '_')}@{index}.png"
+        cv2.imwrite(outpath.as_posix(), image)
+
+
 # ffmpeg -framerate 1 -i ./tmp/test_b/%04d.jpg -vcodec mpeg4 test_b.mp4
-# https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--path', type=str, required=True)
     parser.add_argument('--outdir', type=str, required=True)
+    parser.add_argument("--tag", type=str, required=False)
     args = parser.parse_args()

-    event_acc = event_accumulator.EventAccumulator(
-        args.path, size_guidance={'images': 0})
+    event_acc = event_accumulator.EventAccumulator(args.path, size_guidance={'images': 0})
     event_acc.Reload()

     outdir = pathlib.Path(args.outdir)
     outdir.mkdir(exist_ok=True, parents=True)

-    for tag in event_acc.Tags()['images']:
-        events = event_acc.Images(tag)
-
-        tag_name = tag.replace('/', '_')
-        dirpath = outdir / tag_name
-        dirpath.mkdir(exist_ok=True, parents=True)
-
-        for index, event in enumerate(events):
-            s = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
-            image = cv2.imdecode(s, cv2.IMREAD_COLOR)
-            outpath = dirpath / '{:04}.jpg'.format(index)
-            cv2.imwrite(outpath.as_posix(), image)
+    if args.tag is None:
+        for tag in event_acc.Tags()['images']:
+            save(outdir, tag, event_acc)
+    else:
+        assert args.tag in event_acc.Tags()['images'], f"{args.tag} not in {event_acc.Tags()['images']}"
+        save(outdir, args.tag, event_acc)


 if __name__ == '__main__':
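
The refactor pulls the image-dumping loop into save() and adds an optional --tag flag: without it the script exports every image tag in the event file, with it only the named tag; files now land flat in --outdir as {tag}@{index}.png. A small sketch for listing which tags are available to pass as --tag, using the same EventAccumulator API as the script (the log path is an assumed placeholder):

    from tensorboard.backend.event_processing import event_accumulator

    acc = event_accumulator.EventAccumulator("./runs", size_guidance={'images': 0})  # assumed log dir
    acc.Reload()
    print(acc.Tags()['images'])   # e.g. ['test_a', 'test_b']; any of these is a valid --tag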