Compare commits
5 Commits
58ed4524bf
...
31aafb3470
| Author | SHA1 | Date | |
|---|---|---|---|
| 31aafb3470 | |||
| 54b0799c48 | |||
| 9dfb887c86 | |||
| 1e7f63cf85 | |||
| 35ab7ecd51 |
@ -1,4 +1,4 @@
|
||||
name: selfie2anime-origin
|
||||
name: selfie2anime
|
||||
engine: UGATIT
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
@ -97,7 +97,7 @@ data:
|
||||
std: [0.5, 0.5, 0.5]
|
||||
test:
|
||||
dataloader:
|
||||
batch_size: 4
|
||||
batch_size: 8
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
|
||||
121
engine/UGATIT.py
121
engine/UGATIT.py
@ -19,11 +19,28 @@ from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from util.image import make_2d_grid, fuse_attention_map
|
||||
from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
|
||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.build import build_model, build_optimizer
|
||||
|
||||
|
||||
def build_lr_schedulers(optimizers, config):
|
||||
g_milestones_values = [
|
||||
(0, config.optimizers.generator.lr),
|
||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
d_milestones_values = [
|
||||
(0, config.optimizers.discriminator.lr),
|
||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
return dict(
|
||||
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
|
||||
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
|
||||
)
|
||||
|
||||
|
||||
def get_trainer(config, logger):
|
||||
generators = dict(
|
||||
a2b=build_model(config.model.generator, config.distributed.model),
|
||||
@ -41,23 +58,14 @@ def get_trainer(config, logger):
|
||||
logger.debug(discriminators["ga"])
|
||||
logger.debug(generators["a2b"])
|
||||
|
||||
optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator)
|
||||
optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
|
||||
config.optimizers.discriminator)
|
||||
optimizers = dict(
|
||||
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
|
||||
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
|
||||
)
|
||||
logger.info(f"build optimizers:\n{optimizers}")
|
||||
|
||||
milestones_values = [
|
||||
(0, config.optimizers.generator.lr),
|
||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
milestones_values = [
|
||||
(0, config.optimizers.discriminator.lr),
|
||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
||||
lr_schedulers = build_lr_schedulers(optimizers, config)
|
||||
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
@ -115,26 +123,26 @@ def get_trainer(config, logger):
|
||||
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
|
||||
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
|
||||
|
||||
optimizer_g.zero_grad()
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = dict()
|
||||
for n in ["a", "b"]:
|
||||
loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
|
||||
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
|
||||
sum(loss_g.values()).backward()
|
||||
optimizer_g.step()
|
||||
optimizers["g"].step()
|
||||
for generator in generators.values():
|
||||
generator.apply(rho_clipper)
|
||||
for discriminator in discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
optimizer_d.zero_grad()
|
||||
optimizers["d"].zero_grad()
|
||||
loss_d = dict()
|
||||
for k in discriminators.keys():
|
||||
n = k[-1] # "a" or "b"
|
||||
loss_d.update(
|
||||
criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
|
||||
sum(loss_d.values()).backward()
|
||||
optimizer_d.step()
|
||||
optimizers["d"].step()
|
||||
|
||||
for h in heatmap:
|
||||
heatmap[h] = heatmap[h].detach()
|
||||
@ -156,19 +164,19 @@ def get_trainer(config, logger):
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
|
||||
for lr_shd in lr_schedulers.values():
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
|
||||
|
||||
to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d,
|
||||
lr_scheduler_g=lr_scheduler_g)
|
||||
to_save = dict(trainer=trainer)
|
||||
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
to_save.update({f"generator_{k}": generators[k] for k in generators})
|
||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||
|
||||
setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
|
||||
clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
|
||||
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
|
||||
def output_transform(output):
|
||||
loss = dict()
|
||||
@ -184,31 +192,57 @@ def get_trainer(config, logger):
|
||||
if tensorboard_handler is not None:
|
||||
tensorboard_handler.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
|
||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
|
||||
)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
||||
def show_images(engine):
|
||||
output = engine.state.output
|
||||
image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
|
||||
image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
|
||||
|
||||
image_order = dict(
|
||||
a=["real_a", "fake_b", "rec_a", "id_a"],
|
||||
b=["real_b", "fake_a", "rec_b", "id_b"]
|
||||
)
|
||||
output["img"]["generated"]["real_a"] = fuse_attention_map(
|
||||
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
|
||||
output["img"]["generated"]["real_b"] = fuse_attention_map(
|
||||
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
|
||||
for k in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"train/{k}",
|
||||
make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
tensorboard_handler.writer.add_image(
|
||||
"train/a",
|
||||
make_2d_grid([output["img"]["generated"][o] for o in image_a_order]),
|
||||
engine.state.iteration
|
||||
)
|
||||
tensorboard_handler.writer.add_image(
|
||||
"train/b",
|
||||
make_2d_grid([output["img"]["generated"][o] for o in image_b_order]),
|
||||
engine.state.iteration
|
||||
)
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed)
|
||||
indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
|
||||
test_images = dict(
|
||||
a=[[], [], [], []],
|
||||
b=[[], [], [], []]
|
||||
)
|
||||
for i in indices:
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
|
||||
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
|
||||
fake_b, _, heatmap_a2b = generators["a2b"](real_a)
|
||||
fake_a, _, heatmap_b2a = generators["b2a"](real_b)
|
||||
rec_a = generators["b2a"](fake_b)[0]
|
||||
rec_b = generators["a2b"](fake_a)[0]
|
||||
|
||||
for idx, im in enumerate(
|
||||
[attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
|
||||
test_images["a"][idx].append(im.cpu())
|
||||
for idx, im in enumerate(
|
||||
[attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
|
||||
test_images["b"][idx].append(im.cpu())
|
||||
for n in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{n}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
@ -225,6 +259,9 @@ def run(task, config, logger):
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger)
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||
except Exception:
|
||||
|
||||
4
main.py
4
main.py
@ -16,6 +16,10 @@ def log_basic_info(logger, config):
|
||||
logger.info(f"Train {config.name}")
|
||||
logger.info(f"- PyTorch version: {torch.__version__}")
|
||||
logger.info(f"- Ignite version: {ignite.__version__}")
|
||||
logger.info(f"- CUDA version: {torch.version.cuda}")
|
||||
logger.info(f"- cuDNN version: {torch.backends.cudnn.version()}")
|
||||
logger.info(f"- GPU type: {torch.cuda.get_device_name(0)}")
|
||||
logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
|
||||
if idist.get_world_size() > 1:
|
||||
logger.info("Distributed setting:\n")
|
||||
idist.show_config()
|
||||
|
||||
@ -118,17 +118,16 @@ class ResGenerator(nn.Module):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=3,
|
||||
stride=2, padding=1, bias=use_bias),
|
||||
kernel_size=3, stride=2, padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
|
||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
|
||||
range(num_blocks)])
|
||||
self.resnet_middle = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
|
||||
range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
@ -149,14 +148,13 @@ class ResGenerator(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(self.start_conv(x))
|
||||
for rb in self.res_blocks:
|
||||
x = rb(x)
|
||||
x = self.resnet_middle(x)
|
||||
return self.end_conv(self.decoder(x))
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="BN"):
|
||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
|
||||
super(PatchDiscriminator, self).__init__()
|
||||
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
||||
norm_layer = _select_norm_layer(norm_type)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
@ -16,7 +17,7 @@ def empty_cuda_cache(_):
|
||||
|
||||
|
||||
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
|
||||
to_save=None, metrics_to_print=None, end_event=None):
|
||||
to_save=None, end_event=None, set_epoch_for_dist_sampler=True):
|
||||
"""
|
||||
Helper method to setup trainer with common handlers.
|
||||
1. TerminateOnNan
|
||||
@ -29,23 +30,21 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
:param clear_cuda_cache:
|
||||
:param use_profiler:
|
||||
:param to_save:
|
||||
:param metrics_to_print:
|
||||
:param end_event:
|
||||
:param set_epoch_for_dist_sampler:
|
||||
:return:
|
||||
"""
|
||||
if set_epoch_for_dist_sampler:
|
||||
@trainer.on(Events.EPOCH_STARTED)
|
||||
def distrib_set_epoch(engine):
|
||||
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
|
||||
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
|
||||
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
|
||||
|
||||
# if train_sampler is not None:
|
||||
# if not isinstance(train_sampler, DistributedSampler):
|
||||
# raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
|
||||
#
|
||||
# @trainer.on(Events.EPOCH_STARTED)
|
||||
# def distrib_set_epoch(engine):
|
||||
# train_sampler.set_epoch(engine.state.epoch - 1)
|
||||
|
||||
@trainer.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
def print_dataloader_size(engine):
|
||||
@trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
|
||||
def print_info(engine):
|
||||
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
|
||||
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
|
||||
|
||||
if stop_on_nan:
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||
@ -63,20 +62,8 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
def log_intermediate_results():
|
||||
profiler.print_results(profiler.get_results())
|
||||
|
||||
print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED
|
||||
|
||||
ProgressBar(ncols=0).attach(trainer, "all")
|
||||
|
||||
if metrics_to_print is not None:
|
||||
@trainer.on(print_interval_event)
|
||||
def print_interval(engine):
|
||||
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
|
||||
for m in metrics_to_print:
|
||||
if m not in engine.state.metrics:
|
||||
continue
|
||||
print_str += f"{m}={engine.state.metrics[m]:.3f} "
|
||||
engine.logger.debug(print_str)
|
||||
|
||||
if to_save is not None:
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
|
||||
n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
|
||||
@ -87,11 +74,15 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||
trainer.logger.info(f"load state_dict for {ckp.keys()}")
|
||||
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||
checkpoint_handler)
|
||||
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
|
||||
if end_event is not None:
|
||||
trainer.logger.debug(f"engine will stop on {end_event}")
|
||||
|
||||
@trainer.on(end_event)
|
||||
def terminate(engine):
|
||||
engine.terminate()
|
||||
|
||||
@ -5,6 +5,21 @@ import warnings
|
||||
from torch.nn.functional import interpolate
|
||||
|
||||
|
||||
def attention_colored_map(attentions, size=None, cmap_name="jet"):
|
||||
assert attentions.dim() == 4 and attentions.size(1) == 1
|
||||
|
||||
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
attentions -= min_attentions
|
||||
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
|
||||
if size is not None and attentions.size()[-2:] != size:
|
||||
assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
|
||||
attentions = interpolate(attentions, size, mode="bilinear", align_corners=False)
|
||||
cmap = get_cmap(cmap_name)
|
||||
ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3]
|
||||
return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
|
||||
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
"""
|
||||
|
||||
@ -20,18 +35,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
if attentions.size(1) != 1:
|
||||
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
|
||||
return images
|
||||
|
||||
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
attentions -= min_attentions
|
||||
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
|
||||
if images.size() != attentions.size():
|
||||
attentions = interpolate(attentions, images.size()[-2:])
|
||||
colored_attentions = torch.zeros_like(images)
|
||||
cmap = get_cmap(cmap_name)
|
||||
for i, at in enumerate(attentions):
|
||||
ca = cmap(at[0].cpu().numpy())[:, :, :3]
|
||||
colored_attentions[i] = torch.from_numpy(ca).permute(2, 0, 1).view(colored_attentions[i].size())
|
||||
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
|
||||
return images * alpha + colored_attentions * (1 - alpha)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user