move common handler setup to util

This commit is contained in:
Ray Wong 2020-08-08 07:11:47 +08:00
parent 888a052f05
commit 8abd35467c
4 changed files with 183 additions and 85 deletions

View File

@ -61,6 +61,7 @@ optimizers:
data:
train:
buffer_size: 50
dataloader:
batch_size: 16
shuffle: True

View File

@ -3,58 +3,40 @@ from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from model import MODEL
from loss.gan import GANLoss
from util.distributed import auto_model
from model.weight_init import generation_init_weights
from model.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import Resumer
def _build_model(cfg, distributed_args=None):
    """Instantiate a model from *cfg* and wrap it for (possibly distributed) use.

    The config may carry a private ``_distributed`` section; if its
    ``bn_to_syncbn`` flag is set, BatchNorm layers are converted to
    SyncBatchNorm before wrapping.
    """
    options = OmegaConf.to_container(cfg)
    dist_options = options.pop("_distributed", dict())
    model = MODEL.build_with(options)
    if dist_options.get("bn_to_syncbn"):
        # swap BatchNorm layers for their synchronized counterpart
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if distributed_args is None or idist.get_world_size() == 1:
        # single-process run (or no overrides given): no extra wrapper args
        distributed_args = {}
    return auto_model(model, **distributed_args)
def _build_optimizer(params, cfg):
    """Build the ``torch.optim`` optimizer named by ``cfg["_type"]``.

    Remaining config keys are forwarded as optimizer keyword arguments;
    the result is made distributed-aware via ``idist.auto_optim``.
    """
    assert "_type" in cfg
    options = OmegaConf.to_container(cfg)
    optimizer_cls = getattr(optim, options.pop("_type"))
    return idist.auto_optim(optimizer_cls(params=params, **options))
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer
def get_trainer(config, logger):
generator_a = _build_model(config.model.generator, config.distributed.model)
generator_b = _build_model(config.model.generator, config.distributed.model)
discriminator_a = _build_model(config.model.discriminator, config.distributed.model)
discriminator_b = _build_model(config.model.discriminator, config.distributed.model)
generator_a = build_model(config.model.generator, config.distributed.model)
generator_b = build_model(config.model.generator, config.distributed.model)
discriminator_a = build_model(config.model.discriminator, config.distributed.model)
discriminator_b = build_model(config.model.discriminator, config.distributed.model)
for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
generation_init_weights(m)
logger.debug(discriminator_a)
logger.debug(generator_a)
optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
config.optimizers.generator)
optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
config.optimizers.discriminator)
milestones_values = [
@ -75,16 +57,21 @@ def get_trainer(config, logger):
cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real_a, real_b = batch["a"], batch["b"]
optimizer_g.zero_grad()
fake_b = generator_a(real_a) # G_A(A)
rec_a = generator_b(fake_b) # G_B(G_A(A))
fake_a = generator_b(real_b) # G_B(B)
rec_b = generator_a(fake_a) # G_A(G_B(B))
optimizer_g.zero_grad()
discriminator_a.requires_grad_(False)
discriminator_b.requires_grad_(False)
loss_g = dict(
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
@ -96,17 +83,19 @@ def get_trainer(config, logger):
sum(loss_g.values()).backward()
optimizer_g.step()
discriminator_a.requires_grad_(True)
discriminator_b.requires_grad_(True)
optimizer_d.zero_grad()
loss_d_a = dict(
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True),
fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
)
(sum(loss_d_a.values()) * 0.5).backward()
loss_d_b = dict(
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True),
fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
)
loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2
loss_d.backward()
(sum(loss_d_b.values()) * 0.5).backward()
optimizer_d.step()
return {
@ -129,27 +118,25 @@ def get_trainer(config, logger):
trainer.logger = logger
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
@trainer.on(Events.ITERATION_COMPLETED(every=10))
def print_log(engine):
engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]"
f"loss_g={engine.state.metrics['loss_g']:.3f} "
f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} "
f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ")
to_save = dict(
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
)
trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from))
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler)
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
def terminate(engine):
engine.terminate()
if idist.get_rank() == 0:
# Create a logger
@ -169,7 +156,6 @@ def get_trainer(config, logger):
),
event_name=Events.ITERATION_COMPLETED(every=50)
)
# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
tb_logger.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
@ -180,19 +166,9 @@ def get_trainer(config, logger):
def show_images(engine):
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
# Create an object of the profiler and attach an engine to it
profiler = BasicTimeProfiler()
profiler.attach(trainer)
@trainer.on(Events.EPOCH_COMPLETED(once=1))
@idist.one_rank_only()
def log_intermediate_results():
profiler.print_results(profiler.get_results())
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()
def _():
profiler.write_results(f"{config.output_dir}/time_profiling.csv")
# We need to close the logger when we are done
tb_logger.close()
@ -200,8 +176,8 @@ def get_trainer(config, logger):
def get_tester(config, logger):
generator_a = _build_model(config.model.generator, config.distributed.model)
generator_b = _build_model(config.model.generator, config.distributed.model)
generator_a = build_model(config.model.generator, config.distributed.model)
generator_b = build_model(config.model.generator, config.distributed.model)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
@ -225,7 +201,7 @@ def get_tester(config, logger):
if idist.get_rank == 0:
ProgressBar(ncols=0).attach(tester)
to_load = dict(generator_a=generator_a, generator_b=generator_b)
tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))
setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)
@tester.on(Events.STARTED)
@idist.one_rank_only()
@ -248,15 +224,16 @@ def get_tester(config, logger):
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger)
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
try:
trainer.run(train_data_loader, max_epochs=1)
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
except Exception:
import traceback
print(traceback.format_exc())

View File

@ -1,3 +1,4 @@
import torch
import torch.nn as nn
import functools
from .registry import MODEL
@ -14,6 +15,60 @@ def _select_norm_layer(norm_type):
raise NotImplemented(f'normalization layer {norm_type} is not found')
class GANImageBuffer(object):
    """History buffer of generated images for discriminator updates.

    Feeding the discriminator a mix of current and previously generated
    images — rather than only the newest generator output — reduces model
    oscillation during GAN training.

    Args:
        buffer_size (int): Capacity of the buffer. If buffer_size = 0,
            no buffer is created and queries pass through unchanged.
        buffer_ratio (float): Probability that a query is answered with a
            previously stored image instead of the incoming one.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        if self.buffer_size > 0:
            # img_num counts how many slots are filled; image_buffer holds them
            self.img_num = 0
            self.image_buffer = []
        self.buffer_ratio = buffer_ratio

    def query(self, images):
        """Mix the incoming batch with the stored history.

        Args:
            images (Tensor): Current image batch; iterated along dim 0.
        """
        if self.buffer_size == 0:
            # buffering disabled — hand the batch back untouched
            return images
        collected = []
        for img in images:
            img = torch.unsqueeze(img.data, 0)
            if self.img_num < self.buffer_size:
                # buffer still filling: store the image and echo it back
                self.img_num += 1
                self.image_buffer.append(img)
                collected.append(img)
                continue
            # buffer full: with probability buffer_ratio return a stored
            # image and keep the new one in its place...
            if torch.rand(1) < self.buffer_ratio:
                slot = torch.randint(0, self.buffer_size, (1,)).item()
                stored = self.image_buffer[slot].clone()
                self.image_buffer[slot] = img
                collected.append(stored)
            else:
                # ...otherwise return the current image unchanged
                collected.append(img)
        return torch.cat(collected, 0)
@MODEL.register_module()
class ResidualBlock(nn.Module):
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False):

View File

@ -1,21 +1,86 @@
from pathlib import Path
import torch
from ignite.engine import Engine
from ignite.handlers import Checkpoint
import ignite.distributed as idist
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler
class Resumer:
def __init__(self, to_load, checkpoint_path):
self.to_load = to_load
if checkpoint_path is not None:
checkpoint_path = Path(checkpoint_path)
def setup_common_handlers(
trainer,
output_dir=None,
stop_on_nan=True,
use_profiler=True,
print_interval_event=None,
metrics_to_print=None,
to_save=None,
resume_from=None,
save_interval_event=None,
**checkpoint_kwargs
):
"""
Helper method to setup trainer with common handlers.
1. TerminateOnNan
2. BasicTimeProfiler
3. Print
4. Checkpoint
:param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary
or sequence or a single tensor.
:param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually
:param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
:param use_profiler:
:param print_interval_event:
:param metrics_to_print:
:param to_save:
:param resume_from:
:param save_interval_event:
:param checkpoint_kwargs:
:return:
"""
if stop_on_nan:
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
if use_profiler:
# Create an object of the profiler and attach an engine to it
profiler = BasicTimeProfiler()
profiler.attach(trainer)
@trainer.on(Events.EPOCH_COMPLETED(once=1))
@idist.one_rank_only()
def log_intermediate_results():
profiler.print_results(profiler.get_results())
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()
def _():
profiler.print_results(profiler.get_results())
# profiler.write_results(f"{output_dir}/time_profiling.csv")
if metrics_to_print is not None:
if print_interval_event is None:
raise ValueError(
"If metrics_to_print argument is provided then print_interval_event arguments should be also defined"
)
@trainer.on(print_interval_event)
def print_interval(engine):
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
for m in metrics_to_print:
print_str += f"{m}={engine.state.metrics[m]:.3f} "
engine.logger.info(print_str)
if to_save is not None:
if resume_from is not None:
@trainer.on(Events.STARTED)
def resume(engine):
checkpoint_path = Path(resume_from)
if not checkpoint_path.exists():
raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
self.checkpoint_path = checkpoint_path
def __call__(self, engine: Engine):
if self.checkpoint_path is not None:
ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
if save_interval_event is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
trainer.add_event_handler(save_interval_event, checkpoint_handler)