add U-GAT-IT
parent 323bf2f6ab, commit 1a1cb9b00f
@@ -25,7 +25,7 @@ baseline:
       _type: Adam
   data:
     dataloader:
-      batch_size: 1024
+      batch_size: 1200
       shuffle: True
       num_workers: 16
       pin_memory: True
@@ -37,7 +37,7 @@ baseline:
       pipeline:
         - Load
        - RandomResizedCrop:
-            size: [256, 256]
+            size: [224, 224]
        - ColorJitter:
            brightness: 0.4
            contrast: 0.4
@@ -47,20 +47,5 @@ baseline:
        - Normalize:
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
-    val:
-      path: /data/few-shot/mini_imagenet_full_size/val
-      lmdb_path: /data/few-shot/lmdb/mini-ImageNet/val.lmdb
-      pipeline:
-        - Load
-        - Resize:
-            size: [286, 286]
-        - RandomCrop:
-            size: [256, 256]
-        - ToTensor
-        - Normalize:
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
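For orientation, a minimal sketch of how these baseline settings are consumed by the warmup engine further down in this commit. The config path below is a placeholder; LMDBDataset and the dataloader call mirror the engine code shown later, and nothing here is part of the commit itself.

# Hypothetical sketch; the config path is an assumption, the calls mirror the warmup engine below.
from omegaconf import OmegaConf
import ignite.distributed as idist
from data.dataset import LMDBDataset

config = OmegaConf.load("configs/fewshot/baseline.yml")          # assumed location
train_cfg = config.baseline.data.dataset.train
train_dataset = LMDBDataset(train_cfg.lmdb_path, pipeline=train_cfg.pipeline)
# dataloader kwargs include the batch_size bumped to 1200 in this hunk
train_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)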
configs/synthesizers/UGATIT.yml (new file, 110 lines)
@@ -0,0 +1,110 @@
name: selfie2anime
engine: UGATIT
result_dir: ./result
max_iteration: 100000

distributed:
  model:
    # broadcast_buffers: False

misc:
  random_seed: 324

checkpoints:
  interval: 1000

model:
  generator:
    _type: UGATIT-Generator
    in_channels: 3
    out_channels: 3
    base_channels: 64
    num_blocks: 4
    img_size: 256
    light: True
  local_discriminator:
    _type: UGATIT-Discriminator
    in_channels: 3
    base_channels: 64
    num_blocks: 3
  global_discriminator:
    _type: UGATIT-Discriminator
    in_channels: 3
    base_channels: 64
    num_blocks: 5

loss:
  gan:
    loss_type: lsgan
    weight: 1.0
    real_label_val: 1.0
    fake_label_val: 0.0
  cycle:
    level: 1
    weight: 10.0
  id:
    level: 1
    weight: 10.0
  cam:
    weight: 1000

optimizers:
  generator:
    _type: Adam
    lr: 0.0001
    betas: [0.5, 0.999]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 1e-4
    betas: [0.5, 0.999]
    weight_decay: 0.0001

data:
  train:
    buffer_size: 50
    dataloader:
      batch_size: 8
      shuffle: True
      num_workers: 2
      pin_memory: True
      drop_last: True
    dataset:
      _type: GenerationUnpairedDataset
      root_a: "/data/i2i/selfie2anime/trainA"
      root_b: "/data/i2i/selfie2anime/trainB"
      random_pair: True
      pipeline:
        - Load
        - Resize:
            size: [286, 286]
        - RandomCrop:
            size: [256, 256]
        - RandomHorizontalFlip
        - ToTensor
        - Normalize:
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
    scheduler:
      start: 50000
      target_lr: 0
  test:
    dataloader:
      batch_size: 4
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: GenerationUnpairedDataset
      root_a: "/data/i2i/selfie2anime/testA"
      root_b: "/data/i2i/selfie2anime/testB"
      random_pair: False
      pipeline:
        - Load
        - Resize:
            size: [256, 256]
        - ToTensor
        - Normalize:
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
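A minimal sketch of how this config is consumed; OmegaConf loading follows main.py, and the build_model call is copied from engine/UGATIT.py below. This is illustrative only, not part of the commit.

# Hypothetical sketch, not part of the commit; build_model usage is copied from engine/UGATIT.py.
from omegaconf import OmegaConf
from util.build import build_model

config = OmegaConf.load("configs/synthesizers/UGATIT.yml")
# "_type: UGATIT-Generator" picks the registered class; the sibling keys become its kwargs
# (in_channels=3, out_channels=3, base_channels=64, num_blocks=4, img_size=256, light=True).
generator_a2b = build_model(config.model.generator, config.distributed.model)
local_disc = build_model(config.model.local_discriminator, config.distributed.model)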
@@ -99,9 +99,9 @@ class EpisodicDataset(Dataset):

     def __getitem__(self, _):
         random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
-        support_set_list = []
+        support_set = []
-        query_set_list = []
+        query_set = []
-        target_list = []
+        target_set = []
         for tag, c in enumerate(random_classes):
             image_list = self.origin.classes_list[c]

@@ -113,13 +113,13 @@ class EpisodicDataset(Dataset):

             support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
             query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
-            support_set_list.extend(support)
+            support_set.extend(support)
-            query_set_list.extend(query)
+            query_set.extend(query)
-            target_list.extend([tag] * self.num_query)
+            target_set.extend([tag] * self.num_query)
         return {
-            "support": torch.stack(support_set_list),
+            "support": torch.stack(support_set),
-            "query": torch.stack(query_set_list),
+            "query": torch.stack(query_set),
-            "target": torch.tensor(target_list)
+            "target": torch.tensor(target_set)
         }

     def __repr__(self):
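With the renames, the episode structure reads more directly. A hedged sketch of the tensors one episode yields, assuming N-way, K-shot sampling with Q queries per class and C x H x W image tensors after the pipeline (the variable names below are illustrative):

# Hypothetical shapes for an N-way, K-shot episode with Q queries per class.
episode = episodic_dataset[0]          # the index is ignored; classes are re-sampled every call
episode["support"].shape               # torch.Size([N * K, C, H, W])
episode["query"].shape                 # torch.Size([N * Q, C, H, W])
episode["target"].shape                # torch.Size([N * Q]), values in [0, N), aligned with "query"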
engine/UGATIT.py (new file, 249 lines)
@@ -0,0 +1,249 @@
from itertools import chain
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler

from omegaconf import OmegaConf

import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer


def get_trainer(config, logger):
    generators = dict(
        a2b=build_model(config.model.generator, config.distributed.model),
        b2a=build_model(config.model.generator, config.distributed.model),
    )
    discriminators = dict(
        la=build_model(config.model.local_discriminator, config.distributed.model),
        lb=build_model(config.model.local_discriminator, config.distributed.model),
        ga=build_model(config.model.global_discriminator, config.distributed.model),
        gb=build_model(config.model.global_discriminator, config.distributed.model),
    )
    for m in chain(generators.values(), discriminators.values()):
        generation_init_weights(m)

    logger.debug(discriminators["ga"])
    logger.debug(generators["a2b"])

    optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator)
    optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                                  config.optimizers.discriminator)

    milestones_values = [
        (0, config.optimizers.generator.lr),
        (config.data.train.scheduler.start, config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)

    milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)

    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    bce_loss = nn.BCEWithLogitsLoss().to(idist.device())
    mse_loss = lambda x, t: F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
    bce_loss = lambda x, t: F.binary_cross_entropy_with_logits(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))

    image_buffers = {
        k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) for k in
        discriminators.keys()}

    rho_clipper = RhoClipper(0, 1)

    def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
                           discriminator_g):
        discriminator_g.requires_grad_(False)
        discriminator_l.requires_grad_(False)
        pred_fake_g, cam_gd_pred = discriminator_g(fake)
        pred_fake_l, cam_ld_pred = discriminator_l(fake)
        return {
            f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
            f"id_{name}": config.loss.id.weight * id_loss(real, identity),
            f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
            f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
            f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
            f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
        }

    def cal_discriminator_loss(name, discriminator, real, fake):
        pred_real, cam_real = discriminator(real)
        pred_fake, cam_fake = discriminator(fake)
        # TODO: the original does not divide by 2, but I think it is better to divide by 2.
        loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
        loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
        return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]

        fake = dict()
        cam_generator_pred = dict()
        rec = dict()
        identity = dict()
        cam_identity_pred = dict()
        heatmap = dict()

        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a)
        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b)
        rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
        rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a)
        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b)

        optimizer_g.zero_grad()
        loss_g = dict()
        for n in ["a", "b"]:
            loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
                                             cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
        sum(loss_g.values()).backward()
        optimizer_g.step()
        for generator in generators.values():
            generator.apply(rho_clipper)
        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)

        optimizer_d.zero_grad()
        loss_d = dict()
        for k in discriminators.keys():
            n = k[-1]  # "a" or "b"
            loss_d.update(
                cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach())))
        sum(loss_d.values()).backward()
        optimizer_d.step()

        for h in heatmap:
            heatmap[h] = heatmap[h].detach()
        generated_img = {f"fake_{k}": fake[k].detach() for k in fake}
        generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
        generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})

        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
            },
            "img": {
                "heatmap": heatmap,
                "generated": generated_img
            }
        }

    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d,
                   lr_scheduler_g=lr_scheduler_g)
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})

    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
                          filename_prefix=config.name, to_save=to_save,
                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
                          metrics_to_print=["loss_g", "loss_d"],
                          save_interval_event=Events.ITERATION_COMPLETED(
                              every=config.checkpoints.interval) | Events.COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
    def terminate(engine):
        engine.terminate()

    if idist.get_rank() == 0:
        # Create a logger
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # Attach the logger to the trainer to log training loss at each iteration
        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        def output_transform(output):
            loss = dict()
            for tl in output["loss"]:
                if isinstance(output["loss"][tl], dict):
                    for l in output["loss"][tl]:
                        loss[f"{tl}_{l}"] = output["loss"][tl][l]
                else:
                    loss[tl] = output["loss"][tl]
            return loss

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d"],
                global_step_transform=global_step_transform,
                output_transform=output_transform
            ),
            event_name=Events.ITERATION_COMPLETED(every=50)
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50)
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()),
                                engine.state.iteration)
            tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()),
                                engine.state.iteration)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done
            tb_logger.close()

    return trainer


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        return NotImplemented(f"invalid task: {task}")
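The two PiecewiseLinear handlers above encode the usual flat-then-linear-decay schedule. A small sketch of the values those milestones imply with the numbers from UGATIT.yml (lr=1e-4, scheduler.start=50000, max_iteration=100000, target_lr=0); this is plain arithmetic for illustration, not the ignite scheduler itself.

# Plain arithmetic illustrating the milestones above; not ignite code.
def lr_at(it, base_lr=1e-4, start=50000, max_iteration=100000, target_lr=0.0):
    if it <= start:
        return base_lr                                    # flat phase up to `start`
    frac = (it - start) / (max_iteration - start)
    return base_lr + frac * (target_lr - base_lr)         # linear decay to target_lr

print([lr_at(i) for i in (0, 50000, 75000, 100000)])      # [0.0001, 0.0001, 5e-05, 0.0]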
@@ -17,7 +17,7 @@ from data.transform import transform_pipeline
 from data.dataset import LMDBDataset


-def baseline_trainer(config, logger):
+def warmup_trainer(config, logger):
     model = build_model(config.model, config.distributed.model)
     optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
     loss_fn = nn.CrossEntropyLoss()
@@ -66,18 +66,20 @@ def run(task, config, logger):
     assert torch.backends.cudnn.enabled
     torch.backends.cudnn.benchmark = True
     logger.info(f"start task {task}")
-    if task == "baseline":
+    if task == "warmup":
         train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
                                     pipeline=config.baseline.data.dataset.train.pipeline)
-        # train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
-        #                             transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
         logger.info(f"train with dataset:\n{train_dataset}")
         train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
-        trainer = baseline_trainer(config, logger)
+        trainer = warmup_trainer(config, logger)
         try:
             trainer.run(train_data_loader, max_epochs=400)
         except Exception:
             import traceback
             print(traceback.format_exc())
+    elif task == "protonet-wo":
+        pass
+    elif task == "protonet-w":
+        pass
     else:
-        return NotImplemented(f"invalid task: {task}")
+        return ValueError(f"invalid task: {task}")
@@ -18,7 +18,7 @@ from omegaconf import OmegaConf
 import data
 from loss.gan import GANLoss
 from model.weight_init import generation_init_weights
-from model.residual_generator import GANImageBuffer
+from model.GAN.residual_generator import GANImageBuffer
 from util.image import make_2d_grid
 from util.handler import setup_common_handlers
 from util.build import build_model, build_optimizer
@@ -31,8 +31,8 @@ def get_trainer(config, logger):
     discriminator_b = build_model(config.model.discriminator, config.distributed.model)
     for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
         generation_init_weights(m)
-    logger.debug(discriminator_a)
-    logger.debug(generator_a)
+    logger.info(discriminator_a)
+    logger.info(generator_a)

     optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
                                   config.optimizers.generator)
@@ -56,8 +56,8 @@ def get_trainer(config, logger):
     gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
     gan_loss_cfg.pop("weight")
     gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
-    cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
-    id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
+    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
+    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

     image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
     image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
@@ -93,11 +93,11 @@ def get_trainer(config, logger):
         real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
         fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
     )
-    (sum(loss_d_a.values()) * 0.5).backward()
     loss_d_b = dict(
         real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
         fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
     )
+    (sum(loss_d_a.values()) * 0.5).backward()
     (sum(loss_d_b.values()) * 0.5).backward()
     optimizer_d.step()
engine/fewshot.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from data.dataset import EpisodicDataset, LMDBDataset


def prototypical_trainer(config, logger):
    pass


def prototypical_dataloader(config):
    pass

loss/fewshot/__init__.py (new empty file)
loss/fewshot/prototypical.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrototypicalLoss(nn.Module):
    def __init__(self):
        super().__init__()

    @staticmethod
    def acc(query, target, support):
        prototypes = support.mean(-2)  # batch_size x N_class x D
        distance = PrototypicalLoss.euclidean_dist(query, prototypes)  # batch_size x N_class*N_query x N_class
        indices = distance.argmin(-1)  # smallest distance indices
        acc = torch.eq(target, indices).float().mean().item()
        return acc

    @staticmethod
    def euclidean_dist(x, y):
        # x: B x N x D
        # y: B x M x D
        assert x.size(-1) == y.size(-1) and x.size(0) == y.size(0)
        n = x.size(-2)
        m = y.size(-2)
        d = x.size(-1)
        x = x.unsqueeze(2).expand(x.size(0), n, m, d)  # B x N x M x D
        y = y.unsqueeze(1).expand(x.size(0), n, m, d)
        return torch.pow(x - y, 2).sum(-1)  # B x N x M

    def forward(self, query, target, support):
        """
        calculate prototypical loss
        :param query: Tensor - batch_size x N_class*N_query x D
        :param target: Tensor - batch_size x N_class*N_query, target id set, values must be in [0, N_class)
        :param support: Tensor - batch_size x N_class x N_support x D, must be ordered by class id
        :return: loss item and accuracy
        """

        prototypes = support.mean(-2)  # batch_size x N_class x D
        distance = self.euclidean_dist(query, prototypes)  # batch_size x N_class*N_query x N_class
        indices = distance.argmin(-1)  # smallest distance indices
        acc = torch.eq(target, indices).float().mean().item()

        log_p_y = F.log_softmax(-distance, dim=-1)
        n_class = support.size(1)
        n_query = query.size(1) // n_class
        batch_size = support.size(0)
        target_log_indices = torch.arange(n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).reshape(
            n_class * n_query, 1).view(1, n_class * n_query, 1).expand(batch_size, n_class * n_query, 1)
        loss = -log_p_y.gather(2, target_log_indices).mean()  # select log-probability of true class then get the mean

        return loss, acc
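A hypothetical usage sketch for the loss above. The episode sizes and feature dimension are made up for the example; only the shape contract comes from the docstring.

import torch

criterion = PrototypicalLoss()
B, N, K, Q, D = 2, 5, 5, 15, 64                            # batch, ways, shots, queries per class, feature dim
support = torch.randn(B, N, K, D, requires_grad=True)      # ordered by class id
query = torch.randn(B, N * Q, D, requires_grad=True)
target = torch.arange(N).repeat_interleave(Q).unsqueeze(0).expand(B, -1)  # class id per query
loss, acc = criterion(query, target, support)
loss.backward()   # gradients flow back into the (here random) embeddings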
main.py (20)
@@ -5,7 +5,8 @@ import torch

 import ignite
 import ignite.distributed as idist
-from ignite.utils import manual_seed, setup_logger
+from ignite.utils import manual_seed
+from util.misc import setup_logger

 import fire
 from omegaconf import OmegaConf
@@ -21,14 +22,12 @@ def log_basic_info(logger, config):


 def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
-    logger = setup_logger(name=config.name, distributed_rank=local_rank, **config.log.logger)
-    log_basic_info(logger, config)

     if setup_random_seed:
         manual_seed(config.misc.random_seed + idist.get_rank())
-    if setup_output_dir:
-        output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
-        config.output_dir = str(output_dir)
+    output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
+    config.output_dir = str(output_dir)
+
+    if setup_output_dir and config.resume_from is None:
         if output_dir.exists():
             # assert not any(output_dir.iterdir()), "output_dir must be empty"
             contains = list(output_dir.iterdir())
@@ -37,11 +36,14 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
         else:
             if idist.get_rank() == 0:
                 output_dir.mkdir(parents=True)
-                logger.info(f"mkdir -p {output_dir}")
+                print(f"mkdir -p {output_dir}")
-    logger.info(f"output path: {config.output_dir}")
     if backup_config and idist.get_rank() == 0:
         with open(output_dir / "config.yml", "w+") as f:
             print(config.pretty(), file=f)
+    logger = setup_logger(name=config.name, distributed_rank=local_rank, filepath=output_dir / "train.log")
+    logger.info(f"output path: {config.output_dir}")
+    log_basic_info(logger, config)
+
     OmegaConf.set_readonly(config, True)
model/GAN/UGATIT.py (new file, 253 lines)
@@ -0,0 +1,253 @@
import torch
import torch.nn as nn
from .residual_generator import ResidualBlock
from model.registry import MODEL


class RhoClipper(object):
    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        if hasattr(module, 'rho'):
            w = module.rho.data
            w = w.clamp(self.clip_min, self.clip_max)
            module.rho.data = w


@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
        assert (num_blocks >= 0)
        super(Generator, self).__init__()
        self.input_channels = in_channels
        self.output_channels = out_channels
        self.base_channels = base_channels
        self.num_blocks = num_blocks
        self.img_size = img_size
        self.light = light

        down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
                                  padding_mode="reflect", bias=False),
                        nn.InstanceNorm2d(base_channels),
                        nn.ReLU(True)]

        n_down_sampling = 2
        for i in range(n_down_sampling):
            mult = 2 ** i
            down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2,
                                       padding=1, bias=False, padding_mode="reflect"),
                             nn.InstanceNorm2d(base_channels * mult * 2),
                             nn.ReLU(True)]

        # Down-Sampling Bottleneck
        mult = 2 ** n_down_sampling
        for i in range(num_blocks):
            # TODO: change ResnetBlock to ResidualBlock, check use_bias param
            down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
        self.down_encoder = nn.Sequential(*down_encoder)

        # Class Activation Map
        self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False)
        self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False)
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                  nn.ReLU(True),
                  nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                  nn.ReLU(True)]
        else:
            fc = [
                nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False),
                nn.ReLU(True),
                nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                nn.ReLU(True)]
        self.fc = nn.Sequential(*fc)

        self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
        self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)

        # Up-Sampling Bottleneck
        self.up_bottleneck = nn.ModuleList(
            [ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)])

        # Up-Sampling
        up_decoder = []
        for i in range(n_down_sampling):
            mult = 2 ** (n_down_sampling - i)
            up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'),
                           nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1,
                                     padding=1, padding_mode="reflect", bias=False),
                           ILN(base_channels * mult // 2),
                           nn.ReLU(True)]

        up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
                                 padding_mode="reflect", bias=False),
                       nn.Tanh()]
        self.up_decoder = nn.Sequential(*up_decoder)
        # self.up_decoder = nn.ModuleDict({
        #     "up_1": nn.Upsample(scale_factor=2, mode='nearest'),
        #     "up_conv_1": nn.Sequential(
        #         nn.Conv2d(base_channels * 4, base_channels * 4 // 2, kernel_size=3, stride=1,
        #                   padding=1, padding_mode="reflect", bias=False),
        #         ILN(base_channels * 4 // 2),
        #         nn.ReLU(True)),
        #     "up_2": nn.Upsample(scale_factor=2, mode='nearest'),
        #     "up_conv_2": nn.Sequential(
        #         nn.Conv2d(base_channels * 2, base_channels * 2 // 2, kernel_size=3, stride=1,
        #                   padding=1, padding_mode="reflect", bias=False),
        #         ILN(base_channels * 2 // 2),
        #         nn.ReLU(True)),
        #     "up_end": nn.Sequential(nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
        #                                       padding_mode="reflect", bias=False), nn.Tanh())
        # })

    def forward(self, x):
        x = self.down_encoder(x)

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)

        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
            x_ = self.fc(x_.view(x_.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)

        for ub in self.up_bottleneck:
            x = ub(x, gamma, beta)

        x = self.up_decoder(x)
        return x, cam_logit, heatmap


class ResnetAdaILNBlock(nn.Module):
    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU(True)

        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        out = self.conv1(x)
        out = self.norm1(out, gamma, beta)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.norm2(out, gamma, beta)

        return out + x


def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True)
    out_in = (x - in_mean) / torch.sqrt(in_var + eps)
    ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
    out_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)
    out = rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - rho.expand(x.shape[0], -1, -1, -1)) * out_ln
    out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
    return out


class AdaILN(nn.Module):
    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super(AdaILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(default_rho)

    def forward(self, x, gamma, beta):
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)


class ILN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = nn.Parameter(torch.Tensor(1, num_features))
        self.beta = nn.Parameter(torch.Tensor(1, num_features))
        self.rho.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, x):
        return instance_layer_normalization(
            x, self.gamma.expand(x.shape[0], -1), self.beta.expand(x.shape[0], -1), self.rho, self.eps)


@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    def __init__(self, in_channels, base_channels=64, num_blocks=5):
        super(Discriminator, self).__init__()
        encoder = [self.build_conv_block(in_channels, base_channels)]

        for i in range(1, num_blocks - 2):
            mult = 2 ** (i - 1)
            encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2))

        mult = 2 ** (num_blocks - 2 - 1)
        encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1))

        self.encoder = nn.Sequential(*encoder)

        # Class Activation Map
        mult = 2 ** (num_blocks - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))

    @staticmethod
    def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
        return nn.Sequential(*[
            nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                             bias=True, padding=padding, padding_mode=padding_mode)),
            nn.LeakyReLU(0.2, True),
        ])

    def forward(self, x, return_heatmap=False):
        x = self.encoder(x)
        batch_size = x.size(0)

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)  # B x C x 1 x 1, avg of per channel
        gap_logit = self.gap_fc(gap.view(batch_size, -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(batch_size, -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)

        x = torch.cat([gap, gmp], 1)
        x = self.leaky_relu(self.conv1x1(x))

        if return_heatmap:
            heatmap = torch.sum(x, dim=1, keepdim=True)
            return self.conv(x), cam_logit, heatmap
        else:
            return self.conv(x), cam_logit
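A hedged smoke-test sketch for the two modules above, using the sizes from UGATIT.yml; the commented shapes are inferred from the layer arithmetic, not quoted from the source.

import torch

# Sizes follow configs/synthesizers/UGATIT.yml; the commented shapes are an inference.
g = Generator(in_channels=3, out_channels=3, base_channels=64, num_blocks=4, img_size=256, light=True)
d = Discriminator(in_channels=3, base_channels=64, num_blocks=5)

x = torch.randn(1, 3, 256, 256)
fake, cam_logit, heatmap = g(x)     # fake: (1, 3, 256, 256), cam_logit: (1, 2), heatmap: (1, 1, 64, 64)
pred, d_cam = d(fake)               # patch prediction map and the discriminator's CAM logits

g.apply(RhoClipper(0, 1))           # clamp every `rho` parameter into [0, 1], as the engine does each step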
model/GAN/__init__.py (new empty file)
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import functools
-from .registry import MODEL
+from model.registry import MODEL


 def _select_norm_layer(norm_type):
@@ -71,11 +71,12 @@ class GANImageBuffer(object):

 @MODEL.register_module()
 class ResidualBlock(nn.Module):
-    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False):
+    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
         super(ResidualBlock, self).__init__()

-        # Only for IN, use bias since it does not have affine parameters.
-        use_bias = norm_type == "IN"
+        if use_bias is None:
+            # Only for IN, use bias since it does not have affine parameters.
+            use_bias = norm_type == "IN"
         norm_layer = _select_norm_layer(norm_type)
         models = [nn.Sequential(
             nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
@@ -1,3 +1,3 @@
 from model.registry import MODEL
-import model.residual_generator
+import model.GAN.residual_generator
 import model.fewshot
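The new use_bias override is what lets the UGATIT generator reuse ResidualBlock in its down-sampling bottleneck while forcing the bias off; a short hypothetical illustration:

from model.GAN.residual_generator import ResidualBlock

# Default behaviour is unchanged: with norm_type="IN" the conv keeps its bias.
block_default = ResidualBlock(256)
# UGATIT's bottleneck passes the new argument explicitly to switch the bias off.
block_no_bias = ResidualBlock(256, use_bias=False)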
run.sh (8)
@@ -3,12 +3,18 @@
 CONFIG=$1
 TASK=$2
 GPUS=$3
+MORE_ARG=${*:4}

 _command="print(len('${GPUS}'.split(',')))"
 GPU_COUNT=$(python3 -c "${_command}")

 echo "GPU_COUNT:${GPU_COUNT}"

+echo CUDA_VISIBLE_DEVICES=$GPUS \
+  PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
+  main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"
+
 CUDA_VISIBLE_DEVICES=$GPUS \
 PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
-main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
+main.py "$TASK" "$CONFIG" "$MORE_ARG" --backup_config --setup_output_dir --setup_random_seed
@@ -39,6 +39,7 @@ def setup_common_handlers(
     :param checkpoint_kwargs:
     :return:
     """
+
     @trainer.on(Events.STARTED)
     @idist.one_rank_only()
     def print_dataloader_size(engine):
@@ -79,6 +80,8 @@ def setup_common_handlers(
             engine.logger.info(print_str)

     if to_save is not None:
+        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False),
+                                        **checkpoint_kwargs)
         if resume_from is not None:
             @trainer.on(Events.STARTED)
             def resume(engine):
@@ -89,5 +92,4 @@ def setup_common_handlers(
                 Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                 engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
         if save_interval_event is not None:
-            checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
             trainer.add_event_handler(save_interval_event, checkpoint_handler)
util/misc.py (new file, 85 lines)
@@ -0,0 +1,85 @@
import logging
from typing import Optional


def setup_logger(
    name: Optional[str] = None,
    level: int = logging.INFO,
    logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
    filepath: Optional[str] = None,
    file_level: int = logging.DEBUG,
    distributed_rank: Optional[int] = None,
) -> logging.Logger:
    """Setups logger: name, level, format etc.

    Args:
        name (str, optional): new name for the logger. If None, the standard logger is used.
        level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
        logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
        filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
        file_level (int): Optional logging level for logging file.
        distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
            If None, distributed_rank is initialized to the rank of process.

    Returns:
        logging.Logger

    For example, to improve logs readability when training with a trainer and evaluator:

    .. code-block:: python

        from ignite.utils import setup_logger

        trainer = ...
        evaluator = ...

        trainer.logger = setup_logger("trainer")
        evaluator.logger = setup_logger("evaluator")

        trainer.run(data, max_epochs=10)

        # Logs will look like
        # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
        # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
        # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
        # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
        # ...

    """
    logger = logging.getLogger(name)

    # don't propagate to ancestors
    # the problem here is to attach handlers to loggers
    # should we provide a default configuration less open ?
    if name is not None:
        logger.propagate = False

    # Remove previous handlers
    if logger.hasHandlers():
        for h in list(logger.handlers):
            logger.removeHandler(h)

    formatter = logging.Formatter(logger_format)

    if distributed_rank is None:
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    if distributed_rank > 0:
        logger.addHandler(logging.NullHandler())
    else:
        logger.setLevel(level)

        ch = logging.StreamHandler()
        ch.setLevel(level)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

        if filepath is not None:
            fh = logging.FileHandler(filepath)
            fh.setLevel(file_level)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    return logger
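A hypothetical minimal use of this helper, matching how main.py now wires it up once the output directory exists (the name and paths below are just examples):

from util.misc import setup_logger

# Hypothetical names/paths; main.py passes output_dir / "train.log" after the directory is created.
logger = setup_logger(name="selfie2anime", filepath="./result/selfie2anime/train.log")
logger.info("console handler at INFO, file handler at DEBUG")
# On ranks > 0 the function attaches a NullHandler, so worker processes stay silent.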
@@ -67,7 +67,11 @@ class _Registry:
         if default_args is not None:
             for name, value in default_args.items():
                 args.setdefault(name, value)
-        return obj_cls(**args)
+        try:
+            obj = obj_cls(**args)
+        except TypeError as e:
+            raise TypeError(f"invalid argument in {args} when try to build {obj_cls}\n") from e
+        return obj


 class ModuleRegistry(_Registry):
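With this wrapper, a config typo surfaces as a chained TypeError that names the offending arguments. A hypothetical illustration; the misspelled key and the build_with call are assumptions for the example, modelled on data.DATASET.build_with used elsewhere in this commit.

# Hypothetical: "num_block" is a deliberate misspelling of "num_blocks".
cfg = {"_type": "UGATIT-Generator", "in_channels": 3, "out_channels": 3, "num_block": 4}
MODEL.build_with(cfg)
# TypeError: invalid argument in {'in_channels': 3, 'out_channels': 3, 'num_block': 4}
#            when try to build <class 'model.GAN.UGATIT.Generator'>
# ...chained from the original "unexpected keyword argument 'num_block'" error.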