# raycv/engine/base/i2i.py

from itertools import chain
import logging
from pathlib import Path
import torch
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import read_write, OmegaConf
from util.image import make_2d_grid
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from engine.util.build import build_optimizer
import data


def build_lr_schedulers(optimizers, config):
    # TODO: support more scheduler types
    g_milestones_values = [
        (0, config.optimizers.generator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    d_milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr),
    ]
    # Hold the initial LR until start_proportion of training, then decay linearly to target_lr.
    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values),
    )
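
# Illustrative schedule (the numbers are assumptions for the example, not values
# from this repo): with lr=2e-4, target_lr=0, start_proportion=0.5 and
# max_iteration=100_000, the LR stays at 2e-4 for the first 50_000 iterations,
# then decays linearly, reaching 0 at iteration 100_000.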


class TestEngineKernel(object):
    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(config.name)
        self.generators = self.build_generators()

    def build_generators(self) -> dict:
        raise NotImplementedError

    def to_load(self):
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        raise NotImplementedError
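
# Minimal sketch of a concrete test kernel (illustrative only; the 1x1-conv
# "generator" below is a hypothetical placeholder, not a model from this repo):
#
#     class IdentityTestKernel(TestEngineKernel):
#         def build_generators(self):
#             return {"a": torch.nn.Conv2d(3, 3, 1).to(idist.device()).eval()}
#
#         def inference(self, batch):
#             with torch.no_grad():
#                 return {"a": self.generators["a"](batch["a"])}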


class EngineKernel(object):
    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(config.name)
        self.generators, self.discriminators = self.build_models()

    def build_models(self) -> (dict, dict):
        raise NotImplementedError

    def to_save(self):
        to_save = {}
        to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
        to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
        return to_save

    def setup_before_d(self):
        raise NotImplementedError

    def setup_before_g(self):
        raise NotImplementedError

    def forward(self, batch, inference=False) -> dict:
        raise NotImplementedError

    def criterion_generators(self, batch, generated) -> dict:
        raise NotImplementedError

    def criterion_discriminators(self, batch, generated) -> dict:
        raise NotImplementedError

    def intermediate_images(self, batch, generated) -> dict:
        """
        :param batch: the input batch
        :param generated: dict of generated images returned by ``forward``
        :return: dict like ``{"a": [img1, img2, ...], "b": [img3, img4, ...]}``
        """
        raise NotImplementedError
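
# Minimal sketch of a concrete training kernel, to show how the abstract methods
# fit the trainer below (illustrative only; the 1x1-conv models and the WGAN-style
# critic loss are hypothetical choices, not part of raycv):
#
#     class IdentityKernel(EngineKernel):
#         def build_models(self):
#             g = {"a": torch.nn.Conv2d(3, 3, 1).to(idist.device())}
#             d = {"a": torch.nn.Conv2d(3, 1, 1).to(idist.device())}
#             return g, d
#
#         def setup_before_g(self):
#             # One common choice: freeze the discriminators during the G step.
#             for m in self.discriminators.values():
#                 m.requires_grad_(False)
#
#         def setup_before_d(self):
#             for m in self.discriminators.values():
#                 m.requires_grad_(True)
#
#         def forward(self, batch, inference=False):
#             return {"fake_a": self.generators["a"](batch["a"])}
#
#         def criterion_generators(self, batch, generated):
#             return {"l1": torch.nn.functional.l1_loss(generated["fake_a"], batch["a"])}
#
#         def criterion_discriminators(self, batch, generated):
#             real = self.discriminators["a"](batch["a"])
#             # Detach so the D step does not backprop into the generator.
#             fake = self.discriminators["a"](generated["fake_a"].detach())
#             return {"critic": fake.mean() - real.mean()}
#
#         def intermediate_images(self, batch, generated):
#             return {"a": [batch["a"], generated["fake_a"]]}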


def get_trainer(config, kernel: EngineKernel):
    logger = logging.getLogger(config.name)
    generators, discriminators = kernel.generators, kernel.discriminators
    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
    )
    logger.info(f"built optimizers:\n{optimizers}")
    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"built lr_schedulers:\n{lr_schedulers}")
    # Number of iterations between intermediate-image dumps to tensorboard.
    image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        generated = kernel.forward(batch)
        # Update the generators first.
        kernel.setup_before_g()
        optimizers["g"].zero_grad()
        loss_g = kernel.criterion_generators(batch, generated)
        sum(loss_g.values()).backward()
        optimizers["g"].step()
        # Then update the discriminators on the same generated batch.
        kernel.setup_before_d()
        optimizers["d"].zero_grad()
        loss_d = kernel.criterion_discriminators(batch, generated)
        sum(loss_d.values()).backward()
        optimizers["d"].step()
        # Only attach images on iterations where they will actually be logged.
        if engine.state.iteration % image_per_iteration == 0:
            return {
                "loss": dict(g=loss_g, d=loss_d),
                "img": kernel.intermediate_images(batch, generated),
            }
        return {"loss": dict(g=loss_g, d=loss_d)}
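
    # The step output consumed by the RunningAverage metrics and the image handler
    # below looks like this (key names under "g"/"d" are kernel-specific; "a"/"b"
    # follow the intermediate_images docstring):
    #     {"loss": {"g": {<named generator losses>}, "d": {<named D losses>}},
    #      "img":  {"a": [img1, img2, ...], "b": [img3, img4, ...]}}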

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update(kernel.to_save())
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
                          set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
    tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
    if tensorboard_handler is not None:
        basic_image_event = Events.ITERATION_COMPLETED(every=image_per_iteration)
        # The tensorboard x-axis is the number of image pairs seen so far.
        pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()

        @trainer.on(basic_image_event)
        def show_images(engine):
            output = engine.state.output
            # Log the training images and set up one bucket per image column for
            # the test images collected below.
            test_images = {}
            for k in output["img"]:
                image_list = output["img"][k]
                tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
                                                     engine.state.iteration * pairs_per_iteration)
                test_images[k] = [[] for _ in image_list]
            # Run the kernel on 10 consecutive test samples; the start index is
            # drawn with a fixed seed, so every dump shows the same samples.
            with torch.no_grad():
                g = torch.Generator()
                g.manual_seed(config.misc.random_seed)
                random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
                for i in range(random_start, random_start + 10):
                    batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                    for k in batch:
                        # Add a batch dimension to the single sample.
                        batch[k] = batch[k].view(1, *batch[k].size())
                    generated = kernel.forward(batch)
                    images = kernel.intermediate_images(batch, generated)
                    for k in test_images:
                        for j in range(len(images[k])):
                            test_images[k][j].append(images[k][j])
            for k in test_images:
                tensorboard_handler.writer.add_image(
                    f"test/{k}",
                    make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
                    engine.state.iteration * pairs_per_iteration,
                )

    return trainer


def get_tester(config, kernel: TestEngineKernel):
    logger = logging.getLogger(config.name)

    def _step(engine, batch):
        real_a, path = convert_tensor(batch, idist.device())
        fake = kernel.inference({"a": real_a})["a"]
        return {"path": path, "img": [real_a.detach(), fake.detach()]}

    tester = Engine(_step)
    tester.logger = logger
    setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())

    @tester.on(Events.STARTED)
    def mkdir(engine):
        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
        engine.state.img_output_dir = Path(img_output_dir)
        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
            engine.state.img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        img_tensors = engine.state.output["img"]
        paths = engine.state.output["path"]
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            # Save the real/fake pair side by side, named after the source image.
            image_name = Path(paths[i]).name
            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
                                         nrow=len(img_tensors))

    return tester


def run_kernel(task, config, kernel):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger = logging.getLogger(config.name)
    with read_write(config):
        # Derive the iteration budget from the total number of image pairs to see.
        real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
        config.max_iteration = config.max_pairs // real_batch_size + 1
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
        # auto_dataloader divides the batch size by the world size again, so scale
        # it up first to keep the configured per-process batch size.
        dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
        train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
        with read_write(config):
            config.iterations_per_epoch = len(train_data_loader)
        trainer = get_trainer(config, kernel)
        if idist.get_rank() == 0:
            # Only rank 0 logs test images, so only it needs the test dataset.
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None, "testing requires a checkpoint to resume from"
        test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, kernel)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        raise NotImplementedError(f"invalid task: {task}")
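

# Sketch of an entry point wiring this module together with ignite's distributed
# launcher (illustrative; MyKernel and the config path are assumptions, not part
# of this file):
#
#     def main(local_rank, config):
#         run_kernel("train", config, MyKernel(config))
#
#     if __name__ == "__main__":
#         config = OmegaConf.load("config/i2i.yaml")
#         with idist.Parallel(backend="nccl") as parallel:
#             parallel.run(main, config)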