Compare commits

...

5 Commits

Author SHA1 Message Date
31aafb3470 UGATIT version 0.1 2020-08-24 06:51:42 +08:00
54b0799c48 print more info 2020-08-23 20:35:09 +08:00
9dfb887c86 add set_epoch methods 2020-08-23 20:34:16 +08:00
1e7f63cf85 add test image handler 2020-08-23 19:49:04 +08:00
35ab7ecd51 update cycle module 2020-08-22 20:57:03 +08:00
6 changed files with 123 additions and 89 deletions

View File

@ -1,4 +1,4 @@
name: selfie2anime-origin name: selfie2anime
engine: UGATIT engine: UGATIT
result_dir: ./result result_dir: ./result
max_pairs: 1000000 max_pairs: 1000000
@ -97,7 +97,7 @@ data:
std: [0.5, 0.5, 0.5] std: [0.5, 0.5, 0.5]
test: test:
dataloader: dataloader:
batch_size: 4 batch_size: 8
shuffle: False shuffle: False
num_workers: 1 num_workers: 1
pin_memory: False pin_memory: False

View File

@ -19,11 +19,28 @@ from loss.gan import GANLoss
from model.weight_init import generation_init_weights from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid, fuse_attention_map from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
from util.handler import setup_common_handlers, setup_tensorboard_handler from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer from util.build import build_model, build_optimizer
def build_lr_schedulers(optimizers, config):
    """
    Create one ``PiecewiseLinear`` learning-rate scheduler per optimizer.

    Both schedules share the same three milestones: the optimizer's configured ``lr`` at
    iteration 0, the same ``lr`` again at ``int(start_proportion * max_iteration)``, and
    ``target_lr`` at ``max_iteration``.

    :param optimizers: mapping with the keys ``"g"`` and ``"d"`` holding the optimizers to schedule
    :param config: configuration providing ``max_iteration``, ``optimizers.generator.lr``,
        ``optimizers.discriminator.lr``, ``data.train.scheduler.start_proportion`` and
        ``data.train.scheduler.target_lr``
    :return: dict with the keys ``"g"`` and ``"d"`` mapping to the created schedulers
    """
    scheduler_cfg = config.data.train.scheduler
    total_iterations = config.max_iteration
    decay_start = int(scheduler_cfg.start_proportion * total_iterations)

    def _milestones(initial_lr):
        # constant at initial_lr until decay_start, then target_lr at the final iteration
        return [
            (0, initial_lr),
            (decay_start, initial_lr),
            (total_iterations, scheduler_cfg.target_lr),
        ]

    initial_lrs = dict(
        g=config.optimizers.generator.lr,
        d=config.optimizers.discriminator.lr,
    )
    return {
        key: PiecewiseLinear(optimizers[key], param_name="lr", milestones_values=_milestones(lr))
        for key, lr in initial_lrs.items()
    }
def get_trainer(config, logger): def get_trainer(config, logger):
generators = dict( generators = dict(
a2b=build_model(config.model.generator, config.distributed.model), a2b=build_model(config.model.generator, config.distributed.model),
@ -41,23 +58,14 @@ def get_trainer(config, logger):
logger.debug(discriminators["ga"]) logger.debug(discriminators["ga"])
logger.debug(generators["a2b"]) logger.debug(generators["a2b"])
optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator) optimizers = dict(
optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
config.optimizers.discriminator) d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info(f"build optimizers:\n{optimizers}")
milestones_values = [ lr_schedulers = build_lr_schedulers(optimizers, config)
(0, config.optimizers.generator.lr), logger.info(f"build lr_schedulers:\n{lr_schedulers}")
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight") gan_loss_cfg.pop("weight")
@ -115,26 +123,26 @@ def get_trainer(config, logger):
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"]) identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"]) identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
optimizer_g.zero_grad() optimizers["g"].zero_grad()
loss_g = dict() loss_g = dict()
for n in ["a", "b"]: for n in ["a", "b"]:
loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n], loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n])) cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
sum(loss_g.values()).backward() sum(loss_g.values()).backward()
optimizer_g.step() optimizers["g"].step()
for generator in generators.values(): for generator in generators.values():
generator.apply(rho_clipper) generator.apply(rho_clipper)
for discriminator in discriminators.values(): for discriminator in discriminators.values():
discriminator.requires_grad_(True) discriminator.requires_grad_(True)
optimizer_d.zero_grad() optimizers["d"].zero_grad()
loss_d = dict() loss_d = dict()
for k in discriminators.keys(): for k in discriminators.keys():
n = k[-1] # "a" or "b" n = k[-1] # "a" or "b"
loss_d.update( loss_d.update(
criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach()))) criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
sum(loss_d.values()).backward() sum(loss_d.values()).backward()
optimizer_d.step() optimizers["d"].step()
for h in heatmap: for h in heatmap:
heatmap[h] = heatmap[h].detach() heatmap[h] = heatmap[h].detach()
@ -156,19 +164,19 @@ def get_trainer(config, logger):
trainer = Engine(_step) trainer = Engine(_step)
trainer.logger = logger trainer.logger = logger
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g) for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d) trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d") RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d, to_save = dict(trainer=trainer)
lr_scheduler_g=lr_scheduler_g) to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update({f"generator_{k}": generators[k] for k in generators}) to_save.update({f"generator_{k}": generators[k] for k in generators})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"], end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output): def output_transform(output):
loss = dict() loss = dict()
@ -184,31 +192,57 @@ def get_trainer(config, logger):
if tensorboard_handler is not None: if tensorboard_handler is not None:
tensorboard_handler.attach( tensorboard_handler.attach(
trainer, trainer,
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"), log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar) event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
) )
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image)) @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
def show_images(engine): def show_images(engine):
output = engine.state.output output = engine.state.output
image_a_order = ["real_a", "fake_b", "rec_a", "id_a"] image_order = dict(
image_b_order = ["real_b", "fake_a", "rec_b", "id_b"] a=["real_a", "fake_b", "rec_a", "id_a"],
b=["real_b", "fake_a", "rec_b", "id_b"]
)
output["img"]["generated"]["real_a"] = fuse_attention_map( output["img"]["generated"]["real_a"] = fuse_attention_map(
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"]) output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
output["img"]["generated"]["real_b"] = fuse_attention_map( output["img"]["generated"]["real_b"] = fuse_attention_map(
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"]) output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
for k in "ab":
tensorboard_handler.writer.add_image(
f"train/{k}",
make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
engine.state.iteration
)
tensorboard_handler.writer.add_image( with torch.no_grad():
"train/a", g = torch.Generator()
make_2d_grid([output["img"]["generated"][o] for o in image_a_order]), g.manual_seed(config.misc.random_seed)
engine.state.iteration indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
) test_images = dict(
tensorboard_handler.writer.add_image( a=[[], [], [], []],
"train/b", b=[[], [], [], []]
make_2d_grid([output["img"]["generated"][o] for o in image_b_order]), )
engine.state.iteration for i in indices:
) batch = convert_tensor(engine.state.test_dataset[i], idist.device())
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
fake_b, _, heatmap_a2b = generators["a2b"](real_a)
fake_a, _, heatmap_b2a = generators["b2a"](real_b)
rec_a = generators["b2a"](fake_b)[0]
rec_b = generators["a2b"](fake_a)[0]
for idx, im in enumerate(
[attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
test_images["a"][idx].append(im.cpu())
for idx, im in enumerate(
[attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
test_images["b"][idx].append(im.cpu())
for n in "ab":
tensorboard_handler.writer.add_image(
f"test/{n}",
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
engine.state.iteration
)
return trainer return trainer
@ -225,6 +259,9 @@ def run(task, config, logger):
logger.info(f"train with dataset:\n{train_dataset}") logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger) trainer = get_trainer(config, logger)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try: try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader))) trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception: except Exception:

View File

@ -16,6 +16,10 @@ def log_basic_info(logger, config):
logger.info(f"Train {config.name}") logger.info(f"Train {config.name}")
logger.info(f"- PyTorch version: {torch.__version__}") logger.info(f"- PyTorch version: {torch.__version__}")
logger.info(f"- Ignite version: {ignite.__version__}") logger.info(f"- Ignite version: {ignite.__version__}")
logger.info(f"- CUDA version: {torch.version.cuda}")
logger.info(f"- cuDNN version: {torch.backends.cudnn.version()}")
logger.info(f"- GPU type: {torch.cuda.get_device_name(0)}")
logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
if idist.get_world_size() > 1: if idist.get_world_size() > 1:
logger.info("Distributed setting:\n") logger.info("Distributed setting:\n")
idist.show_config() idist.show_config()

View File

@ -118,17 +118,16 @@ class ResGenerator(nn.Module):
multiple = 2 ** i multiple = 2 ** i
submodules += [ submodules += [
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2, nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
kernel_size=3, kernel_size=3, stride=2, padding=1, bias=use_bias),
stride=2, padding=1, bias=use_bias),
norm_layer(num_features=base_channels * multiple * 2), norm_layer(num_features=base_channels * multiple * 2),
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
] ]
self.encoder = nn.Sequential(*submodules) self.encoder = nn.Sequential(*submodules)
res_block_channels = num_down_sampling ** 2 * base_channels res_block_channels = num_down_sampling ** 2 * base_channels
self.res_blocks = nn.ModuleList( self.resnet_middle = nn.Sequential(
[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
range(num_blocks)]) range(num_blocks)])
# up sampling # up sampling
submodules = [] submodules = []
@ -149,14 +148,13 @@ class ResGenerator(nn.Module):
def forward(self, x): def forward(self, x):
x = self.encoder(self.start_conv(x)) x = self.encoder(self.start_conv(x))
for rb in self.res_blocks: x = self.resnet_middle(x)
x = rb(x)
return self.end_conv(self.decoder(x)) return self.end_conv(self.decoder(x))
@MODEL.register_module() @MODEL.register_module()
class PatchDiscriminator(nn.Module): class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="BN"): def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
super(PatchDiscriminator, self).__init__() super(PatchDiscriminator, self).__init__()
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.' assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
norm_layer = _select_norm_layer(norm_type) norm_layer = _select_norm_layer(norm_type)

View File

@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
import torch import torch
from torch.utils.data.distributed import DistributedSampler
import ignite.distributed as idist import ignite.distributed as idist
from ignite.engine import Events, Engine from ignite.engine import Events, Engine
@ -16,7 +17,7 @@ def empty_cuda_cache(_):
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True, def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
to_save=None, metrics_to_print=None, end_event=None): to_save=None, end_event=None, set_epoch_for_dist_sampler=True):
""" """
Helper method to setup trainer with common handlers. Helper method to setup trainer with common handlers.
1. TerminateOnNan 1. TerminateOnNan
@ -29,23 +30,21 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
:param clear_cuda_cache: :param clear_cuda_cache:
:param use_profiler: :param use_profiler:
:param to_save: :param to_save:
:param metrics_to_print:
:param end_event: :param end_event:
:param set_epoch_for_dist_sampler:
:return: :return:
""" """
if set_epoch_for_dist_sampler:
@trainer.on(Events.EPOCH_STARTED)
def distrib_set_epoch(engine):
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
# if train_sampler is not None: @trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
# if not isinstance(train_sampler, DistributedSampler): def print_info(engine):
# raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
#
# @trainer.on(Events.EPOCH_STARTED)
# def distrib_set_epoch(engine):
# train_sampler.set_epoch(engine.state.epoch - 1)
@trainer.on(Events.STARTED)
@idist.one_rank_only()
def print_dataloader_size(engine):
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}") engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
if stop_on_nan: if stop_on_nan:
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
@ -63,20 +62,8 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
def log_intermediate_results(): def log_intermediate_results():
profiler.print_results(profiler.get_results()) profiler.print_results(profiler.get_results())
print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED
ProgressBar(ncols=0).attach(trainer, "all") ProgressBar(ncols=0).attach(trainer, "all")
if metrics_to_print is not None:
@trainer.on(print_interval_event)
def print_interval(engine):
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
for m in metrics_to_print:
if m not in engine.state.metrics:
continue
print_str += f"{m}={engine.state.metrics[m]:.3f} "
engine.logger.debug(print_str)
if to_save is not None: if to_save is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False), checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
n_saved=config.checkpoint.n_saved, filename_prefix=config.name) n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
@ -87,11 +74,15 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
if not checkpoint_path.exists(): if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found") raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu") ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
trainer.logger.info(f"load state_dict for {ckp.keys()}")
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp) Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}") engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED, trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
checkpoint_handler) checkpoint_handler)
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
if end_event is not None: if end_event is not None:
trainer.logger.debug(f"engine will stop on {end_event}")
@trainer.on(end_event) @trainer.on(end_event)
def terminate(engine): def terminate(engine):
engine.terminate() engine.terminate()

View File

@ -5,6 +5,21 @@ import warnings
from torch.nn.functional import interpolate from torch.nn.functional import interpolate
def attention_colored_map(attentions, size=None, cmap_name="jet"):
    """
    Turn single-channel attention maps into colour-mapped RGB images.

    Every map in the batch is min-max normalised independently, optionally resized to ``size``
    with bilinear interpolation, and then coloured with the named colormap.

    :param attentions: 4-D tensor whose second dimension has size 1; it is left unmodified
    :param size: optional two-element target size for the last two dimensions; nothing is resized
        when it is None or already equals the current size
    :param cmap_name: colormap name handed to ``get_cmap``
    :return: CPU tensor with the three colour channels (alpha dropped) in dimension 1
    """
    assert attentions.dim() == 4 and attentions.size(1) == 1

    # Work on a detached, out-of-place copy: the caller's tensor must not be rewritten by the
    # normalisation, and a tensor that tracks gradients cannot be converted to NumPy below.
    flat = attentions.detach().flatten(1)
    shifted = flat - flat.min(dim=1, keepdim=True)[0]
    span = shifted.max(dim=1, keepdim=True)[0]
    # A constant map has zero range; divide by 1 there so it normalises to 0 instead of NaN.
    span = torch.where(span > 0, span, torch.ones_like(span))
    normalised = (shifted / span).reshape(attentions.shape)

    if size is not None:
        assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
        # compare as tuples so an equal size given as a list is not treated as different
        if tuple(normalised.shape[-2:]) != tuple(size):
            normalised = interpolate(normalised, tuple(size), mode="bilinear", align_corners=False)

    cmap = get_cmap(cmap_name)
    # the colormap returns RGBA in the last axis; keep RGB and move it to the channel dimension
    rgba = cmap(normalised.squeeze(1).cpu().numpy())
    return torch.from_numpy(rgba[..., :3]).permute(0, 3, 1, 2).contiguous()
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5): def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
""" """
@ -20,18 +35,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
if attentions.size(1) != 1: if attentions.size(1) != 1:
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}") warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
return images return images
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
attentions -= min_attentions
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
if images.size() != attentions.size():
attentions = interpolate(attentions, images.size()[-2:])
colored_attentions = torch.zeros_like(images)
cmap = get_cmap(cmap_name)
for i, at in enumerate(attentions):
ca = cmap(at[0].cpu().numpy())[:, :, :3]
colored_attentions[i] = torch.from_numpy(ca).permute(2, 0, 1).view(colored_attentions[i].size())
return images * alpha + colored_attentions * (1 - alpha) return images * alpha + colored_attentions * (1 - alpha)