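"""Generic adversarial (GAN) training setup built on pytorch-ignite.

Provides helpers to build models and piecewise-linear LR schedulers, plus a
trainer Engine that alternates generator/discriminator updates while logging
losses and intermediate images to TensorBoard.
"""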
from itertools import chain
from math import ceil
from pathlib import Path
from typing import Tuple
import logging

import torch

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

from model import MODEL
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_optimizer

from omegaconf import OmegaConf


def build_model(cfg):
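    """Build a model from an OmegaConf node, optionally convert its BatchNorm layers
    to SyncBatchNorm, and wrap it for the current distributed configuration."""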
    cfg = OmegaConf.to_container(cfg)
    bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
    model = MODEL.build_with(cfg)
    if bn_to_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return idist.auto_model(model)


def build_lr_schedulers(optimizers, config):
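    """Create PiecewiseLinear LR schedulers for the generator and discriminator:
    the learning rate stays constant until ``start_proportion * max_iteration`` and
    then decays linearly to ``target_lr`` at ``max_iteration``."""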
    # TODO: support more scheduler types
    g_milestones_values = [
        (0, config.optimizers.generator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    d_milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
    )


class EngineKernel(object):
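    """Abstract kernel describing one adversarial training task.

    Subclasses provide the models, per-step setup hooks, the forward pass, the
    generator/discriminator criteria, and the images to visualize; ``get_trainer``
    turns a kernel into an ignite trainer.
    """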
    def __init__(self, config, logger):
        self.config = config
        self.logger = logger
        self.generators, self.discriminators = self.build_models()

    def build_models(self) -> Tuple[dict, dict]:
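        """Return (generators, discriminators) as two dicts of models."""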
        raise NotImplementedError

    def to_save(self):
        to_save = {}
        to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
        to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
        return to_save

    def setup_before_d(self):
        """Hook run right before the discriminator update of each iteration."""
        raise NotImplementedError

    def setup_before_g(self):
        """Hook run right before the generator update of each iteration."""
        raise NotImplementedError

    def forward(self, batch, inference=False) -> dict:
        """Run the forward pass on a batch and return a dict of generated outputs."""
        raise NotImplementedError

    def criterion_generators(self, batch, generated) -> dict:
        """Return a dict of generator losses; their sum is backpropagated."""
        raise NotImplementedError

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return a dict of discriminator losses; their sum is backpropagated."""
        raise NotImplementedError

    def intermediate_images(self, batch, generated) -> dict:
        """
        Return intermediate images for visualization.

        :param batch: input batch
        :param generated: dict of generated images
        :return: dict like {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        raise NotImplementedError


def get_trainer(config, ek: EngineKernel, iter_per_epoch):
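    """Build an ignite Engine that alternately updates the generators and the
    discriminators of ``ek``, with LR scheduling, the project's common handlers,
    and optional TensorBoard logging."""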
    logger = logging.getLogger(config.name)
    generators, discriminators = ek.generators, ek.discriminators
    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")

    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    def _step(engine, batch):
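        """One training iteration: generator update first, then discriminator update."""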
        batch = convert_tensor(batch, idist.device())

        generated = ek.forward(batch)

        ek.setup_before_g()
        optimizers["g"].zero_grad()
        loss_g = ek.criterion_generators(batch, generated)
        sum(loss_g.values()).backward()
        optimizers["g"].step()

        ek.setup_before_d()
        optimizers["d"].zero_grad()
        loss_d = ek.criterion_discriminators(batch, generated)
        sum(loss_d.values()).backward()
        optimizers["d"].step()

        return {
            "loss": dict(g=loss_g, d=loss_d),
            "img": ek.intermediate_images(batch, generated)
        }

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update(ek.to_save())
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    def output_transform(output):
        loss = dict()
        for tl in output["loss"]:
            if isinstance(output["loss"][tl], dict):
                for l in output["loss"][tl]:
                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
            else:
                loss[tl] = output["loss"][tl]
        return loss

    pairs_per_iteration = config.data.train.dataloader.batch_size
    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
        )

    @trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
    def show_images(engine):
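        """Write the latest training images and a small grid of test samples to TensorBoard."""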
        if tensorboard_handler is None:
            # no TensorBoard handler configured, nothing to log
            return
        output = engine.state.output
        test_images = {}
        for k in output["img"]:
            image_list = output["img"][k]
            tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
                                                 engine.state.iteration * pairs_per_iteration)
            test_images[k] = []
            for i in range(len(image_list)):
                test_images[k].append([])

        with torch.no_grad():
            g = torch.Generator()
            g.manual_seed(config.misc.random_seed)
            random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
            for i in range(random_start, random_start + 10):
                batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                for k in batch:
                    batch[k] = batch[k].view(1, *batch[k].size())
                generated = ek.forward(batch)
                images = ek.intermediate_images(batch, generated)

                for k in test_images:
                    for j in range(len(images[k])):
                        test_images[k][j].append(images[k][j])
            for k in test_images:
                tensorboard_handler.writer.add_image(
                    f"test/{k}",
                    make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
                    engine.state.iteration * pairs_per_iteration
                )
    return trainer