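"""Training script for a dual-decoder, edge-guided image-to-image GAN.

One generator renders an edge map (plus auxiliary channels) into two target
domains, "a" and "b", each judged by its own discriminator; training combines
adversarial, perceptual, edge, and reconstruction losses and is orchestrated
with PyTorch Ignite.
"""
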
from itertools import chain
from math import ceil

import torch
import torch.nn as nn

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

from omegaconf import OmegaConf, read_write

import data
from loss.gan import GANLoss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer


def build_lr_schedulers(optimizers, config):
    """Hold each lr constant for the warm-up portion of training, then decay it linearly to ``target_lr``."""
    decay_start = int(config.data.train.scheduler.start_proportion * config.max_iteration)

    def milestones(initial_lr):
        return [
            (0, initial_lr),
            (decay_start, initial_lr),
            (config.max_iteration, config.data.train.scheduler.target_lr),
        ]

    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr",
                          milestones_values=milestones(config.optimizers.generator.lr)),
        d=PiecewiseLinear(optimizers["d"], param_name="lr",
                          milestones_values=milestones(config.optimizers.discriminator.lr)),
    )
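# Illustrative numbers (not from any real config): with lr=2e-4,
# start_proportion=0.5, target_lr=0 and max_iteration=100_000, both schedules
# hold lr at 2e-4 for 50_000 iterations, then decay it linearly to 0 at 100_000.
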
def get_trainer(config, logger):
    generator = build_model(config.model.generator, config.distributed.model)
    # One discriminator per target domain, both built from the same config.
    discriminators = dict(
        a=build_model(config.model.discriminator, config.distributed.model),
        b=build_model(config.model.discriminator, config.distributed.model),
    )
    generation_init_weights(generator)
    for m in discriminators.values():
        generation_init_weights(m)

    logger.debug(discriminators["a"])
    logger.debug(generator)

    optimizers = dict(
        g=build_optimizer(generator.parameters(), config.optimizers.generator),
        # A single optimizer drives both discriminators, so their parameters are chained.
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                          config.optimizers.discriminator),
    )
    logger.info(f"built optimizers:\n{optimizers}")

    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"built lr_schedulers:\n{lr_schedulers}")

    # Each loss config carries a "weight" entry that is applied in `_step`,
    # not by the loss module itself, so it is popped before construction.
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

    edge_loss_cfg = OmegaConf.to_container(config.loss.edge)
    edge_loss_cfg.pop("weight")
    edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device())

    perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
    perceptual_loss_cfg.pop("weight")
    perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

    recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()

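    # GANImageBuffer is assumed to implement the usual image-history buffer
    # (as in SimGAN/CycleGAN): `query` returns a mix of the current fakes and
    # previously generated ones, so each discriminator also sees older samples
    # and does not overfit to the latest generator output.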
    image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators}

    def _step(engine, batch):
        # `convert_tensor` moves the (possibly nested) batch to the target device.
        batch = convert_tensor(batch, idist.device())
        real = dict(a=batch["a"], b=batch["b"])
        edge = batch["edge"]
        additional_info = batch["additional_info"]
        # The generator takes the edge map plus auxiliary channels as content
        # and a real domain image as style, with one decoder per target domain.
        content_img = torch.cat([edge, additional_info], dim=1)
        fake = dict(
            a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
            b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
        )

        # Generator update: freeze the discriminators so no discriminator
        # gradients are computed while back-propagating the generator losses.
        optimizers["g"].zero_grad()
        loss_g = dict()
        for d in "ab":
            discriminators[d].requires_grad_(False)
            pred_fake = discriminators[d](fake[d])
            loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
            _, t = perceptual_loss(fake[d], real[d])
            loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
        loss_g["edge"] = config.loss.edge.weight * edge_loss(fake["b"], real["a"], gt_is_edge=False)
        loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
        sum(loss_g.values()).backward()
        optimizers["g"].step()

        # Discriminator update: re-enable gradients, then score real images
        # against buffered, detached fakes.
        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)

        optimizers["d"].zero_grad()
        loss_d = dict()
        for k in discriminators:
            pred_real = discriminators[k](real[k])
            pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach()))
            loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) +
                                  gan_loss(pred_fake, False, is_discriminator=True)) / 2
        sum(loss_d.values()).backward()
        optimizers["d"].step()

        generated_img = {f"real_{k}": real[k].detach() for k in real}
        generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
        return {
            "loss": {
                "g": {name: loss_g[name].mean().item() for name in loss_g},
                "d": {name: loss_d[name].mean().item() for name in loss_d},
            },
            "img": generated_img,
        }

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    # Everything needed to resume a run goes into the checkpoint; judging by its
    # arguments, `setup_common_handlers` also clears the CUDA cache, keeps the
    # distributed sampler's epoch in sync, and stops training at max_iteration.
    to_save = dict(trainer=trainer, generator=generator)
    to_save.update({f"lr_scheduler_{k}": v for k, v in lr_schedulers.items()})
    to_save.update({f"optimizer_{k}": v for k, v in optimizers.items()})
    to_save.update({f"discriminator_{k}": v for k, v in discriminators.items()})
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
                          set_epoch_for_dist_sampler=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    def output_transform(output):
        # Flatten the nested loss dict for logging,
        # e.g. {"g": {"gan_a": x}} becomes {"g_gan_a": x}.
        loss = dict()
        for group, value in output["loss"].items():
            if isinstance(value, dict):
                for name, v in value.items():
                    loss[f"{group}_{name}"] = v
            else:
                loss[group] = value
        return loss

    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
        )

        # Registered inside the None check: when no tensorboard logger was set
        # up (e.g. on non-master ranks), `tensorboard_handler.writer` would fail.
        @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
        def show_images(engine):
            output = engine.state.output
            image_order = dict(
                a=["real_a", "fake_a"],
                b=["real_b", "fake_b"],
            )
            for k in "ab":
                tensorboard_handler.writer.add_image(
                    f"train/{k}",
                    make_2d_grid([output["img"][o] for o in image_order[k]]),
                    engine.state.iteration,
                )

    return trainer


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    # `read_write` temporarily lifts any read-only flag on the config so the
    # iteration budget can be derived from the requested number of pairs.
    with read_write(config):
        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)

    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        if idist.get_rank() == 0:
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
        except Exception:
            logger.exception("training failed")
    else:
        raise NotImplementedError(f"invalid task: {task}")
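

if __name__ == "__main__":
    # Minimal usage sketch, not the project's real entry point: it assumes a
    # YAML config exposing every key accessed above ("config/train.yaml" is a
    # hypothetical path) and substitutes a plain stdlib logger.
    import logging

    logging.basicConfig(level=logging.INFO)
    run("train", OmegaConf.load("config/train.yaml"), logging.getLogger(__name__))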