UGATIT version 0.1

This commit is contained in:
budui 2020-08-24 06:51:42 +08:00
parent 54b0799c48
commit 31aafb3470
4 changed files with 90 additions and 109 deletions

View File

@ -1,4 +1,4 @@
name: selfie2anime-origin name: selfie2anime
engine: UGATIT engine: UGATIT
result_dir: ./result result_dir: ./result
max_pairs: 1000000 max_pairs: 1000000

View File

@ -4,7 +4,6 @@ from math import ceil
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.data import DataLoader
import ignite.distributed as idist import ignite.distributed as idist
from ignite.engine import Events, Engine from ignite.engine import Events, Engine
@ -20,11 +19,28 @@ from loss.gan import GANLoss
from model.weight_init import generation_init_weights from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid, fuse_attention_map from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
from util.handler import setup_common_handlers, setup_tensorboard_handler from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer from util.build import build_model, build_optimizer
def build_lr_schedulers(optimizers, config):
g_milestones_values = [
(0, config.optimizers.generator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
d_milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
return dict(
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
)
def get_trainer(config, logger): def get_trainer(config, logger):
generators = dict( generators = dict(
a2b=build_model(config.model.generator, config.distributed.model), a2b=build_model(config.model.generator, config.distributed.model),
@ -42,23 +58,14 @@ def get_trainer(config, logger):
logger.debug(discriminators["ga"]) logger.debug(discriminators["ga"])
logger.debug(generators["a2b"]) logger.debug(generators["a2b"])
optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator) optimizers = dict(
optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
config.optimizers.discriminator) d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info(f"build optimizers:\n{optimizers}")
milestones_values = [ lr_schedulers = build_lr_schedulers(optimizers, config)
(0, config.optimizers.generator.lr), logger.info(f"build lr_schedulers:\n{lr_schedulers}")
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight") gan_loss_cfg.pop("weight")
@ -116,26 +123,26 @@ def get_trainer(config, logger):
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"]) identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"]) identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
optimizer_g.zero_grad() optimizers["g"].zero_grad()
loss_g = dict() loss_g = dict()
for n in ["a", "b"]: for n in ["a", "b"]:
loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n], loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n])) cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
sum(loss_g.values()).backward() sum(loss_g.values()).backward()
optimizer_g.step() optimizers["g"].step()
for generator in generators.values(): for generator in generators.values():
generator.apply(rho_clipper) generator.apply(rho_clipper)
for discriminator in discriminators.values(): for discriminator in discriminators.values():
discriminator.requires_grad_(True) discriminator.requires_grad_(True)
optimizer_d.zero_grad() optimizers["d"].zero_grad()
loss_d = dict() loss_d = dict()
for k in discriminators.keys(): for k in discriminators.keys():
n = k[-1] # "a" or "b" n = k[-1] # "a" or "b"
loss_d.update( loss_d.update(
criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach()))) criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
sum(loss_d.values()).backward() sum(loss_d.values()).backward()
optimizer_d.step() optimizers["d"].step()
for h in heatmap: for h in heatmap:
heatmap[h] = heatmap[h].detach() heatmap[h] = heatmap[h].detach()
@ -157,19 +164,19 @@ def get_trainer(config, logger):
trainer = Engine(_step) trainer = Engine(_step)
trainer.logger = logger trainer.logger = logger
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g) for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d) trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d") RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d, to_save = dict(trainer=trainer)
lr_scheduler_g=lr_scheduler_g) to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update({f"generator_{k}": generators[k] for k in generators}) to_save.update({f"generator_{k}": generators[k] for k in generators})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"], end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
clear_cuda_cache=False, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output): def output_transform(output):
loss = dict() loss = dict()
@ -185,46 +192,36 @@ def get_trainer(config, logger):
if tensorboard_handler is not None: if tensorboard_handler is not None:
tensorboard_handler.attach( tensorboard_handler.attach(
trainer, trainer,
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"), log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar) event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
) )
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image)) @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
def show_images(engine): def show_images(engine):
output = engine.state.output output = engine.state.output
image_a_order = ["real_a", "fake_b", "rec_a", "id_a"] image_order = dict(
image_b_order = ["real_b", "fake_a", "rec_b", "id_b"] a=["real_a", "fake_b", "rec_a", "id_a"],
b=["real_b", "fake_a", "rec_b", "id_b"]
)
output["img"]["generated"]["real_a"] = fuse_attention_map( output["img"]["generated"]["real_a"] = fuse_attention_map(
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"]) output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
output["img"]["generated"]["real_b"] = fuse_attention_map( output["img"]["generated"]["real_b"] = fuse_attention_map(
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"]) output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
for k in "ab":
tensorboard_handler.writer.add_image( tensorboard_handler.writer.add_image(
"train/a", f"train/{k}",
make_2d_grid([output["img"]["generated"][o] for o in image_a_order]), make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
engine.state.iteration engine.state.iteration
) )
tensorboard_handler.writer.add_image(
"train/b",
make_2d_grid([output["img"]["generated"][o] for o in image_b_order]),
engine.state.iteration
)
with torch.no_grad(): with torch.no_grad():
g = torch.Generator() g = torch.Generator()
g.manual_seed(config.misc.random_seed) g.manual_seed(config.misc.random_seed)
indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10] indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
test_images = dict(
empty_grid = torch.zeros(0, config.model.generator.in_channels, config.model.generator.img_size, a=[[], [], [], []],
config.model.generator.img_size) b=[[], [], [], []]
fake = dict(a=empty_grid.clone(), b=empty_grid.clone()) )
rec = dict(a=empty_grid.clone(), b=empty_grid.clone())
heatmap = dict(a2b=torch.zeros(0, 1, config.model.generator.img_size,
config.model.generator.img_size),
b2a=torch.zeros(0, 1, config.model.generator.img_size,
config.model.generator.img_size))
real = dict(a=empty_grid.clone(), b=empty_grid.clone())
for i in indices: for i in indices:
batch = convert_tensor(engine.state.test_dataset[i], idist.device()) batch = convert_tensor(engine.state.test_dataset[i], idist.device())
@ -234,27 +231,18 @@ def get_trainer(config, logger):
rec_a = generators["b2a"](fake_b)[0] rec_a = generators["b2a"](fake_b)[0]
rec_b = generators["a2b"](fake_a)[0] rec_b = generators["a2b"](fake_a)[0]
fake["a"] = torch.cat([fake["a"], fake_a.cpu()]) for idx, im in enumerate(
fake["b"] = torch.cat([fake["b"], fake_b.cpu()]) [attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
real["a"] = torch.cat([real["a"], real_a.cpu()]) test_images["a"][idx].append(im.cpu())
real["b"] = torch.cat([real["b"], real_b.cpu()]) for idx, im in enumerate(
rec["a"] = torch.cat([rec["a"], rec_a.cpu()]) [attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
rec["b"] = torch.cat([rec["b"], rec_b.cpu()]) test_images["b"][idx].append(im.cpu())
for n in "ab":
heatmap["a2b"] = torch.cat( tensorboard_handler.writer.add_image(
[heatmap["a2b"], torch.nn.functional.interpolate(heatmap_a2b, real_a.size()[-2:]).cpu()]) f"test/{n}",
heatmap["b2a"] = torch.cat( make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
[heatmap["b2a"], torch.nn.functional.interpolate(heatmap_b2a, real_a.size()[-2:]).cpu()]) engine.state.iteration
tensorboard_handler.writer.add_image( )
"test/a",
make_2d_grid([heatmap["a2b"].expand_as(real["a"]), real["a"], fake["b"], rec["a"]]),
engine.state.iteration
)
tensorboard_handler.writer.add_image(
"test/b",
make_2d_grid([heatmap["b2a"].expand_as(real["a"]), real["b"], fake["a"], rec["b"]]),
engine.state.iteration
)
return trainer return trainer

View File

@ -17,7 +17,7 @@ def empty_cuda_cache(_):
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True, def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
to_save=None, metrics_to_print=None, end_event=None): to_save=None, end_event=None, set_epoch_for_dist_sampler=True):
""" """
Helper method to setup trainer with common handlers. Helper method to setup trainer with common handlers.
1. TerminateOnNan 1. TerminateOnNan
@ -30,21 +30,21 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
:param clear_cuda_cache: :param clear_cuda_cache:
:param use_profiler: :param use_profiler:
:param to_save: :param to_save:
:param metrics_to_print:
:param end_event: :param end_event:
:param set_epoch_for_dist_sampler:
:return: :return:
""" """
if set_epoch_for_dist_sampler:
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
@trainer.on(Events.EPOCH_STARTED) @trainer.on(Events.EPOCH_STARTED)
def distrib_set_epoch(engine): def distrib_set_epoch(engine):
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler") if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1) trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
@trainer.on(Events.STARTED) @trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
@idist.one_rank_only() def print_info(engine):
def print_dataloader_size(engine):
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}") engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
if stop_on_nan: if stop_on_nan:
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
@ -62,20 +62,8 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
def log_intermediate_results(): def log_intermediate_results():
profiler.print_results(profiler.get_results()) profiler.print_results(profiler.get_results())
print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED
ProgressBar(ncols=0).attach(trainer, "all") ProgressBar(ncols=0).attach(trainer, "all")
if metrics_to_print is not None:
@trainer.on(print_interval_event)
def print_interval(engine):
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
for m in metrics_to_print:
if m not in engine.state.metrics:
continue
print_str += f"{m}={engine.state.metrics[m]:.3f} "
engine.logger.debug(print_str)
if to_save is not None: if to_save is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False), checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
n_saved=config.checkpoint.n_saved, filename_prefix=config.name) n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
@ -86,6 +74,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
if not checkpoint_path.exists(): if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found") raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu") ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
trainer.logger.info(f"load state_dict for {ckp.keys()}")
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp) Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}") engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED, trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,

View File

@ -5,6 +5,21 @@ import warnings
from torch.nn.functional import interpolate from torch.nn.functional import interpolate
def attention_colored_map(attentions, size=None, cmap_name="jet"):
assert attentions.dim() == 4 and attentions.size(1) == 1
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
attentions -= min_attentions
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
if size is not None and attentions.size()[-2:] != size:
assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
attentions = interpolate(attentions, size, mode="bilinear", align_corners=False)
cmap = get_cmap(cmap_name)
ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3]
return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous()
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5): def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
""" """
@ -20,18 +35,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
if attentions.size(1) != 1: if attentions.size(1) != 1:
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}") warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
return images return images
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
attentions -= min_attentions
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
if images.size() != attentions.size():
attentions = interpolate(attentions, images.size()[-2:])
colored_attentions = torch.zeros_like(images)
cmap = get_cmap(cmap_name)
for i, at in enumerate(attentions):
ca = cmap(at[0].cpu().numpy())[:, :, :3]
colored_attentions[i] = torch.from_numpy(ca).permute(2, 0, 1).view(colored_attentions[i].size())
return images * alpha + colored_attentions * (1 - alpha) return images * alpha + colored_attentions * (1 - alpha)