almost 0.1
This commit is contained in:
parent
e3c760d0c5
commit
ab545843bf
@ -1,51 +0,0 @@
|
||||
name: cross-domain-1
|
||||
engine: crossdomain
|
||||
result_dir: ./result
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
|
||||
misc:
|
||||
random_seed: 1004
|
||||
|
||||
checkpoints:
|
||||
interval: 2000
|
||||
|
||||
log:
|
||||
logger:
|
||||
level: 20 # DEBUG(10) INFO(20)
|
||||
|
||||
model:
|
||||
_type: resnet10
|
||||
|
||||
baseline:
|
||||
plusplus: False
|
||||
optimizers:
|
||||
_type: Adam
|
||||
data:
|
||||
dataloader:
|
||||
batch_size: 1200
|
||||
shuffle: True
|
||||
num_workers: 16
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
train:
|
||||
path: /data/few-shot/mini_imagenet_full_size/train
|
||||
lmdb_path: /data/few-shot/lmdb/mini-ImageNet/train.lmdb
|
||||
pipeline:
|
||||
- Load
|
||||
- RandomResizedCrop:
|
||||
size: [224, 224]
|
||||
- ColorJitter:
|
||||
brightness: 0.4
|
||||
contrast: 0.4
|
||||
saturation: 0.4
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
|
||||
|
||||
@ -1,40 +1,34 @@
|
||||
name: horse2zebra
|
||||
engine: cyclegan
|
||||
name: horse2zebra-CyCleGAN
|
||||
engine: CyCleGAN
|
||||
result_dir: ./result
|
||||
max_iteration: 16600
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
max_pairs: 266800
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
checkpoints:
|
||||
interval: 2000
|
||||
|
||||
log:
|
||||
logger:
|
||||
level: 20 # DEBUG(10) INFO(20)
|
||||
handler:
|
||||
clear_cuda_cache: False
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 2 # log image `image` times per epoch
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: ResGenerator
|
||||
_type: CyCle-Generator
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 9
|
||||
padding_mode: reflect
|
||||
norm_type: IN
|
||||
use_dropout: False
|
||||
discriminator:
|
||||
_type: PatchDiscriminator
|
||||
# _distributed:
|
||||
# bn_to_syncbn: False
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 3
|
||||
norm_type: IN
|
||||
|
||||
loss:
|
||||
gan:
|
||||
@ -53,19 +47,22 @@ optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 2e-4
|
||||
betas: [0.5, 0.999]
|
||||
betas: [ 0.5, 0.999 ]
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 2e-4
|
||||
betas: [0.5, 0.999]
|
||||
betas: [ 0.5, 0.999 ]
|
||||
|
||||
data:
|
||||
train:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 16
|
||||
batch_size: 6
|
||||
shuffle: True
|
||||
num_workers: 4
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
@ -76,14 +73,14 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [286, 286]
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [256, 256]
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
scheduler:
|
||||
start: 8300
|
||||
target_lr: 0
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
dataloader:
|
||||
batch_size: 4
|
||||
@ -99,5 +96,8 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [256, 256]
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
@ -1,7 +1,7 @@
|
||||
name: TAFG
|
||||
engine: TAFG
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
max_pairs: 1500000
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
@ -28,7 +28,7 @@ model:
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
discriminator_cfg:
|
||||
_type: pix2pixHD-PatchDiscriminator
|
||||
_type: PatchDiscriminator
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
use_spectral: True
|
||||
|
||||
101
engine/CyCleGAN.py
Normal file
101
engine/CyCleGAN.py
Normal file
@ -0,0 +1,101 @@
|
||||
from itertools import chain
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from loss.gan import GANLoss
|
||||
from model.GAN.base import GANImageBuffer
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class TAFGEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
a2b=build_model(self.config.model.generator),
|
||||
b2a=build_model(self.config.model.generator)
|
||||
)
|
||||
discriminators = dict(
|
||||
a=build_model(self.config.model.discriminator),
|
||||
b=build_model(self.config.model.discriminator)
|
||||
)
|
||||
self.logger.debug(discriminators["a"])
|
||||
self.logger.debug(generators["a2b"])
|
||||
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_after_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
def setup_before_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(False)
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
images["a2b"] = self.generators["a2b"](batch["a"])
|
||||
images["b2a"] = self.generators["b2a"](batch["b"])
|
||||
images["a2b2a"] = self.generators["b2a"](images["a2b"])
|
||||
images["b2a2b"] = self.generators["a2b"](images["b2a"])
|
||||
if self.config.loss.id.weight > 0:
|
||||
images["a2a"] = self.generators["b2a"](batch["a"])
|
||||
images["b2b"] = self.generators["a2b"](batch["b"])
|
||||
return images
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for phase in ["a2b", "b2a"]:
|
||||
loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
|
||||
generated[f"{phase}2{phase[0]}"], batch[phase[0]])
|
||||
loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
|
||||
self.discriminators[phase[-1]](generated[phase]), True)
|
||||
if self.config.loss.id.weight > 0:
|
||||
loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
|
||||
generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for phase in "ab":
|
||||
generated_image = self.image_buffers[phase].query(generated["b2a" if phase == "a" else "a2b"].detach())
|
||||
loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False,
|
||||
is_discriminator=True) +
|
||||
self.gan_loss(self.discriminators[phase](batch[phase]), True,
|
||||
is_discriminator=True)) / 2
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
:param batch:
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
return dict(
|
||||
a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()],
|
||||
b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()],
|
||||
)
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = TAFGEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
@ -5,6 +5,9 @@ from omegaconf import OmegaConf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events
|
||||
|
||||
from omegaconf import read_write, OmegaConf
|
||||
|
||||
from model.weight_init import generation_init_weights
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
@ -49,7 +52,7 @@ class TAFGEngineKernel(EngineKernel):
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_before_d(self):
|
||||
def setup_after_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
@ -89,7 +92,7 @@ class TAFGEngineKernel(EngineKernel):
|
||||
for j in range(num_intermediate_outputs):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
|
||||
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
|
||||
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
# self.generators["main"].module.style_encoders["b"](batch["b"]),
|
||||
# self.generators["main"].module.style_encoders["b"](generated["b"])
|
||||
@ -122,6 +125,12 @@ class TAFGEngineKernel(EngineKernel):
|
||||
generated["b"].detach()]
|
||||
)
|
||||
|
||||
def change_engine(self, config, trainer):
|
||||
@trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
|
||||
def change_config(engine):
|
||||
with read_write(config):
|
||||
config.loss.perceptual.weight = 5
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = TAFGEngineKernel(config)
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from itertools import chain
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import torch
|
||||
@ -7,10 +5,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import ignite.distributed as idist
|
||||
|
||||
from model.weight_init import generation_init_weights
|
||||
from loss.gan import GANLoss
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from model.GAN.base import GANImageBuffer
|
||||
from util.image import attention_colored_map
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
@ -36,6 +33,7 @@ class UGATITEngineKernel(EngineKernel):
|
||||
self.rho_clipper = RhoClipper(0, 1)
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
self.train_generator_first = False
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
@ -51,12 +49,9 @@ class UGATITEngineKernel(EngineKernel):
|
||||
self.logger.debug(discriminators["ga"])
|
||||
self.logger.debug(generators["a2b"])
|
||||
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_before_d(self):
|
||||
def setup_after_g(self):
|
||||
for generator in self.generators.values():
|
||||
generator.apply(self.rho_clipper)
|
||||
for discriminator in self.discriminators.values():
|
||||
@ -101,8 +96,7 @@ class UGATITEngineKernel(EngineKernel):
|
||||
loss = dict()
|
||||
for phase in "ab":
|
||||
for level in "gl":
|
||||
generated_image = self.image_buffers[level + phase].query(
|
||||
generated["images"]["a2b" if phase == "b" else "b2a"])
|
||||
generated_image = generated["images"]["b2a" if phase == "a" else "a2b"].detach()
|
||||
pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
|
||||
pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
|
||||
loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(
|
||||
|
||||
@ -1,23 +1,21 @@
|
||||
from itertools import chain
|
||||
import logging
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
|
||||
from math import ceil
|
||||
from omegaconf import read_write, OmegaConf
|
||||
|
||||
from util.image import make_2d_grid
|
||||
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from engine.util.build import build_optimizer
|
||||
|
||||
import data
|
||||
from engine.util.build import build_optimizer
|
||||
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.image import make_2d_grid
|
||||
|
||||
|
||||
def build_lr_schedulers(optimizers, config):
|
||||
@ -59,6 +57,7 @@ class EngineKernel(object):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(config.name)
|
||||
self.generators, self.discriminators = self.build_models()
|
||||
self.train_generator_first = True
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
raise NotImplemented
|
||||
@ -69,7 +68,7 @@ class EngineKernel(object):
|
||||
to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
|
||||
return to_save
|
||||
|
||||
def setup_before_d(self):
|
||||
def setup_after_g(self):
|
||||
raise NotImplemented
|
||||
|
||||
def setup_before_g(self):
|
||||
@ -93,6 +92,9 @@ class EngineKernel(object):
|
||||
"""
|
||||
raise NotImplemented
|
||||
|
||||
def change_engine(self, config, engine: Engine):
|
||||
pass
|
||||
|
||||
|
||||
def get_trainer(config, kernel: EngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
@ -106,26 +108,37 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
lr_schedulers = build_lr_schedulers(optimizers, config)
|
||||
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
|
||||
|
||||
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
|
||||
iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
|
||||
|
||||
def train_generators(batch, generated):
|
||||
kernel.setup_before_g()
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = kernel.criterion_generators(batch, generated)
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
kernel.setup_after_g()
|
||||
return loss_g
|
||||
|
||||
def train_discriminators(batch, generated):
|
||||
optimizers["d"].zero_grad()
|
||||
loss_d = kernel.criterion_discriminators(batch, generated)
|
||||
sum(loss_d.values()).backward()
|
||||
optimizers["d"].step()
|
||||
return loss_d
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
|
||||
generated = kernel.forward(batch)
|
||||
|
||||
kernel.setup_before_g()
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = kernel.criterion_generators(batch, generated)
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
if kernel.train_generator_first:
|
||||
loss_g = train_generators(batch, generated)
|
||||
loss_d = train_discriminators(batch, generated)
|
||||
else:
|
||||
loss_d = train_discriminators(batch, generated)
|
||||
loss_g = train_generators(batch, generated)
|
||||
|
||||
kernel.setup_before_d()
|
||||
optimizers["d"].zero_grad()
|
||||
loss_d = kernel.criterion_discriminators(batch, generated)
|
||||
sum(loss_d.values()).backward()
|
||||
optimizers["d"].step()
|
||||
|
||||
if engine.state.iteration % image_per_iteration == 0:
|
||||
if engine.state.iteration % iteration_per_image == 0:
|
||||
return {
|
||||
"loss": dict(g=loss_g, d=loss_d),
|
||||
"img": kernel.intermediate_images(batch, generated)
|
||||
@ -137,6 +150,8 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
for lr_shd in lr_schedulers.values():
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
|
||||
|
||||
kernel.change_engine(config, trainer)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
|
||||
to_save = dict(trainer=trainer)
|
||||
@ -150,7 +165,7 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
|
||||
if tensorboard_handler is not None:
|
||||
basic_image_event = Events.ITERATION_COMPLETED(
|
||||
every=image_per_iteration)
|
||||
every=iteration_per_image)
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
|
||||
@trainer.on(basic_image_event)
|
||||
@ -227,7 +242,7 @@ def run_kernel(task, config, kernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
with read_write(config):
|
||||
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
config.max_iteration = config.max_pairs // real_batch_size + 1
|
||||
config.max_iteration = ceil(config.max_pairs / real_batch_size)
|
||||
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
@ -243,7 +258,7 @@ def run_kernel(task, config, kernel):
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
||||
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
@ -1,268 +0,0 @@
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.utils
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.contrib.handlers import ProgressBar
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import data
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from util.image import make_2d_grid
|
||||
from util.handler import setup_common_handlers
|
||||
from util.build import build_model, build_optimizer
|
||||
|
||||
|
||||
def get_trainer(config, logger):
|
||||
generator_a = build_model(config.model.generator, config.distributed.model)
|
||||
generator_b = build_model(config.model.generator, config.distributed.model)
|
||||
discriminator_a = build_model(config.model.discriminator, config.distributed.model)
|
||||
discriminator_b = build_model(config.model.discriminator, config.distributed.model)
|
||||
for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
|
||||
generation_init_weights(m)
|
||||
logger.info(discriminator_a)
|
||||
logger.info(generator_a)
|
||||
|
||||
optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
|
||||
config.optimizers.generator)
|
||||
optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
|
||||
config.optimizers.discriminator)
|
||||
|
||||
milestones_values = [
|
||||
(0, config.optimizers.generator.lr),
|
||||
(100, config.optimizers.generator.lr),
|
||||
(200, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
milestones_values = [
|
||||
(0, config.optimizers.discriminator.lr),
|
||||
(100, config.optimizers.discriminator.lr),
|
||||
(200, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
|
||||
image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
|
||||
image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real_a, real_b = batch["a"], batch["b"]
|
||||
|
||||
fake_b = generator_a(real_a) # G_A(A)
|
||||
rec_a = generator_b(fake_b) # G_B(G_A(A))
|
||||
fake_a = generator_b(real_b) # G_B(B)
|
||||
rec_b = generator_a(fake_a) # G_A(G_B(B))
|
||||
|
||||
optimizer_g.zero_grad()
|
||||
discriminator_a.requires_grad_(False)
|
||||
discriminator_b.requires_grad_(False)
|
||||
loss_g = dict(
|
||||
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
|
||||
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
|
||||
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
|
||||
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
|
||||
)
|
||||
if config.loss.id.weight > 0:
|
||||
loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
||||
loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
||||
sum(loss_g.values()).backward()
|
||||
optimizer_g.step()
|
||||
|
||||
discriminator_a.requires_grad_(True)
|
||||
discriminator_b.requires_grad_(True)
|
||||
optimizer_d.zero_grad()
|
||||
loss_d_a = dict(
|
||||
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
|
||||
fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
|
||||
)
|
||||
loss_d_b = dict(
|
||||
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
|
||||
fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
|
||||
)
|
||||
(sum(loss_d_a.values()) * 0.5).backward()
|
||||
(sum(loss_d_b.values()) * 0.5).backward()
|
||||
optimizer_d.step()
|
||||
|
||||
return {
|
||||
"loss": {
|
||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||
"d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
|
||||
"d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
|
||||
},
|
||||
"img": [
|
||||
real_a.detach(),
|
||||
fake_b.detach(),
|
||||
rec_a.detach(),
|
||||
real_b.detach(),
|
||||
fake_a.detach(),
|
||||
rec_b.detach()
|
||||
]
|
||||
}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
|
||||
|
||||
to_save = dict(
|
||||
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
|
||||
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
|
||||
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
|
||||
)
|
||||
|
||||
setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
|
||||
filename_prefix=config.name, to_save=to_save,
|
||||
print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
|
||||
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
|
||||
save_interval_event=Events.ITERATION_COMPLETED(
|
||||
every=config.checkpoints.interval) | Events.COMPLETED)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
def terminate(engine):
|
||||
engine.terminate()
|
||||
|
||||
if idist.get_rank() == 0:
|
||||
# Create a logger
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
tb_writer = tb_logger.writer
|
||||
|
||||
# Attach the logger to the trainer to log training loss at each iteration
|
||||
def global_step_transform(*args, **kwargs):
|
||||
return trainer.state.iteration
|
||||
|
||||
def output_transform(output):
|
||||
loss = dict()
|
||||
for tl in output["loss"]:
|
||||
if isinstance(output["loss"][tl], dict):
|
||||
for l in output["loss"][tl]:
|
||||
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||
else:
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="loss",
|
||||
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
|
||||
global_step_transform=global_step_transform,
|
||||
output_transform=output_transform
|
||||
),
|
||||
event_name=Events.ITERATION_COMPLETED(every=50)
|
||||
)
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=50)
|
||||
)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
|
||||
def show_images(engine):
|
||||
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
def _():
|
||||
# We need to close the logger with we are done
|
||||
tb_logger.close()
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
def get_tester(config, logger):
|
||||
generator_a = build_model(config.model.generator, config.distributed.model)
|
||||
generator_b = build_model(config.model.generator, config.distributed.model)
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real_a, real_b = batch["a"], batch["b"]
|
||||
with torch.no_grad():
|
||||
fake_b = generator_a(real_a) # G_A(A)
|
||||
rec_a = generator_b(fake_b) # G_B(G_A(A))
|
||||
fake_a = generator_b(real_b) # G_B(B)
|
||||
rec_b = generator_a(fake_a) # G_A(G_B(B))
|
||||
return [
|
||||
real_a.detach(),
|
||||
fake_b.detach(),
|
||||
rec_a.detach(),
|
||||
real_b.detach(),
|
||||
fake_a.detach(),
|
||||
rec_b.detach()
|
||||
]
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
if idist.get_rank == 0:
|
||||
ProgressBar(ncols=0).attach(tester)
|
||||
to_load = dict(generator_a=generator_a, generator_b=generator_b)
|
||||
setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)
|
||||
|
||||
@tester.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
def mkdir(engine):
|
||||
img_output_dir = Path(config.output_dir) / "test_images"
|
||||
if not img_output_dir.exists():
|
||||
engine.logger.info(f"mkdir {img_output_dir}")
|
||||
img_output_dir.mkdir()
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors],
|
||||
Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
|
||||
nrow=len(img_tensors))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger.info(f"start task {task}")
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger)
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, logger)
|
||||
try:
|
||||
tester.run(test_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
@ -17,6 +17,6 @@ dependencies:
|
||||
- omegaconf
|
||||
- python-lmdb
|
||||
- fire
|
||||
# - opencv
|
||||
- opencv
|
||||
# - jupyterlab
|
||||
|
||||
|
||||
62
model/GAN/CycleGAN.py
Normal file
62
model/GAN/CycleGAN.py
Normal file
@ -0,0 +1,62 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model.registry import MODEL
|
||||
from .base import ResidualBlock
|
||||
|
||||
|
||||
@MODEL.register_module("CyCle-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
|
||||
norm_type="IN"):
|
||||
super(Generator, self).__init__()
|
||||
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
self.start_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=use_bias),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
# down sampling
|
||||
submodules = []
|
||||
num_down_sampling = 2
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=3, stride=2, padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
|
||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||
self.resnet_middle = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
|
||||
range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** (num_down_sampling - i)
|
||||
submodules += [
|
||||
nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
|
||||
padding=1, output_padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
self.decoder = nn.Sequential(*submodules)
|
||||
|
||||
self.end_conv = nn.Sequential(
|
||||
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(self.start_conv(x))
|
||||
x = self.resnet_middle(x)
|
||||
return self.end_conv(self.decoder(x))
|
||||
@ -45,7 +45,6 @@ class Generator(nn.Module):
|
||||
# Down-Sampling Bottleneck
|
||||
mult = 2 ** n_down_sampling
|
||||
for i in range(num_blocks):
|
||||
# TODO: change ResnetBlock to ResidualBlock, check use_bias param
|
||||
down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
|
||||
self.down_encoder = nn.Sequential(*down_encoder)
|
||||
|
||||
|
||||
@ -1,13 +1,68 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model import MODEL
|
||||
|
||||
|
||||
class GANImageBuffer(object):
|
||||
"""This class implements an image buffer that stores previously
|
||||
generated images.
|
||||
This buffer allows us to update the discriminator using a history of
|
||||
generated images rather than the ones produced by the latest generator
|
||||
to reduce model oscillation.
|
||||
Args:
|
||||
buffer_size (int): The size of image buffer. If buffer_size = 0,
|
||||
no buffer will be created.
|
||||
buffer_ratio (float): The chance / possibility to use the images
|
||||
previously stored in the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size, buffer_ratio=0.5):
|
||||
self.buffer_size = buffer_size
|
||||
# create an empty buffer
|
||||
if self.buffer_size > 0:
|
||||
self.img_num = 0
|
||||
self.image_buffer = []
|
||||
self.buffer_ratio = buffer_ratio
|
||||
|
||||
def query(self, images):
|
||||
"""Query current image batch using a history of generated images.
|
||||
Args:
|
||||
images (Tensor): Current image batch without history information.
|
||||
"""
|
||||
if self.buffer_size == 0: # if the buffer size is 0, do nothing
|
||||
return images
|
||||
return_images = []
|
||||
for image in images:
|
||||
image = torch.unsqueeze(image.data, 0)
|
||||
# if the buffer is not full, keep inserting current images
|
||||
if self.img_num < self.buffer_size:
|
||||
self.img_num = self.img_num + 1
|
||||
self.image_buffer.append(image)
|
||||
return_images.append(image)
|
||||
else:
|
||||
use_buffer = torch.rand(1) < self.buffer_ratio
|
||||
# by self.buffer_ratio, the buffer will return a previously
|
||||
# stored image, and insert the current image into the buffer
|
||||
if use_buffer:
|
||||
random_id = torch.randint(0, self.buffer_size, (1,)).item()
|
||||
image_tmp = self.image_buffer[random_id].clone()
|
||||
self.image_buffer[random_id] = image
|
||||
return_images.append(image_tmp)
|
||||
# by (1 - self.buffer_ratio), the buffer will return the
|
||||
# current image
|
||||
else:
|
||||
return_images.append(image)
|
||||
# collect all the images and return
|
||||
return_images = torch.cat(return_images, 0)
|
||||
return return_images
|
||||
|
||||
|
||||
# based SPADE or pix2pixHD Discriminator
|
||||
@MODEL.register_module("pix2pixHD-PatchDiscriminator")
|
||||
@MODEL.register_module("PatchDiscriminator")
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
|
||||
need_intermediate_feature=False):
|
||||
|
||||
@ -1,182 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model.registry import MODEL
|
||||
from model.normalization import select_norm_layer
|
||||
|
||||
|
||||
class GANImageBuffer(object):
|
||||
"""This class implements an image buffer that stores previously
|
||||
generated images.
|
||||
This buffer allows us to update the discriminator using a history of
|
||||
generated images rather than the ones produced by the latest generator
|
||||
to reduce model oscillation.
|
||||
Args:
|
||||
buffer_size (int): The size of image buffer. If buffer_size = 0,
|
||||
no buffer will be created.
|
||||
buffer_ratio (float): The chance / possibility to use the images
|
||||
previously stored in the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size, buffer_ratio=0.5):
|
||||
self.buffer_size = buffer_size
|
||||
# create an empty buffer
|
||||
if self.buffer_size > 0:
|
||||
self.img_num = 0
|
||||
self.image_buffer = []
|
||||
self.buffer_ratio = buffer_ratio
|
||||
|
||||
def query(self, images):
|
||||
"""Query current image batch using a history of generated images.
|
||||
Args:
|
||||
images (Tensor): Current image batch without history information.
|
||||
"""
|
||||
if self.buffer_size == 0: # if the buffer size is 0, do nothing
|
||||
return images
|
||||
return_images = []
|
||||
for image in images:
|
||||
image = torch.unsqueeze(image.data, 0)
|
||||
# if the buffer is not full, keep inserting current images
|
||||
if self.img_num < self.buffer_size:
|
||||
self.img_num = self.img_num + 1
|
||||
self.image_buffer.append(image)
|
||||
return_images.append(image)
|
||||
else:
|
||||
use_buffer = torch.rand(1) < self.buffer_ratio
|
||||
# by self.buffer_ratio, the buffer will return a previously
|
||||
# stored image, and insert the current image into the buffer
|
||||
if use_buffer:
|
||||
random_id = torch.randint(0, self.buffer_size, (1,)).item()
|
||||
image_tmp = self.image_buffer[random_id].clone()
|
||||
self.image_buffer[random_id] = image
|
||||
return_images.append(image_tmp)
|
||||
# by (1 - self.buffer_ratio), the buffer will return the
|
||||
# current image
|
||||
else:
|
||||
return_images.append(image)
|
||||
# collect all the images and return
|
||||
return_images = torch.cat(return_images, 0)
|
||||
return return_images
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
if use_bias is None:
|
||||
# Only for IN, use bias since it does not have affine parameters.
|
||||
use_bias = norm_type == "IN"
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
models = [nn.Sequential(
|
||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)]
|
||||
if use_dropout:
|
||||
models.append(nn.Dropout(0.5))
|
||||
models.append(nn.Sequential(
|
||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_channels),
|
||||
))
|
||||
self.block = nn.Sequential(*models)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResGenerator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
|
||||
norm_type="IN"):
|
||||
super(ResGenerator, self).__init__()
|
||||
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
self.start_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=use_bias),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
# down sampling
|
||||
submodules = []
|
||||
num_down_sampling = 2
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=3, stride=2, padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
|
||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||
self.resnet_middle = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
|
||||
range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** (num_down_sampling - i)
|
||||
submodules += [
|
||||
nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
|
||||
padding=1, output_padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
self.decoder = nn.Sequential(*submodules)
|
||||
|
||||
self.end_conv = nn.Sequential(
|
||||
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(self.start_conv(x))
|
||||
x = self.resnet_middle(x)
|
||||
return self.end_conv(self.decoder(x))
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
|
||||
super(PatchDiscriminator, self).__init__()
|
||||
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
kernel_size = 4
|
||||
padding = 1
|
||||
sequence = [
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
|
||||
# stacked intermediate layers,
|
||||
# gradually increasing the number of filters
|
||||
multiple_now = 1
|
||||
for n in range(1, num_conv):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** n, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=kernel_size,
|
||||
padding=padding, stride=2, bias=use_bias),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
]
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** num_conv, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size, stride=1,
|
||||
padding=padding, bias=use_bias),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding)
|
||||
]
|
||||
self.model = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
@ -1,7 +1,6 @@
|
||||
from model.registry import MODEL
|
||||
import model.GAN.residual_generator
|
||||
import model.GAN.CycleGAN
|
||||
import model.GAN.TAFG
|
||||
import model.GAN.UGATIT
|
||||
import model.fewshot
|
||||
import model.GAN.wrapper
|
||||
import model.GAN.base
|
||||
|
||||
105
model/fewshot.py
105
model/fewshot.py
@ -1,105 +0,0 @@
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .registry import MODEL
|
||||
|
||||
|
||||
# --- gaussian initialize ---
|
||||
def init_layer(l):
|
||||
# Initialization using fan-in
|
||||
if isinstance(l, nn.Conv2d):
|
||||
n = l.kernel_size[0] * l.kernel_size[1] * l.out_channels
|
||||
l.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
|
||||
elif isinstance(l, nn.BatchNorm2d):
|
||||
l.weight.data.fill_(1)
|
||||
l.bias.data.fill_(0)
|
||||
elif isinstance(l, nn.Linear):
|
||||
l.bias.data.fill_(0)
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def __init__(self):
|
||||
super(Flatten, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
class SimpleBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, half_res, leakyrelu=False):
|
||||
super(SimpleBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
)
|
||||
self.relu = nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True)
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, 1, 2 if half_res else 1, bias=False),
|
||||
nn.BatchNorm2d(out_channels)
|
||||
)
|
||||
else:
|
||||
self.shortcut = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
o = self.block(x)
|
||||
return self.relu(o + self.shortcut(x))
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, layers, dims, num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
|
||||
super().__init__()
|
||||
assert len(layers) == 4, 'Can have only four stages'
|
||||
self.inplanes = 64
|
||||
|
||||
self.start = nn.Sequential(
|
||||
nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
|
||||
nn.BatchNorm2d(self.inplanes),
|
||||
nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
)
|
||||
|
||||
trunk = []
|
||||
in_channels = self.inplanes
|
||||
for i in range(4):
|
||||
for j in range(layers[i]):
|
||||
half_res = i >= 1 and j == 0
|
||||
trunk.append(block(in_channels, dims[i], half_res, leakyrelu))
|
||||
in_channels = dims[i]
|
||||
if flatten:
|
||||
trunk.append(nn.AvgPool2d(7))
|
||||
trunk.append(Flatten())
|
||||
if num_classes is not None:
|
||||
if classifier_type == "linear":
|
||||
trunk.append(nn.Linear(in_channels, num_classes))
|
||||
elif classifier_type == "distlinear":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"invalid classifier_type:{classifier_type}")
|
||||
self.trunk = nn.Sequential(*trunk)
|
||||
self.apply(init_layer)
|
||||
|
||||
def forward(self, x):
|
||||
return self.trunk(self.start(x))
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
def resnet10(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
|
||||
return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
def resnet18(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
|
||||
return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
def resnet34(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
|
||||
return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)
|
||||
Loading…
Reference in New Issue
Block a user