add U-GAT-IT

parent 323bf2f6ab
commit 1a1cb9b00f
@ -25,7 +25,7 @@ baseline:
      _type: Adam
  data:
    dataloader:
      batch_size: 1024
      batch_size: 1200
      shuffle: True
      num_workers: 16
      pin_memory: True
@ -37,7 +37,7 @@ baseline:
      pipeline:
        - Load
        - RandomResizedCrop:
            size: [256, 256]
            size: [224, 224]
        - ColorJitter:
            brightness: 0.4
            contrast: 0.4
@ -47,20 +47,5 @@ baseline:
        - Normalize:
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
    val:
      path: /data/few-shot/mini_imagenet_full_size/val
      lmdb_path: /data/few-shot/lmdb/mini-ImageNet/val.lmdb
      pipeline:
        - Load
        - Resize:
            size: [286, 286]
        - RandomCrop:
            size: [256, 256]
        - ToTensor
        - Normalize:
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
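For orientation, the val pipeline listed above corresponds roughly to a standard torchvision composition. This is only an illustrative sketch: the repo builds pipelines from the YAML via its own transform builder, `Load` is assumed to map to PIL image loading, and the image path below is hypothetical.

from PIL import Image
from torchvision import transforms

# Rough torchvision equivalent of: Load -> Resize 286 -> RandomCrop 256 -> ToTensor -> Normalize
val_transform = transforms.Compose([
    transforms.Resize((286, 286)),
    transforms.RandomCrop((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open("some_image.jpg").convert("RGB")  # hypothetical input file
tensor = val_transform(img)                        # 3 x 256 x 256 float tensor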
configs/synthesizers/UGATIT.yml (new file, 110 lines)
@ -0,0 +1,110 @@
name: selfie2anime
engine: UGATIT
result_dir: ./result
max_iteration: 100000

distributed:
  model:
    # broadcast_buffers: False

misc:
  random_seed: 324

checkpoints:
  interval: 1000

model:
  generator:
    _type: UGATIT-Generator
    in_channels: 3
    out_channels: 3
    base_channels: 64
    num_blocks: 4
    img_size: 256
    light: True
  local_discriminator:
    _type: UGATIT-Discriminator
    in_channels: 3
    base_channels: 64
    num_blocks: 3
  global_discriminator:
    _type: UGATIT-Discriminator
    in_channels: 3
    base_channels: 64
    num_blocks: 5

loss:
  gan:
    loss_type: lsgan
    weight: 1.0
    real_label_val: 1.0
    fake_label_val: 0.0
  cycle:
    level: 1
    weight: 10.0
  id:
    level: 1
    weight: 10.0
  cam:
    weight: 1000

optimizers:
  generator:
    _type: Adam
    lr: 0.0001
    betas: [0.5, 0.999]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 1e-4
    betas: [0.5, 0.999]
    weight_decay: 0.0001

data:
  train:
    buffer_size: 50
    dataloader:
      batch_size: 8
      shuffle: True
      num_workers: 2
      pin_memory: True
      drop_last: True
    dataset:
      _type: GenerationUnpairedDataset
      root_a: "/data/i2i/selfie2anime/trainA"
      root_b: "/data/i2i/selfie2anime/trainB"
      random_pair: True
      pipeline:
        - Load
        - Resize:
            size: [286, 286]
        - RandomCrop:
            size: [256, 256]
        - RandomHorizontalFlip
        - ToTensor
        - Normalize:
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
    scheduler:
      start: 50000
      target_lr: 0
  test:
    dataloader:
      batch_size: 4
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: GenerationUnpairedDataset
      root_a: "/data/i2i/selfie2anime/testA"
      root_b: "/data/i2i/selfie2anime/testB"
      random_pair: False
      pipeline:
        - Load
        - Resize:
            size: [256, 256]
        - ToTensor
        - Normalize:
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
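A minimal sketch of how this new config might be consumed, assuming OmegaConf for loading. In the repo the dispatch is handled by main.py / run.sh, so the commented-out call at the end is only illustrative and would additionally require fields such as output_dir and resume_from to be set.

from omegaconf import OmegaConf

# Illustrative only: load the YAML and inspect the generator spec.
config = OmegaConf.load("configs/synthesizers/UGATIT.yml")
print(config.engine)                                  # UGATIT
print(OmegaConf.to_yaml(config.model.generator))      # _type: UGATIT-Generator, ...

# from engine.UGATIT import run
# run("train", config, logger=some_logger)  # hypothetical direct call; main.py normally does this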
@ -99,9 +99,9 @@ class EpisodicDataset(Dataset):

    def __getitem__(self, _):
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set_list = []
        query_set_list = []
        target_list = []
        support_set = []
        query_set = []
        target_set = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]

@ -113,13 +113,13 @@ class EpisodicDataset(Dataset):

            support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
            query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
            support_set_list.extend(support)
            query_set_list.extend(query)
            target_list.extend([tag] * self.num_query)
            support_set.extend(support)
            query_set.extend(query)
            target_set.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set_list),
            "query": torch.stack(query_set_list),
            "target": torch.tensor(target_list)
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set)
        }

    def __repr__(self):
engine/UGATIT.py (new file, 249 lines)
@ -0,0 +1,249 @@
from itertools import chain
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils

import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler

from omegaconf import OmegaConf

import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer


def get_trainer(config, logger):
    generators = dict(
        a2b=build_model(config.model.generator, config.distributed.model),
        b2a=build_model(config.model.generator, config.distributed.model),
    )
    discriminators = dict(
        la=build_model(config.model.local_discriminator, config.distributed.model),
        lb=build_model(config.model.local_discriminator, config.distributed.model),
        ga=build_model(config.model.global_discriminator, config.distributed.model),
        gb=build_model(config.model.global_discriminator, config.distributed.model),
    )
    for m in chain(generators.values(), discriminators.values()):
        generation_init_weights(m)

    logger.debug(discriminators["ga"])
    logger.debug(generators["a2b"])

    optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator)
    optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
                                  config.optimizers.discriminator)

    milestones_values = [
        (0, config.optimizers.generator.lr),
        (config.data.train.scheduler.start, config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)

    milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)

    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    bce_loss = nn.BCEWithLogitsLoss().to(idist.device())
    mse_loss = lambda x, t: F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
    bce_loss = lambda x, t: F.binary_cross_entropy_with_logits(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))

    image_buffers = {
        k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) for k in
        discriminators.keys()}

    rho_clipper = RhoClipper(0, 1)

    def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
                           discriminator_g):
        discriminator_g.requires_grad_(False)
        discriminator_l.requires_grad_(False)
        pred_fake_g, cam_gd_pred = discriminator_g(fake)
        pred_fake_l, cam_ld_pred = discriminator_l(fake)
        return {
            f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
            f"id_{name}": config.loss.id.weight * id_loss(real, identity),
            f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
            f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
            f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
            f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
        }

    def cal_discriminator_loss(name, discriminator, real, fake):
        pred_real, cam_real = discriminator(real)
        pred_fake, cam_fake = discriminator(fake)
        # TODO: the original does not divide by 2, but I think it is better to divide by 2.
        loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
        loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
        return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]

        fake = dict()
        cam_generator_pred = dict()
        rec = dict()
        identity = dict()
        cam_identity_pred = dict()
        heatmap = dict()

        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a)
        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b)
        rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
        rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a)
        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b)

        optimizer_g.zero_grad()
        loss_g = dict()
        for n in ["a", "b"]:
            loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
                                             cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
        sum(loss_g.values()).backward()
        optimizer_g.step()
        for generator in generators.values():
            generator.apply(rho_clipper)
        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)

        optimizer_d.zero_grad()
        loss_d = dict()
        for k in discriminators.keys():
            n = k[-1]  # "a" or "b"
            loss_d.update(
                cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach())))
        sum(loss_d.values()).backward()
        optimizer_d.step()

        for h in heatmap:
            heatmap[h] = heatmap[h].detach()
        generated_img = {f"fake_{k}": fake[k].detach() for k in fake}
        generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
        generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})

        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
            },
            "img": {
                "heatmap": heatmap,
                "generated": generated_img
            }
        }

    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d,
                   lr_scheduler_g=lr_scheduler_g)
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})

    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
                          filename_prefix=config.name, to_save=to_save,
                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
                          metrics_to_print=["loss_g", "loss_d"],
                          save_interval_event=Events.ITERATION_COMPLETED(
                              every=config.checkpoints.interval) | Events.COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
    def terminate(engine):
        engine.terminate()

    if idist.get_rank() == 0:
        # Create a logger
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # Attach the logger to the trainer to log training loss at each iteration
        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        def output_transform(output):
            loss = dict()
            for tl in output["loss"]:
                if isinstance(output["loss"][tl], dict):
                    for l in output["loss"][tl]:
                        loss[f"{tl}_{l}"] = output["loss"][tl][l]
                else:
                    loss[tl] = output["loss"][tl]
            return loss

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d"],
                global_step_transform=global_step_transform,
                output_transform=output_transform
            ),
            event_name=Events.ITERATION_COMPLETED(every=50)
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50)
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()),
                                engine.state.iteration)
            tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()),
                                engine.state.iteration)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done
            tb_logger.close()

    return trainer


def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        return NotImplemented(f"invalid task: {task}")
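For readability, the per-domain generator objective assembled by cal_generator_loss above can be summarized as follows. This is a restatement of the code in my own notation, not text from the commit; the lambda weights are the loss section of UGATIT.yml, eta denotes CAM logits, BCE is binary cross-entropy with logits, and l_GAN is the lsgan criterion from loss.gan.

\mathcal{L}_G^{(n)} =
  \lambda_{cycle}\,\lVert \mathrm{real}_n - \mathrm{rec}_n \rVert_1
+ \lambda_{id}\,\lVert \mathrm{real}_n - \mathrm{id}_n \rVert_1
+ \lambda_{cam}\big(\mathrm{BCE}(\eta^{gen}_n, 1) + \mathrm{BCE}(\eta^{id}_n, 0)\big)
+ \lambda_{gan}\big(\ell_{GAN}(D^{l}_n(\mathrm{fake}_n), 1) + \ell_{GAN}(D^{g}_n(\mathrm{fake}_n), 1)
+ \mathrm{MSE}(\eta_{D^{g}_n}, 1) + \mathrm{MSE}(\eta_{D^{l}_n}, 1)\big),
\qquad \mathcal{L}_G = \mathcal{L}_G^{(a)} + \mathcal{L}_G^{(b)}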
@ -17,7 +17,7 @@ from data.transform import transform_pipeline
from data.dataset import LMDBDataset


def baseline_trainer(config, logger):
def warmup_trainer(config, logger):
    model = build_model(config.model, config.distributed.model)
    optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
    loss_fn = nn.CrossEntropyLoss()
@ -66,18 +66,20 @@ def run(task, config, logger):
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "baseline":
    if task == "warmup":
        train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
                                    pipeline=config.baseline.data.dataset.train.pipeline)
        # train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
        #                             transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
        trainer = baseline_trainer(config, logger)
        trainer = warmup_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=400)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "protonet-wo":
        pass
    elif task == "protonet-w":
        pass
    else:
        return NotImplemented(f"invalid task: {task}")
        return ValueError(f"invalid task: {task}")
@ -18,7 +18,7 @@ from omegaconf import OmegaConf
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.residual_generator import GANImageBuffer
from model.GAN.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer
@ -31,8 +31,8 @@ def get_trainer(config, logger):
    discriminator_b = build_model(config.model.discriminator, config.distributed.model)
    for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
        generation_init_weights(m)
    logger.debug(discriminator_a)
    logger.debug(generator_a)
    logger.info(discriminator_a)
    logger.info(generator_a)

    optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
                                  config.optimizers.generator)
@ -56,8 +56,8 @@ def get_trainer(config, logger):
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

    image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
    image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
@ -93,11 +93,11 @@ def get_trainer(config, logger):
            real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
            fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
        )
        (sum(loss_d_a.values()) * 0.5).backward()
        loss_d_b = dict(
            real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
            fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
        )
        (sum(loss_d_a.values()) * 0.5).backward()
        (sum(loss_d_b.values()) * 0.5).backward()
        optimizer_d.step()
engine/fewshot.py (new file, 9 lines)
@ -0,0 +1,9 @@
from data.dataset import EpisodicDataset, LMDBDataset


def prototypical_trainer(config, logger):
    pass


def prototypical_dataloader(config):
    pass
loss/fewshot/__init__.py (new empty file)
loss/fewshot/prototypical.py (new file, 52 lines)
@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrototypicalLoss(nn.Module):
    def __init__(self):
        super().__init__()

    @staticmethod
    def acc(query, target, support):
        prototypes = support.mean(-2)  # batch_size x N_class x D
        distance = PrototypicalLoss.euclidean_dist(query, prototypes)  # batch_size x N_class*N_query x N_class
        indices = distance.argmin(-1)  # smallest distance indices
        acc = torch.eq(target, indices).float().mean().item()
        return acc

    @staticmethod
    def euclidean_dist(x, y):
        # x: B x N x D
        # y: B x M x D
        assert x.size(-1) == y.size(-1) and x.size(0) == y.size(0)
        n = x.size(-2)
        m = y.size(-2)
        d = x.size(-1)
        x = x.unsqueeze(2).expand(x.size(0), n, m, d)  # B x N x M x D
        y = y.unsqueeze(1).expand(x.size(0), n, m, d)
        return torch.pow(x - y, 2).sum(-1)  # B x N x M

    def forward(self, query, target, support):
        """
        calculate prototypical loss
        :param query: Tensor - batch_size x N_class*N_query x D
        :param target: Tensor - batch_size x N_class*N_query, target id set, values must be in [0, N_class)
        :param support: Tensor - batch_size x N_class x N_support x D, must be ordered by class id
        :return: loss item and accuracy
        """

        prototypes = support.mean(-2)  # batch_size x N_class x D
        distance = self.euclidean_dist(query, prototypes)  # batch_size x N_class*N_query x N_class
        indices = distance.argmin(-1)  # smallest distance indices
        acc = torch.eq(target, indices).float().mean().item()

        log_p_y = F.log_softmax(-distance, dim=-1)
        n_class = support.size(1)
        n_query = query.size(1) // n_class
        batch_size = support.size(0)
        target_log_indices = torch.arange(n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).reshape(
            n_class * n_query, 1).view(1, n_class * n_query, 1).expand(batch_size, n_class * n_query, 1)
        loss = -log_p_y.gather(2, target_log_indices).mean()  # select log-probability of true class then get the mean

        return loss, acc
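A quick shape check for the loss above. This is an illustrative sketch only: tensor sizes follow the docstring, the episode sizes are hypothetical, and the import path assumes the new loss/fewshot/prototypical.py module is on PYTHONPATH.

import torch
from loss.fewshot.prototypical import PrototypicalLoss  # path as added in this commit

B, N_CLASS, N_SUPPORT, N_QUERY, D = 2, 5, 5, 15, 64     # hypothetical episode sizes
support = torch.randn(B, N_CLASS, N_SUPPORT, D)          # ordered by class id
query = torch.randn(B, N_CLASS * N_QUERY, D)
# target must be class-major to match the arange-based index layout in forward()
target = torch.arange(N_CLASS).repeat_interleave(N_QUERY).unsqueeze(0).expand(B, -1)

loss, acc = PrototypicalLoss()(query, target, support)
print(loss.item(), acc)  # scalar loss; accuracy around 1/N_CLASS for random features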
main.py (16 lines changed)
@ -5,7 +5,8 @@ import torch

import ignite
import ignite.distributed as idist
from ignite.utils import manual_seed, setup_logger
from ignite.utils import manual_seed
from util.misc import setup_logger

import fire
from omegaconf import OmegaConf
@ -21,14 +22,12 @@ def log_basic_info(logger, config):


def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
    logger = setup_logger(name=config.name, distributed_rank=local_rank, **config.log.logger)
    log_basic_info(logger, config)

    if setup_random_seed:
        manual_seed(config.misc.random_seed + idist.get_rank())
    if setup_output_dir:
        output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
        config.output_dir = str(output_dir)

    if setup_output_dir and config.resume_from is None:
        if output_dir.exists():
            # assert not any(output_dir.iterdir()), "output_dir must be empty"
            contains = list(output_dir.iterdir())
@ -37,11 +36,14 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
        else:
            if idist.get_rank() == 0:
                output_dir.mkdir(parents=True)
                logger.info(f"mkdir -p {output_dir}")
    logger.info(f"output path: {config.output_dir}")
                print(f"mkdir -p {output_dir}")

    if backup_config and idist.get_rank() == 0:
        with open(output_dir / "config.yml", "w+") as f:
            print(config.pretty(), file=f)
    logger = setup_logger(name=config.name, distributed_rank=local_rank, filepath=output_dir / "train.log")
    logger.info(f"output path: {config.output_dir}")
    log_basic_info(logger, config)

    OmegaConf.set_readonly(config, True)
model/GAN/UGATIT.py (new file, 253 lines)
@ -0,0 +1,253 @@
import torch
import torch.nn as nn
from .residual_generator import ResidualBlock
from model.registry import MODEL


class RhoClipper(object):
    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        if hasattr(module, 'rho'):
            w = module.rho.data
            w = w.clamp(self.clip_min, self.clip_max)
            module.rho.data = w


@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
        assert (num_blocks >= 0)
        super(Generator, self).__init__()
        self.input_channels = in_channels
        self.output_channels = out_channels
        self.base_channels = base_channels
        self.num_blocks = num_blocks
        self.img_size = img_size
        self.light = light

        down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
                                  padding_mode="reflect", bias=False),
                        nn.InstanceNorm2d(base_channels),
                        nn.ReLU(True)]

        n_down_sampling = 2
        for i in range(n_down_sampling):
            mult = 2 ** i
            down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2,
                                       padding=1, bias=False, padding_mode="reflect"),
                             nn.InstanceNorm2d(base_channels * mult * 2),
                             nn.ReLU(True)]

        # Down-Sampling Bottleneck
        mult = 2 ** n_down_sampling
        for i in range(num_blocks):
            # TODO: change ResnetBlock to ResidualBlock, check use_bias param
            down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
        self.down_encoder = nn.Sequential(*down_encoder)

        # Class Activation Map
        self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False)
        self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False)
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                  nn.ReLU(True),
                  nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                  nn.ReLU(True)]
        else:
            fc = [
                nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False),
                nn.ReLU(True),
                nn.Linear(base_channels * mult, base_channels * mult, bias=False),
                nn.ReLU(True)]
        self.fc = nn.Sequential(*fc)

        self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
        self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)

        # Up-Sampling Bottleneck
        self.up_bottleneck = nn.ModuleList(
            [ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)])

        # Up-Sampling
        up_decoder = []
        for i in range(n_down_sampling):
            mult = 2 ** (n_down_sampling - i)
            up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'),
                           nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1,
                                     padding=1, padding_mode="reflect", bias=False),
                           ILN(base_channels * mult // 2),
                           nn.ReLU(True)]

        up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
                                 padding_mode="reflect", bias=False),
                       nn.Tanh()]
        self.up_decoder = nn.Sequential(*up_decoder)
        # self.up_decoder = nn.ModuleDict({
        #     "up_1": nn.Upsample(scale_factor=2, mode='nearest'),
        #     "up_conv_1": nn.Sequential(
        #         nn.Conv2d(base_channels * 4, base_channels * 4 // 2, kernel_size=3, stride=1,
        #                   padding=1, padding_mode="reflect", bias=False),
        #         ILN(base_channels * 4 // 2),
        #         nn.ReLU(True)),
        #     "up_2": nn.Upsample(scale_factor=2, mode='nearest'),
        #     "up_conv_2": nn.Sequential(
        #         nn.Conv2d(base_channels * 2, base_channels * 2 // 2, kernel_size=3, stride=1,
        #                   padding=1, padding_mode="reflect", bias=False),
        #         ILN(base_channels * 2 // 2),
        #         nn.ReLU(True)),
        #     "up_end": nn.Sequential(nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
        #                                       padding_mode="reflect", bias=False), nn.Tanh())
        # })

    def forward(self, x):
        x = self.down_encoder(x)

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)

        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
            x_ = self.fc(x_.view(x_.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)

        for ub in self.up_bottleneck:
            x = ub(x, gamma, beta)

        x = self.up_decoder(x)
        return x, cam_logit, heatmap


class ResnetAdaILNBlock(nn.Module):
    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU(True)

        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        out = self.conv1(x)
        out = self.norm1(out, gamma, beta)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.norm2(out, gamma, beta)

        return out + x


def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True)
    out_in = (x - in_mean) / torch.sqrt(in_var + eps)
    ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
    out_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)
    out = rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - rho.expand(x.shape[0], -1, -1, -1)) * out_ln
    out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
    return out


class AdaILN(nn.Module):
    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super(AdaILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(default_rho)

    def forward(self, x, gamma, beta):
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)


class ILN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = nn.Parameter(torch.Tensor(1, num_features))
        self.beta = nn.Parameter(torch.Tensor(1, num_features))
        self.rho.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, x):
        return instance_layer_normalization(
            x, self.gamma.expand(x.shape[0], -1), self.beta.expand(x.shape[0], -1), self.rho, self.eps)


@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    def __init__(self, in_channels, base_channels=64, num_blocks=5):
        super(Discriminator, self).__init__()
        encoder = [self.build_conv_block(in_channels, base_channels)]

        for i in range(1, num_blocks - 2):
            mult = 2 ** (i - 1)
            encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2))

        mult = 2 ** (num_blocks - 2 - 1)
        encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1))

        self.encoder = nn.Sequential(*encoder)

        # Class Activation Map
        mult = 2 ** (num_blocks - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))

    @staticmethod
    def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
        return nn.Sequential(*[
            nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                             bias=True, padding=padding, padding_mode=padding_mode)),
            nn.LeakyReLU(0.2, True),
        ])

    def forward(self, x, return_heatmap=False):
        x = self.encoder(x)
        batch_size = x.size(0)

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)  # B x C x 1 x 1, avg of per channel
        gap_logit = self.gap_fc(gap.view(batch_size, -1))
        gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(batch_size, -1))
        gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)

        x = torch.cat([gap, gmp], 1)
        x = self.leaky_relu(self.conv1x1(x))

        if return_heatmap:
            heatmap = torch.sum(x, dim=1, keepdim=True)
            return self.conv(x), cam_logit, heatmap
        else:
            return self.conv(x), cam_logit
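As a reading aid, the adaptive instance-layer normalization implemented by instance_layer_normalization above blends instance-norm and layer-norm statistics. This is a restatement of the code in standard notation, not text from the commit:

\hat{x}_{IN} = \frac{x - \mu_{HW}(x)}{\sqrt{\sigma^2_{HW}(x) + \epsilon}}, \qquad
\hat{x}_{LN} = \frac{x - \mu_{CHW}(x)}{\sqrt{\sigma^2_{CHW}(x) + \epsilon}}

\mathrm{ILN}(x) = \gamma \odot \big(\rho \odot \hat{x}_{IN} + (1 - \rho) \odot \hat{x}_{LN}\big) + \beta

In AdaILN, gamma and beta come from the generator's fc/gamma/beta heads and rho is a learned per-channel parameter initialized to 0.9; in ILN, gamma, beta, and rho are learned directly with rho initialized to 0. RhoClipper keeps rho within [0, 1] via generator.apply(rho_clipper) after each generator step in engine/UGATIT.py.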
model/GAN/__init__.py (new empty file)
@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import functools
from .registry import MODEL
from model.registry import MODEL


def _select_norm_layer(norm_type):
@ -71,9 +71,10 @@ class GANImageBuffer(object):

@MODEL.register_module()
class ResidualBlock(nn.Module):
    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False):
    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
        super(ResidualBlock, self).__init__()

        if use_bias is None:
            # Use bias only for IN, since it does not have affine parameters by default.
            use_bias = norm_type == "IN"
        norm_layer = _select_norm_layer(norm_type)
@ -1,3 +1,3 @@
from model.registry import MODEL
import model.residual_generator
import model.GAN.residual_generator
import model.fewshot
run.sh (8 lines changed)
@ -3,12 +3,18 @@
CONFIG=$1
TASK=$2
GPUS=$3
MORE_ARG=${*:4}

_command="print(len('${GPUS}'.split(',')))"
GPU_COUNT=$(python3 -c "${_command}")

echo "GPU_COUNT:${GPU_COUNT}"

echo CUDA_VISIBLE_DEVICES=$GPUS \
  PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
  main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"

CUDA_VISIBLE_DEVICES=$GPUS \
  PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
  main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed
  main.py "$TASK" "$CONFIG" "$MORE_ARG" --backup_config --setup_output_dir --setup_random_seed
@ -39,6 +39,7 @@ def setup_common_handlers(
    :param checkpoint_kwargs:
    :return:
    """

    @trainer.on(Events.STARTED)
    @idist.one_rank_only()
    def print_dataloader_size(engine):
@ -79,6 +80,8 @@ def setup_common_handlers(
            engine.logger.info(print_str)

    if to_save is not None:
        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False),
                                        **checkpoint_kwargs)
        if resume_from is not None:
            @trainer.on(Events.STARTED)
            def resume(engine):
@ -89,5 +92,4 @@ def setup_common_handlers(
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
        if save_interval_event is not None:
            checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
            trainer.add_event_handler(save_interval_event, checkpoint_handler)
util/misc.py (new file, 85 lines)
@ -0,0 +1,85 @@
import logging
from typing import Optional


def setup_logger(
    name: Optional[str] = None,
    level: int = logging.INFO,
    logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
    filepath: Optional[str] = None,
    file_level: int = logging.DEBUG,
    distributed_rank: Optional[int] = None,
) -> logging.Logger:
    """Setups logger: name, level, format etc.

    Args:
        name (str, optional): new name for the logger. If None, the standard logger is used.
        level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
        logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
        filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
        file_level (int): Optional logging level for logging file.
        distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
            If None, distributed_rank is initialized to the rank of process.

    Returns:
        logging.Logger

    For example, to improve logs readability when training with a trainer and evaluator:

    .. code-block:: python

        from ignite.utils import setup_logger

        trainer = ...
        evaluator = ...

        trainer.logger = setup_logger("trainer")
        evaluator.logger = setup_logger("evaluator")

        trainer.run(data, max_epochs=10)

        # Logs will look like
        # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
        # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
        # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
        # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
        # ...

    """
    logger = logging.getLogger(name)

    # don't propagate to ancestors
    # the problem here is to attach handlers to loggers
    # should we provide a default configuration less open ?
    if name is not None:
        logger.propagate = False

    # Remove previous handlers
    if logger.hasHandlers():
        for h in list(logger.handlers):
            logger.removeHandler(h)

    formatter = logging.Formatter(logger_format)

    if distributed_rank is None:
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    if distributed_rank > 0:
        logger.addHandler(logging.NullHandler())
    else:
        logger.setLevel(level)

        ch = logging.StreamHandler()
        ch.setLevel(level)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

        if filepath is not None:
            fh = logging.FileHandler(filepath)
            fh.setLevel(file_level)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    return logger
@ -67,7 +67,11 @@ class _Registry:
        if default_args is not None:
            for name, value in default_args.items():
                args.setdefault(name, value)
        return obj_cls(**args)
        try:
            obj = obj_cls(**args)
        except TypeError as e:
            raise TypeError(f"invalid argument in {args} when try to build {obj_cls}\n") from e
        return obj


class ModuleRegistry(_Registry):