commit f7843de45d59d4a8c4cc73ed13ac9b19aac9060b Author: Ray Wong Date: Fri Aug 7 09:48:09 2020 +0800 base code for pytorch distributed, add cyclegan diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d759cd1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,243 @@ +# Created by .ignore support plugin (hsz.mobi) +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### JupyterNotebooks template +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +### Linux template +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..1322a6f --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Datasource local storage ignored files +/../../../../../../:\Users\wr\Code\raycv\.idea/dataSources/ +/dataSources.local.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/.idea/deployment.xml b/.idea/deployment.xml new file mode 100644 index 0000000..cebeb10 --- /dev/null +++ b/.idea/deployment.xml @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..1b9173d --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..0e62d99 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/raycv.iml b/.idea/raycv.iml new file mode 100644 index 0000000..9781a97 --- /dev/null +++ b/.idea/raycv.iml @@ -0,0 +1,11 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 
0000000..9661ac7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/configs/synthesizers/cyclegan.yml b/configs/synthesizers/cyclegan.yml
new file mode 100644
index 0000000..13b30fa
--- /dev/null
+++ b/configs/synthesizers/cyclegan.yml
@@ -0,0 +1,102 @@
+name: horse2zebra
+engine: cyclegan
+result_dir: ./result
+max_iteration: 18000
+
+distributed:
+  model:
+    # broadcast_buffers: False
+
+misc:
+  random_seed: 1004
+
+checkpoints:
+  interval: 2000
+
+log:
+  logger:
+    level: 20 # DEBUG(10) INFO(20)
+
+model:
+  generator:
+    _type: ResGenerator
+    in_channels: 3
+    out_channels: 3
+    base_channels: 64
+    num_blocks: 9
+    padding_mode: reflect
+    norm_type: IN
+    use_dropout: False
+  discriminator:
+    _type: PatchDiscriminator
+    _distributed:
+      bn_to_syncbn: True
+    in_channels: 3
+    base_channels: 64
+    num_conv: 3
+    norm_type: BN
+
+loss:
+  gan:
+    loss_type: lsgan
+    weight: 1.0
+    real_label_val: 1.0
+    fake_label_val: 0.0
+  cycle:
+    level: 1
+    weight: 10.0
+  id:
+    level: 1
+    weight: 0
+
+optimizers:
+  generator:
+    _type: Adam
+    lr: 2e-4
+    betas: [0.5, 0.999]
+  discriminator:
+    _type: Adam
+    lr: 2e-4
+    betas: [0.5, 0.999]
+
+data:
+  train:
+    dataloader:
+      batch_size: 16
+      shuffle: True
+      num_workers: 4
+      pin_memory: True
+      drop_last: True
+    dataset:
+      _type: GenerationUnpairedDataset
+      root_a: "/data/i2i/horse2zebra/trainA"
+      root_b: "/data/i2i/horse2zebra/trainB"
+      random_pair: True
+      pipeline:
+        - Load
+        - Resize:
+            size: [286, 286]
+        - RandomCrop:
+            size: [256, 256]
+        - RandomHorizontalFlip
+        - ToTensor
+    scheduler:
+      start: 9000
+      target_lr: 0
+  test:
+    dataloader:
+      batch_size: 4
+      shuffle: False
+      num_workers: 1
+      pin_memory: False
+      drop_last: False
+    dataset:
+      _type: GenerationUnpairedDataset
+      root_a: "/data/i2i/horse2zebra/testA"
+      root_b: "/data/i2i/horse2zebra/testB"
+      random_pair: False
+      pipeline:
+        - Load
+        - Resize:
+            size: [256, 256]
+        - ToTensor
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..7dcc718
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,4 @@
+import data.dataset
+import data.transform
+from data.registry import DATASET, TRANSFORM
+
diff --git a/data/dataset.py b/data/dataset.py
new file mode 100644
index 0000000..87c0789
--- /dev/null
+++ b/data/dataset.py
@@ -0,0 +1,73 @@
+import os
+import pickle
+
+import torch
+from torch.utils.data import Dataset
+from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
+
+import lmdb
+
+from .transform import transform_pipeline
+from .registry import DATASET
+
+
+class LMDBDataset(Dataset):
+    def __init__(self, lmdb_path, output_transform=None, map_size=2 ** 40, readonly=True, **lmdb_kwargs):
+        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
+                            **lmdb_kwargs)
+        self.output_transform = output_transform
+        with self.db.begin(write=False) as txn:
+            self._len = pickle.loads(txn.get(b"__len__"))
+
+    def __len__(self):
+        return self._len
+
+    def __getitem__(self, idx):
+        with self.db.begin(write=False) as txn:
+            sample = pickle.loads(txn.get("{}".format(idx).encode()))
+        if self.output_transform is not None:
+            sample = self.output_transform(sample)
+        return sample
+
+
+@DATASET.register_module()
+class SingleFolderDataset(Dataset):
+    def __init__(self, root, pipeline):
+        assert os.path.isdir(root)
+        self.root = root
+        samples = []
+        for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
+            for fn in sorted(fns):
+                path = os.path.join(r, fn)
+                if
has_file_allowed_extension(path, IMG_EXTENSIONS): + samples.append(path) + self.samples = samples + self.pipeline = transform_pipeline(pipeline) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.pipeline(self.samples[idx]) + + def __repr__(self): + return f"" + + +@DATASET.register_module() +class GenerationUnpairedDataset(Dataset): + def __init__(self, root_a, root_b, random_pair, pipeline): + self.A = SingleFolderDataset(root_a, pipeline) + self.B = SingleFolderDataset(root_b, pipeline) + self.random_pair = random_pair + + def __getitem__(self, idx): + a_idx = idx % len(self.A) + b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item() + return dict(a=self.A[a_idx], b=self.B[b_idx]) + + def __len__(self): + return max(len(self.A), len(self.B)) + + def __repr__(self): + return f"\nPipeline:\n{self.A.pipeline}" diff --git a/data/registry.py b/data/registry.py new file mode 100644 index 0000000..f9c71bd --- /dev/null +++ b/data/registry.py @@ -0,0 +1,4 @@ +from util.registry import Registry + +DATASET = Registry("dataset") +TRANSFORM = Registry("transform") diff --git a/data/transform.py b/data/transform.py new file mode 100644 index 0000000..1aa6c06 --- /dev/null +++ b/data/transform.py @@ -0,0 +1,34 @@ +from torchvision import transforms +from torchvision.datasets.folder import default_loader + +from .registry import TRANSFORM + +# from https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html +_VALID_TORCHVISION_TRANSFORMS = ["ToTensor", "ToPILImage", "Normalize", "Resize", + "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", + "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", + "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", + "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", + "RandomErasing"] + +for vtt in _VALID_TORCHVISION_TRANSFORMS: + TRANSFORM.register_module(module=getattr(transforms, vtt)) + + +@TRANSFORM.register_module() +class Load: + def __init__(self, loader=default_loader): + self.loader = loader + + def __call__(self, image_path): + return self.loader(image_path) + + def __repr__(self): + return self.__class__.__name__ + "()" + + +def transform_pipeline(pipeline_description): + if len(pipeline_description) == 0: + return lambda x: x + transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description] + return transforms.Compose(transform_list) diff --git a/docs/run.md b/docs/run.md new file mode 100644 index 0000000..e69de29 diff --git a/engine/__init__.py b/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/engine/cyclegan.py b/engine/cyclegan.py new file mode 100644 index 0000000..7391d19 --- /dev/null +++ b/engine/cyclegan.py @@ -0,0 +1,274 @@ +import itertools +from pathlib import Path + +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision.utils + +import ignite.distributed as idist +from ignite.engine import Events, Engine +from ignite.contrib.handlers.param_scheduler import PiecewiseLinear +from ignite.metrics import RunningAverage +from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan +from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar +from ignite.utils import convert_tensor +from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler + +from omegaconf import OmegaConf + +import 
data +from model import MODEL +from loss.gan import GANLoss +from util.distributed import auto_model +from util.image import make_2d_grid +from util.handler import Resumer + + +def _build_model(cfg, distributed_args=None): + cfg = OmegaConf.to_container(cfg) + model_distributed_config = cfg.pop("_distributed", dict()) + model = MODEL.build_with(cfg) + + if model_distributed_config.get("bn_to_syncbn"): + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + + distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args + return auto_model(model, **distributed_args) + + +def _build_optimizer(params, cfg): + assert "_type" in cfg + cfg = OmegaConf.to_container(cfg) + optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg) + return idist.auto_optim(optimizer) + + +def get_trainer(config, logger): + generator_a = _build_model(config.model.generator, config.distributed.model) + generator_b = _build_model(config.model.generator, config.distributed.model) + discriminator_a = _build_model(config.model.discriminator, config.distributed.model) + discriminator_b = _build_model(config.model.discriminator, config.distributed.model) + logger.debug(discriminator_a) + logger.debug(generator_a) + + optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()), + config.optimizers.generator) + optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()), + config.optimizers.discriminator) + + milestones_values = [ + (config.data.train.scheduler.start, config.optimizers.generator.lr), + (config.max_iteration, config.data.train.scheduler.target_lr), + ] + lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values) + + milestones_values = [ + (config.data.train.scheduler.start, config.optimizers.discriminator.lr), + (config.max_iteration, config.data.train.scheduler.target_lr), + ] + lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values) + + gan_loss_cfg = OmegaConf.to_container(config.loss.gan) + gan_loss_cfg.pop("weight") + gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) + cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss() + id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss() + + def _step(engine, batch): + batch = convert_tensor(batch, idist.device()) + real_a, real_b = batch["a"], batch["b"] + + optimizer_g.zero_grad() + fake_b = generator_a(real_a) # G_A(A) + rec_a = generator_b(fake_b) # G_B(G_A(A)) + fake_a = generator_b(real_b) # G_B(B) + rec_b = generator_a(fake_a) # G_A(G_B(B)) + + loss_g = dict( + id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B) + id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A) + cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a), + cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b), + gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True), + gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True) + ) + sum(loss_g.values()).backward() + optimizer_g.step() + + optimizer_d.zero_grad() + loss_d_a = dict( + real=gan_loss(discriminator_a(real_b), True, is_discriminator=True), + fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True), + ) + loss_d_b = dict( + real=gan_loss(discriminator_b(real_a), True, is_discriminator=True), + fake=gan_loss(discriminator_b(fake_a.detach()), False, 
is_discriminator=True), + ) + loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2 + loss_d.backward() + optimizer_d.step() + + return { + "loss": { + "g": {ln: loss_g[ln].mean().item() for ln in loss_g}, + "d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a}, + "d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b}, + }, + "img": [ + real_a.detach(), + fake_b.detach(), + rec_a.detach(), + real_b.detach(), + fake_a.detach(), + rec_b.detach() + ] + } + + trainer = Engine(_step) + trainer.logger = logger + trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g) + trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d) + trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) + RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") + RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a") + RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b") + + @trainer.on(Events.ITERATION_COMPLETED(every=10)) + def print_log(engine): + engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]" + f"loss_g={engine.state.metrics['loss_g']:.3f} " + f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} " + f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ") + + to_save = dict( + generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a, + discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, + lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g + ) + + trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from)) + checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None) + trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler) + + if idist.get_rank() == 0: + # Create a logger + tb_logger = TensorboardLogger(log_dir=config.output_dir) + tb_writer = tb_logger.writer + + # Attach the logger to the trainer to log training loss at each iteration + def global_step_transform(*args, **kwargs): + return trainer.state.iteration + + tb_logger.attach( + trainer, + log_handler=OutputHandler( + tag="loss", + metric_names=["loss_g", "loss_d_a", "loss_d_b"], + global_step_transform=global_step_transform, + ), + event_name=Events.ITERATION_COMPLETED(every=50) + ) + # Attach the logger to the trainer to log optimizer's parameters, e.g. 
learning rate at each iteration
+        tb_logger.attach(
+            trainer,
+            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
+            event_name=Events.ITERATION_STARTED(every=50)
+        )
+
+        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
+        def show_images(engine):
+            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
+
+        # Create an object of the profiler and attach an engine to it
+        profiler = BasicTimeProfiler()
+        profiler.attach(trainer)
+
+        @trainer.on(Events.EPOCH_COMPLETED(once=1))
+        @idist.one_rank_only()
+        def log_intermediate_results():
+            profiler.print_results(profiler.get_results())
+
+        @trainer.on(Events.COMPLETED)
+        @idist.one_rank_only()
+        def _():
+            profiler.write_results(f"{config.output_dir}/time_profiling.csv")
+            # We need to close the logger when we are done
+            tb_logger.close()
+
+    return trainer
+
+
+def get_tester(config, logger):
+    generator_a = _build_model(config.model.generator, config.distributed.model)
+    generator_b = _build_model(config.model.generator, config.distributed.model)
+
+    def _step(engine, batch):
+        batch = convert_tensor(batch, idist.device())
+        real_a, real_b = batch["a"], batch["b"]
+        with torch.no_grad():
+            fake_b = generator_a(real_a)  # G_A(A)
+            rec_a = generator_b(fake_b)  # G_B(G_A(A))
+            fake_a = generator_b(real_b)  # G_B(B)
+            rec_b = generator_a(fake_a)  # G_A(G_B(B))
+        return [
+            real_a.detach(),
+            fake_b.detach(),
+            rec_a.detach(),
+            real_b.detach(),
+            fake_a.detach(),
+            rec_b.detach()
+        ]
+
+    tester = Engine(_step)
+    tester.logger = logger
+    if idist.get_rank() == 0:
+        ProgressBar(ncols=0).attach(tester)
+    to_load = dict(generator_a=generator_a, generator_b=generator_b)
+    tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))
+
+    @tester.on(Events.STARTED)
+    @idist.one_rank_only()
+    def mkdir(engine):
+        img_output_dir = Path(config.output_dir) / "test_images"
+        if not img_output_dir.exists():
+            engine.logger.info(f"mkdir {img_output_dir}")
+            img_output_dir.mkdir()
+
+    @tester.on(Events.ITERATION_COMPLETED)
+    def save_images(engine):
+        img_tensors = engine.state.output
+        batch_size = img_tensors[0].size(0)
+        for i in range(batch_size):
+            torchvision.utils.save_image([img[i] for img in img_tensors],
+                                         Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
+                                         nrow=len(img_tensors))
+
+    return tester
+
+
+def run(task, config, logger):
+    logger.info(f"start task {task}")
+    if task == "train":
+        train_dataset = data.DATASET.build_with(config.data.train.dataset)
+        logger.info(f"train with dataset:\n{train_dataset}")
+        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
+        trainer = get_trainer(config, logger)
+        try:
+            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
+        except Exception:
+            import traceback
+            print(traceback.format_exc())
+    elif task == "test":
+        test_dataset = data.DATASET.build_with(config.data.test.dataset)
+        logger.info(f"test with dataset:\n{test_dataset}")
+        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
+        tester = get_tester(config, logger)
+        try:
+            tester.run(test_data_loader, max_epochs=1)
+        except Exception:
+            import traceback
+            print(traceback.format_exc())
+    else:
+        raise NotImplementedError(f"invalid task: {task}")
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..9d11461
--- /dev/null
+++ b/environment.yml
@@ -0,0
+1,22 @@ +name: raycv +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - python=3.8 + - numpy + - ipython + - tqdm + - pyyaml + - pytorch=1.6.* + - torchvision + - cudatoolkit=10.2 + - ignite + - tensorboard + - omegaconf + - python-lmdb + - fire + # - opencv + # - jupyterlab + diff --git a/loss/__init__.py b/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loss/gan.py b/loss/gan.py new file mode 100644 index 0000000..5e30bc4 --- /dev/null +++ b/loss/gan.py @@ -0,0 +1,39 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class GANLoss(nn.Module): + def __init__(self, loss_type, real_label_val=1.0, fake_label_val=0.0): + super().__init__() + assert loss_type in ["vanilla", "lsgan", "hinge", "wgan"] + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + self.loss_type = loss_type + + def forward(self, prediction, target_is_real: bool, is_discriminator=False): + """ + gan loss forward + :param prediction: network prediction + :param target_is_real: whether the target is real or fake + :param is_discriminator: whether the loss for is_discriminator or not. default False + :return: Tensor, GAN loss value + """ + target_val = self.real_label_val if target_is_real else self.fake_label_val + target = prediction.new_ones(prediction.size()) * target_val + + if self.loss_type == "vanilla": + return F.binary_cross_entropy_with_logits(prediction, target) + elif self.loss_type == "lsgan": + return F.mse_loss(prediction, target) + elif self.loss_type == "hinge": + if is_discriminator: + prediction = -prediction if target_is_real else prediction + loss = F.relu(1 + prediction).mean() + else: + loss = -prediction.mean() + return loss + elif self.loss_type == "wgan": + loss = -prediction.mean() if target_is_real else prediction.mean() + return loss + else: + raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.') diff --git a/main.py b/main.py new file mode 100644 index 0000000..c6435d9 --- /dev/null +++ b/main.py @@ -0,0 +1,60 @@ +from pathlib import Path +from importlib import import_module + +import torch + +import ignite +import ignite.distributed as idist +from ignite.utils import manual_seed, setup_logger + +import fire +from omegaconf import OmegaConf + + +def log_basic_info(logger, config): + logger.info(f"Train {config.name}") + logger.info(f"- PyTorch version: {torch.__version__}") + logger.info(f"- Ignite version: {ignite.__version__}") + if idist.get_world_size() > 1: + logger.info("Distributed setting:\n") + idist.show_config() + + +def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False): + logger = setup_logger(name=config.name, distributed_rank=local_rank, **config.log.logger) + log_basic_info(logger, config) + + if setup_random_seed: + manual_seed(config.misc.random_seed + idist.get_rank()) + if setup_output_dir: + output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir + config.output_dir = str(output_dir) + if idist.get_rank() == 0: + if not output_dir.exists(): + output_dir.mkdir(parents=True) + logger.info(f"mkdir -p {output_dir}") + logger.info(f"output path: {config.output_dir}") + if backup_config: + with open(output_dir / "config.yml", "w+") as f: + print(config.pretty(), file=f) + + OmegaConf.set_readonly(config, True) + + engine = import_module(f"engine.{config.engine}") + engine.run(task, config, logger) + + +def run(task, config: str, *omega_options, **kwargs): + omega_options = 
[str(o) for o in omega_options] + conf = OmegaConf.merge(OmegaConf.load(config), OmegaConf.from_cli(omega_options)) + backend = kwargs.get("backend", "nccl") + backup_config = kwargs.get("backup_config", False) + setup_output_dir = kwargs.get("setup_output_dir", False) + setup_random_seed = kwargs.get("setup_random_seed", False) + with idist.Parallel(backend=backend) as parallel: + parallel.run(running, conf, task, backup_config=backup_config, setup_output_dir=setup_output_dir, + setup_random_seed=setup_random_seed) + + +if __name__ == '__main__': + fire.Fire(run) diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..28ea45c --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,2 @@ +from model.registry import MODEL +import model.residual_generator diff --git a/model/normalization.py b/model/normalization.py new file mode 100644 index 0000000..e69de29 diff --git a/model/registry.py b/model/registry.py new file mode 100644 index 0000000..6711b05 --- /dev/null +++ b/model/registry.py @@ -0,0 +1,3 @@ +from util.registry import Registry + +MODEL = Registry("model") diff --git a/model/residual_generator.py b/model/residual_generator.py new file mode 100644 index 0000000..413cdec --- /dev/null +++ b/model/residual_generator.py @@ -0,0 +1,140 @@ +import torch.nn as nn +import functools +from .registry import MODEL + + +def _select_norm_layer(norm_type): + if norm_type == "BN": + return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == "IN": + return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == "NONE": + return lambda x: nn.Identity() + else: + raise NotImplemented(f'normalization layer {norm_type} is not found') + + +@MODEL.register_module() +class ResidualBlock(nn.Module): + def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False): + super(ResidualBlock, self).__init__() + + # Only for IN, use bias since it does not have affine parameters. + use_bias = norm_type == "IN" + norm_layer = _select_norm_layer(norm_type) + models = [nn.Sequential( + nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias), + norm_layer(num_channels), + nn.ReLU(inplace=True), + )] + if use_dropout: + models.append(nn.Dropout(0.5)) + models.append(nn.Sequential( + nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias), + norm_layer(num_channels), + )) + self.block = nn.Sequential(*models) + + def forward(self, x): + return x + self.block(x) + + +@MODEL.register_module() +class ResGenerator(nn.Module): + def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect', + norm_type="IN", use_dropout=False): + super(ResGenerator, self).__init__() + assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.' 
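+        # Architecture note: a 7x7 stem conv, two stride-2 downsampling convs
+        # (64 -> 128 -> 256 channels with the default base_channels=64), num_blocks
+        # residual blocks at the bottleneck, two stride-2 transposed convs back to
+        # base_channels, and a final 7x7 conv + Tanh. This is the standard
+        # CycleGAN ResNet generator layout.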
+ norm_layer = _select_norm_layer(norm_type) + use_bias = norm_type == "IN" + + self.start_conv = nn.Sequential( + nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3, + bias=use_bias), + norm_layer(num_features=base_channels), + nn.ReLU(inplace=True) + ) + + # down sampling + submodules = [] + num_down_sampling = 2 + for i in range(num_down_sampling): + multiple = 2 ** i + submodules += [ + nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2, + kernel_size=3, + stride=2, padding=1, bias=use_bias), + norm_layer(num_features=base_channels * multiple * 2), + nn.ReLU(inplace=True) + ] + self.encoder = nn.Sequential(*submodules) + + res_block_channels = num_down_sampling ** 2 * base_channels + self.res_blocks = nn.ModuleList( + [ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in + range(num_blocks)]) + + # up sampling + submodules = [] + for i in range(num_down_sampling): + multiple = 2 ** (num_down_sampling - i) + submodules += [ + nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2, + padding=1, output_padding=1, bias=use_bias), + norm_layer(num_features=base_channels * multiple // 2), + nn.ReLU(inplace=True), + ] + self.decoder = nn.Sequential(*submodules) + + self.end_conv = nn.Sequential( + nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode), + nn.Tanh() + ) + + def forward(self, x): + x = self.encoder(self.start_conv(x)) + for rb in self.res_blocks: + x = rb(x) + return self.end_conv(self.decoder(x)) + + +@MODEL.register_module() +class PatchDiscriminator(nn.Module): + def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="BN"): + super(PatchDiscriminator, self).__init__() + assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.' 
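+        # Architecture note: a stack of stride-2 4x4 convs with LeakyReLU(0.2),
+        # doubling the channel width each layer (capped at 8 * base_channels),
+        # followed by a stride-1 conv and a final 1-channel conv that scores
+        # overlapping patches. With the default num_conv=3 this is the usual
+        # 70x70 PatchGAN discriminator.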
+ norm_layer = _select_norm_layer(norm_type) + use_bias = norm_type == "IN" + + kernel_size = 4 + padding = 1 + sequence = [ + nn.Conv2d(in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding), + nn.LeakyReLU(0.2, inplace=True), + ] + + # stacked intermediate layers, + # gradually increasing the number of filters + multiple_now = 1 + for n in range(1, num_conv): + multiple_prev = multiple_now + multiple_now = min(2 ** n, 8) + sequence += [ + nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=kernel_size, + padding=padding, stride=2, bias=use_bias), + norm_layer(base_channels * multiple_now), + nn.LeakyReLU(0.2, inplace=True) + ] + multiple_prev = multiple_now + multiple_now = min(2 ** num_conv, 8) + sequence += [ + nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size, stride=1, + padding=padding, bias=use_bias), + norm_layer(base_channels * multiple_now), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding) + ] + self.model = nn.Sequential(*sequence) + + def forward(self, x): + return self.model(x) diff --git a/model/weight_init.py b/model/weight_init.py new file mode 100644 index 0000000..8adc46f --- /dev/null +++ b/model/weight_init.py @@ -0,0 +1,70 @@ +import torch.nn as nn +from torch.nn import init + + +def kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu', bias=0.0, distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + nn.init.kaiming_uniform_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def xavier_init(module, gain=1.0, bias=0.0, distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def normal_init(module, mean=0.0, std=1.0, bias=0.0): + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def generation_init_weights(module, init_type='normal', init_gain=0.02): + """Default initialization of network weights for image generation. + By default, we use normal init, but xavier and kaiming might work + better for some applications. + Args: + module (nn.Module): Module to be initialized. + init_type (str): The name of an initialization method: + normal | xavier | kaiming | orthogonal. + init_gain (float): Scaling factor for normal, xavier and + orthogonal. + """ + + def init_func(m): + """Initialization function. + Args: + m (nn.Module): Module to be initialized. 
+ """ + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + normal_init(m, 0.0, init_gain) + elif init_type == 'xavier': + xavier_init(m, gain=init_gain, distribution='normal') + elif init_type == 'kaiming': + kaiming_init(m, a=0, mode='fan_in', nonlinearity='leaky_relu', distribution='normal') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight, gain=init_gain) + init.constant_(m.bias.data, 0.0) + else: + raise NotImplementedError( + f"Initialization method '{init_type}' is not implemented") + elif classname.find('BatchNorm2d') != -1: + # BatchNorm Layer's weight is not a matrix; + # only normal distribution applies. + normal_init(m, 1.0, init_gain) + + module.apply(init_func) diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..d075595 --- /dev/null +++ b/run.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 + +# CUDA_VISIBLE_DEVICES=$GPUS \ +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPUS" \ + main.py train "$CONFIG" --backup_config --setup_output_dir --setup_random_seed diff --git a/tool/img2skeleton.py b/tool/img2skeleton.py new file mode 100644 index 0000000..1637270 --- /dev/null +++ b/tool/img2skeleton.py @@ -0,0 +1,41 @@ +from pathlib import Path +import fire +import cv2 +import numpy as np +from tqdm import tqdm + + +def xDoG(img, sigma, k_sigma, p, epsilon, phi): + sigma_large = sigma * k_sigma + sharped_img = (1 + p) * cv2.GaussianBlur(img, (0, 0), sigma) - p * cv2.GaussianBlur(img, (0, 0), sigma_large) + img = np.multiply(img, sharped_img) + + t = np.zeros(img.shape) + t[img >= epsilon] = 1.0 + img_dark_indices = img < epsilon + t[img_dark_indices] = 1.0 + np.tanh(phi * (img[img_dark_indices] - epsilon)) + + return t * 256 + + +def transform_single(origin, to, anime: bool = True): + img = cv2.imread(str(origin), cv2.IMREAD_GRAYSCALE) + if anime: + r = xDoG(img / 256, 0.3, 4.5, 19, 0.01, 10 ^ 8 * 2) + else: + r = xDoG(img / 256, 0.7, 5, 20, 0.02, 10 ^ 8 * 2) + cv2.imwrite(str(to), r) + + +def transform(origin, to, anime=True): + origin = Path(origin) + to = Path(to) + if origin.is_dir() and to.is_dir(): + for f in tqdm(origin.glob("*")): + transform_single(f, to / f.name, anime) + elif origin.is_file(): + transform_single(origin, to, anime) + + +if __name__ == '__main__': + fire.Fire(transform) diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/distributed.py b/util/distributed.py new file mode 100644 index 0000000..fd10615 --- /dev/null +++ b/util/distributed.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +from ignite.distributed import utils as idist +from ignite.distributed.comp_models import native as idist_native +from ignite.utils import setup_logger + + +def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module: + """Helper method to adapt provided model for non-distributed and distributed configurations (supporting + all available backends from :meth:`~ignite.distributed.utils.available_backends()`). + + Internally, we perform to following: + + - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device. + - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1. + - wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available. 
+ + Examples: + + .. code-block:: python + + import ignite.distribted as idist + + model = idist.auto_model(model) + + In addition with NVidia/Apex, it can be used in the following way: + + .. code-block:: python + + import ignite.distribted as idist + + model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) + model = idist.auto_model(model) + + Args: + model (torch.nn.Module): model to adapt. + + Returns: + torch.nn.Module + + .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel + .. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel + """ + logger = setup_logger(__name__ + ".auto_model") + + # Put model's parameters to device if its parameters are not on the device + device = idist.device() + if not all([p.device == device for p in model.parameters()]): + model.to(device) + + # distributed data parallel model + if idist.get_world_size() > 1: + if idist.backend() == idist_native.NCCL: + lrank = idist.get_local_rank() + logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank)) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs) + elif idist.backend() == idist_native.GLOO: + logger.info("Apply torch DistributedDataParallel on model") + model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs) + + # not distributed but multiple GPUs reachable so data parallel model + elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type: + logger.info("Apply torch DataParallel on model") + model = torch.nn.parallel.DataParallel(model, **additional_kwargs) + + return model diff --git a/util/handler.py b/util/handler.py new file mode 100644 index 0000000..fd2291e --- /dev/null +++ b/util/handler.py @@ -0,0 +1,21 @@ +from pathlib import Path + +import torch +from ignite.engine import Engine +from ignite.handlers import Checkpoint + + +class Resumer: + def __init__(self, to_load, checkpoint_path): + self.to_load = to_load + if checkpoint_path is not None: + checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + raise ValueError(f"Checkpoint '{checkpoint_path}' is not found") + self.checkpoint_path = checkpoint_path + + def __call__(self, engine: Engine): + if self.checkpoint_path is not None: + ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu") + Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp) + engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}") diff --git a/util/image.py b/util/image.py new file mode 100644 index 0000000..126fbe6 --- /dev/null +++ b/util/image.py @@ -0,0 +1,10 @@ +import torchvision.utils + + +def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0): + # merge image in a batch in `y` direction first. + grids = [torchvision.utils.make_grid(img_batch, padding=padding, nrow=1, normalize=normalize, range=range, + scale_each=scale_each, pad_value=pad_value) + for img_batch in tensors] + # merge images in `x` direction. 
+ return torchvision.utils.make_grid(grids, padding=0, nrow=len(grids)) diff --git a/util/registry.py b/util/registry.py new file mode 100644 index 0000000..ecb0c29 --- /dev/null +++ b/util/registry.py @@ -0,0 +1,163 @@ +import inspect +from omegaconf.dictconfig import DictConfig +from omegaconf import OmegaConf +from types import ModuleType + + +class _Registry: + def __init__(self, name): + self._name = name + + def get(self, key): + raise NotImplemented + + def keys(self): + raise NotImplemented + + def __len__(self): + len(self.keys()) + + def __contains__(self, key): + return self.get(key) is not None + + @property + def name(self): + return self._name + + def __repr__(self): + return f"{self.__class__.__name__}(name={self._name}, items={self.keys()})" + + def build_with(self, cfg, default_args=None): + """Build a module from config dict. + Args: + cfg (dict): Config dict. It should at least contain the key "type". + default_args (dict, optional): Default initialization arguments. + Returns: + object: The constructed object. + """ + if isinstance(cfg, DictConfig): + cfg = OmegaConf.to_container(cfg) + if isinstance(cfg, dict): + if '_type' in cfg: + args = cfg.copy() + obj_type = args.pop('_type') + elif len(cfg) == 1: + obj_type, args = list(cfg.items())[0] + else: + raise KeyError(f'the cfg dict must contain the key "_type", but got {cfg}') + elif isinstance(cfg, str): + obj_type = cfg + args = dict() + else: + raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}') + + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError('default_args must be a dict or None, ' + f'but got {type(default_args)}') + + if isinstance(obj_type, str): + obj_cls = self.get(obj_type) + if obj_cls is None: + raise KeyError(f'{obj_type} is not in the {self.name} registry') + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}') + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + return obj_cls(**args) + + +class ModuleRegistry(_Registry): + def __init__(self, name, module, predefined_valid_list=None): + super().__init__(name) + + assert isinstance(module, ModuleType), f"module must be ModuleType, but got {type(module)}" + self._module = module + if predefined_valid_list is not None: + self._valid_set = set(predefined_valid_list) & set(self._module.__dict__.keys()) + else: + self._valid_set = set(self._module.__dict__.keys()) + + def keys(self): + return tuple(self._valid_set) + + def get(self, key): + """Get the registry record. + Args: + key (str): The class name in string format. + Returns: + class: The corresponding class. + """ + if key not in self._valid_set: + return None + return getattr(self._module, key) + + +class Registry(_Registry): + """A registry to map strings to classes. + Args: + name (str): Registry name. + """ + + def __init__(self, name): + super().__init__(name) + self._module_dict = dict() + + def keys(self): + return tuple(self._module_dict.keys()) + + def get(self, key): + """Get the registry record. + Args: + key (str): The class name in string format. + Returns: + class: The corresponding class. 
+ """ + return self._module_dict.get(key, None) + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError('module must be a class, ' + f'but got {type(module_class)}') + + if module_name is None: + module_name = module_class.__name__ + if not force and module_name in self._module_dict: + raise KeyError(f'{module_name} is already registered ' + f'in {self.name}') + self._module_dict[module_name] = module_class + + def register_module(self, name=None, force=False, module=None): + """Register a module. + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module( + module_class=module, module_name=name, force=force) + return module + + # raise the error ahead of time + if not (name is None or isinstance(name, str)): + raise TypeError(f'name must be a str, but got {type(name)}') + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module(module_class=cls, module_name=name, force=force) + return cls + + return _register