TAFG 0.01

budui 2020-09-03 09:34:38 +08:00
parent 14d4247112
commit 2469bf15fe
6 changed files with 37 additions and 388 deletions

View File

@@ -49,7 +49,7 @@ loss:
criterion: 'L1'
style_loss: False
perceptual_loss: True
-    weight: 1
+    weight: 5
style:
layer_weights:
"1": 0.03125
@@ -63,10 +63,10 @@ loss:
weight: 0
fm:
level: 1
-    weight: 1
+    weight: 10
recon:
level: 1
-    weight: 1
+    weight: 5
optimizers:
generator:
@@ -97,7 +97,8 @@ data:
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
-      edge_type: "hed"
+      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
+      edge_type: "landmark_canny"
size: [128, 128]
random_pair: True
pipeline:

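Reading the two config hunks above together: the perceptual, feature-matching, and reconstruction weights go up (1 to 5, 1 to 10, 1 to 5), and the dataset switches from plain HED edges to landmark-augmented Canny edges. A minimal sketch of how such weights combine, assuming the per-term multiply used in the engine kernel later in this commit; the values and names below are stand-ins, and the gan weight of 1.0 is taken from the removed config rather than this hunk:

```python
import torch

# Hedged sketch, not the project's code: each loss term is scaled by its
# configured weight before the generator backward pass.
raw = {k: torch.rand((), requires_grad=True) for k in ("gan", "perceptual", "fm", "recon")}
weights = {"gan": 1.0, "perceptual": 5, "fm": 10, "recon": 5}  # gan assumed; rest from this hunk
loss_g = sum(weights[k] * raw[k] for k in raw)
loss_g.backward()  # in the real engine this propagates through the generator
```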
View File

@@ -1,129 +0,0 @@
name: TAHG
engine: TAHG
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
checkpoint:
epoch_interval: 1 # one checkpoint per epoch
n_saved: 2
interval:
print_per_iteration: 10 # print once every 10 iterations
tensorboard:
scalar: 100
image: 2
model:
generator:
_type: TAHG-Generator
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
num_blocks: 4
discriminator:
_type: TAHG-Discriminator
in_channels: 3
loss:
gan:
loss_type: lsgan
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
edge:
criterion: 'L1'
hed_pretrained_model_path: "./network-bsds500.pytorch"
weight: 1
perceptual:
layer_weights:
"3": 1.0
# "0": 1.0
# "5": 1.0
# "10": 1.0
# "19": 1.0
criterion: 'L2'
style_loss: True
perceptual_loss: False
weight: 20
recon:
level: 1
weight: 1
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 1e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 160
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
size: [128, 128]
random_pair: True
pipeline:
- Load
- Resize:
size: [128, 128]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
dataloader:
batch_size: 8
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/testA"
root_b: "/data/i2i/VoxCeleb2Anime/testB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
random_pair: False
size: [128, 128]
pipeline:
- Load
- Resize:
size: [128, 128]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -15,6 +15,7 @@ from tqdm import tqdm
from .transform import transform_pipeline
from .registry import DATASET
+from .util import dlib_landmark
def default_transform_way(transform, sample):
@@ -178,20 +179,38 @@ class GenerationUnpairedDataset(Dataset):
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
-    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
+    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
+        self.landmarks_path = Path(landmarks_path)
assert self.edges_path.exists()
+        assert self.landmarks_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
-        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
-        img = Image.open(edge_path).resize(self.size)
-        return F.to_tensor(img)
+        if self.edge_type.startswith("landmark_"):
+            edge_type = self.edge_type[len("landmark_"):]  # slice off the prefix; str.lstrip("landmark_") would strip a character set, not the prefix
+            use_landmark = True
+        else:
+            edge_type = self.edge_type
+            use_landmark = False
+        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
+        origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size))
+        if not use_landmark:
+            return origin_edge
+        else:
+            landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.{edge_type}.txt"
+            key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
+            dist_tensor = torch.from_numpy(dlib_landmark.dist_tensor(key_points))
+            part_labels = torch.from_numpy(part_labels)
+            edges = origin_edge * (part_labels.sum(0) == 0)  # remove edges within the face region
+            edges = part_edge + edges
+            return torch.cat([edges, dist_tensor, part_labels], dim=0)
def __getitem__(self, idx):
a_idx = idx % len(self.A)
@@ -200,7 +219,6 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
output["a"], path_a = self.A[a_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
output["edge_b"] = self.get_edge(path_b)
return output
def __len__(self):

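With a `landmark_*` edge type, `get_edge` now returns a multi-channel tensor rather than a single edge map: the base edges masked outside the face parts plus the landmark part edges, concatenated with per-keypoint distance fields and part-label masks. A rough shape sketch, assuming dlib's 68-point model and 9 part labels; the channel counts are assumptions, and the `read_keypoints`/`dist_tensor` outputs are inferred from the diff:

```python
import torch

# Hedged sketch of the tensor layout get_edge returns for "landmark_*" types.
H, W, P = 128, 128, 9                               # size from the config; P is assumed
origin_edge = torch.rand(1, H, W)                   # grayscale edge map
dist_tensor = torch.rand(68, H, W)                  # one distance field per keypoint (assumed)
part_labels = torch.randint(0, 2, (P, H, W)).float()
part_edge = torch.rand(1, H, W)

edges = origin_edge * (part_labels.sum(0) == 0)     # drop edges inside labeled face parts
edges = part_edge + edges                           # add landmark-derived edges back
sample = torch.cat([edges, dist_tensor, part_labels], dim=0)
print(sample.shape)                                 # torch.Size([78, 128, 128]) with these assumptions
```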
View File

@@ -65,12 +65,14 @@ class TAFGEngineKernel(EngineKernel):
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
loss["perceptual"], _, = self.perceptual_loss(generated["b"], batch["b"]) * self.config.loss.perceptual.weight
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
for phase in "ab":
pred_fake = self.discriminators[phase](generated[phase])
for i, sub_pred_fake in enumerate(pred_fake):
loss[f"gan_{phase}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
if self.config.loss.fm.weight > 0 and phase == "b":
pred_real = self.discriminators[phase](batch[phase])

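Two changes in this hunk: the removed line multiplied the `(loss, features)` tuple returned by `perceptual_loss` by the weight before unpacking it (and the target changes from `batch["b"]` to `batch["a"]`), and the per-scale GAN terms are now summed into a single `gan_a`/`gan_b` entry. A sketch of the aggregation pattern, assuming a multi-scale discriminator that returns one feature list per scale with the prediction map last, and the LSGAN criterion from the config:

```python
import torch
import torch.nn.functional as F

def gan_loss(pred, target_is_real):
    # LSGAN-style criterion, matching loss_type: lsgan in the config
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    return F.mse_loss(pred, target)

# Stand-in multi-scale output: two scales, each a list of feature maps
# whose last element is the prediction map.
pred_fake = [
    [torch.rand(4, 64, 16, 16), torch.rand(4, 1, 8, 8)],
    [torch.rand(4, 64, 8, 8), torch.rand(4, 1, 4, 4)],
]

loss_gan = 0
for sub_pred_fake in pred_fake:
    loss_gan += gan_loss(sub_pred_fake[-1], True)  # last output is the prediction
print(float(loss_gan))
```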
View File

@@ -1,245 +0,0 @@
from itertools import chain
from math import ceil
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf, read_write
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer
def build_lr_schedulers(optimizers, config):
g_milestones_values = [
(0, config.optimizers.generator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
d_milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
return dict(
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
)
def get_trainer(config, logger, train_data_loader):
generator = build_model(config.model.generator, config.distributed.model)
discriminators = dict(
a=build_model(config.model.discriminator, config.distributed.model),
b=build_model(config.model.discriminator, config.distributed.model),
)
generation_init_weights(generator)
for m in discriminators.values():
generation_init_weights(m)
logger.debug(discriminators["a"])
logger.debug(generator)
optimizers = dict(
g=build_optimizer(generator.parameters(), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
edge_loss_cfg = OmegaConf.to_container(config.loss.edge)
edge_loss_cfg.pop("weight")
edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device())
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real = dict(a=batch["a"], b=batch["b"])
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
)
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
optimizers["g"].zero_grad()
loss_g = dict()
for d in "ab":
discriminators[d].requires_grad_(False)
pred_fake = discriminators[d](fake[d])
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
_, t = perceptual_loss(fake[d], real[d])
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
sum(loss_g.values()).backward()
optimizers["g"].step()
for discriminator in discriminators.values():
discriminator.requires_grad_(True)
optimizers["d"].zero_grad()
loss_d = dict()
for k in discriminators.keys():
pred_real = discriminators[k](real[k])
pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach()))
loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) +
gan_loss(pred_fake, False, is_discriminator=True)) / 2
sum(loss_d.values()).backward()
optimizers["d"].step()
generated_img = {f"real_{k}": real[k].detach() for k in real}
generated_img["rec_b"] = rec_b.detach()
generated_img["rec_bb"] = rec_b.detach()
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
return {
"loss": {
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
},
"img": generated_img
}
trainer = Engine(_step)
trainer.logger = logger
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update({"generator": generator})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
iter_per_epoch = len(train_data_loader)
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
)
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
def show_images(engine):
output = engine.state.output
image_order = dict(
a=["edge_a", "real_a", "fake_a", "fake_b"],
b=["edge_b", "real_b", "rec_b", "rec_bb"]
)
for k in "ab":
tensorboard_handler.writer.add_image(
f"train/{k}",
make_2d_grid([output["img"][o] for o in image_order[k]]),
engine.state.iteration
)
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
test_images = dict(
a=[[], [], [], []],
b=[[], [], [], []]
)
for i in range(random_start, random_start + 10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
real = dict(a=batch["a"], b=batch["b"])
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
)
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
test_images["a"][0].append(batch["edge_a"])
test_images["a"][1].append(batch["a"])
test_images["a"][2].append(fake["a"])
test_images["a"][3].append(fake["b"])
test_images["b"][0].append(batch["edge_b"])
test_images["b"][1].append(batch["b"])
test_images["b"][2].append(rec_b)
test_images["b"][3].append(rec_bb)
for n in "ab":
tensorboard_handler.writer.add_image(
f"test/{n}",
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
engine.state.iteration
)
return trainer
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger, train_data_loader)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
else:
raise NotImplementedError(f"invalid task: {task}")

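The deleted engine built its learning-rate schedules with `PiecewiseLinear`: flat at the base rate for the first `start_proportion` of training, then a linear ramp to `target_lr`. A standalone sketch of that shape with stand-in values, calling the scheduler directly instead of attaching it to an ignite event as the engine did:

```python
import torch
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

max_iteration, base_lr = 1000, 1e-4                       # stand-in values
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
scheduler = PiecewiseLinear(
    opt, param_name="lr",
    milestones_values=[(0, base_lr),
                       (int(0.5 * max_iteration), base_lr),  # start_proportion: 0.5
                       (max_iteration, 0.0)],                # target_lr: 0
)
# Normally attached to Events.ITERATION_COMPLETED; invoked manually here.
for _ in range(max_iteration):
    scheduler(None)
print(opt.param_groups[0]["lr"])  # ~0.0 at the end
```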
View File

@@ -145,6 +145,7 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
loss[tl] = output["loss"][tl]
return loss
+    pairs_per_iteration = config.data.train.dataloader.batch_size
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
if tensorboard_handler is not None:
tensorboard_handler.attach(
@@ -159,7 +160,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
test_images = {}
for k in output["img"]:
image_list = output["img"][k]
-            tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), engine.state.iteration)
+            tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
+                                                 engine.state.iteration * pairs_per_iteration)
test_images[k] = []
for i in range(len(image_list)):
test_images[k].append([])
@@ -182,6 +184,6 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
-                engine.state.iteration
+                engine.state.iteration * pairs_per_iteration
)
return trainer
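Both step changes above rescale the TensorBoard global step from raw iterations to pairs seen (`iteration * pairs_per_iteration`), so image logs from runs with different batch sizes line up on the same axis. A minimal sketch of the convention; the 160 comes from the removed config's train dataloader, and the new batch size is not shown in this diff:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("/tmp/tafg_demo")     # stand-in log dir
pairs_per_iteration = 160                    # batch_size from the removed train dataloader config
for iteration in range(1, 4):
    grid = torch.rand(3, 64, 64)
    writer.add_image("train/demo", grid, iteration * pairs_per_iteration)
writer.close()
```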