This commit is contained in:
budui 2020-09-01 17:56:18 +08:00
parent e71e8d95d0
commit 14d4247112
11 changed files with 563 additions and 7 deletions

View File

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
</project>

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">

View File

@ -0,0 +1,145 @@
name: TAFG
engine: TAFG
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
checkpoint:
epoch_interval: 1 # one checkpoint every 1 epoch
n_saved: 2
interval:
print_per_iteration: 10 # print once per 10 iteration
tensorboard:
scalar: 100
image: 2
model:
generator:
_type: TAHG-Generator
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
num_blocks: 4
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: base-PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 1
style:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L2'
style_loss: True
perceptual_loss: False
weight: 0
fm:
level: 1
weight: 1
recon:
level: 1
weight: 1
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 256
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
size: [128, 128]
random_pair: True
pipeline:
- Load
- Resize:
size: [128, 128]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
dataloader:
batch_size: 8
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/testA"
root_b: "/data/i2i/VoxCeleb2Anime/testB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
random_pair: False
size: [128, 128]
pipeline:
- Load
- Resize:
size: [128, 128]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@ -3,10 +3,6 @@ engine: TAHG
result_dir: ./result
max_pairs: 1000000
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 324
@ -23,6 +19,7 @@ interval:
model:
generator:
_type: TAHG-Generator
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
num_blocks: 4

133
engine/TAFG.py Normal file
View File

@ -0,0 +1,133 @@
from itertools import chain
from math import ceil
from omegaconf import read_write, OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
import data
from engine.base.i2i import get_trainer, EngineKernel, build_model
from model.weight_init import generation_init_weights
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
class TAFGEngineKernel(EngineKernel):
def __init__(self, config, logger):
super().__init__(config, logger)
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
def build_models(self) -> (dict, dict):
generators = dict(
main=build_model(self.config.model.generator)
)
discriminators = dict(
a=build_model(self.config.model.discriminator),
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["a"])
self.logger.debug(generators["main"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_before_d(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
generator = self.generators["main"]
with torch.set_grad_enabled(not inference):
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
)
return fake
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
loss["perceptual"], _, = self.perceptual_loss(generated["b"], batch["b"]) * self.config.loss.perceptual.weight
for phase in "ab":
pred_fake = self.discriminators[phase](generated[phase])
for i, sub_pred_fake in enumerate(pred_fake):
# last output is actual prediction
loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)
if self.config.loss.fm.weight > 0 and phase == "b":
pred_real = self.discriminators[phase](batch[phase])
loss_fm = 0
num_scale_discriminator = len(pred_fake)
for i in range(num_scale_discriminator):
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase])
pred_fake = self.discriminators[phase](generated[phase].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
return dict(
a=[batch[f"edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
b=[batch["b"].detach(), generated["b"].detach()]
)
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")

0
engine/base/__init__.py Normal file
View File

187
engine/base/i2i.py Normal file
View File

@ -0,0 +1,187 @@
from itertools import chain
from math import ceil
from pathlib import Path
import logging
import torch
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from model import MODEL
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_optimizer
from omegaconf import OmegaConf
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
model = MODEL.build_with(cfg)
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return idist.auto_model(model)
def build_lr_schedulers(optimizers, config):
# TODO: support more scheduler type
g_milestones_values = [
(0, config.optimizers.generator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
d_milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
return dict(
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
)
class EngineKernel(object):
def __init__(self, config, logger):
self.config = config
self.logger = logger
self.generators, self.discriminators = self.build_models()
def build_models(self) -> (dict, dict):
raise NotImplemented
def to_save(self):
to_save = {}
to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
return to_save
def setup_before_d(self):
raise NotImplemented
def setup_before_g(self):
raise NotImplemented
def forward(self, batch, inference=False) -> dict:
raise NotImplemented
def criterion_generators(self, batch, generated) -> dict:
raise NotImplemented
def criterion_discriminators(self, batch, generated) -> dict:
raise NotImplemented
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
raise NotImplemented
def get_trainer(config, ek: EngineKernel, iter_per_epoch):
logger = logging.getLogger(config.name)
generators, discriminators = ek.generators, ek.discriminators
optimizers = dict(
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info("build optimizers", optimizers)
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
generated = ek.forward(batch)
ek.setup_before_g()
optimizers["g"].zero_grad()
loss_g = ek.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
ek.setup_before_d()
optimizers["d"].zero_grad()
loss_d = ek.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
return {
"loss": dict(g=loss_g, d=loss_d),
"img": ek.intermediate_images(batch, generated)
}
trainer = Engine(_step)
trainer.logger = logger
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update(ek.to_save())
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
)
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
def show_images(engine):
output = engine.state.output
test_images = {}
for k in output["img"]:
image_list = output["img"][k]
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), engine.state.iteration)
test_images[k] = []
for i in range(len(image_list)):
test_images[k].append([])
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
for i in range(random_start, random_start + 10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
generated = ek.forward(batch)
images = ek.intermediate_images(batch, generated)
for k in test_images:
for j in range(len(images[k])):
test_images[k][j].append(images[k][j])
for k in test_images:
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
engine.state.iteration
)
return trainer

61
model/GAN/base.py Normal file
View File

@ -0,0 +1,61 @@
import math
import torch.nn as nn
from model.normalization import select_norm_layer
from model import MODEL
# based SPADE or pix2pixHD Discriminator
@MODEL.register_module("base-PatchDiscriminator")
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
need_intermediate_feature=False):
super().__init__()
self.need_intermediate_feature = need_intermediate_feature
kernel_size = 4
padding = math.ceil((kernel_size - 1.0) / 2)
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
padding_mode = "zeros"
sequence = [nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
nn.LeakyReLU(0.2, False)
)]
multiple_now = 1
for i in range(1, num_conv):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 3)
stride = 1 if i == num_conv - 1 else 2
sequence.append(nn.Sequential(
self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
norm_layer(base_channels * multiple_now),
nn.LeakyReLU(0.2, inplace=False),
))
multiple_now = min(2 ** num_conv, 8)
sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
padding_mode=padding_mode))
self.conv_blocks = nn.ModuleList(sequence)
@staticmethod
def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
bias=True, padding_mode: str = 'zeros'):
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, padding_mode=padding_mode)
if not use_spectral:
return conv
return nn.utils.spectral_norm(conv)
def forward(self, x):
if self.need_intermediate_feature:
intermediate_feature = []
for layer in self.conv_blocks:
x = layer(x)
intermediate_feature.append(x)
return tuple(intermediate_feature)
else:
for layer in self.conv_blocks:
x = layer(x)
return x

25
model/GAN/wrapper.py Normal file
View File

@ -0,0 +1,25 @@
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
def __init__(self, num_scale, discriminator_cfg):
super().__init__()
self.discriminator_list = nn.ModuleList([
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
])
@staticmethod
def down_sample(x):
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
def forward(self, x):
results = []
for discriminator in self.discriminator_list:
results.append(discriminator(x))
x = self.down_sample(x)
return results

View File

@ -3,3 +3,5 @@ import model.GAN.residual_generator
import model.GAN.TAHG
import model.GAN.UGATIT
import model.fewshot
import model.GAN.wrapper
import model.GAN.base

View File

@ -2,7 +2,7 @@ import inspect
from omegaconf.dictconfig import DictConfig
from omegaconf import OmegaConf
from types import ModuleType
import warnings
class _Registry:
def __init__(self, name):
@ -51,6 +51,12 @@ class _Registry:
else:
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
for k in args:
assert isinstance(k, str)
if k.startswith("_"):
warnings.warn(f"got param start with `_`: {k}, will remove it")
args.pop(k)
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')