Compare commits

...

5 Commits

7 changed files with 201 additions and 95 deletions

View File

@@ -61,6 +61,7 @@ optimizers:
 data:
   train:
+    buffer_size: 50
     dataloader:
       batch_size: 16
       shuffle: True
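
The new `buffer_size` key is consumed by the trainer below as `config.data.train.buffer_size`, which falls back to 50 when the value is None. A small stand-alone sketch of reading it the same way with OmegaConf (the config object here is made up, not the repo's loader):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"data": {"train": {"buffer_size": 50}}})
size = cfg.data.train.buffer_size if cfg.data.train.buffer_size is not None else 50
print(size)  # -> 50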

View File

@@ -3,59 +3,41 @@ from pathlib import Path
 import torch
 import torch.nn as nn
-import torch.optim as optim
 import torchvision.utils
 import ignite.distributed as idist
 from ignite.engine import Events, Engine
 from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
 from ignite.metrics import RunningAverage
-from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
-from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
+from ignite.contrib.handlers import ProgressBar
 from ignite.utils import convert_tensor
 from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
 from omegaconf import OmegaConf

 import data
-from model import MODEL
 from loss.gan import GANLoss
-from util.distributed import auto_model
+from model.weight_init import generation_init_weights
+from model.residual_generator import GANImageBuffer
 from util.image import make_2d_grid
-from util.handler import Resumer
+from util.handler import setup_common_handlers
+from util.build import build_model, build_optimizer


-def _build_model(cfg, distributed_args=None):
-    cfg = OmegaConf.to_container(cfg)
-    model_distributed_config = cfg.pop("_distributed", dict())
-    model = MODEL.build_with(cfg)
-    if model_distributed_config.get("bn_to_syncbn"):
-        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-    distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
-    return auto_model(model, **distributed_args)
-
-
-def _build_optimizer(params, cfg):
-    assert "_type" in cfg
-    cfg = OmegaConf.to_container(cfg)
-    optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
-    return idist.auto_optim(optimizer)
-
-
 def get_trainer(config, logger):
-    generator_a = _build_model(config.model.generator, config.distributed.model)
-    generator_b = _build_model(config.model.generator, config.distributed.model)
-    discriminator_a = _build_model(config.model.discriminator, config.distributed.model)
-    discriminator_b = _build_model(config.model.discriminator, config.distributed.model)
+    generator_a = build_model(config.model.generator, config.distributed.model)
+    generator_b = build_model(config.model.generator, config.distributed.model)
+    discriminator_a = build_model(config.model.discriminator, config.distributed.model)
+    discriminator_b = build_model(config.model.discriminator, config.distributed.model)
+    for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
+        generation_init_weights(m)
     logger.debug(discriminator_a)
     logger.debug(generator_a)

-    optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
-                                   config.optimizers.generator)
-    optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
-                                   config.optimizers.discriminator)
+    optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
+                                  config.optimizers.generator)
+    optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
+                                  config.optimizers.discriminator)

     milestones_values = [
         (config.data.train.scheduler.start, config.optimizers.generator.lr),
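
The `_build_model` / `_build_optimizer` helpers removed here reappear as `build_model` / `build_optimizer` imported from `util.build`. For reference, the optimizer helper follows this config-driven pattern; the sketch below is reconstructed from the removed code and is not necessarily the contents of util/build.py:

import torch
import torch.optim as optim
import ignite.distributed as idist
from omegaconf import OmegaConf

def build_optimizer(params, cfg):
    # "_type" names the torch.optim class; the remaining keys become its kwargs.
    cfg = OmegaConf.to_container(cfg)
    optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
    return idist.auto_optim(optimizer)

opt = build_optimizer(torch.nn.Linear(2, 2).parameters(),
                      OmegaConf.create({"_type": "Adam", "lr": 2e-4, "betas": [0.5, 0.999]}))
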
@@ -75,16 +57,21 @@ def get_trainer(config, logger):
     cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
     id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
+    image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
+    image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)

     def _step(engine, batch):
         batch = convert_tensor(batch, idist.device())
         real_a, real_b = batch["a"], batch["b"]
-        optimizer_g.zero_grad()
         fake_b = generator_a(real_a)  # G_A(A)
         rec_a = generator_b(fake_b)  # G_B(G_A(A))
         fake_a = generator_b(real_b)  # G_B(B)
         rec_b = generator_a(fake_a)  # G_A(G_B(B))
+        optimizer_g.zero_grad()
+        discriminator_a.requires_grad_(False)
+        discriminator_b.requires_grad_(False)
         loss_g = dict(
             id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b),  # G_A(B)
             id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a),  # G_B(A)
@@ -96,17 +83,19 @@ def get_trainer(config, logger):
         sum(loss_g.values()).backward()
         optimizer_g.step()
+        discriminator_a.requires_grad_(True)
+        discriminator_b.requires_grad_(True)
         optimizer_d.zero_grad()
         loss_d_a = dict(
             real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
-            fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True),
+            fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
         )
+        (sum(loss_d_a.values()) * 0.5).backward()
         loss_d_b = dict(
             real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
-            fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True),
+            fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
         )
-        loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2
-        loss_d.backward()
+        (sum(loss_d_b.values()) * 0.5).backward()
         optimizer_d.step()

         return {
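
The reworked `_step` freezes the discriminators while the generators are updated, then re-enables them and backpropagates each discriminator loss separately at half weight, drawing the fake samples from the image buffers. A condensed, self-contained sketch of that pattern (plain MSE/LSGAN criterion and placeholder names, not the repo's `GANLoss`):

import torch
import torch.nn as nn

def adversarial_step(generator, discriminator, opt_g, opt_d, real_src, real_tgt, fake_buffer):
    mse = nn.MSELoss()  # LSGAN-style criterion, for illustration only

    # Generator update: discriminator frozen so no grads accumulate in it.
    discriminator.requires_grad_(False)
    opt_g.zero_grad()
    fake = generator(real_src)
    pred = discriminator(fake)
    loss_g = mse(pred, torch.ones_like(pred))
    loss_g.backward()
    opt_g.step()

    # Discriminator update: fakes come from the history buffer, loss scaled by 0.5.
    discriminator.requires_grad_(True)
    opt_d.zero_grad()
    pred_real = discriminator(real_tgt)
    pred_fake = discriminator(fake_buffer.query(fake.detach()))
    loss_d = 0.5 * (mse(pred_real, torch.ones_like(pred_real)) +
                    mse(pred_fake, torch.zeros_like(pred_fake)))
    loss_d.backward()
    opt_d.step()
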
@@ -129,27 +118,25 @@ def get_trainer(config, logger):
     trainer.logger = logger
     trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
     trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

     RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")

-    @trainer.on(Events.ITERATION_COMPLETED(every=10))
-    def print_log(engine):
-        engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]"
-                           f"loss_g={engine.state.metrics['loss_g']:.3f} "
-                           f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} "
-                           f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ")
-
     to_save = dict(
         generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
         discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
         lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
     )
-    trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from))
-    checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None)
-    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler)
+    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
+                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
+                          resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
+                          save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
+
+    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
+    def terminate(engine):
+        engine.terminate()

     if idist.get_rank() == 0:
         # Create a logger
@@ -169,7 +156,6 @@ def get_trainer(config, logger):
             ),
            event_name=Events.ITERATION_COMPLETED(every=50)
        )
-        # Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
@@ -180,28 +166,18 @@ def get_trainer(config, logger):
         def show_images(engine):
             tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)

-        # Create an object of the profiler and attach an engine to it
-        profiler = BasicTimeProfiler()
-        profiler.attach(trainer)
-
-        @trainer.on(Events.EPOCH_COMPLETED(once=1))
-        @idist.one_rank_only()
-        def log_intermediate_results():
-            profiler.print_results(profiler.get_results())
-
-        @trainer.on(Events.COMPLETED)
-        @idist.one_rank_only()
-        def _():
-            profiler.write_results(f"{config.output_dir}/time_profiling.csv")
-            # We need to close the logger with we are done
-            tb_logger.close()
+        @trainer.on(Events.COMPLETED)
+        @idist.one_rank_only()
+        def _():
+            # We need to close the logger with we are done
+            tb_logger.close()

     return trainer


 def get_tester(config, logger):
-    generator_a = _build_model(config.model.generator, config.distributed.model)
-    generator_b = _build_model(config.model.generator, config.distributed.model)
+    generator_a = build_model(config.model.generator, config.distributed.model)
+    generator_b = build_model(config.model.generator, config.distributed.model)

     def _step(engine, batch):
         batch = convert_tensor(batch, idist.device())
@@ -225,7 +201,7 @@ def get_tester(config, logger):
     if idist.get_rank == 0:
         ProgressBar(ncols=0).attach(tester)

     to_load = dict(generator_a=generator_a, generator_b=generator_b)
-    tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))
+    setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)

     @tester.on(Events.STARTED)
     @idist.one_rank_only()
@@ -248,19 +224,21 @@ def get_tester(config, logger):
 def run(task, config, logger):
+    assert torch.backends.cudnn.enabled
+    torch.backends.cudnn.benchmark = True
     logger.info(f"start task {task}")

     if task == "train":
         train_dataset = data.DATASET.build_with(config.data.train.dataset)
         logger.info(f"train with dataset:\n{train_dataset}")
         train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
         trainer = get_trainer(config, logger)
         try:
-            trainer.run(train_data_loader, max_epochs=1)
+            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
         except Exception:
             import traceback
             print(traceback.format_exc())
     elif task == "test":
+        assert config.resume_from is not None
         test_dataset = data.DATASET.build_with(config.data.test.dataset)
         logger.info(f"test with dataset:\n{test_dataset}")
         test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
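
The new `max_epochs` argument just converts the iteration budget into whole epochs; the `once=config.max_iteration` handler added in `get_trainer` then stops the run exactly on the target iteration. A quick check of the arithmetic with made-up numbers:

max_iteration = 100_000            # hypothetical iteration budget
iters_per_epoch = 1_563            # hypothetical len(train_data_loader)
max_epochs = max_iteration // iters_per_epoch + 1
print(max_epochs)                  # 64 -> up to 100_032 iterations, terminated at 100_000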

main.py
View File

@@ -29,14 +29,16 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=False
     if setup_output_dir:
         output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
         config.output_dir = str(output_dir)
-        if idist.get_rank() == 0:
-            if not output_dir.exists():
+        if output_dir.exists():
+            assert not any(output_dir.iterdir()), "output_dir must be empty"
+        else:
+            if idist.get_rank() == 0:
                 output_dir.mkdir(parents=True)
                 logger.info(f"mkdir -p {output_dir}")
     logger.info(f"output path: {config.output_dir}")
-    if backup_config:
+    if backup_config and idist.get_rank() == 0:
         with open(output_dir / "config.yml", "w+") as f:
             print(config.pretty(), file=f)
     OmegaConf.set_readonly(config, True)
@@ -46,7 +48,10 @@ def run(task, config: str, *omega_options, **kwargs):
     omega_options = [str(o) for o in omega_options]
-    conf = OmegaConf.merge(OmegaConf.load(config), OmegaConf.from_cli(omega_options))
+    cli_conf = OmegaConf.from_cli(omega_options)
+    if len(cli_conf) > 0:
+        print(cli_conf.pretty())
+    conf = OmegaConf.merge(OmegaConf.load(config), cli_conf)
     backend = kwargs.get("backend", "nccl")
     backup_config = kwargs.get("backup_config", False)
     setup_output_dir = kwargs.get("setup_output_dir", False)
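
With this change, dotlist options given on the command line are echoed before being merged over the YAML file. The same OmegaConf calls in isolation (keys and values are made up; `pretty()` matches the older OmegaConf API this code base uses):

from omegaconf import OmegaConf

base = OmegaConf.create({"data": {"train": {"dataloader": {"batch_size": 16}}}})
cli_conf = OmegaConf.from_cli(["data.train.dataloader.batch_size=8"])
if len(cli_conf) > 0:
    print(cli_conf.pretty())
conf = OmegaConf.merge(base, cli_conf)
assert conf.data.train.dataloader.batch_size == 8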

View File

@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import functools
 from .registry import MODEL
@@ -14,6 +15,60 @@ def _select_norm_layer(norm_type):
     raise NotImplemented(f'normalization layer {norm_type} is not found')


+class GANImageBuffer(object):
+    """This class implements an image buffer that stores previously
+    generated images.
+
+    This buffer allows us to update the discriminator using a history of
+    generated images rather than the ones produced by the latest generator
+    to reduce model oscillation.
+
+    Args:
+        buffer_size (int): The size of image buffer. If buffer_size = 0,
+            no buffer will be created.
+        buffer_ratio (float): The chance / possibility to use the images
+            previously stored in the buffer.
+    """
+
+    def __init__(self, buffer_size, buffer_ratio=0.5):
+        self.buffer_size = buffer_size
+        # create an empty buffer
+        if self.buffer_size > 0:
+            self.img_num = 0
+            self.image_buffer = []
+        self.buffer_ratio = buffer_ratio
+
+    def query(self, images):
+        """Query current image batch using a history of generated images.
+
+        Args:
+            images (Tensor): Current image batch without history information.
+        """
+        if self.buffer_size == 0:  # if the buffer size is 0, do nothing
+            return images
+        return_images = []
+        for image in images:
+            image = torch.unsqueeze(image.data, 0)
+            # if the buffer is not full, keep inserting current images
+            if self.img_num < self.buffer_size:
+                self.img_num = self.img_num + 1
+                self.image_buffer.append(image)
+                return_images.append(image)
+            else:
+                use_buffer = torch.rand(1) < self.buffer_ratio
+                # by self.buffer_ratio, the buffer will return a previously
+                # stored image, and insert the current image into the buffer
+                if use_buffer:
+                    random_id = torch.randint(0, self.buffer_size, (1,)).item()
+                    image_tmp = self.image_buffer[random_id].clone()
+                    self.image_buffer[random_id] = image
+                    return_images.append(image_tmp)
+                # by (1 - self.buffer_ratio), the buffer will return the
+                # current image
+                else:
+                    return_images.append(image)
+        # collect all the images and return
+        return_images = torch.cat(return_images, 0)
+        return return_images
+
+
 @MODEL.register_module()
 class ResidualBlock(nn.Module):
     def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False):
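
A short usage sketch for the buffer added above, as the trainer uses it inside the discriminator update (batch shape and loop are illustrative):

import torch
from model.residual_generator import GANImageBuffer  # import path as used by the trainer in this diff

buffer = GANImageBuffer(buffer_size=50)
for _ in range(5):
    fake = torch.randn(16, 3, 256, 256)      # a batch of freshly generated images
    mixed = buffer.query(fake.detach())      # ~50% current fakes, ~50% replayed history once full
    assert mixed.shape == fake.shape         # query always returns a batch of the same size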

View File

@@ -67,4 +67,6 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
             # only normal distribution applies.
             normal_init(m, 1.0, init_gain)

+    assert isinstance(module, nn.Module)
     module.apply(init_func)

View File

@@ -1,21 +1,86 @@
 from pathlib import Path

 import torch
-from ignite.engine import Engine
-from ignite.handlers import Checkpoint
+import ignite.distributed as idist
+from ignite.engine import Events
+from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
+from ignite.contrib.handlers import BasicTimeProfiler


-class Resumer:
-    def __init__(self, to_load, checkpoint_path):
-        self.to_load = to_load
-        if checkpoint_path is not None:
-            checkpoint_path = Path(checkpoint_path)
-            if not checkpoint_path.exists():
-                raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
-        self.checkpoint_path = checkpoint_path
-
-    def __call__(self, engine: Engine):
-        if self.checkpoint_path is not None:
-            ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
-            Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
-            engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")
+def setup_common_handlers(
+        trainer,
+        output_dir=None,
+        stop_on_nan=True,
+        use_profiler=True,
+        print_interval_event=None,
+        metrics_to_print=None,
+        to_save=None,
+        resume_from=None,
+        save_interval_event=None,
+        **checkpoint_kwargs
+):
+    """
+    Helper method to setup trainer with common handlers.
+    1. TerminateOnNan
+    2. BasicTimeProfiler
+    3. Print
+    4. Checkpoint
+
+    :param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary
+        or sequence or a single tensor.
+    :param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually
+    :param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
+    :param use_profiler:
+    :param print_interval_event:
+    :param metrics_to_print:
+    :param to_save:
+    :param resume_from:
+    :param save_interval_event:
+    :param checkpoint_kwargs:
+    :return:
+    """
+    if stop_on_nan:
+        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
+
+    if use_profiler:
+        # Create an object of the profiler and attach an engine to it
+        profiler = BasicTimeProfiler()
+        profiler.attach(trainer)
+
+        @trainer.on(Events.EPOCH_COMPLETED(once=1))
+        @idist.one_rank_only()
+        def log_intermediate_results():
+            profiler.print_results(profiler.get_results())
+
+        @trainer.on(Events.COMPLETED)
+        @idist.one_rank_only()
+        def _():
+            profiler.print_results(profiler.get_results())
+            # profiler.write_results(f"{output_dir}/time_profiling.csv")
+
+    if metrics_to_print is not None:
+        if print_interval_event is None:
+            raise ValueError(
+                "If metrics_to_print argument is provided then print_interval_event arguments should be also defined"
+            )
+
+        @trainer.on(print_interval_event)
+        def print_interval(engine):
+            print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
+            for m in metrics_to_print:
+                print_str += f"{m}={engine.state.metrics[m]:.3f} "
+            engine.logger.info(print_str)
+
+    if to_save is not None:
+        if resume_from is not None:
+            @trainer.on(Events.STARTED)
+            def resume(engine):
+                checkpoint_path = Path(resume_from)
+                if not checkpoint_path.exists():
+                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
+                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
+                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
+                engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
+
+        if save_interval_event is not None:
+            checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
+            trainer.add_event_handler(save_interval_event, checkpoint_handler)
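
A minimal usage sketch of the helper above on a bare Engine, wired the way the trainer calls it (the dummy update function, model, and output path are placeholders):

import torch
from ignite.engine import Engine, Events

def update(_engine, _batch):
    return torch.tensor(0.0)                 # dummy output so TerminateOnNan has something to inspect

trainer = Engine(update)
model = torch.nn.Linear(2, 2)
setup_common_handlers(
    trainer,
    output_dir="/tmp/example_run",           # placeholder directory for DiskSaver
    to_save={"model": model},
    save_interval_event=Events.ITERATION_COMPLETED(every=5),
    n_saved=2,                               # forwarded to Checkpoint through **checkpoint_kwargs
)
trainer.run([0] * 10, max_epochs=1)          # NaN handler and profiler are attached by default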

View File

@@ -119,9 +119,9 @@ class Registry(_Registry):
         return self._module_dict.get(key, None)

     def _register_module(self, module_class, module_name=None, force=False):
-        if not inspect.isclass(module_class):
-            raise TypeError('module must be a class, '
-                            f'but got {type(module_class)}')
+        # if not inspect.isclass(module_class):
+        #     raise TypeError('module must be a class, '
+        #                     f'but got {type(module_class)}')
         if module_name is None:
             module_name = module_class.__name__