raycv/engine/cyclegan.py

import itertools
from pathlib import Path
import torch
import torch.nn as nn
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer


def get_trainer(config, logger):
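    """Build the CycleGAN training engine.

    Naming follows the CycleGAN paper: G_A maps domain A -> B, G_B maps
    B -> A, and D_A / D_B judge images in the B / A domains respectively.
    """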
    generator_a = build_model(config.model.generator, config.distributed.model)
    generator_b = build_model(config.model.generator, config.distributed.model)
    discriminator_a = build_model(config.model.discriminator, config.distributed.model)
    discriminator_b = build_model(config.model.discriminator, config.distributed.model)
    for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
        generation_init_weights(m)
    logger.debug(discriminator_a)
    logger.debug(generator_a)
    optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
                                  config.optimizers.generator)
    optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
                                  config.optimizers.discriminator)
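
    # Keep each LR constant until `scheduler.start`, then decay it linearly
    # to `target_lr` at `max_iteration` (CycleGAN-style decay).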
    milestones_values = [
        (config.data.train.scheduler.start, config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
    milestones_values = [
        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")  # the weight is applied manually in _step
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    # NOTE: assumes the cycle/identity loss configs expose a "level" key
    # selecting the norm (1 -> L1, otherwise L2/MSE).
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
    # Replay buffers of previously generated fakes for the discriminator
    # updates; 50 is the CycleGAN default.
    buffer_size = config.data.train.buffer_size if config.data.train.buffer_size is not None else 50
    image_buffer_a = GANImageBuffer(buffer_size)
    image_buffer_b = GANImageBuffer(buffer_size)
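
    # One optimization step over a single batch: update both generators
    # jointly, then both discriminators against fakes drawn from the
    # replay buffers.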
    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        fake_b = generator_a(real_a)  # G_A(A)
        rec_a = generator_b(fake_b)  # G_B(G_A(A))
        fake_a = generator_b(real_b)  # G_B(B)
        rec_b = generator_a(fake_a)  # G_A(G_B(B))
        # Generator update: freeze the discriminators so no gradients are
        # accumulated in them while the generator losses backpropagate.
        optimizer_g.zero_grad()
        discriminator_a.requires_grad_(False)
        discriminator_b.requires_grad_(False)
        loss_g = dict(
            id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b),  # G_A(B)
            id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a),  # G_B(A)
            cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
            cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
            gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
            gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
        )
        sum(loss_g.values()).backward()
        optimizer_g.step()
        discriminator_a.requires_grad_(True)
        discriminator_b.requires_grad_(True)
        # Discriminator update: real images vs. buffered (detached) fakes,
        # scaled by 0.5 to slow D relative to G as in the CycleGAN paper.
        optimizer_d.zero_grad()
        loss_d_a = dict(
            real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
            fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
        )
        (sum(loss_d_a.values()) * 0.5).backward()
        loss_d_b = dict(
            real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
            fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
        )
        (sum(loss_d_b.values()) * 0.5).backward()
        optimizer_d.step()
        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
                "d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
            },
            "img": [
                real_a.detach(),
                fake_b.detach(),
                rec_a.detach(),
                real_b.detach(),
                fake_a.detach(),
                rec_b.detach()
            ]
        }
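
    # _step's scalar losses feed the RunningAverage metrics attached below;
    # its image list is written to TensorBoard by `show_images`.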
    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
    to_save = dict(
        generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
        discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
        lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
    )
    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
                          resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
                          save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
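
    # Enforce the exact iteration budget: `run` rounds max_epochs up, so the
    # trainer must be stopped once `max_iteration` is reached.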
    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
    def terminate(engine):
        engine.terminate()

    if idist.get_rank() == 0:
        # Create a TensorBoard logger on the main process only
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # Attach the logger to the trainer to log the training losses every 50 iterations
        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d_a", "loss_d_b"],
                global_step_transform=global_step_transform,
            ),
            event_name=Events.ITERATION_COMPLETED(every=50)
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50)
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done
            tb_logger.close()

    return trainer


def get_tester(config, logger):
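    """Build the test engine: run both generators over paired test batches
    (no gradients) and save each sample as a row of
    [real_a, fake_b, rec_a, real_b, fake_a, rec_b] images."""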
    generator_a = build_model(config.model.generator, config.distributed.model)
    generator_b = build_model(config.model.generator, config.distributed.model)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        with torch.no_grad():
            fake_b = generator_a(real_a)  # G_A(A)
            rec_a = generator_b(fake_b)  # G_B(G_A(A))
            fake_a = generator_b(real_b)  # G_B(B)
            rec_b = generator_a(fake_a)  # G_A(G_B(B))
        return [
            real_a.detach(),
            fake_b.detach(),
            rec_a.detach(),
            real_b.detach(),
            fake_a.detach(),
            rec_b.detach()
        ]

    tester = Engine(_step)
    tester.logger = logger
    if idist.get_rank() == 0:
        ProgressBar(ncols=0).attach(tester)
    to_load = dict(generator_a=generator_a, generator_b=generator_b)
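    # Reuse the common checkpoint handlers so `resume_from` restores the
    # trained generator weights before testing.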
    setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)

    @tester.on(Events.STARTED)
    @idist.one_rank_only()
    def mkdir(engine):
        img_output_dir = Path(config.output_dir) / "test_images"
        if not img_output_dir.exists():
            engine.logger.info(f"mkdir {img_output_dir}")
            img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        img_tensors = engine.state.output
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            torchvision.utils.save_image([img[i] for img in img_tensors],
                                         Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
                                         nrow=len(img_tensors))

    return tester


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        try:
            # Convert the iteration budget into an epoch count; the terminate
            # handler in get_trainer stops at exactly max_iteration.
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None, "testing requires a checkpoint to load"
        test_dataset = data.DATASET.build_with(config.data.test.dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, logger)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")
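

# A minimal launch sketch, not part of the original entry point: it assumes a
# YAML config whose keys match those accessed above (`model`, `optimizers`,
# `loss`, `data`, `max_iteration`, ...) at a hypothetical path, and uses
# ignite's idist.Parallel context so the same code covers serial and
# distributed runs. The real project presumably wires this up elsewhere.
if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    config = OmegaConf.load("configs/cyclegan.yaml")  # hypothetical config path

    with idist.Parallel(backend=None) as parallel:  # set backend="nccl" for DDP
        parallel.run(lambda local_rank: run("train", config, logging.getLogger("cyclegan")))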