# raycv/engine/UGATIT.py
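"""UGATIT training engine built on PyTorch Ignite.

Two generators (a2b, b2a) are trained against per-domain local and global
discriminators in a CycleGAN-style loop, with UGATIT's CAM attention losses
and rho (AdaLIN) clipping.
"""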
from itertools import chain
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer


def get_trainer(config, logger):
    generators = dict(
        a2b=build_model(config.model.generator, config.distributed.model),
        b2a=build_model(config.model.generator, config.distributed.model),
    )
    discriminators = dict(
        la=build_model(config.model.local_discriminator, config.distributed.model),
        lb=build_model(config.model.local_discriminator, config.distributed.model),
        ga=build_model(config.model.global_discriminator, config.distributed.model),
        gb=build_model(config.model.global_discriminator, config.distributed.model),
    )
    for m in chain(generators.values(), discriminators.values()):
        generation_init_weights(m)
    logger.debug(discriminators["ga"])
    logger.debug(generators["a2b"])
    optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator)
    optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                                  config.optimizers.discriminator)
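    # Both optimizers share the same schedule shape: hold the initial LR until
    # `scheduler.start`, then decay linearly to `target_lr` at `max_iteration`.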
    milestones_values = [
        (0, config.optimizers.generator.lr),
        (config.data.train.scheduler.start, config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
    milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
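    # Pixel-level losses (cycle/identity) are L1 when `level == 1`, MSE
    # otherwise; `mse_loss`/`bce_loss` below score predictions against
    # all-ones (real) or all-zeros (fake) targets.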
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

    def mse_loss(x, t):
        return F.mse_loss(x, torch.ones_like(x) if t else torch.zeros_like(x))

    def bce_loss(x, t):
        return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if t else torch.zeros_like(x))

    image_buffers = {
        k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
        for k in discriminators.keys()
    }
    rho_clipper = RhoClipper(0, 1)
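
    # Per-domain generator loss: cycle reconstruction, identity preservation,
    # CAM classification on the generator's attention logits, and adversarial
    # terms against both the local and global discriminators (including their
    # CAM heads). Discriminator grads are frozen here; they are re-enabled in
    # the discriminator update inside `_step`.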
    def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
                           discriminator_g):
        discriminator_g.requires_grad_(False)
        discriminator_l.requires_grad_(False)
        pred_fake_g, cam_gd_pred = discriminator_g(fake)
        pred_fake_l, cam_ld_pred = discriminator_l(fake)
        return {
            f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
            f"id_{name}": config.loss.id.weight * id_loss(real, identity),
            f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
            f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
            f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
            f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
        }
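
    # Per-discriminator loss: adversarial real/fake terms via `gan_loss`, plus
    # an MSE term on the discriminator's CAM logits.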
    def cal_discriminator_loss(name, discriminator, real, fake):
        pred_real, cam_real = discriminator(real)
        pred_fake, cam_fake = discriminator(fake)
        # TODO: the original implementation does not divide by 2, but dividing by 2 seems better.
        loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
        loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
        return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
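
    # One training iteration: six generator forward passes (translation, cycle
    # reconstruction, identity), a generator update with rho clipped to [0, 1],
    # then a discriminator update on history-buffered fakes.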
    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        fake = dict()
        cam_generator_pred = dict()
        rec = dict()
        identity = dict()
        cam_identity_pred = dict()
        heatmap = dict()
        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a)
        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b)
        rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
        rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a)
        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b)
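
        # Generator update (discriminators are frozen inside cal_generator_loss).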
        optimizer_g.zero_grad()
        loss_g = dict()
        for n in ["a", "b"]:
            loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
                                             cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
        sum(loss_g.values()).backward()
        optimizer_g.step()
        for generator in generators.values():
            generator.apply(rho_clipper)
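
        # Discriminator update: re-enable grads, and score detached fakes drawn
        # from a history buffer rather than only the current batch.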
        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)
        optimizer_d.zero_grad()
        loss_d = dict()
        for k in discriminators.keys():
            n = k[-1]  # "a" or "b"
            loss_d.update(
                cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach())))
        sum(loss_d.values()).backward()
        optimizer_d.step()
        for h in heatmap:
            heatmap[h] = heatmap[h].detach()
        generated_img = {f"fake_{k}": fake[k].detach() for k in fake}
        generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
        generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
            },
            "img": {
                "heatmap": heatmap,
                "generated": generated_img,
            },
        }

    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
    to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d,
                   lr_scheduler_g=lr_scheduler_g)
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
                          filename_prefix=config.name, to_save=to_save,
                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
                          metrics_to_print=["loss_g", "loss_d"],
                          save_interval_event=Events.ITERATION_COMPLETED(
                              every=config.checkpoints.interval) | Events.COMPLETED)
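
    # `run` over-provisions max_epochs, so this handler is what actually ends
    # training, at exactly `max_iteration`.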
    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
    def terminate(engine):
        engine.terminate()

    if idist.get_rank() == 0:
        # TensorBoard logging is rank-0 only.
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        def output_transform(output):
            # Flatten the nested {"g": {...}, "d": {...}} loss dict into
            # "g_<name>"/"d_<name>" scalars for TensorBoard.
            loss = dict()
            for tl in output["loss"]:
                if isinstance(output["loss"][tl], dict):
                    for l in output["loss"][tl]:
                        loss[f"{tl}_{l}"] = output["loss"][tl][l]
                else:
                    loss[tl] = output["loss"][tl]
            return loss

        # Log the per-term losses and the running averages every 50 iterations.
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d"],
                global_step_transform=global_step_transform,
                output_transform=output_transform,
            ),
            event_name=Events.ITERATION_COMPLETED(every=50),
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50),
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()),
                                engine.state.iteration)
            tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()),
                                engine.state.iteration)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done
            tb_logger.close()

    return trainer


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        try:
            # Over-provision epochs; the ITERATION_COMPLETED(once=max_iteration)
            # handler in get_trainer terminates training at the right step.
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            logger.error(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")
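

if __name__ == "__main__":
    # Minimal, hypothetical entry point for illustration only: the real raycv
    # launcher builds the config and logger elsewhere, and the config path
    # below is an assumption.
    import logging

    logging.basicConfig(level=logging.INFO)
    run("train", OmegaConf.load("configs/ugatit.yaml"), logging.getLogger("UGATIT"))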