almost 0.1

This commit is contained in:
Ray Wong 2020-09-06 10:34:52 +08:00
parent e3c760d0c5
commit ab545843bf
15 changed files with 308 additions and 680 deletions

View File

@ -1,51 +0,0 @@
name: cross-domain-1
engine: crossdomain
result_dir: ./result
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 1004
checkpoints:
interval: 2000
log:
logger:
level: 20 # DEBUG(10) INFO(20)
model:
_type: resnet10
baseline:
plusplus: False
optimizers:
_type: Adam
data:
dataloader:
batch_size: 1200
shuffle: True
num_workers: 16
pin_memory: True
drop_last: True
dataset:
train:
path: /data/few-shot/mini_imagenet_full_size/train
lmdb_path: /data/few-shot/lmdb/mini-ImageNet/train.lmdb
pipeline:
- Load
- RandomResizedCrop:
size: [224, 224]
- ColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

View File

@ -1,40 +1,34 @@
name: horse2zebra name: horse2zebra-CyCleGAN
engine: cyclegan engine: CyCleGAN
result_dir: ./result result_dir: ./result
max_iteration: 16600 max_pairs: 266800
distributed:
model:
# broadcast_buffers: False
misc: misc:
random_seed: 324 random_seed: 324
checkpoints: handler:
interval: 2000 clear_cuda_cache: False
set_epoch_for_dist_sampler: True
log: checkpoint:
logger: epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
level: 20 # DEBUG(10) INFO(20) n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
model: model:
generator: generator:
_type: ResGenerator _type: CyCle-Generator
in_channels: 3 in_channels: 3
out_channels: 3 out_channels: 3
base_channels: 64 base_channels: 64
num_blocks: 9 num_blocks: 9
padding_mode: reflect padding_mode: reflect
norm_type: IN norm_type: IN
use_dropout: False
discriminator: discriminator:
_type: PatchDiscriminator _type: PatchDiscriminator
# _distributed:
# bn_to_syncbn: False
in_channels: 3 in_channels: 3
base_channels: 64 base_channels: 64
num_conv: 3
norm_type: IN
loss: loss:
gan: gan:
@ -53,19 +47,22 @@ optimizers:
generator: generator:
_type: Adam _type: Adam
lr: 2e-4 lr: 2e-4
betas: [0.5, 0.999] betas: [ 0.5, 0.999 ]
discriminator: discriminator:
_type: Adam _type: Adam
lr: 2e-4 lr: 2e-4
betas: [0.5, 0.999] betas: [ 0.5, 0.999 ]
data: data:
train: train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50 buffer_size: 50
dataloader: dataloader:
batch_size: 16 batch_size: 6
shuffle: True shuffle: True
num_workers: 4 num_workers: 2
pin_memory: True pin_memory: True
drop_last: True drop_last: True
dataset: dataset:
@ -76,14 +73,14 @@ data:
pipeline: pipeline:
- Load - Load
- Resize: - Resize:
size: [286, 286] size: [ 286, 286 ]
- RandomCrop: - RandomCrop:
size: [256, 256] size: [ 256, 256 ]
- RandomHorizontalFlip - RandomHorizontalFlip
- ToTensor - ToTensor
scheduler: - Normalize:
start: 8300 mean: [ 0.5, 0.5, 0.5 ]
target_lr: 0 std: [ 0.5, 0.5, 0.5 ]
test: test:
dataloader: dataloader:
batch_size: 4 batch_size: 4
@ -99,5 +96,8 @@ data:
pipeline: pipeline:
- Load - Load
- Resize: - Resize:
size: [256, 256] size: [ 256, 256 ]
- ToTensor - ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@ -1,7 +1,7 @@
name: TAFG name: TAFG
engine: TAFG engine: TAFG
result_dir: ./result result_dir: ./result
max_pairs: 1000000 max_pairs: 1500000
handler: handler:
clear_cuda_cache: True clear_cuda_cache: True
@ -28,7 +28,7 @@ model:
_type: MultiScaleDiscriminator _type: MultiScaleDiscriminator
num_scale: 2 num_scale: 2
discriminator_cfg: discriminator_cfg:
_type: pix2pixHD-PatchDiscriminator _type: PatchDiscriminator
in_channels: 3 in_channels: 3
base_channels: 64 base_channels: 64
use_spectral: True use_spectral: True

101
engine/CyCleGAN.py Normal file
View File

@ -0,0 +1,101 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.gan import GANLoss
from model.GAN.base import GANImageBuffer
from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel):
    """CycleGAN training kernel: two generators (a2b, b2a) and two
    discriminators (a, b) with GAN, cycle-consistency and identity losses.

    NOTE(review): the class is named TAFGEngineKernel although it lives in the
    CyCleGAN engine and implements the CycleGAN losses — presumably a
    copy-paste leftover from the TAFG engine; confirm before renaming, since
    run() below instantiates it by this name.
    """

    def __init__(self, config):
        super().__init__(config)
        # GANLoss does not accept the `weight` key; the weight is applied
        # manually in criterion_generators, so strip it from the dict first.
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        # level == 1 selects L1, anything else MSE — for cycle and identity.
        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
        self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
        # One history buffer per discriminator (reduces GAN oscillation);
        # falls back to 50 images when buffer_size is falsy/absent.
        self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
                              self.discriminators.keys()}

    def build_models(self) -> (dict, dict):
        """Build both generators and both discriminators, apply weight init.

        Returns:
            (generators, discriminators) dicts keyed "a2b"/"b2a" and "a"/"b".
        """
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a2b"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
        # Re-enable discriminator gradients for the discriminator step.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # Freeze discriminators so the generator step does not accumulate
        # gradients into them.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run all generator passes for one batch.

        :param batch: dict with image tensors under keys "a" and "b"
        :param inference: when True, disable gradient tracking
        :return: dict of generated images — translations ("a2b", "b2a"),
                 cycle reconstructions ("a2b2a", "b2a2b") and, when the
                 identity loss is enabled, identity mappings ("a2a", "b2b")
        """
        images = dict()
        with torch.set_grad_enabled(not inference):
            images["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"] = self.generators["b2a"](batch["b"])
            images["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"] = self.generators["a2b"](images["b2a"])
            if self.config.loss.id.weight > 0:
                images["a2a"] = self.generators["b2a"](batch["a"])
                images["b2b"] = self.generators["a2b"](batch["b"])
        return images

    def criterion_generators(self, batch, generated) -> dict:
        """Weighted generator losses: cycle, GAN and (optionally) identity."""
        loss = dict()
        for phase in ["a2b", "b2a"]:
            # phase[0] is the source domain, phase[-1] the target domain.
            loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
                generated[f"{phase}2{phase[0]}"], batch[phase[0]])
            loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
                self.discriminators[phase[-1]](generated[phase]), True)
            if self.config.loss.id.weight > 0:
                loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
                    generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Discriminator losses, averaging real and (buffered) fake terms."""
        loss = dict()
        for phase in "ab":
            # Query the history buffer with the detached fake that lands in
            # this domain (domain "a" receives "b2a" fakes and vice versa).
            generated_image = self.image_buffers[phase].query(generated["b2a" if phase == "a" else "a2b"].detach())
            loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False,
                                                  is_discriminator=True) +
                                    self.gan_loss(self.discriminators[phase](batch[phase]), True,
                                                  is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect detached tensors for visualization.

        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch: input batch with keys "a" and "b"
        :param generated: dict of images produced by forward()
        :return: dict like: {"a": [real, fake, reconstruction], "b": [...]}
        """
        return dict(
            a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()],
            b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()],
        )
def run(task, config, _):
    """Engine entry point: build the CycleGAN kernel and delegate to the
    shared i2i runner."""
    run_kernel(task, config, TAFGEngineKernel(config))

View File

@ -5,6 +5,9 @@ from omegaconf import OmegaConf
import torch import torch
import torch.nn as nn import torch.nn as nn
import ignite.distributed as idist import ignite.distributed as idist
from ignite.engine import Events
from omegaconf import read_write, OmegaConf
from model.weight_init import generation_init_weights from model.weight_init import generation_init_weights
from loss.I2I.perceptual_loss import PerceptualLoss from loss.I2I.perceptual_loss import PerceptualLoss
@ -49,7 +52,7 @@ class TAFGEngineKernel(EngineKernel):
return generators, discriminators return generators, discriminators
def setup_before_d(self): def setup_after_g(self):
for discriminator in self.discriminators.values(): for discriminator in self.discriminators.values():
discriminator.requires_grad_(True) discriminator.requires_grad_(True)
@ -89,7 +92,7 @@ class TAFGEngineKernel(EngineKernel):
for j in range(num_intermediate_outputs): for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss( # loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
# self.generators["main"].module.style_encoders["b"](batch["b"]), # self.generators["main"].module.style_encoders["b"](batch["b"]),
# self.generators["main"].module.style_encoders["b"](generated["b"]) # self.generators["main"].module.style_encoders["b"](generated["b"])
@ -122,6 +125,12 @@ class TAFGEngineKernel(EngineKernel):
generated["b"].detach()] generated["b"].detach()]
) )
def change_engine(self, config, trainer):
@trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
def change_config(engine):
with read_write(config):
config.loss.perceptual.weight = 5
def run(task, config, _): def run(task, config, _):
kernel = TAFGEngineKernel(config) kernel = TAFGEngineKernel(config)

View File

@ -1,5 +1,3 @@
from itertools import chain
from omegaconf import OmegaConf from omegaconf import OmegaConf
import torch import torch
@ -7,10 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import ignite.distributed as idist import ignite.distributed as idist
from model.weight_init import generation_init_weights
from loss.gan import GANLoss from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper from model.GAN.UGATIT import RhoClipper
from model.GAN.residual_generator import GANImageBuffer from model.GAN.base import GANImageBuffer
from util.image import attention_colored_map from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model from engine.util.build import build_model
@ -36,6 +33,7 @@ class UGATITEngineKernel(EngineKernel):
self.rho_clipper = RhoClipper(0, 1) self.rho_clipper = RhoClipper(0, 1)
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()} self.discriminators.keys()}
self.train_generator_first = False
def build_models(self) -> (dict, dict): def build_models(self) -> (dict, dict):
generators = dict( generators = dict(
@ -51,12 +49,9 @@ class UGATITEngineKernel(EngineKernel):
self.logger.debug(discriminators["ga"]) self.logger.debug(discriminators["ga"])
self.logger.debug(generators["a2b"]) self.logger.debug(generators["a2b"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators return generators, discriminators
def setup_before_d(self): def setup_after_g(self):
for generator in self.generators.values(): for generator in self.generators.values():
generator.apply(self.rho_clipper) generator.apply(self.rho_clipper)
for discriminator in self.discriminators.values(): for discriminator in self.discriminators.values():
@ -101,8 +96,7 @@ class UGATITEngineKernel(EngineKernel):
loss = dict() loss = dict()
for phase in "ab": for phase in "ab":
for level in "gl": for level in "gl":
generated_image = self.image_buffers[level + phase].query( generated_image = generated["images"]["b2a" if phase == "a" else "a2b"].detach()
generated["images"]["a2b" if phase == "b" else "b2a"])
pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image) pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase]) pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss( loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(

View File

@ -1,23 +1,21 @@
from itertools import chain
import logging import logging
from itertools import chain
from pathlib import Path from pathlib import Path
import ignite.distributed as idist
import torch import torch
import torchvision import torchvision
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
import ignite.distributed as idist
from ignite.engine import Events, Engine from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor from ignite.utils import convert_tensor
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear from math import ceil
from omegaconf import read_write, OmegaConf from omegaconf import read_write, OmegaConf
from util.image import make_2d_grid
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from engine.util.build import build_optimizer
import data import data
from engine.util.build import build_optimizer
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from util.image import make_2d_grid
def build_lr_schedulers(optimizers, config): def build_lr_schedulers(optimizers, config):
@ -59,6 +57,7 @@ class EngineKernel(object):
self.config = config self.config = config
self.logger = logging.getLogger(config.name) self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models() self.generators, self.discriminators = self.build_models()
self.train_generator_first = True
def build_models(self) -> (dict, dict): def build_models(self) -> (dict, dict):
raise NotImplemented raise NotImplemented
@ -69,7 +68,7 @@ class EngineKernel(object):
to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators}) to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
return to_save return to_save
def setup_before_d(self): def setup_after_g(self):
raise NotImplemented raise NotImplemented
def setup_before_g(self): def setup_before_g(self):
@ -93,6 +92,9 @@ class EngineKernel(object):
""" """
raise NotImplemented raise NotImplemented
def change_engine(self, config, engine: Engine):
pass
def get_trainer(config, kernel: EngineKernel): def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name) logger = logging.getLogger(config.name)
@ -106,26 +108,37 @@ def get_trainer(config, kernel: EngineKernel):
lr_schedulers = build_lr_schedulers(optimizers, config) lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}") logger.info(f"build lr_schedulers:\n{lr_schedulers}")
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1) iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
def train_generators(batch, generated):
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
kernel.setup_after_g()
return loss_g
def train_discriminators(batch, generated):
optimizers["d"].zero_grad()
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
return loss_d
def _step(engine, batch): def _step(engine, batch):
batch = convert_tensor(batch, idist.device()) batch = convert_tensor(batch, idist.device())
generated = kernel.forward(batch) generated = kernel.forward(batch)
kernel.setup_before_g() if kernel.train_generator_first:
optimizers["g"].zero_grad() loss_g = train_generators(batch, generated)
loss_g = kernel.criterion_generators(batch, generated) loss_d = train_discriminators(batch, generated)
sum(loss_g.values()).backward() else:
optimizers["g"].step() loss_d = train_discriminators(batch, generated)
loss_g = train_generators(batch, generated)
kernel.setup_before_d() if engine.state.iteration % iteration_per_image == 0:
optimizers["d"].zero_grad()
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
if engine.state.iteration % image_per_iteration == 0:
return { return {
"loss": dict(g=loss_g, d=loss_d), "loss": dict(g=loss_g, d=loss_d),
"img": kernel.intermediate_images(batch, generated) "img": kernel.intermediate_images(batch, generated)
@ -137,6 +150,8 @@ def get_trainer(config, kernel: EngineKernel):
for lr_shd in lr_schedulers.values(): for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd) trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
kernel.change_engine(config, trainer)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d") RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer) to_save = dict(trainer=trainer)
@ -150,7 +165,7 @@ def get_trainer(config, kernel: EngineKernel):
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item") tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
if tensorboard_handler is not None: if tensorboard_handler is not None:
basic_image_event = Events.ITERATION_COMPLETED( basic_image_event = Events.ITERATION_COMPLETED(
every=image_per_iteration) every=iteration_per_image)
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size() pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
@trainer.on(basic_image_event) @trainer.on(basic_image_event)
@ -227,7 +242,7 @@ def run_kernel(task, config, kernel):
logger = logging.getLogger(config.name) logger = logging.getLogger(config.name)
with read_write(config): with read_write(config):
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size() real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
config.max_iteration = config.max_pairs // real_batch_size + 1 config.max_iteration = ceil(config.max_pairs / real_batch_size)
if task == "train": if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset) train_dataset = data.DATASET.build_with(config.data.train.dataset)
@ -243,7 +258,7 @@ def run_kernel(task, config, kernel):
test_dataset = data.DATASET.build_with(config.data.test.dataset) test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset trainer.state.test_dataset = test_dataset
try: try:
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1) trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception: except Exception:
import traceback import traceback
print(traceback.format_exc()) print(traceback.format_exc())

View File

@ -1,268 +0,0 @@
import itertools
from pathlib import Path
import torch
import torch.nn as nn
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer
def get_trainer(config, logger):
    """Build the CycleGAN training Engine.

    Constructs both generators/discriminators, their joint optimizers and
    piecewise-linear LR decay schedules, the GAN/cycle/identity losses, and
    wires checkpointing, metrics and tensorboard logging.

    :param config: OmegaConf config (model, optimizers, loss, data, ... keys)
    :param logger: logger used by the trainer
    :return: configured ignite ``Engine``
    """
    generator_a = build_model(config.model.generator, config.distributed.model)
    generator_b = build_model(config.model.generator, config.distributed.model)
    discriminator_a = build_model(config.model.discriminator, config.distributed.model)
    discriminator_b = build_model(config.model.discriminator, config.distributed.model)
    for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
        generation_init_weights(m)
    logger.info(discriminator_a)
    logger.info(generator_a)
    # One optimizer drives both generators; another drives both discriminators.
    optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
                                  config.optimizers.generator)
    optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
                                  config.optimizers.discriminator)
    # Constant LR until epoch 100, then linear decay to target_lr at epoch 200
    # (schedulers fire on EPOCH_COMPLETED below).
    milestones_values = [
        (0, config.optimizers.generator.lr),
        (100, config.optimizers.generator.lr),
        (200, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
    milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (100, config.optimizers.discriminator.lr),
        (200, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
    # GANLoss does not accept the `weight` key; the weight is applied manually
    # in the step function below.
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    # BUG FIX: the identity loss previously keyed off config.loss.cycle.level;
    # it must use its own config.loss.id.level setting.
    id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
    image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
    image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        fake_b = generator_a(real_a)  # G_A(A)
        rec_a = generator_b(fake_b)  # G_B(G_A(A))
        fake_a = generator_b(real_b)  # G_B(B)
        rec_b = generator_a(fake_a)  # G_A(G_B(B))
        # ---- generator step (discriminators frozen) ----
        optimizer_g.zero_grad()
        discriminator_a.requires_grad_(False)
        discriminator_b.requires_grad_(False)
        loss_g = dict(
            cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
            cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
            gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
            gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
        )
        if config.loss.id.weight > 0:
            # BUG FIX: these two assignments had trailing commas, wrapping the
            # loss tensors in 1-tuples and breaking sum(...).backward() below.
            loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b)  # G_A(B)
            loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a)  # G_B(A)
        sum(loss_g.values()).backward()
        optimizer_g.step()
        # ---- discriminator step ----
        discriminator_a.requires_grad_(True)
        discriminator_b.requires_grad_(True)
        optimizer_d.zero_grad()
        # Fakes are detached and drawn through the history buffers.
        loss_d_a = dict(
            real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
            fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
        )
        loss_d_b = dict(
            real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
            fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
        )
        (sum(loss_d_a.values()) * 0.5).backward()
        (sum(loss_d_b.values()) * 0.5).backward()
        optimizer_d.step()
        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
                "d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
            },
            "img": [
                real_a.detach(),
                fake_b.detach(),
                rec_a.detach(),
                real_b.detach(),
                fake_a.detach(),
                rec_b.detach()
            ]
        }

    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
    to_save = dict(
        generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
        discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
        lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
    )
    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
                          filename_prefix=config.name, to_save=to_save,
                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
                          save_interval_event=Events.ITERATION_COMPLETED(
                              every=config.checkpoints.interval) | Events.COMPLETED)

    # Stop once the configured iteration budget is reached.
    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
    def terminate(engine):
        engine.terminate()

    if idist.get_rank() == 0:
        # Create a logger (rank 0 only)
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # Attach the logger to the trainer to log training loss at each iteration
        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        def output_transform(output):
            # Flatten the nested loss dict to "<group>_<name>" scalar keys.
            loss = dict()
            for tl in output["loss"]:
                if isinstance(output["loss"][tl], dict):
                    for l in output["loss"][tl]:
                        loss[f"{tl}_{l}"] = output["loss"][tl][l]
                else:
                    loss[tl] = output["loss"][tl]
            return loss

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d_a", "loss_d_b"],
                global_step_transform=global_step_transform,
                output_transform=output_transform
            ),
            event_name=Events.ITERATION_COMPLETED(every=50)
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50)
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done
            tb_logger.close()
    return trainer
def get_tester(config, logger):
    """Build the inference Engine.

    Loads both generators from ``config.resume_from`` via the common handlers
    and saves a side-by-side image grid per test sample under
    ``<output_dir>/test_images``.

    :param config: OmegaConf config
    :param logger: logger used by the tester
    :return: configured ignite ``Engine``
    """
    generator_a = build_model(config.model.generator, config.distributed.model)
    generator_b = build_model(config.model.generator, config.distributed.model)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        with torch.no_grad():
            fake_b = generator_a(real_a)  # G_A(A)
            rec_a = generator_b(fake_b)  # G_B(G_A(A))
            fake_a = generator_b(real_b)  # G_B(B)
            rec_b = generator_a(fake_a)  # G_A(G_B(B))
        return [
            real_a.detach(),
            fake_b.detach(),
            rec_a.detach(),
            real_b.detach(),
            fake_a.detach(),
            rec_b.detach()
        ]

    tester = Engine(_step)
    tester.logger = logger
    # BUG FIX: was `idist.get_rank == 0`, which compares the function object
    # itself to 0 (always False), so the progress bar never attached.
    if idist.get_rank() == 0:
        ProgressBar(ncols=0).attach(tester)
    to_load = dict(generator_a=generator_a, generator_b=generator_b)
    setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)

    @tester.on(Events.STARTED)
    @idist.one_rank_only()
    def mkdir(engine):
        img_output_dir = Path(config.output_dir) / "test_images"
        if not img_output_dir.exists():
            engine.logger.info(f"mkdir {img_output_dir}")
            img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        # One grid row per sample: real/fake/rec for both directions.
        img_tensors = engine.state.output
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            torchvision.utils.save_image([img[i] for img in img_tensors],
                                         Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
                                         nrow=len(img_tensors))

    return tester
def run(task, config, logger):
    """Dispatch a task for the legacy CycleGAN engine.

    :param task: "train" or "test"
    :param config: OmegaConf config
    :param logger: logger for progress messages
    :raises NotImplementedError: when ``task`` is unknown
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test.dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, logger)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        # BUG FIX: was `return NotImplemented(...)` — NotImplemented is a
        # singleton, not a callable, so that line itself raised TypeError.
        raise NotImplementedError(f"invalid task: {task}")

View File

@ -17,6 +17,6 @@ dependencies:
- omegaconf - omegaconf
- python-lmdb - python-lmdb
- fire - fire
# - opencv - opencv
# - jupyterlab # - jupyterlab

62
model/GAN/CycleGAN.py Normal file
View File

@ -0,0 +1,62 @@
import torch.nn as nn
from model.normalization import select_norm_layer
from model.registry import MODEL
from .base import ResidualBlock
@MODEL.register_module("CyCle-Generator")
class Generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Architecture: 7x7 stem conv -> `num_down_sampling` strided convs (each
    doubling channels) -> `num_blocks` residual blocks -> mirrored transposed
    convs -> 7x7 head conv + Tanh.

    Args:
        in_channels (int): channels of the input image.
        out_channels (int): channels of the output image.
        base_channels (int): width of the stem; doubled at each down-sampling.
        num_blocks (int): number of residual blocks in the bottleneck.
        padding_mode (str): padding mode for the 7x7 convs and residual blocks.
        norm_type (str): normalization selected via select_norm_layer;
            bias is only used with "IN" (instance norm has no affine mean shift).
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
                 norm_type="IN"):
        super(Generator, self).__init__()
        assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"
        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=use_bias),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )
        # down sampling: each stage halves the spatial size, doubles channels
        submodules = []
        num_down_sampling = 2
        for i in range(num_down_sampling):
            multiple = 2 ** i
            submodules += [
                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
                          kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(num_features=base_channels * multiple * 2),
                nn.ReLU(inplace=True)
            ]
        self.encoder = nn.Sequential(*submodules)
        # BUG FIX: was `num_down_sampling ** 2 * base_channels`, which only
        # coincidentally equals the encoder's output width when
        # num_down_sampling == 2; the channel count after k doublings is
        # base_channels * 2**k.
        res_block_channels = base_channels * 2 ** num_down_sampling
        self.resnet_middle = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
              range(num_blocks)])
        # up sampling: mirror of the encoder, halving channels each stage
        submodules = []
        for i in range(num_down_sampling):
            multiple = 2 ** (num_down_sampling - i)
            submodules += [
                nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(num_features=base_channels * multiple // 2),
                nn.ReLU(inplace=True),
            ]
        self.decoder = nn.Sequential(*submodules)
        self.end_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
            nn.Tanh()
        )

    def forward(self, x):
        """Map an image batch (N, in_channels, H, W) to (N, out_channels, H, W)."""
        x = self.encoder(self.start_conv(x))
        x = self.resnet_middle(x)
        return self.end_conv(self.decoder(x))

View File

@ -45,7 +45,6 @@ class Generator(nn.Module):
# Down-Sampling Bottleneck # Down-Sampling Bottleneck
mult = 2 ** n_down_sampling mult = 2 ** n_down_sampling
for i in range(num_blocks): for i in range(num_blocks):
# TODO: change ResnetBlock to ResidualBlock, check use_bias param
down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)] down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
self.down_encoder = nn.Sequential(*down_encoder) self.down_encoder = nn.Sequential(*down_encoder)

View File

@ -1,13 +1,68 @@
import math import math
import torch
import torch.nn as nn import torch.nn as nn
from model.normalization import select_norm_layer from model.normalization import select_norm_layer
from model import MODEL from model import MODEL
class GANImageBuffer(object):
    """Pool of previously generated images used when training the discriminator.

    Feeding the discriminator a mix of current and historical generator
    outputs, rather than only the latest batch, reduces model oscillation.

    Args:
        buffer_size (int): Maximum number of images kept. If 0, no buffer
            is created and ``query`` is a pass-through.
        buffer_ratio (float): Once the buffer is full, the probability of
            returning a previously stored image instead of the incoming one.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        self.buffer_ratio = buffer_ratio
        if self.buffer_size > 0:
            # lazily filled history of generated images
            self.img_num = 0
            self.image_buffer = []

    def query(self, images):
        """Return a batch mixing ``images`` with buffered history.

        Args:
            images (Tensor): Current image batch without history information.
        """
        if self.buffer_size == 0:
            # buffering disabled: pass the batch through untouched
            return images
        outputs = []
        for img in images:
            img = torch.unsqueeze(img.data, 0)
            if self.img_num < self.buffer_size:
                # buffer not yet full: store and return the current image
                self.img_num += 1
                self.image_buffer.append(img)
                outputs.append(img)
                continue
            if torch.rand(1) < self.buffer_ratio:
                # swap: hand back a random stored image, keep the new one
                idx = torch.randint(0, self.buffer_size, (1,)).item()
                stored = self.image_buffer[idx].clone()
                self.image_buffer[idx] = img
                outputs.append(stored)
            else:
                # otherwise return the current image unchanged
                outputs.append(img)
        return torch.cat(outputs, 0)
 # based SPADE or pix2pixHD Discriminator
-@MODEL.register_module("pix2pixHD-PatchDiscriminator")
+@MODEL.register_module("PatchDiscriminator")
 class PatchDiscriminator(nn.Module):
     def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
                  need_intermediate_feature=False):

View File

@ -1,182 +0,0 @@
import torch
import torch.nn as nn
from model.registry import MODEL
from model.normalization import select_norm_layer
class GANImageBuffer(object):
    """History pool of generator outputs for discriminator updates.

    Training the discriminator on a blend of fresh and previously generated
    images dampens oscillation between generator and discriminator.

    Args:
        buffer_size (int): Capacity of the pool; 0 disables pooling entirely.
        buffer_ratio (float): Once full, probability that a queried image is
            swapped for a randomly chosen stored one.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        if self.buffer_size > 0:
            # count of images currently stored, and their storage
            self.img_num = 0
            self.image_buffer = []
        self.buffer_ratio = buffer_ratio

    def query(self, images):
        """Return a batch of the same size, possibly mixing in buffered history.

        Args:
            images (Tensor): Current image batch without history information.
        """
        if self.buffer_size == 0:
            return images  # pooling disabled: pass through
        selected = []
        for current in images:
            current = torch.unsqueeze(current.data, 0)
            if self.img_num < self.buffer_size:
                # still filling: remember the image and emit it unchanged
                self.img_num = self.img_num + 1
                self.image_buffer.append(current)
                selected.append(current)
            else:
                # full: with probability buffer_ratio, trade the fresh image
                # for a random stored one; otherwise emit it directly
                if torch.rand(1) < self.buffer_ratio:
                    slot = torch.randint(0, self.buffer_size, (1,)).item()
                    old = self.image_buffer[slot].clone()
                    self.image_buffer[slot] = current
                    selected.append(old)
                else:
                    selected.append(current)
        return torch.cat(selected, 0)
class ResidualBlock(nn.Module):
    """conv-norm-ReLU(-dropout)-conv-norm block with an identity skip.

    Args:
        num_channels (int): Channel count preserved through the block.
        padding_mode (str): Padding mode for both 3x3 convolutions.
        norm_type (str): Key understood by ``select_norm_layer``.
        use_dropout (bool): Insert ``Dropout(0.5)`` between the two conv stages.
        use_bias (bool | None): Conv bias; when None, defaults to True only
            for IN, which has no affine parameters of its own.
    """

    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
        super(ResidualBlock, self).__init__()
        if use_bias is None:
            # Only for IN, use bias since it does not have affine parameters.
            use_bias = (norm_type == "IN")
        norm_layer = select_norm_layer(norm_type)

        def conv_norm():
            # one 3x3 conv followed by normalization, channel-preserving
            return [
                nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1,
                          padding_mode=padding_mode, bias=use_bias),
                norm_layer(num_channels),
            ]

        layers = [nn.Sequential(*conv_norm(), nn.ReLU(inplace=True))]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers.append(nn.Sequential(*conv_norm()))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        # identity skip connection
        return x + self.block(x)
@MODEL.register_module()
class ResGenerator(nn.Module):
    """ResNet-style image-to-image generator (CycleGAN architecture).

    Layout: 7x7 stem conv -> 2 stride-2 down-sampling convs -> ``num_blocks``
    residual blocks -> 2 stride-2 transposed convs -> 7x7 output conv + Tanh.

    Args:
        in_channels (int): Channels of the input image.
        out_channels (int): Channels of the generated image.
        base_channels (int): Channel width after the stem conv.
        num_blocks (int): Number of residual blocks in the middle.
        padding_mode (str): Padding mode of the 7x7 and residual convs.
        norm_type (str): Normalization key for ``select_norm_layer``.
        use_dropout (bool): Forwarded to each ResidualBlock (Dropout(0.5)
            between its two convs). Defaults to False, matching the previous
            behavior.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
                 norm_type="IN", use_dropout=False):
        super(ResGenerator, self).__init__()
        assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
        norm_layer = select_norm_layer(norm_type)
        # Only for IN, use conv bias since IN has no affine parameters.
        use_bias = norm_type == "IN"
        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=use_bias),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )
        # down sampling: each stage halves resolution and doubles channels
        submodules = []
        num_down_sampling = 2
        for i in range(num_down_sampling):
            multiple = 2 ** i
            submodules += [
                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
                          kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(num_features=base_channels * multiple * 2),
                nn.ReLU(inplace=True)
            ]
        self.encoder = nn.Sequential(*submodules)
        # Channels after the encoder grow by a factor of 2 per stage, i.e.
        # 2 ** num_down_sampling. The previous `num_down_sampling ** 2` only
        # coincided with the correct value for exactly 2 stages.
        res_block_channels = 2 ** num_down_sampling * base_channels
        self.resnet_middle = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout)
              for _ in range(num_blocks)])
        # up sampling: mirror of the encoder
        submodules = []
        for i in range(num_down_sampling):
            multiple = 2 ** (num_down_sampling - i)
            submodules += [
                nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(num_features=base_channels * multiple // 2),
                nn.ReLU(inplace=True),
            ]
        self.decoder = nn.Sequential(*submodules)
        self.end_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
            nn.Tanh()
        )

    def forward(self, x):
        """Map input images to generated images in [-1, 1]."""
        x = self.encoder(self.start_conv(x))
        x = self.resnet_middle(x)
        return self.end_conv(self.decoder(x))
@MODEL.register_module()
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator producing a patch-wise realness map.

    Args:
        in_channels (int): Channels of the input image.
        base_channels (int): Filters of the first conv layer.
        num_conv (int): Number of conv blocks after the first layer.
        norm_type (str): Normalization key for ``select_norm_layer``.
    """

    def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
        super(PatchDiscriminator, self).__init__()
        assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
        norm_layer = select_norm_layer(norm_type)
        # Only for IN, use conv bias since IN has no affine parameters.
        use_bias = norm_type == "IN"
        kernel_size = 4
        padding = 1
        # first layer: no normalization
        layers = [
            nn.Conv2d(in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # stride-2 stages that double the filter count (capped at 8x)
        mult = 1
        for n in range(1, num_conv):
            prev, mult = mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(base_channels * prev, base_channels * mult, kernel_size=kernel_size,
                          padding=padding, stride=2, bias=use_bias),
                norm_layer(base_channels * mult),
                nn.LeakyReLU(0.2, inplace=True)
            ]
        # one stride-1 stage, then a single-channel prediction conv
        prev, mult = mult, min(2 ** num_conv, 8)
        layers += [
            nn.Conv2d(base_channels * prev, base_channels * mult, kernel_size, stride=1,
                      padding=padding, bias=use_bias),
            norm_layer(base_channels * mult),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base_channels * mult, 1, kernel_size, stride=1, padding=padding)
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the patch-level realness logits for ``x``."""
        return self.model(x)

View File

@@ -1,7 +1,6 @@
 from model.registry import MODEL
-import model.GAN.residual_generator
+import model.GAN.CycleGAN
 import model.GAN.TAFG
 import model.GAN.UGATIT
-import model.fewshot
 import model.GAN.wrapper
 import model.GAN.base

View File

@ -1,105 +0,0 @@
import math
import torch.nn as nn
from .registry import MODEL
# --- gaussian initialize ---
def init_layer(l):
    """Initialize a single layer in place (fan-in style Gaussian for convs).

    Conv2d weights are drawn from N(0, sqrt(2/n)) with n = kh*kw*out_channels;
    BatchNorm2d is reset to identity (weight=1, bias=0); Linear only has its
    bias zeroed. Other layer types are left untouched.
    """
    if isinstance(l, nn.Conv2d):
        n = l.kernel_size[0] * l.kernel_size[1] * l.out_channels
        std = math.sqrt(2.0 / float(n))
        l.weight.data.normal_(0, std)
        return
    if isinstance(l, nn.BatchNorm2d):
        l.weight.data.fill_(1)
        l.bias.data.fill_(0)
        return
    if isinstance(l, nn.Linear):
        l.bias.data.fill_(0)
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # keep dim 0 (batch), flatten the rest
        batch_size = x.size(0)
        return x.view(batch_size, -1)
class SimpleBlock(nn.Module):
    """Basic two-conv residual block with BatchNorm and a projection shortcut.

    Args:
        in_channels (int): Input channel count.
        out_channels (int): Output channel count.
        half_res (bool): If True, the first conv (and the shortcut) use
            stride 2, halving spatial resolution.
        leakyrelu (bool): Use LeakyReLU(0.2) instead of ReLU.
    """

    def __init__(self, in_channels, out_channels, half_res, leakyrelu=False):
        super(SimpleBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        stride = 2 if half_res else 1

        def act():
            return nn.LeakyReLU(0.2, inplace=True) if leakyrelu else nn.ReLU(inplace=True)

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            act(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = act()
        # project the skip path only when the channel count changes
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = self.block(x)
        return self.relu(residual + self.shortcut(x))
class ResNet(nn.Module):
    """Few-shot backbone: 7x7 stem conv + four stages of residual blocks.

    Args:
        block: Block factory called as ``block(in_ch, out_ch, half_res, leakyrelu)``.
        layers (list[int]): Number of blocks in each of the four stages.
        dims (list[int]): Output channels of each stage.
        num_classes (int | None): If given, append a classifier head.
        classifier_type (str): "linear" appends ``nn.Linear``; "distlinear" is
            accepted but currently a no-op stub; anything else raises ValueError.
        flatten (bool): Append 7x7 average pool + Flatten after the stages.
        leakyrelu (bool): Use LeakyReLU(0.2) activations instead of ReLU.
    """

    def __init__(self, block, layers, dims, num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
        super().__init__()
        assert len(layers) == 4, 'Can have only four stages'
        self.inplanes = 64
        self.start = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.inplanes),
            nn.LeakyReLU(0.2, inplace=True) if leakyrelu else nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        trunk = []
        channels = self.inplanes
        for stage, num_blocks in enumerate(layers):
            for block_idx in range(num_blocks):
                # only the first block of stages 1-3 halves the resolution
                is_half = stage >= 1 and block_idx == 0
                trunk.append(block(channels, dims[stage], is_half, leakyrelu))
                channels = dims[stage]
        if flatten:
            trunk.append(nn.AvgPool2d(7))
            trunk.append(Flatten())
        if num_classes is not None:
            if classifier_type == "linear":
                trunk.append(nn.Linear(channels, num_classes))
            elif classifier_type == "distlinear":
                # NOTE(review): "distlinear" is accepted but adds no head yet
                pass
            else:
                raise ValueError(f"invalid classifier_type:{classifier_type}")
        self.trunk = nn.Sequential(*trunk)
        # re-initialize all layers with the module-level scheme
        self.apply(init_layer)

    def forward(self, x):
        return self.trunk(self.start(x))
@MODEL.register_module()
def resnet10(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    """ResNet-10 backbone: one SimpleBlock per stage, widths 64-512."""
    return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512],
                  num_classes=num_classes, classifier_type=classifier_type,
                  flatten=flatten, leakyrelu=leakyrelu)
@MODEL.register_module()
def resnet18(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    """ResNet-18 backbone: two SimpleBlocks per stage, widths 64-512."""
    return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512],
                  num_classes=num_classes, classifier_type=classifier_type,
                  flatten=flatten, leakyrelu=leakyrelu)
@MODEL.register_module()
def resnet34(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    """ResNet-34 backbone: 3/4/6/3 SimpleBlocks per stage, widths 64-512."""
    return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512],
                  num_classes=num_classes, classifier_type=classifier_type,
                  flatten=flatten, leakyrelu=leakyrelu)