raycv/engine/base/i2i.py
2020-09-03 09:34:38 +08:00

190 lines
7.4 KiB
Python

from itertools import chain
from math import ceil
from pathlib import Path
import logging
import torch
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from model import MODEL
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_optimizer
from omegaconf import OmegaConf
def build_model(cfg):
    """Build a model from its config and wrap it for the current distributed setup.

    The config is converted to a plain container with interpolations resolved, so
    ``${...}`` references reach the model constructor as real values rather than as
    literal strings. The private key ``_bn_to_sync_bn`` (default ``False``) is removed
    before building. When it is true, the model's BatchNorm layers are converted with
    ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``.

    :param cfg: OmegaConf config describing the model, passed on to ``MODEL.build_with``.
    :return: the model after ``idist.auto_model``.
    """
    # resolve=True: with the default resolve=False, interpolated values would be
    # handed to the model constructor as unresolved "${...}" strings.
    cfg = OmegaConf.to_container(cfg, resolve=True)
    bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
    model = MODEL.build_with(cfg)
    if bn_to_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return idist.auto_model(model)
def build_lr_schedulers(optimizers, config):
    """Create one ``PiecewiseLinear`` learning-rate scheduler per optimizer.

    Both schedules use the same milestones. The lr stays at the optimizer's configured
    value from iteration 0 until ``start_proportion * max_iteration``. It then moves to
    ``data.train.scheduler.target_lr`` at ``max_iteration``.

    :param optimizers: dict with keys ``"g"`` (generator) and ``"d"`` (discriminator).
    :param config: experiment config providing ``optimizers.{generator,discriminator}.lr``,
        ``data.train.scheduler.{start_proportion,target_lr}`` and ``max_iteration``.
    :return: dict with keys ``"g"`` and ``"d"`` mapping to the schedulers.
    """
    # TODO: support more scheduler type
    scheduler_cfg = config.data.train.scheduler
    last_iteration = config.max_iteration
    hold_until = int(scheduler_cfg.start_proportion * last_iteration)

    def _schedule_for(optimizer, base_lr):
        # three milestones: constant base_lr, then toward target_lr by the final iteration
        points = [
            (0, base_lr),
            (hold_until, base_lr),
            (last_iteration, scheduler_cfg.target_lr),
        ]
        return PiecewiseLinear(optimizer, param_name="lr", milestones_values=points)

    schedulers = {}
    schedulers["g"] = _schedule_for(optimizers["g"], config.optimizers.generator.lr)
    schedulers["d"] = _schedule_for(optimizers["d"], config.optimizers.discriminator.lr)
    return schedulers
class EngineKernel(object):
    """Base class bundling the models, forward pass and losses that ``get_trainer`` drives.

    Subclasses must override every method below that raises ``NotImplementedError``.
    """

    def __init__(self, config, logger):
        self.config = config
        self.logger = logger
        # build_models() comes from the subclass and must return (generators, discriminators)
        self.generators, self.discriminators = self.build_models()

    def build_models(self) -> (dict, dict):
        """Return ``(generators, discriminators)``.

        Both are dicts of models; ``get_trainer`` optimizes each value's ``.parameters()``.
        """
        raise NotImplementedError

    def to_save(self):
        """Return the models to checkpoint.

        Keys are ``generator_<name>`` and ``discriminator_<name>``.
        """
        to_save = {f"generator_{k}": model for k, model in self.generators.items()}
        to_save.update({f"discriminator_{k}": model for k, model in self.discriminators.items()})
        return to_save

    def setup_before_d(self):
        """Hook called in each training step right before the discriminator update."""
        raise NotImplementedError

    def setup_before_g(self):
        """Hook called in each training step right before the generator update."""
        raise NotImplementedError

    def forward(self, batch, inference=False) -> dict:
        """Run the models on ``batch``.

        Returns a dict of outputs that is passed to the criterion methods and to
        ``intermediate_images``. ``get_trainer`` never passes ``inference``; its meaning
        is left to subclasses.
        """
        raise NotImplementedError

    def criterion_generators(self, batch, generated) -> dict:
        """Return a dict of generator loss tensors.

        The caller sums the values and backpropagates the total.
        """
        raise NotImplementedError

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return a dict of discriminator loss tensors.

        The caller sums the values and backpropagates the total.
        """
        raise NotImplementedError

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        raise NotImplementedError
def _flatten_losses(output):
    """Flatten ``output["loss"]`` into a single-level dict for scalar logging.

    Nested dicts become ``<group>_<name>`` keys. For example, ``{"g": {"adv": t}}``
    becomes ``{"g_adv": t}``. Non-dict entries are kept under their own key.
    """
    loss = dict()
    for group, value in output["loss"].items():
        if isinstance(value, dict):
            for name, item in value.items():
                loss[f"{group}_{name}"] = item
        else:
            loss[group] = value
    return loss


def _pick_test_indices(dataset_size, seed, num_samples=10):
    """Return the range of test-dataset indices shown in the ``test/*`` images.

    The start index is the first element of a ``torch.randperm`` seeded with ``seed``,
    so the same samples are chosen on every call. Datasets too small for that window
    fall back to the first ``min(num_samples, dataset_size)`` samples, where the
    original code crashed inside ``torch.randperm``.
    """
    span = dataset_size - num_samples - 1
    if span > 0:
        g = torch.Generator()
        g.manual_seed(seed)
        start = torch.randperm(span, generator=g)[0].item()
    else:
        start = 0
    return range(start, min(start + num_samples, dataset_size))


def get_trainer(config, ek: EngineKernel, iter_per_epoch):
    """Build the ignite ``Engine`` that alternately updates generators and discriminators.

    Each iteration runs ``ek.forward`` once. It then does one generator step and one
    discriminator step on the same ``generated`` outputs. This function also attaches:

    * the optimizers,
    * the lr schedulers from ``build_lr_schedulers``,
    * running-average losses,
    * the common handlers, including checkpointing of ``to_save``,
    * TensorBoard logging, when ``setup_tensorboard_handler`` returns a handler.

    :param config: experiment config. Reads ``name``, ``optimizers``, ``data``,
        ``interval``, ``misc`` and ``max_iteration``.
    :param ek: engine kernel providing models, forward pass, losses and preview images.
    :param iter_per_epoch: iterations per epoch, used to derive the logging intervals.
    :return: the configured ``ignite.engine.Engine``.
    """
    logger = logging.getLogger(config.name)
    generators, discriminators = ek.generators, ek.discriminators
    optimizers = dict(
        g=build_optimizer(chain.from_iterable(m.parameters() for m in generators.values()),
                          config.optimizers.generator),
        d=build_optimizer(chain.from_iterable(m.parameters() for m in discriminators.values()),
                          config.optimizers.discriminator),
    )
    # Previously `logger.info("build optimizers", optimizers)`: with no placeholder, the
    # single mapping argument was silently dropped and the optimizers were never logged.
    logger.info(f"build optimizers:\n{optimizers}")
    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        generated = ek.forward(batch)

        # generator update
        ek.setup_before_g()
        optimizers["g"].zero_grad()
        loss_g = ek.criterion_generators(batch, generated)
        sum(loss_g.values()).backward()
        optimizers["g"].step()

        # discriminator update, reusing the same forward outputs
        ek.setup_before_d()
        optimizers["d"].zero_grad()
        loss_d = ek.criterion_discriminators(batch, generated)
        sum(loss_d.values()).backward()
        optimizers["d"].step()

        return {
            "loss": dict(g=loss_g, d=loss_d),
            "img": ek.intermediate_images(batch, generated)
        }

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": shd for k, shd in lr_schedulers.items()})
    to_save.update({f"optimizer_{k}": opt for k, opt in optimizers.items()})
    to_save.update(ek.to_save())
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    pairs_per_iteration = config.data.train.dataloader.batch_size
    tensorboard_handler = setup_tensorboard_handler(trainer, config, _flatten_losses, iter_per_epoch)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
        )

        # Registered only when TensorBoard is enabled, because it writes through
        # tensorboard_handler.writer.
        @trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
        def show_images(engine):
            """Log training images and a fixed set of test-sample images to TensorBoard.

            ``train/<key>`` is a grid of the images in the latest training output.
            ``test/<key>`` shows, for each image slot of ``ek.intermediate_images``,
            that slot's outputs over a few consecutive ``engine.state.test_dataset``
            samples, concatenated along dim 0.
            """
            output = engine.state.output
            global_step = engine.state.iteration * pairs_per_iteration
            test_images = {}
            for k, image_list in output["img"].items():
                tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), global_step)
                # one collector per image slot; each gathers that slot across test samples
                test_images[k] = [[] for _ in image_list]

            test_dataset = engine.state.test_dataset
            indices = _pick_test_indices(len(test_dataset), config.misc.random_seed)
            if len(indices) == 0:
                return

            device = idist.device()
            # NOTE(review): models are not switched to eval() here, so layers with
            # train-mode behaviour run as in training during this preview. Confirm intended.
            with torch.no_grad():
                for i in indices:
                    batch = convert_tensor(test_dataset[i], device)
                    for k in batch:
                        # add a leading batch dimension of size 1
                        batch[k] = batch[k].unsqueeze(0)
                    generated = ek.forward(batch)
                    images = ek.intermediate_images(batch, generated)
                    for k in test_images:
                        for j, img in enumerate(images[k]):
                            test_images[k][j].append(img)

            for k, slots in test_images.items():
                tensorboard_handler.writer.add_image(
                    f"test/{k}",
                    make_2d_grid([torch.cat(ti) for ti in slots]),
                    global_step
                )
    return trainer