# raycv/engine/cyclegan.py
import itertools
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from model import MODEL
from loss.gan import GANLoss
from util.distributed import auto_model
from util.image import make_2d_grid
from util.handler import Resumer
def _build_model(cfg, distributed_args=None):
    """Instantiate a model from *cfg* and adapt it to the current distributed setup.

    The ``_distributed`` sub-section is stripped from the config before it is
    handed to the model registry; when it requests ``bn_to_syncbn``, every
    BatchNorm layer is converted to SyncBatchNorm.  *distributed_args* are
    forwarded to :func:`auto_model` only when more than one process is running.
    """
    options = OmegaConf.to_container(cfg)
    dist_options = options.pop("_distributed", dict())
    net = MODEL.build_with(options)
    if dist_options.get("bn_to_syncbn"):
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    # Extra wrapper kwargs only make sense in a true multi-process run.
    if distributed_args is None or idist.get_world_size() == 1:
        distributed_args = {}
    return auto_model(net, **distributed_args)
def _build_optimizer(params, cfg):
    """Create the torch optimizer named by ``cfg._type`` over *params* and make it distributed-aware."""
    assert "_type" in cfg
    options = OmegaConf.to_container(cfg)
    optimizer_cls = getattr(optim, options.pop("_type"))
    optimizer = optimizer_cls(params=params, **options)
    return idist.auto_optim(optimizer)
def get_trainer(config, logger):
    """Build and wire up the CycleGAN training engine.

    Creates two generators (``generator_a``: A->B, ``generator_b``: B->A) and two
    discriminators, one optimizer per model pair, piecewise-linear LR decay,
    the GAN/cycle/identity losses, and an ignite :class:`~ignite.engine.Engine`
    whose step performs one generator update followed by one discriminator
    update.  Also attaches console logging, NaN termination, checkpointing,
    resume, TensorBoard (rank 0 only) and a basic time profiler.

    Returns the ready-to-run engine.
    """
    # The two generators share one architecture config, as do the discriminators.
    generator_a = _build_model(config.model.generator, config.distributed.model)
    generator_b = _build_model(config.model.generator, config.distributed.model)
    discriminator_a = _build_model(config.model.discriminator, config.distributed.model)
    discriminator_b = _build_model(config.model.discriminator, config.distributed.model)
    logger.debug(discriminator_a)
    logger.debug(generator_a)
    # One optimizer drives both generators; a second one both discriminators.
    optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
                                   config.optimizers.generator)
    optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
                                   config.optimizers.discriminator)
    # Linear LR decay from the configured lr (at scheduler.start) down to
    # target_lr at max_iteration; stepped once per iteration below.
    milestones_values = [
        (config.data.train.scheduler.start, config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
    milestones_values = [
        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")  # the weight is applied manually inside _step, not by GANLoss
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    # NOTE(review): `config.loss.cycle` is also dereferenced as
    # `config.loss.cycle.weight` in _step, so comparing the config node itself
    # to 1 looks wrong (likely always false, i.e. always MSELoss); the second
    # line presumably meant to inspect `config.loss.id` — confirm against the
    # config schema.
    cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()

    def _step(engine, batch):
        """One training iteration: generator update, then discriminator update.

        Returns a dict with per-term scalar losses and the six image tensors
        (real/fake/reconstructed for both domains) for visualization.
        """
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        optimizer_g.zero_grad()
        fake_b = generator_a(real_a)  # G_A(A)
        rec_a = generator_b(fake_b)  # G_B(G_A(A))
        fake_a = generator_b(real_b)  # G_B(B)
        rec_b = generator_a(fake_a)  # G_A(G_B(B))
        # NOTE(review): discriminator parameters are not frozen here, so this
        # backward pass also accumulates gradients into the discriminators;
        # optimizer_d.zero_grad() below clears them before the D step, which
        # keeps the update correct but wastes computation — confirm intended.
        loss_g = dict(
            id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b),  # G_A(B)
            id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a),  # G_B(A)
            cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
            cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
            gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
            gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
        )
        sum(loss_g.values()).backward()
        optimizer_g.step()
        optimizer_d.zero_grad()
        # Discriminator targets: real images classified as real, detached
        # fakes as fake (detach stops gradients flowing into the generators).
        loss_d_a = dict(
            real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
            fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True),
        )
        loss_d_b = dict(
            real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
            fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True),
        )
        loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2
        loss_d.backward()
        optimizer_d.step()
        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
                "d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
            },
            "img": [
                real_a.detach(),
                fake_b.detach(),
                rec_a.detach(),
                real_b.detach(),
                fake_a.detach(),
                rec_b.detach()
            ]
        }

    trainer = Engine(_step)
    trainer.logger = logger
    # Step both LR schedulers every iteration; abort the run on NaN losses.
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    # Smoothed aggregate losses used by the console log and TensorBoard.
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def print_log(engine):
        # Console progress line every 10 iterations.
        engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]"
                           f"loss_g={engine.state.metrics['loss_g']:.3f} "
                           f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} "
                           f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ")

    # Everything needed to resume training exactly where it stopped.
    to_save = dict(
        generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
        discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
        lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
    )
    trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from))
    checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler)
    if idist.get_rank() == 0:
        # Create a logger (TensorBoard output only on the main process).
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # Attach the logger to the trainer to log training loss at each iteration
        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d_a", "loss_d_b"],
                global_step_transform=global_step_transform,
            ),
            event_name=Events.ITERATION_COMPLETED(every=50)
        )
        # Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50)
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            # Dump the six-image grid produced by the last _step to TensorBoard.
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)

    # Create an object of the profiler and attach an engine to it
    profiler = BasicTimeProfiler()
    profiler.attach(trainer)

    @trainer.on(Events.EPOCH_COMPLETED(once=1))
    @idist.one_rank_only()
    def log_intermediate_results():
        # Print timing stats once, after the first epoch.
        profiler.print_results(profiler.get_results())

    @trainer.on(Events.COMPLETED)
    @idist.one_rank_only()
    def _():
        profiler.write_results(f"{config.output_dir}/time_profiling.csv")
        # We need to close the logger when we are done.
        # NOTE(review): `tb_logger` only exists on rank 0; one_rank_only makes
        # this a no-op elsewhere, but confirm the intended nesting of this
        # handler relative to the rank-0 guard (source indentation was lost).
        tb_logger.close()
    return trainer
def get_tester(config, logger):
    """Build an inference engine that runs both generators over a test set.

    For each batch it translates in both directions (A->B->A and B->A->B) with
    gradients disabled and saves, per sample, a one-row image grid
    ``[real_a, fake_b, rec_a, real_b, fake_a, rec_b]`` under
    ``<output_dir>/test_images/``.  Generator weights are restored from
    ``config.resume_from`` when the run starts.
    """
    generator_a = _build_model(config.model.generator, config.distributed.model)
    generator_b = _build_model(config.model.generator, config.distributed.model)

    def _step(engine, batch):
        # Forward both translation directions without building a graph.
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        with torch.no_grad():
            fake_b = generator_a(real_a)  # G_A(A)
            rec_a = generator_b(fake_b)  # G_B(G_A(A))
            fake_a = generator_b(real_b)  # G_B(B)
            rec_b = generator_a(fake_a)  # G_A(G_B(B))
        return [
            real_a.detach(),
            fake_b.detach(),
            rec_a.detach(),
            real_b.detach(),
            fake_a.detach(),
            rec_b.detach()
        ]

    tester = Engine(_step)
    tester.logger = logger
    # BUG FIX: the original wrote `if idist.get_rank == 0:`, comparing the
    # function object itself to 0 — always False, so the progress bar was
    # never attached.  The function must be called.
    if idist.get_rank() == 0:
        ProgressBar(ncols=0).attach(tester)
    to_load = dict(generator_a=generator_a, generator_b=generator_b)
    tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))

    @tester.on(Events.STARTED)
    @idist.one_rank_only()
    def mkdir(engine):
        # Ensure the output directory for saved grids exists (rank 0 only).
        img_output_dir = Path(config.output_dir) / "test_images"
        if not img_output_dir.exists():
            engine.logger.info(f"mkdir {img_output_dir}")
            img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        # One .jpg per sample; the grid lays all six images on a single row.
        img_tensors = engine.state.output
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            torchvision.utils.save_image([img[i] for img in img_tensors],
                                         Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
                                         nrow=len(img_tensors))
    return tester
def run(task, config, logger):
    """Entry point: dispatch *task* to the train or test pipeline.

    ``train`` builds the training dataloader and runs the trainer for enough
    epochs to cover ``config.max_iteration`` iterations; ``test`` runs the
    tester for a single epoch.  Exceptions from a run are logged as a
    traceback rather than propagated (best-effort completion).

    Raises:
        NotImplementedError: if *task* is neither ``"train"`` nor ``"test"``.
    """
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        # BUG FIX: the original called trainer.run twice — once unguarded with
        # the computed epoch count and then again inside try/except with
        # max_epochs=1 — restarting training after it finished.  A single
        # guarded run with the computed epoch count is intended.
        try:
            # +1 rounds the epoch count up so max_iteration is always reached.
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        test_dataset = data.DATASET.build_with(config.data.test.dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, logger)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        # BUG FIX: `NotImplemented` is a comparison sentinel, not callable;
        # `return NotImplemented(...)` raised TypeError instead of reporting
        # the invalid task.  Raise NotImplementedError explicitly.
        raise NotImplementedError(f"invalid task: {task}")