add U-GAT-IT

This commit is contained in:
Ray Wong 2020-08-21 16:14:30 +08:00
parent 323bf2f6ab
commit 1a1cb9b00f
18 changed files with 815 additions and 55 deletions

View File

@ -25,7 +25,7 @@ baseline:
_type: Adam _type: Adam
data: data:
dataloader: dataloader:
batch_size: 1024 batch_size: 1200
shuffle: True shuffle: True
num_workers: 16 num_workers: 16
pin_memory: True pin_memory: True
@ -37,7 +37,7 @@ baseline:
pipeline: pipeline:
- Load - Load
- RandomResizedCrop: - RandomResizedCrop:
size: [256, 256] size: [224, 224]
- ColorJitter: - ColorJitter:
brightness: 0.4 brightness: 0.4
contrast: 0.4 contrast: 0.4
@ -47,20 +47,5 @@ baseline:
- Normalize: - Normalize:
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
val:
path: /data/few-shot/mini_imagenet_full_size/val
lmdb_path: /data/few-shot/lmdb/mini-ImageNet/val.lmdb
pipeline:
- Load
- Resize:
size: [286, 286]
- RandomCrop:
size: [256, 256]
- ToTensor
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

View File

@ -0,0 +1,110 @@
name: selfie2anime
engine: UGATIT
result_dir: ./result
max_iteration: 100000
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 324
checkpoints:
interval: 1000
model:
generator:
_type: UGATIT-Generator
in_channels: 3
out_channels: 3
base_channels: 64
num_blocks: 4
img_size: 256
light: True
local_discriminator:
_type: UGATIT-Discriminator
in_channels: 3
base_channels: 64
num_blocks: 3
global_discriminator:
_type: UGATIT-Discriminator
in_channels: 3
base_channels: 64
num_blocks: 5
loss:
gan:
loss_type: lsgan
weight: 1.0
real_label_val: 1.0
fake_label_val: 0.0
cycle:
level: 1
weight: 10.0
id:
level: 1
weight: 10.0
cam:
weight: 1000
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [0.5, 0.999]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 1e-4
betas: [0.5, 0.999]
weight_decay: 0.0001
data:
train:
buffer_size: 50
dataloader:
batch_size: 8
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/selfie2anime/trainA"
root_b: "/data/i2i/selfie2anime/trainB"
random_pair: True
pipeline:
- Load
- Resize:
size: [286, 286]
- RandomCrop:
size: [256, 256]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
scheduler:
start: 50000
target_lr: 0
test:
dataloader:
batch_size: 4
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/selfie2anime/testA"
root_b: "/data/i2i/selfie2anime/testB"
random_pair: False
pipeline:
- Load
- Resize:
size: [256, 256]
- ToTensor
- Normalize:
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]

View File

@ -99,9 +99,9 @@ class EpisodicDataset(Dataset):
def __getitem__(self, _): def __getitem__(self, _):
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist() random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set_list = [] support_set = []
query_set_list = [] query_set = []
target_list = [] target_set = []
for tag, c in enumerate(random_classes): for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c] image_list = self.origin.classes_list[c]
@ -113,13 +113,13 @@ class EpisodicDataset(Dataset):
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support])) support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:])) query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
support_set_list.extend(support) support_set.extend(support)
query_set_list.extend(query) query_set.extend(query)
target_list.extend([tag] * self.num_query) target_set.extend([tag] * self.num_query)
return { return {
"support": torch.stack(support_set_list), "support": torch.stack(support_set),
"query": torch.stack(query_set_list), "query": torch.stack(query_set),
"target": torch.tensor(target_list) "target": torch.tensor(target_set)
} }
def __repr__(self): def __repr__(self):

249
engine/UGATIT.py Normal file
View File

@ -0,0 +1,249 @@
from itertools import chain
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer
def get_trainer(config, logger):
    """Build the U-GAT-IT training :class:`ignite.engine.Engine`.

    Creates the two generators (a2b / b2a), the four discriminators
    (local/global patch scale for each domain), their optimizers with a
    constant-then-linear-decay LR schedule, and wires up checkpointing,
    metric logging and (on rank 0) tensorboard output.

    :param config: read-only experiment configuration (OmegaConf).
    :param logger: logger used for model dumps and engine messages.
    :return: a configured ``Engine`` ready to consume the unpaired dataloader.
    """
    generators = dict(
        a2b=build_model(config.model.generator, config.distributed.model),
        b2a=build_model(config.model.generator, config.distributed.model),
    )
    # Keys encode scope + domain: "l"/"g" = local/global discriminator, "a"/"b" = domain.
    discriminators = dict(
        la=build_model(config.model.local_discriminator, config.distributed.model),
        lb=build_model(config.model.local_discriminator, config.distributed.model),
        ga=build_model(config.model.global_discriminator, config.distributed.model),
        gb=build_model(config.model.global_discriminator, config.distributed.model),
    )
    for m in chain(generators.values(), discriminators.values()):
        generation_init_weights(m)
    logger.debug(discriminators["ga"])
    logger.debug(generators["a2b"])

    optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator)
    optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                                  config.optimizers.discriminator)

    # Constant LR until scheduler.start, then linear decay to target_lr at max_iteration.
    milestones_values = [
        (0, config.optimizers.generator.lr),
        (config.data.train.scheduler.start, config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
    milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)

    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")  # weight is applied manually below; GANLoss does not take it
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    # BUGFIX: the identity-loss norm was previously selected from config.loss.cycle.level,
    # so config.loss.id.level was silently ignored.
    id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()

    def mse_loss(x, t):
        # LSGAN-style regression of discriminator CAM logits toward all-ones / all-zeros.
        return F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))

    def bce_loss(x, t):
        # CAM classification loss on generator attention logits.
        return F.binary_cross_entropy_with_logits(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))

    # One history buffer of previously generated images per discriminator.
    image_buffers = {
        k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) for k in
        discriminators.keys()}
    rho_clipper = RhoClipper(0, 1)  # keeps every AdaILN/ILN rho inside [0, 1]

    def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
                           discriminator_g):
        # Freeze discriminators while scoring generator outputs.
        discriminator_g.requires_grad_(False)
        discriminator_l.requires_grad_(False)
        pred_fake_g, cam_gd_pred = discriminator_g(fake)
        pred_fake_l, cam_ld_pred = discriminator_l(fake)
        return {
            f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
            f"id_{name}": config.loss.id.weight * id_loss(real, identity),
            f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
            f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
            f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
            f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
        }

    def cal_discriminator_loss(name, discriminator, real, fake):
        pred_real, cam_real = discriminator(real)
        pred_fake, cam_fake = discriminator(fake)
        # TODO: origin do not divide 2, but I think it better to divide 2.
        loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
        loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
        return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}

    def _step(engine, batch):
        """One training iteration: generator update, then discriminator update."""
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        fake = dict()
        cam_generator_pred = dict()
        rec = dict()
        identity = dict()
        cam_identity_pred = dict()
        heatmap = dict()
        # Translation, cycle reconstruction and identity mapping for both domains.
        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a)
        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b)
        rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
        rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a)
        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b)
        optimizer_g.zero_grad()
        loss_g = dict()
        for n in ["a", "b"]:
            loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
                                             cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
        sum(loss_g.values()).backward()
        optimizer_g.step()
        # Project rho parameters back into [0, 1] after every generator step.
        for generator in generators.values():
            generator.apply(rho_clipper)
        # Re-enable discriminator grads (frozen inside cal_generator_loss).
        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)
        optimizer_d.zero_grad()
        loss_d = dict()
        for k in discriminators.keys():
            n = k[-1]  # "a" or "b"
            loss_d.update(
                cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach())))
        sum(loss_d.values()).backward()
        optimizer_d.step()
        for h in heatmap:
            heatmap[h] = heatmap[h].detach()
        generated_img = {f"fake_{k}": fake[k].detach() for k in fake}
        generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
        generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
            },
            "img": {
                "heatmap": heatmap,
                "generated": generated_img
            }
        }

    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
    to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d,
                   lr_scheduler_g=lr_scheduler_g)
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
                          filename_prefix=config.name, to_save=to_save,
                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
                          metrics_to_print=["loss_g", "loss_d"],
                          save_interval_event=Events.ITERATION_COMPLETED(
                              every=config.checkpoints.interval) | Events.COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
    def terminate(engine):
        # Stop exactly at max_iteration regardless of the epoch bound chosen in run().
        engine.terminate()

    if idist.get_rank() == 0:
        # Create a logger
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # Attach the logger to the trainer to log training loss at each iteration
        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        def output_transform(output):
            # Flatten {"g": {...}, "d": {...}} into tag_name -> scalar pairs.
            loss = dict()
            for tl in output["loss"]:
                if isinstance(output["loss"][tl], dict):
                    for l in output["loss"][tl]:
                        loss[f"{tl}_{l}"] = output["loss"][tl][l]
                else:
                    loss[tl] = output["loss"][tl]
            return loss

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d"],
                global_step_transform=global_step_transform,
                output_transform=output_transform
            ),
            event_name=Events.ITERATION_COMPLETED(every=50)
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50)
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()),
                                engine.state.iteration)
            tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()),
                                engine.state.iteration)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done
            tb_logger.close()

    return trainer
def run(task, config, logger):
    """Entry point for the U-GAT-IT engine.

    :param task: currently only ``"train"`` is supported.
    :param config: read-only experiment configuration (OmegaConf).
    :param logger: logger for progress output.
    :raises NotImplementedError: for any task other than ``"train"``.
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        try:
            # The engine terminates itself at max_iteration; derive an epoch bound that covers it.
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            # Best-effort: print the traceback instead of crashing sibling distributed workers.
            import traceback
            print(traceback.format_exc())
    else:
        # BUGFIX: was `return NotImplemented(...)` -- the NotImplemented constant is not
        # callable and would raise a confusing TypeError; raise the intended exception.
        raise NotImplementedError(f"invalid task: {task}")

View File

@ -17,7 +17,7 @@ from data.transform import transform_pipeline
from data.dataset import LMDBDataset from data.dataset import LMDBDataset
def baseline_trainer(config, logger): def warmup_trainer(config, logger):
model = build_model(config.model, config.distributed.model) model = build_model(config.model, config.distributed.model)
optimizer = build_optimizer(model.parameters(), config.baseline.optimizers) optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
loss_fn = nn.CrossEntropyLoss() loss_fn = nn.CrossEntropyLoss()
@ -66,18 +66,20 @@ def run(task, config, logger):
assert torch.backends.cudnn.enabled assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}") logger.info(f"start task {task}")
if task == "baseline": if task == "warmup":
train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path, train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
pipeline=config.baseline.data.dataset.train.pipeline) pipeline=config.baseline.data.dataset.train.pipeline)
# train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
# transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
logger.info(f"train with dataset:\n{train_dataset}") logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader) train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
trainer = baseline_trainer(config, logger) trainer = warmup_trainer(config, logger)
try: try:
trainer.run(train_data_loader, max_epochs=400) trainer.run(train_data_loader, max_epochs=400)
except Exception: except Exception:
import traceback import traceback
print(traceback.format_exc()) print(traceback.format_exc())
elif task == "protonet-wo":
pass
elif task == "protonet-w":
pass
else: else:
return NotImplemented(f"invalid task: {task}") return ValueError(f"invalid task: {task}")

View File

@ -18,7 +18,7 @@ from omegaconf import OmegaConf
import data import data
from loss.gan import GANLoss from loss.gan import GANLoss
from model.weight_init import generation_init_weights from model.weight_init import generation_init_weights
from model.residual_generator import GANImageBuffer from model.GAN.residual_generator import GANImageBuffer
from util.image import make_2d_grid from util.image import make_2d_grid
from util.handler import setup_common_handlers from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer from util.build import build_model, build_optimizer
@ -31,8 +31,8 @@ def get_trainer(config, logger):
discriminator_b = build_model(config.model.discriminator, config.distributed.model) discriminator_b = build_model(config.model.discriminator, config.distributed.model)
for m in [generator_b, generator_a, discriminator_b, discriminator_a]: for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
generation_init_weights(m) generation_init_weights(m)
logger.debug(discriminator_a) logger.info(discriminator_a)
logger.debug(generator_a) logger.info(generator_a)
optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()), optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
config.optimizers.generator) config.optimizers.generator)
@ -56,8 +56,8 @@ def get_trainer(config, logger):
gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight") gan_loss_cfg.pop("weight")
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss() cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss() id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
@ -93,11 +93,11 @@ def get_trainer(config, logger):
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True), real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True), fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
) )
(sum(loss_d_a.values()) * 0.5).backward()
loss_d_b = dict( loss_d_b = dict(
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True), real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True), fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
) )
(sum(loss_d_a.values()) * 0.5).backward()
(sum(loss_d_b.values()) * 0.5).backward() (sum(loss_d_b.values()) * 0.5).backward()
optimizer_d.step() optimizer_d.step()

9
engine/fewshot.py Normal file
View File

@ -0,0 +1,9 @@
from data.dataset import EpisodicDataset, LMDBDataset
def prototypical_trainer(config, logger):
    """Build a prototypical-network training engine.

    Stub -- not yet implemented.
    """
    pass
def prototypical_dataloader(config):
    """Build the episodic dataloader for prototypical-network training.

    Stub -- not yet implemented.
    """
    pass

0
loss/fewshot/__init__.py Normal file
View File

View File

@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class PrototypicalLoss(nn.Module):
    """Prototypical-network loss (Snell et al., 2017).

    Class prototypes are the mean of the support embeddings; queries are
    classified by smallest squared Euclidean distance to each prototype and
    trained with the negative log-softmax over negated distances.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def acc(query, target, support):
        """Return classification accuracy only; shapes as in :meth:`forward`."""
        prototypes = support.mean(-2)  # batch_size x N_class x D
        distance = PrototypicalLoss.euclidean_dist(query, prototypes)  # batch_size x N_class*N_query x N_class
        indices = distance.argmin(-1)  # nearest-prototype class per query
        acc = torch.eq(target, indices).float().mean().item()
        return acc

    @staticmethod
    def euclidean_dist(x, y):
        """Pairwise squared Euclidean distance.

        :param x: Tensor - B x N x D
        :param y: Tensor - B x M x D
        :return: Tensor - B x N x M
        """
        assert x.size(-1) == y.size(-1) and x.size(0) == y.size(0)
        n = x.size(-2)
        m = y.size(-2)
        d = x.size(-1)
        x = x.unsqueeze(2).expand(x.size(0), n, m, d)  # B x N x M x D
        y = y.unsqueeze(1).expand(x.size(0), n, m, d)
        return torch.pow(x - y, 2).sum(-1)  # B x N x M

    def forward(self, query, target, support):
        """
        calculate prototypical loss

        :param query: Tensor - batch_size x N_class*N_query x D, grouped by class
            (the N_query embeddings of class 0 first, then class 1, ...)
        :param target: Tensor - batch_size x N_class*N_query, target id set, value must in [0, N_class)
        :param support: Tensor - batch_size x N_class x N_support x D, must be ordered by class id
        :return: loss item and accuracy
        """
        prototypes = support.mean(-2)  # batch_size x N_class x D
        distance = self.euclidean_dist(query, prototypes)  # batch_size x N_class*N_query x N_class
        indices = distance.argmin(-1)  # smallest distance indices
        acc = torch.eq(target, indices).float().mean().item()
        log_p_y = F.log_softmax(-distance, dim=-1)
        n_class = support.size(1)
        n_query = query.size(1) // n_class
        batch_size = support.size(0)
        # BUGFIX: `.resharp(...)` was a typo for `.reshape(...)` (AttributeError at runtime);
        # reshape (not view) is required because expand() yields a non-contiguous tensor.
        # Also create the index tensor on the query's device so GPU training works.
        target_log_indices = torch.arange(n_class, device=query.device).view(n_class, 1, 1).expand(
            n_class, n_query, 1).reshape(n_class * n_query, 1).view(1, n_class * n_query, 1).expand(
            batch_size, n_class * n_query, 1)
        loss = -log_p_y.gather(2, target_log_indices).mean()  # select log-probability of true class then get the mean
        return loss, acc

20
main.py
View File

@ -5,7 +5,8 @@ import torch
import ignite import ignite
import ignite.distributed as idist import ignite.distributed as idist
from ignite.utils import manual_seed, setup_logger from ignite.utils import manual_seed
from util.misc import setup_logger
import fire import fire
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -21,14 +22,12 @@ def log_basic_info(logger, config):
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False): def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
logger = setup_logger(name=config.name, distributed_rank=local_rank, **config.log.logger)
log_basic_info(logger, config)
if setup_random_seed: if setup_random_seed:
manual_seed(config.misc.random_seed + idist.get_rank()) manual_seed(config.misc.random_seed + idist.get_rank())
if setup_output_dir: output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir config.output_dir = str(output_dir)
config.output_dir = str(output_dir)
if setup_output_dir and config.resume_from is None:
if output_dir.exists(): if output_dir.exists():
# assert not any(output_dir.iterdir()), "output_dir must be empty" # assert not any(output_dir.iterdir()), "output_dir must be empty"
contains = list(output_dir.iterdir()) contains = list(output_dir.iterdir())
@ -37,11 +36,14 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
else: else:
if idist.get_rank() == 0: if idist.get_rank() == 0:
output_dir.mkdir(parents=True) output_dir.mkdir(parents=True)
logger.info(f"mkdir -p {output_dir}") print(f"mkdir -p {output_dir}")
logger.info(f"output path: {config.output_dir}")
if backup_config and idist.get_rank() == 0: if backup_config and idist.get_rank() == 0:
with open(output_dir / "config.yml", "w+") as f: with open(output_dir / "config.yml", "w+") as f:
print(config.pretty(), file=f) print(config.pretty(), file=f)
logger = setup_logger(name=config.name, distributed_rank=local_rank, filepath=output_dir / "train.log")
logger.info(f"output path: {config.output_dir}")
log_basic_info(logger, config)
OmegaConf.set_readonly(config, True) OmegaConf.set_readonly(config, True)

253
model/GAN/UGATIT.py Normal file
View File

@ -0,0 +1,253 @@
import torch
import torch.nn as nn
from .residual_generator import ResidualBlock
from model.registry import MODEL
class RhoClipper(object):
    """Callable for ``module.apply``: clamps any learnable ``rho`` attribute.

    U-GAT-IT's AdaILN/ILN layers expose a blending ratio ``rho`` that must
    stay inside a fixed interval; applying this object after each optimizer
    step projects it back into ``[clip_min, clip_max]``.
    """

    def __init__(self, clip_min, clip_max):
        assert clip_min < clip_max
        self.clip_min = clip_min
        self.clip_max = clip_max

    def __call__(self, module):
        # Only modules that actually own a `rho` parameter are touched.
        if hasattr(module, 'rho'):
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """U-GAT-IT generator: ResNet down-encoder + CAM attention + AdaILN up-decoder.

    ``forward`` returns a triple ``(image, cam_logit, heatmap)``:
      * ``image``     -- translated image (Tanh output),
      * ``cam_logit`` -- concatenated avg-/max-pool attention logits,
      * ``heatmap``   -- channel-summed attention map for visualization.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
        # `light` selects the memory-saving variant: gamma/beta are predicted from a
        # globally pooled feature instead of the full flattened feature map.
        assert (num_blocks >= 0)
        super(Generator, self).__init__()
        self.input_channels = in_channels
        self.output_channels = out_channels
        self.base_channels = base_channels
        self.num_blocks = num_blocks
        self.img_size = img_size
        self.light = light
        # Stem: 7x7 reflect-padded conv keeping spatial size.
        down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
                                  padding_mode="reflect", bias=False),
                        nn.InstanceNorm2d(base_channels),
                        nn.ReLU(True)]
        n_down_sampling = 2
        # Two stride-2 convs: channels double, spatial size halves each time.
        for i in range(n_down_sampling):
            mult = 2 ** i
            down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2,
                                       padding=1, bias=False, padding_mode="reflect"),
                             nn.InstanceNorm2d(base_channels * mult * 2),
                             nn.ReLU(True)]
        # Down-Sampling Bottleneck
        mult = 2 ** n_down_sampling
        for i in range(num_blocks):
            # TODO: change ResnetBlock to ResidualBlock, check use_bias param
            down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
        self.down_encoder = nn.Sequential(*down_encoder)
        # Class Activation Map: per-channel logits from global avg / max pooling.
        self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False)
        self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False)
        # 1x1 conv fuses the concatenated avg/max attention features back to `mult` channels.
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)
        # Gamma, Beta block: MLP that predicts the AdaILN affine parameters.
        if self.light:
            fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                  nn.ReLU(True),
                  nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                  nn.ReLU(True)]
        else:
            # Full variant flattens the entire (img_size/mult)^2 feature map.
            fc = [
                nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False),
                nn.ReLU(True),
                nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                nn.ReLU(True)]
        self.fc = nn.Sequential(*fc)
        self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
        self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
        # Up-Sampling Bottleneck: ModuleList (not Sequential) because each block
        # also takes the shared (gamma, beta) at call time.
        self.up_bottleneck = nn.ModuleList(
            [ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)])
        # Up-Sampling: nearest upsample + conv instead of transposed conv.
        up_decoder = []
        for i in range(n_down_sampling):
            mult = 2 ** (n_down_sampling - i)
            up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'),
                           nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1,
                                     padding=1, padding_mode="reflect", bias=False),
                           ILN(base_channels * mult // 2),
                           nn.ReLU(True)]
        up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
                                 padding_mode="reflect", bias=False),
                       nn.Tanh()]
        self.up_decoder = nn.Sequential(*up_decoder)

    def forward(self, x):
        x = self.down_encoder(x)
        # CAM: logits from pooled features, then reuse the fc weights as
        # per-channel attention over the spatial feature map.
        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))
        heatmap = torch.sum(x, dim=1, keepdim=True)
        # Predict AdaILN affine parameters from the attended features.
        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
            x_ = self.fc(x_.view(x_.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)
        # The same (gamma, beta) conditions every up-bottleneck block.
        for ub in self.up_bottleneck:
            x = ub(x, gamma, beta)
        x = self.up_decoder(x)
        return x, cam_logit, heatmap
class ResnetAdaILNBlock(nn.Module):
    """Residual block whose two convolutions are normalized with AdaILN.

    Unlike a plain residual block, ``forward`` additionally takes the
    externally predicted ``(gamma, beta)`` affine parameters and passes them
    to both AdaILN layers.
    """

    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU(True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        residual = x
        out = self.relu1(self.norm1(self.conv1(x), gamma, beta))
        out = self.norm2(self.conv2(out), gamma, beta)
        # Skip connection closes the residual path.
        return out + residual
def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """Blend instance norm and layer norm, then apply an external affine.

    ``out = rho * IN(x) + (1 - rho) * LN(x)``, scaled by ``gamma`` and
    shifted by ``beta`` (both per-channel, broadcast over H and W).

    :param x: B x C x H x W input.
    :param gamma: B x C (or broadcastable) scale.
    :param beta: B x C (or broadcastable) shift.
    :param rho: 1 x C x 1 x 1 blending ratio between IN and LN.
    :param eps: numerical-stability constant inside the sqrt.
    """
    def _normalize(dims):
        mean = torch.mean(x, dim=dims, keepdim=True)
        var = torch.var(x, dim=dims, keepdim=True)
        return (x - mean) / torch.sqrt(var + eps)

    normed_in = _normalize([2, 3])       # per-sample, per-channel statistics
    normed_ln = _normalize([1, 2, 3])    # per-sample statistics over all channels
    mix = rho.expand(x.shape[0], -1, -1, -1)
    blended = mix * normed_in + (1 - mix) * normed_ln
    return blended * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
class AdaILN(nn.Module):
    """Adaptive instance-layer normalization.

    Owns only the learnable IN/LN blending ratio ``rho``; the affine
    parameters ``gamma`` and ``beta`` are supplied by the caller at forward
    time (predicted by the generator's fc head).
    """

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super(AdaILN, self).__init__()
        self.eps = eps
        # Start biased toward instance normalization (rho close to 1).
        self.rho = nn.Parameter(torch.full((1, num_features, 1, 1), default_rho))

    def forward(self, x, gamma, beta):
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
class ILN(nn.Module):
    """Instance-layer normalization with its own learnable affine parameters.

    Unlike :class:`AdaILN`, ``gamma`` and ``beta`` are module parameters
    rather than caller-supplied, and ``rho`` starts at 0 (pure layer norm).
    """

    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.gamma = nn.Parameter(torch.ones(1, num_features))
        self.beta = nn.Parameter(torch.zeros(1, num_features))

    def forward(self, x):
        batch = x.shape[0]
        # Broadcast the per-channel affine over the batch dimension.
        return instance_layer_normalization(
            x, self.gamma.expand(batch, -1), self.beta.expand(batch, -1), self.rho, self.eps)
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """U-GAT-IT PatchGAN-style discriminator with CAM attention.

    Spectral normalization is applied to every conv/linear layer.
    ``forward`` returns ``(patch_logits, cam_logit)`` and, when
    ``return_heatmap`` is set, additionally the attention heatmap.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5):
        super(Discriminator, self).__init__()
        # Stem block, then (num_blocks - 3) stride-2 blocks doubling channels.
        encoder = [self.build_conv_block(in_channels, base_channels)]
        for i in range(1, num_blocks - 2):
            mult = 2 ** (i - 1)
            encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2))
        # Final encoder block keeps spatial size (stride 1).
        mult = 2 ** (num_blocks - 2 - 1)
        encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1))
        self.encoder = nn.Sequential(*encoder)
        # Class Activation Map: per-channel logits from global avg / max pooling.
        mult = 2 ** (num_blocks - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        # 1x1 conv fuses concatenated avg/max attention features back to `mult` channels.
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)
        # Final patch classifier producing a 1-channel logit map.
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))

    @staticmethod
    def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
        """Spectral-normalized conv + LeakyReLU building block."""
        return nn.Sequential(*[
            nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                             bias=True, padding=padding, padding_mode=padding_mode)),
            nn.LeakyReLU(0.2, True),
        ])

    def forward(self, x, return_heatmap=False):
        x = self.encoder(x)
        batch_size = x.size(0)
        # CAM: logits from pooled features, then reuse the fc weights as
        # per-channel attention over the spatial feature map.
        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)  # B x C x 1 x 1, avg of per channel
        gap_logit = self.gap_fc(gap.view(batch_size, -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(batch_size, -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.leaky_relu(self.conv1x1(x))
        if return_heatmap:
            heatmap = torch.sum(x, dim=1, keepdim=True)
            return self.conv(x), cam_logit, heatmap
        else:
            return self.conv(x), cam_logit

0
model/GAN/__init__.py Normal file
View File

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import functools import functools
from .registry import MODEL from model.registry import MODEL
def _select_norm_layer(norm_type): def _select_norm_layer(norm_type):
@ -71,11 +71,12 @@ class GANImageBuffer(object):
@MODEL.register_module() @MODEL.register_module()
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False): def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
super(ResidualBlock, self).__init__() super(ResidualBlock, self).__init__()
# Only for IN, use bias since it does not have affine parameters. if use_bias is None:
use_bias = norm_type == "IN" # Only for IN, use bias since it does not have affine parameters.
use_bias = norm_type == "IN"
norm_layer = _select_norm_layer(norm_type) norm_layer = _select_norm_layer(norm_type)
models = [nn.Sequential( models = [nn.Sequential(
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias), nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),

View File

@ -1,3 +1,3 @@
from model.registry import MODEL from model.registry import MODEL
import model.residual_generator import model.GAN.residual_generator
import model.fewshot import model.fewshot

8
run.sh
View File

@ -3,12 +3,18 @@
CONFIG=$1 CONFIG=$1
TASK=$2 TASK=$2
GPUS=$3 GPUS=$3
MORE_ARG=${*:4}
_command="print(len('${GPUS}'.split(',')))" _command="print(len('${GPUS}'.split(',')))"
GPU_COUNT=$(python3 -c "${_command}") GPU_COUNT=$(python3 -c "${_command}")
echo "GPU_COUNT:${GPU_COUNT}" echo "GPU_COUNT:${GPU_COUNT}"
echo CUDA_VISIBLE_DEVICES=$GPUS \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"
CUDA_VISIBLE_DEVICES=$GPUS \ CUDA_VISIBLE_DEVICES=$GPUS \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \ PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed main.py "$TASK" "$CONFIG" "$MORE_ARG" --backup_config --setup_output_dir --setup_random_seed

View File

@ -39,6 +39,7 @@ def setup_common_handlers(
:param checkpoint_kwargs: :param checkpoint_kwargs:
:return: :return:
""" """
@trainer.on(Events.STARTED) @trainer.on(Events.STARTED)
@idist.one_rank_only() @idist.one_rank_only()
def print_dataloader_size(engine): def print_dataloader_size(engine):
@ -79,6 +80,8 @@ def setup_common_handlers(
engine.logger.info(print_str) engine.logger.info(print_str)
if to_save is not None: if to_save is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False),
**checkpoint_kwargs)
if resume_from is not None: if resume_from is not None:
@trainer.on(Events.STARTED) @trainer.on(Events.STARTED)
def resume(engine): def resume(engine):
@ -89,5 +92,4 @@ def setup_common_handlers(
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp) Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}") engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
if save_interval_event is not None: if save_interval_event is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
trainer.add_event_handler(save_interval_event, checkpoint_handler) trainer.add_event_handler(save_interval_event, checkpoint_handler)

85
util/misc.py Normal file
View File

@ -0,0 +1,85 @@
import logging
from typing import Optional
def setup_logger(
    name: Optional[str] = None,
    level: int = logging.INFO,
    logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
    filepath: Optional[str] = None,
    file_level: int = logging.DEBUG,
    distributed_rank: Optional[int] = None,
) -> logging.Logger:
    """Configure and return a logger, silencing non-zero distributed ranks.

    Args:
        name (str, optional): logger name. If None, the root logger is used.
        level (int): console logging level, e.g. ``logging.INFO``.
        logger_format (str): record format string.
        filepath (str, optional): if given, also log to this file (rank 0 only).
        file_level (int): logging level for the file handler.
        distributed_rank (int, optional): rank of the current process; when
            None it is looked up via ``ignite.distributed``. Ranks > 0 only
            get a ``NullHandler`` so workers stay quiet.

    Returns:
        logging.Logger: the configured logger.
    """
    logger = logging.getLogger(name)

    # Named loggers manage their own handlers; do not forward to ancestors,
    # which would duplicate output through the root logger.
    if name is not None:
        logger.propagate = False

    # Drop handlers left over from a previous call so setup is idempotent.
    while logger.handlers:
        logger.removeHandler(logger.handlers[0])

    formatter = logging.Formatter(logger_format)

    if distributed_rank is None:
        # Imported lazily so the function works without ignite when an
        # explicit rank is supplied.
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    if distributed_rank > 0:
        # Worker processes: swallow all records.
        logger.addHandler(logging.NullHandler())
        return logger

    logger.setLevel(level)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if filepath is not None:
        file_handler = logging.FileHandler(filepath)
        file_handler.setLevel(file_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger

View File

@ -67,7 +67,11 @@ class _Registry:
if default_args is not None: if default_args is not None:
for name, value in default_args.items(): for name, value in default_args.items():
args.setdefault(name, value) args.setdefault(name, value)
return obj_cls(**args) try:
obj = obj_cls(**args)
except TypeError as e:
raise TypeError(f"invalid argument in {args} when try to build {obj_cls}\n") from e
return obj
class ModuleRegistry(_Registry): class ModuleRegistry(_Registry):