TAFG good result
This commit is contained in:
parent
87cbcf34d3
commit
7ea9c6d0d5
@ -23,7 +23,8 @@ model:
|
||||
_bn_to_sync_bn: False
|
||||
style_in_channels: 3
|
||||
content_in_channels: 24
|
||||
num_blocks: 8
|
||||
num_adain_blocks: 8
|
||||
num_res_blocks: 0
|
||||
discriminator:
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
@ -47,21 +48,17 @@ loss:
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L2'
|
||||
criterion: 'L1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 0.5
|
||||
weight: 10
|
||||
style:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L2'
|
||||
"3": 1
|
||||
criterion: 'L1'
|
||||
style_loss: True
|
||||
perceptual_loss: False
|
||||
weight: 0
|
||||
weight: 10
|
||||
fm:
|
||||
level: 1
|
||||
weight: 10
|
||||
@ -71,6 +68,9 @@ loss:
|
||||
style_recon:
|
||||
level: 1
|
||||
weight: 0
|
||||
edge:
|
||||
weight: 10
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
@ -91,7 +91,7 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 24
|
||||
batch_size: 8
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
|
||||
@ -1,20 +1,17 @@
|
||||
from itertools import chain
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events
|
||||
|
||||
from omegaconf import read_write, OmegaConf
|
||||
|
||||
from model.weight_init import generation_init_weights
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.edge_loss import EdgeLoss
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class TAFGEngineKernel(EngineKernel):
|
||||
@ -24,6 +21,10 @@ class TAFGEngineKernel(EngineKernel):
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
style_loss_cfg = OmegaConf.to_container(config.loss.style)
|
||||
style_loss_cfg.pop("weight")
|
||||
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
@ -32,6 +33,9 @@ class TAFGEngineKernel(EngineKernel):
|
||||
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
|
||||
|
||||
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
|
||||
idist.device())
|
||||
|
||||
def _process_batch(self, batch, inference=False):
|
||||
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
|
||||
return batch
|
||||
@ -74,7 +78,9 @@ class TAFGEngineKernel(EngineKernel):
|
||||
batch = self._process_batch(batch)
|
||||
loss = dict()
|
||||
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
|
||||
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
|
||||
_, loss_style = self.style_loss(generated["a"], batch["a"])
|
||||
loss["style"] = self.config.loss.style.weight * loss_style
|
||||
loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
|
||||
for phase in "ab":
|
||||
pred_fake = self.discriminators[phase](generated[phase])
|
||||
loss[f"gan_{phase}"] = 0
|
||||
@ -93,10 +99,7 @@ class TAFGEngineKernel(EngineKernel):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
|
||||
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
# self.generators["main"].module.style_encoders["b"](batch["b"]),
|
||||
# self.generators["main"].module.style_encoders["b"](generated["b"])
|
||||
# )
|
||||
loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .base import ResidualBlock
|
||||
from model.registry import MODEL
|
||||
from torchvision.models import vgg19
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model.registry import MODEL
|
||||
from .base import ResidualBlock
|
||||
|
||||
|
||||
class VGG19StyleEncoder(nn.Module):
|
||||
@ -169,25 +170,37 @@ class StyleGenerator(nn.Module):
|
||||
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
|
||||
num_adain_blocks=8, num_res_blocks=4,
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.num_adain_blocks=num_adain_blocks
|
||||
self.style_encoders = nn.ModuleDict({
|
||||
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
|
||||
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
|
||||
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
})
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
|
||||
padding_mode=padding_mode, norm_type="IN")
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
self.adain_resnet_a = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.adain_resnet_b = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
self.resnet = nn.ModuleDict({
|
||||
"a": nn.Sequential(*[
|
||||
ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
|
||||
]),
|
||||
"b": nn.Sequential(*[
|
||||
ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
|
||||
])
|
||||
})
|
||||
self.adain_resnet = nn.ModuleDict({
|
||||
"a": nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
|
||||
]),
|
||||
"b": nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
|
||||
])
|
||||
})
|
||||
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
|
||||
@ -196,10 +209,10 @@ class Generator(nn.Module):
|
||||
|
||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||
x = self.content_encoder(content_img)
|
||||
x = self.resnet[which_decoder](x)
|
||||
styles = self.style_encoders[which_decoder](style_img)
|
||||
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
||||
resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b
|
||||
for i, ar in enumerate(resnet):
|
||||
styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
|
||||
for i, ar in enumerate(self.adain_resnet[which_decoder]):
|
||||
ar.norm1.set_style(styles[2 * i])
|
||||
ar.norm2.set_style(styles[2 * i + 1])
|
||||
x = ar(x)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# edit from https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
@ -8,33 +10,36 @@ import numpy as np
|
||||
from tensorboard.backend.event_processing import event_accumulator
|
||||
|
||||
|
||||
def save(outdir: pathlib.Path, tag, event_acc):
    """Write every image logged under ``tag`` into ``outdir`` as a PNG file.

    Each file is named ``<tag with '/' replaced by '_'>@<index>.png``, where
    ``index`` is the position of the event in ``event_acc.Images(tag)``.

    Args:
        outdir: Directory the files are written to. It is not created here.
        tag: Image tag to export.
        event_acc: Object whose ``Images(tag)`` returns events carrying an
            ``encoded_image_string`` bytes payload.

    Raises:
        ValueError: If an event's payload cannot be decoded as an image.
        OSError: If OpenCV reports that a file could not be written.
    """
    # The file-name prefix does not depend on the event, so build it once.
    prefix = tag.replace('/', '_')

    for index, event in enumerate(event_acc.Images(tag)):
        buffer = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
        image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        # imdecode signals failure by returning None instead of raising;
        # fail here with context rather than deep inside imwrite.
        if image is None:
            raise ValueError(f"cannot decode image {index} of tag {tag!r}")

        outpath = outdir / f"{prefix}@{index}.png"
        # imwrite reports failure through its boolean return value.
        if not cv2.imwrite(outpath.as_posix(), image):
            raise OSError(f"failed to write {outpath}")
|
||||
|
||||
|
||||
# ffmpeg -framerate 1 -i ./tmp/test_b/%04d.jpg -vcodec mpeg4 test_b.mp4
|
||||
# https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
|
||||
def main():
    """Export images stored in a TensorBoard event log to PNG files.

    Command-line arguments:
        --path:   event file or log directory handed to ``EventAccumulator``.
        --outdir: directory receiving the exported images (created if missing).
        --tag:    optional image tag; when omitted every image tag is exported.

    An unknown ``--tag`` is reported through ``parser.error`` (usage message
    and exit status 2) instead of an ``assert``, which would be stripped
    when Python runs with ``-O``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True)
    parser.add_argument('--outdir', type=str, required=True)
    parser.add_argument("--tag", type=str, required=False)
    args = parser.parse_args()

    # size_guidance of 0 for images: TensorBoard documents 0 as "keep all".
    event_acc = event_accumulator.EventAccumulator(args.path, size_guidance={'images': 0})
    event_acc.Reload()

    outdir = pathlib.Path(args.outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    # Query the tag list once; it is used for both validation and iteration.
    image_tags = event_acc.Tags()['images']

    if args.tag is None:
        for tag in image_tags:
            save(outdir, tag, event_acc)
    elif args.tag in image_tags:
        save(outdir, args.tag, event_acc)
    else:
        parser.error(f"{args.tag} not in {image_tags}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Loading…
Reference in New Issue
Block a user