base code for pytorch distributed, add cyclegan
Commit: f7843de45d

.gitignore (vendored, new file, +243)
@@ -0,0 +1,243 @@
# Created by .ignore support plugin (hsz.mobi)
|
||||
### JetBrains template
|
||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
|
||||
# User-specific stuff
|
||||
.idea/**/workspace.xml
|
||||
.idea/**/tasks.xml
|
||||
.idea/**/usage.statistics.xml
|
||||
.idea/**/dictionaries
|
||||
.idea/**/shelf
|
||||
|
||||
# Generated files
|
||||
.idea/**/contentModel.xml
|
||||
|
||||
# Sensitive or high-churn files
|
||||
.idea/**/dataSources/
|
||||
.idea/**/dataSources.ids
|
||||
.idea/**/dataSources.local.xml
|
||||
.idea/**/sqlDataSources.xml
|
||||
.idea/**/dynamic.xml
|
||||
.idea/**/uiDesigner.xml
|
||||
.idea/**/dbnavigator.xml
|
||||
|
||||
# Gradle
|
||||
.idea/**/gradle.xml
|
||||
.idea/**/libraries
|
||||
|
||||
# Gradle and Maven with auto-import
|
||||
# When using Gradle or Maven with auto-import, you should exclude module files,
|
||||
# since they will be recreated, and may cause churn. Uncomment if using
|
||||
# auto-import.
|
||||
# .idea/artifacts
|
||||
# .idea/compiler.xml
|
||||
# .idea/jarRepositories.xml
|
||||
# .idea/modules.xml
|
||||
# .idea/*.iml
|
||||
# .idea/modules
|
||||
# *.iml
|
||||
# *.ipr
|
||||
|
||||
# CMake
|
||||
cmake-build-*/
|
||||
|
||||
# Mongo Explorer plugin
|
||||
.idea/**/mongoSettings.xml
|
||||
|
||||
# File-based project format
|
||||
*.iws
|
||||
|
||||
# IntelliJ
|
||||
out/
|
||||
|
||||
# mpeltonen/sbt-idea plugin
|
||||
.idea_modules/
|
||||
|
||||
# JIRA plugin
|
||||
atlassian-ide-plugin.xml
|
||||
|
||||
# Cursive Clojure plugin
|
||||
.idea/replstate.xml
|
||||
|
||||
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||
com_crashlytics_export_strings.xml
|
||||
crashlytics.properties
|
||||
crashlytics-build.properties
|
||||
fabric.properties
|
||||
|
||||
# Editor-based Rest Client
|
||||
.idea/httpRequests
|
||||
|
||||
# Android studio 3.1+ serialized cache file
|
||||
.idea/caches/build_file_checksums.ser
|
||||
|
||||
### JupyterNotebooks template
|
||||
# gitignore template for Jupyter Notebooks
|
||||
# website: http://jupyter.org/
|
||||
|
||||
.ipynb_checkpoints
|
||||
*/.ipynb_checkpoints/*
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# Remove previous ipynb_checkpoints
|
||||
# git rm -r .ipynb_checkpoints/
|
||||
|
||||
### Python template
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
### Linux template
|
||||
*~
|
||||
|
||||
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||
.fuse_hidden*
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
|
||||
# .nfs files are created when an open file is removed but is still being accessed
|
||||
.nfs*

.idea/.gitignore (vendored, new file, +8)
@@ -0,0 +1,8 @@
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Datasource local storage ignored files
|
||||
/../../../../../../:\Users\wr\Code\raycv\.idea/dataSources/
|
||||
/dataSources.local.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/

.idea/deployment.xml (new file, +22)
@@ -0,0 +1,22 @@
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="14d">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="15d">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="raycv" local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
<option name="myAutoUpload" value="ALWAYS" />
|
||||
</component>
|
||||
</project>

.idea/inspectionProfiles/profiles_settings.xml (new file, +6)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>

.idea/misc.xml (new file, +4)
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
|
||||
</project>

.idea/modules.xml (new file, +8)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/raycv.iml" filepath="$PROJECT_DIR$/.idea/raycv.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>

.idea/raycv.iml (new file, +11)
@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="pytest" />
|
||||
</component>
|
||||
</module>

.idea/vcs.xml (new file, +6)
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>

configs/synthesizers/cyclegan.yml (new file, +102)
@@ -0,0 +1,102 @@
name: horse2zebra
|
||||
engine: cyclegan
|
||||
result_dir: ./result
|
||||
max_iteration: 18000
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
|
||||
misc:
|
||||
random_seed: 1004
|
||||
|
||||
checkpoints:
|
||||
interval: 2000
|
||||
|
||||
log:
|
||||
logger:
|
||||
level: 20 # DEBUG(10) INFO(20)
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: ResGenerator
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 9
|
||||
padding_mode: reflect
|
||||
norm_type: IN
|
||||
use_dropout: False
|
||||
discriminator:
|
||||
_type: PatchDiscriminator
|
||||
_distributed:
|
||||
bn_to_syncbn: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 3
|
||||
norm_type: BN
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: lsgan
|
||||
weight: 1.0
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
cycle:
|
||||
level: 1
|
||||
weight: 10.0
|
||||
id:
|
||||
level: 1
|
||||
weight: 0
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 2e-4
|
||||
betas: [0.5, 0.999]
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 2e-4
|
||||
betas: [0.5, 0.999]
|
||||
|
||||
data:
|
||||
train:
|
||||
dataloader:
|
||||
batch_size: 16
|
||||
shuffle: True
|
||||
num_workers: 4
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/horse2zebra/trainA"
|
||||
root_b: "/data/i2i/horse2zebra/trainB"
|
||||
random_pair: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [286, 286]
|
||||
- RandomCrop:
|
||||
size: [256, 256]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
scheduler:
|
||||
start: 9000
|
||||
target_lr: 0
|
||||
test:
|
||||
dataloader:
|
||||
batch_size: 4
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/horse2zebra/testA"
|
||||
root_b: "/data/i2i/horse2zebra/testB"
|
||||
random_pair: False
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [256, 256]
|
||||
- ToTensor
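
For reference, each `_type` block above is resolved through the registries added below: `_type` selects the registered class and the remaining keys become constructor keyword arguments. A minimal sketch (assuming the repo root is on PYTHONPATH) of how the generator section can be built by hand:

from omegaconf import OmegaConf
from model import MODEL  # importing model registers ResGenerator and PatchDiscriminator

cfg = OmegaConf.load("configs/synthesizers/cyclegan.yml")
gen_cfg = OmegaConf.to_container(cfg.model.generator)
gen_cfg.pop("_distributed", None)      # distributed-only hints are stripped before building
generator = MODEL.build_with(gen_cfg)  # ResGenerator(in_channels=3, out_channels=3, ...)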

data/__init__.py (new file, +4)
@@ -0,0 +1,4 @@
import data.dataset
|
||||
import data.transform
|
||||
from data.registry import DATASET, TRANSFORM

data/dataset.py (new file, +73)
@@ -0,0 +1,73 @@
import os
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
||||
|
||||
import lmdb
|
||||
|
||||
from .transform import transform_pipeline
|
||||
from .registry import DATASET
|
||||
|
||||
|
||||
class LMDBDataset(Dataset):
|
||||
def __init__(self, lmdb_path, output_transform=None, map_size=2 ** 40, readonly=True, **lmdb_kwargs):
|
||||
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
|
||||
**lmdb_kwargs)
|
||||
self.output_transform = output_transform
|
||||
with self.db.begin(write=False) as txn:
|
||||
self._len = pickle.loads(txn.get(b"__len__"))
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with self.db.begin(write=False) as txn:
|
||||
sample = pickle.loads(txn.get("{}".format(idx).encode()))
|
||||
if self.output_transform is not None:
|
||||
sample = self.output_transform(sample)
|
||||
return sample
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class SingleFolderDataset(Dataset):
|
||||
def __init__(self, root, pipeline):
|
||||
assert os.path.isdir(root)
|
||||
self.root = root
|
||||
samples = []
|
||||
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
|
||||
for fn in sorted(fns):
|
||||
path = os.path.join(r, fn)
|
||||
if has_file_allowed_extension(path, IMG_EXTENSIONS):
|
||||
samples.append(path)
|
||||
self.samples = samples
|
||||
self.pipeline = transform_pipeline(pipeline)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pipeline(self.samples[idx])
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SingleFolderDataset root={self.root} len={len(self)}>"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDataset(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline):
|
||||
self.A = SingleFolderDataset(root_a, pipeline)
|
||||
self.B = SingleFolderDataset(root_b, pipeline)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
return dict(a=self.A[a_idx], b=self.B[b_idx])
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"

data/registry.py (new file, +4)
@@ -0,0 +1,4 @@
from util.registry import Registry
|
||||
|
||||
DATASET = Registry("dataset")
|
||||
TRANSFORM = Registry("transform")

data/transform.py (new file, +34)
@@ -0,0 +1,34 @@
from torchvision import transforms
|
||||
from torchvision.datasets.folder import default_loader
|
||||
|
||||
from .registry import TRANSFORM
|
||||
|
||||
# from https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html
|
||||
_VALID_TORCHVISION_TRANSFORMS = ["ToTensor", "ToPILImage", "Normalize", "Resize",
|
||||
"Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder",
|
||||
"RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
|
||||
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter",
|
||||
"RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective",
|
||||
"RandomErasing"]
|
||||
|
||||
for vtt in _VALID_TORCHVISION_TRANSFORMS:
|
||||
TRANSFORM.register_module(module=getattr(transforms, vtt))
|
||||
|
||||
|
||||
@TRANSFORM.register_module()
|
||||
class Load:
|
||||
def __init__(self, loader=default_loader):
|
||||
self.loader = loader
|
||||
|
||||
def __call__(self, image_path):
|
||||
return self.loader(image_path)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "()"
|
||||
|
||||
|
||||
def transform_pipeline(pipeline_description):
|
||||
if len(pipeline_description) == 0:
|
||||
return lambda x: x
|
||||
transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
|
||||
return transforms.Compose(transform_list)
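
Custom transforms can be registered next to Load with the same decorator and then referenced by name in a config pipeline; a hypothetical example (Identity is not part of this commit):

@TRANSFORM.register_module()
class Identity:
    # a no-op transform, registered under the name "Identity"
    def __call__(self, sample):
        return sample

    def __repr__(self):
        return self.__class__.__name__ + "()"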

docs/run.md (new empty file)
engine/__init__.py (new empty file)

engine/cyclegan.py (new file, +274)
@@ -0,0 +1,274 @@
import itertools
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision.utils
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import data
|
||||
from model import MODEL
|
||||
from loss.gan import GANLoss
|
||||
from util.distributed import auto_model
|
||||
from util.image import make_2d_grid
|
||||
from util.handler import Resumer
|
||||
|
||||
|
||||
def _build_model(cfg, distributed_args=None):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
model_distributed_config = cfg.pop("_distributed", dict())
|
||||
model = MODEL.build_with(cfg)
|
||||
|
||||
if model_distributed_config.get("bn_to_syncbn"):
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
|
||||
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
|
||||
return auto_model(model, **distributed_args)
|
||||
|
||||
|
||||
def _build_optimizer(params, cfg):
|
||||
assert "_type" in cfg
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
|
||||
return idist.auto_optim(optimizer)
|
||||
|
||||
|
||||
def get_trainer(config, logger):
|
||||
generator_a = _build_model(config.model.generator, config.distributed.model)
|
||||
generator_b = _build_model(config.model.generator, config.distributed.model)
|
||||
discriminator_a = _build_model(config.model.discriminator, config.distributed.model)
|
||||
discriminator_b = _build_model(config.model.discriminator, config.distributed.model)
|
||||
logger.debug(discriminator_a)
|
||||
logger.debug(generator_a)
|
||||
|
||||
optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
|
||||
config.optimizers.generator)
|
||||
optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
|
||||
config.optimizers.discriminator)
|
||||
|
||||
milestones_values = [
|
||||
(config.data.train.scheduler.start, config.optimizers.generator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr),
|
||||
]
|
||||
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
milestones_values = [
|
||||
(config.data.train.scheduler.start, config.optimizers.discriminator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr),
|
||||
]
|
||||
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real_a, real_b = batch["a"], batch["b"]
|
||||
|
||||
optimizer_g.zero_grad()
|
||||
fake_b = generator_a(real_a) # G_A(A)
|
||||
rec_a = generator_b(fake_b) # G_B(G_A(A))
|
||||
fake_a = generator_b(real_b) # G_B(B)
|
||||
rec_b = generator_a(fake_a) # G_A(G_B(B))
|
||||
|
||||
loss_g = dict(
|
||||
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
||||
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
||||
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
|
||||
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
|
||||
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
|
||||
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
|
||||
)
|
||||
sum(loss_g.values()).backward()
|
||||
optimizer_g.step()
|
||||
|
||||
optimizer_d.zero_grad()
|
||||
loss_d_a = dict(
|
||||
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
|
||||
fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True),
|
||||
)
|
||||
loss_d_b = dict(
|
||||
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
|
||||
fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True),
|
||||
)
|
||||
loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2
|
||||
loss_d.backward()
|
||||
optimizer_d.step()
|
||||
|
||||
return {
|
||||
"loss": {
|
||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||
"d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
|
||||
"d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
|
||||
},
|
||||
"img": [
|
||||
real_a.detach(),
|
||||
fake_b.detach(),
|
||||
rec_a.detach(),
|
||||
real_b.detach(),
|
||||
fake_a.detach(),
|
||||
rec_b.detach()
|
||||
]
|
||||
}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=10))
|
||||
def print_log(engine):
|
||||
engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]"
|
||||
f"loss_g={engine.state.metrics['loss_g']:.3f} "
|
||||
f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} "
|
||||
f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ")
|
||||
|
||||
to_save = dict(
|
||||
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
|
||||
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
|
||||
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
|
||||
)
|
||||
|
||||
trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from))
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None)
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler)
|
||||
|
||||
if idist.get_rank() == 0:
|
||||
# Create a logger
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
tb_writer = tb_logger.writer
|
||||
|
||||
# Attach the logger to the trainer to log training loss at each iteration
|
||||
def global_step_transform(*args, **kwargs):
|
||||
return trainer.state.iteration
|
||||
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="loss",
|
||||
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
|
||||
global_step_transform=global_step_transform,
|
||||
),
|
||||
event_name=Events.ITERATION_COMPLETED(every=50)
|
||||
)
|
||||
# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=50)
|
||||
)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
|
||||
def show_images(engine):
|
||||
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
|
||||
|
||||
# Create an object of the profiler and attach an engine to it
|
||||
profiler = BasicTimeProfiler()
|
||||
profiler.attach(trainer)
|
||||
|
||||
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
||||
@idist.one_rank_only()
|
||||
def log_intermediate_results():
|
||||
profiler.print_results(profiler.get_results())
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
def _():
|
||||
profiler.write_results(f"{config.output_dir}/time_profiling.csv")
|
||||
# We need to close the logger when we are done
|
||||
tb_logger.close()
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
def get_tester(config, logger):
|
||||
generator_a = _build_model(config.model.generator, config.distributed.model)
|
||||
generator_b = _build_model(config.model.generator, config.distributed.model)
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real_a, real_b = batch["a"], batch["b"]
|
||||
with torch.no_grad():
|
||||
fake_b = generator_a(real_a) # G_A(A)
|
||||
rec_a = generator_b(fake_b) # G_B(G_A(A))
|
||||
fake_a = generator_b(real_b) # G_B(B)
|
||||
rec_b = generator_a(fake_a) # G_A(G_B(B))
|
||||
return [
|
||||
real_a.detach(),
|
||||
fake_b.detach(),
|
||||
rec_a.detach(),
|
||||
real_b.detach(),
|
||||
fake_a.detach(),
|
||||
rec_b.detach()
|
||||
]
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
if idist.get_rank() == 0:
|
||||
ProgressBar(ncols=0).attach(tester)
|
||||
to_load = dict(generator_a=generator_a, generator_b=generator_b)
|
||||
tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))
|
||||
|
||||
@tester.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
def mkdir(engine):
|
||||
img_output_dir = Path(config.output_dir) / "test_images"
|
||||
if not img_output_dir.exists():
|
||||
engine.logger.info(f"mkdir {img_output_dir}")
|
||||
img_output_dir.mkdir()
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors],
|
||||
Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
|
||||
nrow=len(img_tensors))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
logger.info(f"start task {task}")
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger)
|
||||
try:
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, logger)
|
||||
try:
|
||||
tester.run(test_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
raise NotImplementedError(f"invalid task: {task}")

environment.yml (new file, +22)
@@ -0,0 +1,22 @@
name: raycv
|
||||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.8
|
||||
- numpy
|
||||
- ipython
|
||||
- tqdm
|
||||
- pyyaml
|
||||
- pytorch=1.6.*
|
||||
- torchvision
|
||||
- cudatoolkit=10.2
|
||||
- ignite
|
||||
- tensorboard
|
||||
- omegaconf
|
||||
- python-lmdb
|
||||
- fire
|
||||
# - opencv
|
||||
# - jupyterlab
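
The environment can be created with the standard conda workflow (the environment name raycv comes from the file above):

conda env create -f environment.yml
conda activate raycv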

loss/__init__.py (new empty file)

loss/gan.py (new file, +39)
@@ -0,0 +1,39 @@
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
|
||||
def __init__(self, loss_type, real_label_val=1.0, fake_label_val=0.0):
|
||||
super().__init__()
|
||||
assert loss_type in ["vanilla", "lsgan", "hinge", "wgan"]
|
||||
self.real_label_val = real_label_val
|
||||
self.fake_label_val = fake_label_val
|
||||
self.loss_type = loss_type
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
"""
|
||||
gan loss forward
|
||||
:param prediction: network prediction
|
||||
:param target_is_real: whether the target is real or fake
|
||||
:param is_discriminator: whether the loss is being computed for the discriminator. Default: False
|
||||
:return: Tensor, GAN loss value
|
||||
"""
|
||||
target_val = self.real_label_val if target_is_real else self.fake_label_val
|
||||
target = prediction.new_ones(prediction.size()) * target_val
|
||||
|
||||
if self.loss_type == "vanilla":
|
||||
return F.binary_cross_entropy_with_logits(prediction, target)
|
||||
elif self.loss_type == "lsgan":
|
||||
return F.mse_loss(prediction, target)
|
||||
elif self.loss_type == "hinge":
|
||||
if is_discriminator:
|
||||
prediction = -prediction if target_is_real else prediction
|
||||
loss = F.relu(1 + prediction).mean()
|
||||
else:
|
||||
loss = -prediction.mean()
|
||||
return loss
|
||||
elif self.loss_type == "wgan":
|
||||
loss = -prediction.mean() if target_is_real else prediction.mean()
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
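
A minimal sketch of how the loss is called from both sides, mirroring the calls in engine/cyclegan.py; for lsgan the is_discriminator flag has no effect, it only changes the hinge variant:

import torch
from loss.gan import GANLoss

gan_loss = GANLoss(loss_type="lsgan")
logits = torch.randn(4, 1, 30, 30)  # e.g. a PatchDiscriminator output
loss_g = gan_loss(logits, True)                               # generator: fakes should look real
loss_d_real = gan_loss(logits, True, is_discriminator=True)   # discriminator on real images
loss_d_fake = gan_loss(logits, False, is_discriminator=True)  # discriminator on detached fakes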

main.py (new file, +60)
@@ -0,0 +1,60 @@
from pathlib import Path
|
||||
from importlib import import_module
|
||||
|
||||
import torch
|
||||
|
||||
import ignite
|
||||
import ignite.distributed as idist
|
||||
from ignite.utils import manual_seed, setup_logger
|
||||
|
||||
import fire
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def log_basic_info(logger, config):
|
||||
logger.info(f"Train {config.name}")
|
||||
logger.info(f"- PyTorch version: {torch.__version__}")
|
||||
logger.info(f"- Ignite version: {ignite.__version__}")
|
||||
if idist.get_world_size() > 1:
|
||||
logger.info("Distributed setting:\n")
|
||||
idist.show_config()
|
||||
|
||||
|
||||
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
|
||||
logger = setup_logger(name=config.name, distributed_rank=local_rank, **config.log.logger)
|
||||
log_basic_info(logger, config)
|
||||
|
||||
if setup_random_seed:
|
||||
manual_seed(config.misc.random_seed + idist.get_rank())
|
||||
if setup_output_dir:
|
||||
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
|
||||
config.output_dir = str(output_dir)
|
||||
if idist.get_rank() == 0:
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True)
|
||||
logger.info(f"mkdir -p {output_dir}")
|
||||
logger.info(f"output path: {config.output_dir}")
|
||||
if backup_config:
|
||||
with open(output_dir / "config.yml", "w+") as f:
|
||||
print(config.pretty(), file=f)
|
||||
|
||||
OmegaConf.set_readonly(config, True)
|
||||
|
||||
engine = import_module(f"engine.{config.engine}")
|
||||
engine.run(task, config, logger)
|
||||
|
||||
|
||||
def run(task, config: str, *omega_options, **kwargs):
|
||||
omega_options = [str(o) for o in omega_options]
|
||||
conf = OmegaConf.merge(OmegaConf.load(config), OmegaConf.from_cli(omega_options))
|
||||
backend = kwargs.get("backend", "nccl")
|
||||
backup_config = kwargs.get("backup_config", False)
|
||||
setup_output_dir = kwargs.get("setup_output_dir", False)
|
||||
setup_random_seed = kwargs.get("setup_random_seed", False)
|
||||
with idist.Parallel(backend=backend) as parallel:
|
||||
parallel.run(running, conf, task, backup_config=backup_config, setup_output_dir=setup_output_dir,
|
||||
setup_random_seed=setup_random_seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(run)
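
Extra positional key=value arguments are merged into the YAML config via OmegaConf.from_cli, so any entry can be overridden at launch time; for example (same launcher as run.sh below, override keys taken from the cyclegan config):

PYTHONPATH=. OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=2 \
    main.py train configs/synthesizers/cyclegan.yml \
    data.train.dataloader.batch_size=8 max_iteration=20000 \
    --backup_config --setup_output_dir --setup_random_seed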

model/__init__.py (new file, +2)
@@ -0,0 +1,2 @@
from model.registry import MODEL
|
||||
import model.residual_generator

model/normalization.py (new empty file)

model/registry.py (new file, +3)
@@ -0,0 +1,3 @@
from util.registry import Registry
|
||||
|
||||
MODEL = Registry("model")

model/residual_generator.py (new file, +140)
@@ -0,0 +1,140 @@
import torch.nn as nn
|
||||
import functools
|
||||
from .registry import MODEL
|
||||
|
||||
|
||||
def _select_norm_layer(norm_type):
|
||||
if norm_type == "BN":
|
||||
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
||||
elif norm_type == "IN":
|
||||
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
||||
elif norm_type == "NONE":
|
||||
return lambda x: nn.Identity()
|
||||
else:
|
||||
raise NotImplementedError(f'normalization layer {norm_type} is not found')
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
# Only for IN, use bias since it does not have affine parameters.
|
||||
use_bias = norm_type == "IN"
|
||||
norm_layer = _select_norm_layer(norm_type)
|
||||
models = [nn.Sequential(
|
||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)]
|
||||
if use_dropout:
|
||||
models.append(nn.Dropout(0.5))
|
||||
models.append(nn.Sequential(
|
||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_channels),
|
||||
))
|
||||
self.block = nn.Sequential(*models)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResGenerator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
|
||||
norm_type="IN", use_dropout=False):
|
||||
super(ResGenerator, self).__init__()
|
||||
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
||||
norm_layer = _select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
self.start_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=use_bias),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
# down sampling
|
||||
submodules = []
|
||||
num_down_sampling = 2
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=3,
|
||||
stride=2, padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
|
||||
res_block_channels = 2 ** num_down_sampling * base_channels
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
|
||||
range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** (num_down_sampling - i)
|
||||
submodules += [
|
||||
nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
|
||||
padding=1, output_padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
self.decoder = nn.Sequential(*submodules)
|
||||
|
||||
self.end_conv = nn.Sequential(
|
||||
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(self.start_conv(x))
|
||||
for rb in self.res_blocks:
|
||||
x = rb(x)
|
||||
return self.end_conv(self.decoder(x))
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="BN"):
|
||||
super(PatchDiscriminator, self).__init__()
|
||||
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
||||
norm_layer = _select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
kernel_size = 4
|
||||
padding = 1
|
||||
sequence = [
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
|
||||
# stacked intermediate layers,
|
||||
# gradually increasing the number of filters
|
||||
multiple_now = 1
|
||||
for n in range(1, num_conv):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** n, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=kernel_size,
|
||||
padding=padding, stride=2, bias=use_bias),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
]
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** num_conv, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size, stride=1,
|
||||
padding=padding, bias=use_bias),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding)
|
||||
]
|
||||
self.model = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
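
A quick shape check (assuming the repo root is on PYTHONPATH): with the defaults the generator preserves spatial size, and with num_conv=3 the discriminator is the usual 70x70 PatchGAN that maps a 256x256 image to a 30x30 patch map:

import torch
from model.residual_generator import ResGenerator, PatchDiscriminator

g = ResGenerator(in_channels=3, out_channels=3)
d = PatchDiscriminator(in_channels=3)
x = torch.randn(2, 3, 256, 256)
print(g(x).shape)  # torch.Size([2, 3, 256, 256])
print(d(x).shape)  # torch.Size([2, 1, 30, 30])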

model/weight_init.py (new file, +70)
@@ -0,0 +1,70 @@
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
|
||||
|
||||
def kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu', bias=0.0, distribution='normal'):
|
||||
assert distribution in ['uniform', 'normal']
|
||||
if distribution == 'uniform':
|
||||
nn.init.kaiming_uniform_(
|
||||
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
|
||||
else:
|
||||
nn.init.kaiming_normal_(
|
||||
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def xavier_init(module, gain=1.0, bias=0.0, distribution='normal'):
|
||||
assert distribution in ['uniform', 'normal']
|
||||
if distribution == 'uniform':
|
||||
nn.init.xavier_uniform_(module.weight, gain=gain)
|
||||
else:
|
||||
nn.init.xavier_normal_(module.weight, gain=gain)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def normal_init(module, mean=0.0, std=1.0, bias=0.0):
|
||||
nn.init.normal_(module.weight, mean, std)
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
nn.init.constant_(module.bias, bias)
|
||||
|
||||
|
||||
def generation_init_weights(module, init_type='normal', init_gain=0.02):
|
||||
"""Default initialization of network weights for image generation.
|
||||
By default, we use normal init, but xavier and kaiming might work
|
||||
better for some applications.
|
||||
Args:
|
||||
module (nn.Module): Module to be initialized.
|
||||
init_type (str): The name of an initialization method:
|
||||
normal | xavier | kaiming | orthogonal.
|
||||
init_gain (float): Scaling factor for normal, xavier and
|
||||
orthogonal.
|
||||
"""
|
||||
|
||||
def init_func(m):
|
||||
"""Initialization function.
|
||||
Args:
|
||||
m (nn.Module): Module to be initialized.
|
||||
"""
|
||||
classname = m.__class__.__name__
|
||||
if hasattr(m, 'weight') and (classname.find('Conv') != -1
|
||||
or classname.find('Linear') != -1):
|
||||
if init_type == 'normal':
|
||||
normal_init(m, 0.0, init_gain)
|
||||
elif init_type == 'xavier':
|
||||
xavier_init(m, gain=init_gain, distribution='normal')
|
||||
elif init_type == 'kaiming':
|
||||
kaiming_init(m, a=0, mode='fan_in', nonlinearity='leaky_relu', distribution='normal')
|
||||
elif init_type == 'orthogonal':
|
||||
init.orthogonal_(m.weight, gain=init_gain)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Initialization method '{init_type}' is not implemented")
|
||||
elif classname.find('BatchNorm2d') != -1:
|
||||
# BatchNorm Layer's weight is not a matrix;
|
||||
# only normal distribution applies.
|
||||
normal_init(m, 1.0, init_gain)
|
||||
|
||||
module.apply(init_func)
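
These initializers do not appear to be called from engine/cyclegan.py in this commit; if desired they can be applied right after the models are built (generator_a and discriminator_a are the engine's local names):

from model.weight_init import generation_init_weights

generation_init_weights(generator_a, init_type='normal', init_gain=0.02)
generation_init_weights(discriminator_a, init_type='normal', init_gain=0.02)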

run.sh (new file, +8)
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
|
||||
|
||||
CONFIG=$1
|
||||
GPUS=$2
|
||||
|
||||
# CUDA_VISIBLE_DEVICES=$GPUS \
|
||||
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPUS" \
|
||||
main.py train "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
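
Example: to train the cyclegan config on 2 GPUs,

bash run.sh configs/synthesizers/cyclegan.yml 2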

tool/img2skeleton.py (new file, +41)
@@ -0,0 +1,41 @@
from pathlib import Path
|
||||
import fire
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def xDoG(img, sigma, k_sigma, p, epsilon, phi):
|
||||
sigma_large = sigma * k_sigma
|
||||
sharped_img = (1 + p) * cv2.GaussianBlur(img, (0, 0), sigma) - p * cv2.GaussianBlur(img, (0, 0), sigma_large)
|
||||
img = np.multiply(img, sharped_img)
|
||||
|
||||
t = np.zeros(img.shape)
|
||||
t[img >= epsilon] = 1.0
|
||||
img_dark_indices = img < epsilon
|
||||
t[img_dark_indices] = 1.0 + np.tanh(phi * (img[img_dark_indices] - epsilon))
|
||||
|
||||
return t * 256
|
||||
|
||||
|
||||
def transform_single(origin, to, anime: bool = True):
|
||||
img = cv2.imread(str(origin), cv2.IMREAD_GRAYSCALE)
|
||||
if anime:
|
||||
r = xDoG(img / 256, 0.3, 4.5, 19, 0.01, 10 ** 8 * 2)
|
||||
else:
|
||||
r = xDoG(img / 256, 0.7, 5, 20, 0.02, 10 ** 8 * 2)
|
||||
cv2.imwrite(str(to), r)
|
||||
|
||||
|
||||
def transform(origin, to, anime=True):
|
||||
origin = Path(origin)
|
||||
to = Path(to)
|
||||
if origin.is_dir() and to.is_dir():
|
||||
for f in tqdm(origin.glob("*")):
|
||||
transform_single(f, to / f.name, anime)
|
||||
elif origin.is_file():
|
||||
transform_single(origin, to, anime)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(transform)
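
The tool accepts either a single image or a directory (the output directory must already exist); the paths below are placeholders:

python tool/img2skeleton.py /path/to/photos /path/to/sketches --anime=True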

util/__init__.py (new empty file)

util/distributed.py (new file, +66)
@@ -0,0 +1,66 @@
import torch
|
||||
import torch.nn as nn
|
||||
from ignite.distributed import utils as idist
|
||||
from ignite.distributed.comp_models import native as idist_native
|
||||
from ignite.utils import setup_logger
|
||||
|
||||
|
||||
def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
|
||||
"""Helper method to adapt provided model for non-distributed and distributed configurations (supporting
|
||||
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).
|
||||
|
||||
Internally, we perform the following:
|
||||
|
||||
- send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
|
||||
- wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
|
||||
- wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ignite.distributed as idist
|
||||
|
||||
model = idist.auto_model(model)
|
||||
|
||||
In addition with NVidia/Apex, it can be used in the following way:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ignite.distributed as idist
|
||||
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
|
||||
model = idist.auto_model(model)
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to adapt.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module
|
||||
|
||||
.. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
|
||||
.. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
|
||||
"""
|
||||
logger = setup_logger(__name__ + ".auto_model")
|
||||
|
||||
# Put model's parameters to device if its parameters are not on the device
|
||||
device = idist.device()
|
||||
if not all([p.device == device for p in model.parameters()]):
|
||||
model.to(device)
|
||||
|
||||
# distributed data parallel model
|
||||
if idist.get_world_size() > 1:
|
||||
if idist.backend() == idist_native.NCCL:
|
||||
lrank = idist.get_local_rank()
|
||||
logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
|
||||
elif idist.backend() == idist_native.GLOO:
|
||||
logger.info("Apply torch DistributedDataParallel on model")
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)
|
||||
|
||||
# not distributed but multiple GPUs reachable so data parallel model
|
||||
elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
|
||||
logger.info("Apply torch DataParallel on model")
|
||||
model = torch.nn.parallel.DataParallel(model, **additional_kwargs)
|
||||
|
||||
return model

util/handler.py (new file, +21)
@@ -0,0 +1,21 @@
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from ignite.engine import Engine
|
||||
from ignite.handlers import Checkpoint
|
||||
|
||||
|
||||
class Resumer:
|
||||
def __init__(self, to_load, checkpoint_path):
|
||||
self.to_load = to_load
|
||||
if checkpoint_path is not None:
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
if not checkpoint_path.exists():
|
||||
raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||
self.checkpoint_path = checkpoint_path
|
||||
|
||||
def __call__(self, engine: Engine):
|
||||
if self.checkpoint_path is not None:
|
||||
ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
|
||||
Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")

util/image.py (new file, +10)
@@ -0,0 +1,10 @@
import torchvision.utils
|
||||
|
||||
|
||||
def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):
|
||||
# merge the images within each batch along the `y` direction first.
|
||||
grids = [torchvision.utils.make_grid(img_batch, padding=padding, nrow=1, normalize=normalize, range=range,
|
||||
scale_each=scale_each, pad_value=pad_value)
|
||||
for img_batch in tensors]
|
||||
# merge images in `x` direction.
|
||||
return torchvision.utils.make_grid(grids, padding=0, nrow=len(grids))

util/registry.py (new file, +163)
@@ -0,0 +1,163 @@
import inspect
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from omegaconf import OmegaConf
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
class _Registry:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
|
||||
def get(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def keys(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self):
|
||||
return len(self.keys())
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.get(key) is not None
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(name={self._name}, items={self.keys()})"
|
||||
|
||||
def build_with(self, cfg, default_args=None):
|
||||
"""Build a module from config dict.
|
||||
Args:
|
||||
cfg (dict): Config dict. It should at least contain the key "type".
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
Returns:
|
||||
object: The constructed object.
|
||||
"""
|
||||
if isinstance(cfg, DictConfig):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
if isinstance(cfg, dict):
|
||||
if '_type' in cfg:
|
||||
args = cfg.copy()
|
||||
obj_type = args.pop('_type')
|
||||
elif len(cfg) == 1:
|
||||
obj_type, args = list(cfg.items())[0]
|
||||
else:
|
||||
raise KeyError(f'the cfg dict must contain the key "_type", but got {cfg}')
|
||||
elif isinstance(cfg, str):
|
||||
obj_type = cfg
|
||||
args = dict()
|
||||
else:
|
||||
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
|
||||
|
||||
if not (isinstance(default_args, dict) or default_args is None):
|
||||
raise TypeError('default_args must be a dict or None, '
|
||||
f'but got {type(default_args)}')
|
||||
|
||||
if isinstance(obj_type, str):
|
||||
obj_cls = self.get(obj_type)
|
||||
if obj_cls is None:
|
||||
raise KeyError(f'{obj_type} is not in the {self.name} registry')
|
||||
elif inspect.isclass(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')
|
||||
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
return obj_cls(**args)
|
||||
|
||||
|
||||
class ModuleRegistry(_Registry):
|
||||
def __init__(self, name, module, predefined_valid_list=None):
|
||||
super().__init__(name)
|
||||
|
||||
assert isinstance(module, ModuleType), f"module must be ModuleType, but got {type(module)}"
|
||||
self._module = module
|
||||
if predefined_valid_list is not None:
|
||||
self._valid_set = set(predefined_valid_list) & set(self._module.__dict__.keys())
|
||||
else:
|
||||
self._valid_set = set(self._module.__dict__.keys())
|
||||
|
||||
def keys(self):
|
||||
return tuple(self._valid_set)
|
||||
|
||||
def get(self, key):
|
||||
"""Get the registry record.
|
||||
Args:
|
||||
key (str): The class name in string format.
|
||||
Returns:
|
||||
class: The corresponding class.
|
||||
"""
|
||||
if key not in self._valid_set:
|
||||
return None
|
||||
return getattr(self._module, key)
|
||||
|
||||
|
||||
class Registry(_Registry):
|
||||
"""A registry to map strings to classes.
|
||||
Args:
|
||||
name (str): Registry name.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
self._module_dict = dict()
|
||||
|
||||
def keys(self):
|
||||
return tuple(self._module_dict.keys())
|
||||
|
||||
def get(self, key):
|
||||
"""Get the registry record.
|
||||
Args:
|
||||
key (str): The class name in string format.
|
||||
Returns:
|
||||
class: The corresponding class.
|
||||
"""
|
||||
return self._module_dict.get(key, None)
|
||||
|
||||
def _register_module(self, module_class, module_name=None, force=False):
|
||||
if not inspect.isclass(module_class):
|
||||
raise TypeError('module must be a class, '
|
||||
f'but got {type(module_class)}')
|
||||
|
||||
if module_name is None:
|
||||
module_name = module_class.__name__
|
||||
if not force and module_name in self._module_dict:
|
||||
raise KeyError(f'{module_name} is already registered '
|
||||
f'in {self.name}')
|
||||
self._module_dict[module_name] = module_class
|
||||
|
||||
def register_module(self, name=None, force=False, module=None):
|
||||
"""Register a module.
|
||||
A record will be added to `self._module_dict`, whose key is the class
|
||||
name or the specified name, and value is the class itself.
|
||||
It can be used as a decorator or a normal function.
|
||||
Args:
|
||||
name (str | None): The module name to be registered. If not
|
||||
specified, the class name will be used.
|
||||
force (bool, optional): Whether to override an existing class with
|
||||
the same name. Default: False.
|
||||
module (type): Module class to be registered.
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
self._register_module(
|
||||
module_class=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str)):
|
||||
raise TypeError(f'name must be a str, but got {type(name)}')
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(cls):
|
||||
self._register_module(module_class=cls, module_name=name, force=force)
|
||||
return cls
|
||||
|
||||
return _register
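
build_with accepts three config shapes, which is what lets the YAML pipelines mix bare names and one-key mappings; a minimal sketch:

from data import TRANSFORM

t1 = TRANSFORM.build_with("ToTensor")                             # plain string
t2 = TRANSFORM.build_with({"_type": "Resize", "size": 256})       # explicit _type key
t3 = TRANSFORM.build_with({"RandomCrop": {"size": [256, 256]}})   # single-key mapping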