diff --git a/configs/few-shot/crossdomain.yml b/configs/few-shot/crossdomain.yml index b777b17..bafd121 100644 --- a/configs/few-shot/crossdomain.yml +++ b/configs/few-shot/crossdomain.yml @@ -25,7 +25,7 @@ baseline: _type: Adam data: dataloader: - batch_size: 1024 + batch_size: 1200 shuffle: True num_workers: 16 pin_memory: True @@ -37,7 +37,7 @@ baseline: pipeline: - Load - RandomResizedCrop: - size: [256, 256] + size: [224, 224] - ColorJitter: brightness: 0.4 contrast: 0.4 @@ -47,20 +47,5 @@ baseline: - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - val: - path: /data/few-shot/mini_imagenet_full_size/val - lmdb_path: /data/few-shot/lmdb/mini-ImageNet/val.lmdb - pipeline: - - Load - - Resize: - size: [286, 286] - - RandomCrop: - size: [256, 256] - - ToTensor - - Normalize: - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - - diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml new file mode 100644 index 0000000..cee8eeb --- /dev/null +++ b/configs/synthesizers/UGATIT.yml @@ -0,0 +1,110 @@ +name: selfie2anime +engine: UGATIT +result_dir: ./result +max_iteration: 100000 + +distributed: + model: + # broadcast_buffers: False + +misc: + random_seed: 324 + +checkpoints: + interval: 1000 + +model: + generator: + _type: UGATIT-Generator + in_channels: 3 + out_channels: 3 + base_channels: 64 + num_blocks: 4 + img_size: 256 + light: True + local_discriminator: + _type: UGATIT-Discriminator + in_channels: 3 + base_channels: 64 + num_blocks: 3 + global_discriminator: + _type: UGATIT-Discriminator + in_channels: 3 + base_channels: 64 + num_blocks: 5 + +loss: + gan: + loss_type: lsgan + weight: 1.0 + real_label_val: 1.0 + fake_label_val: 0.0 + cycle: + level: 1 + weight: 10.0 + id: + level: 1 + weight: 10.0 + cam: + weight: 1000 + +optimizers: + generator: + _type: Adam + lr: 0.0001 + betas: [0.5, 0.999] + weight_decay: 0.0001 + discriminator: + _type: Adam + lr: 1e-4 + betas: [0.5, 0.999] + weight_decay: 0.0001 + +data: 
+ train: + buffer_size: 50 + dataloader: + batch_size: 8 + shuffle: True + num_workers: 2 + pin_memory: True + drop_last: True + dataset: + _type: GenerationUnpairedDataset + root_a: "/data/i2i/selfie2anime/trainA" + root_b: "/data/i2i/selfie2anime/trainB" + random_pair: True + pipeline: + - Load + - Resize: + size: [286, 286] + - RandomCrop: + size: [256, 256] + - RandomHorizontalFlip + - ToTensor + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + scheduler: + start: 50000 + target_lr: 0 + test: + dataloader: + batch_size: 4 + shuffle: False + num_workers: 1 + pin_memory: False + drop_last: False + dataset: + _type: GenerationUnpairedDataset + root_a: "/data/i2i/selfie2anime/testA" + root_b: "/data/i2i/selfie2anime/testB" + random_pair: False + pipeline: + - Load + - Resize: + size: [256, 256] + - ToTensor + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] diff --git a/data/dataset.py b/data/dataset.py index 3cb9800..f8a57bc 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -99,9 +99,9 @@ class EpisodicDataset(Dataset): def __getitem__(self, _): random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist() - support_set_list = [] - query_set_list = [] - target_list = [] + support_set = [] + query_set = [] + target_set = [] for tag, c in enumerate(random_classes): image_list = self.origin.classes_list[c] @@ -113,13 +113,13 @@ class EpisodicDataset(Dataset): support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support])) query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:])) - support_set_list.extend(support) - query_set_list.extend(query) - target_list.extend([tag] * self.num_query) + support_set.extend(support) + query_set.extend(query) + target_set.extend([tag] * self.num_query) return { - "support": torch.stack(support_set_list), - "query": torch.stack(query_set_list), - "target": torch.tensor(target_list) + "support": torch.stack(support_set), + 
"query": torch.stack(query_set), + "target": torch.tensor(target_set) } def __repr__(self): diff --git a/engine/UGATIT.py b/engine/UGATIT.py new file mode 100644 index 0000000..007b0d7 --- /dev/null +++ b/engine/UGATIT.py @@ -0,0 +1,249 @@ +from itertools import chain +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.utils + +import ignite.distributed as idist +from ignite.engine import Events, Engine +from ignite.contrib.handlers.param_scheduler import PiecewiseLinear +from ignite.metrics import RunningAverage +from ignite.contrib.handlers import ProgressBar +from ignite.utils import convert_tensor +from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler + +from omegaconf import OmegaConf + +import data +from loss.gan import GANLoss +from model.weight_init import generation_init_weights +from model.GAN.residual_generator import GANImageBuffer +from model.GAN.UGATIT import RhoClipper +from util.image import make_2d_grid +from util.handler import setup_common_handlers +from util.build import build_model, build_optimizer + + +def get_trainer(config, logger): + generators = dict( + a2b=build_model(config.model.generator, config.distributed.model), + b2a=build_model(config.model.generator, config.distributed.model), + ) + discriminators = dict( + la=build_model(config.model.local_discriminator, config.distributed.model), + lb=build_model(config.model.local_discriminator, config.distributed.model), + ga=build_model(config.model.global_discriminator, config.distributed.model), + gb=build_model(config.model.global_discriminator, config.distributed.model), + ) + for m in chain(generators.values(), discriminators.values()): + generation_init_weights(m) + + logger.debug(discriminators["ga"]) + logger.debug(generators["a2b"]) + + optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator) + 
optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), + config.optimizers.discriminator) + + milestones_values = [ + (0, config.optimizers.generator.lr), + (config.data.train.scheduler.start, config.optimizers.generator.lr), + (config.max_iteration, config.data.train.scheduler.target_lr) + ] + lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values) + + milestones_values = [ + (0, config.optimizers.discriminator.lr), + (config.data.train.scheduler.start, config.optimizers.discriminator.lr), + (config.max_iteration, config.data.train.scheduler.target_lr) + ] + lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values) + + gan_loss_cfg = OmegaConf.to_container(config.loss.gan) + gan_loss_cfg.pop("weight") + gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) + cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() + id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() + bce_loss = nn.BCEWithLogitsLoss().to(idist.device()) + mse_loss = lambda x, t: F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size())) + bce_loss = lambda x, t: F.binary_cross_entropy_with_logits(x, x.new_ones(x.size()) if t else x.new_zeros(x.size())) + + image_buffers = { + k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) for k in + discriminators.keys()} + + rho_clipper = RhoClipper(0, 1) + + def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l, + discriminator_g): + discriminator_g.requires_grad_(False) + discriminator_l.requires_grad_(False) + pred_fake_g, cam_gd_pred = discriminator_g(fake) + pred_fake_l, cam_ld_pred = discriminator_l(fake) + return { + f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec), + f"id_{name}": config.loss.id.weight * id_loss(real, identity), + f"cam_{name}": config.loss.cam.weight * 
(bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)), + f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True), + f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True), + f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True), + f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True), + } + + def cal_discriminator_loss(name, discriminator, real, fake): + pred_real, cam_real = discriminator(real) + pred_fake, cam_fake = discriminator(fake) + # TODO: origin do not divide 2, but I think it better to divide 2. + loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True) + loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False) + return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam} + + def _step(engine, batch): + batch = convert_tensor(batch, idist.device()) + real_a, real_b = batch["a"], batch["b"] + + fake = dict() + cam_generator_pred = dict() + rec = dict() + identity = dict() + cam_identity_pred = dict() + heatmap = dict() + + fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a) + fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b) + rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"]) + rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"]) + identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a) + identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b) + + optimizer_g.zero_grad() + loss_g = dict() + for n in ["a", "b"]: + loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n], + cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n])) + sum(loss_g.values()).backward() + optimizer_g.step() + for generator in generators.values(): + generator.apply(rho_clipper) + for discriminator in discriminators.values(): + discriminator.requires_grad_(True) + + 
optimizer_d.zero_grad() + loss_d = dict() + for k in discriminators.keys(): + n = k[-1] # "a" or "b" + loss_d.update( + cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach()))) + sum(loss_d.values()).backward() + optimizer_d.step() + + for h in heatmap: + heatmap[h] = heatmap[h].detach() + generated_img = {f"fake_{k}": fake[k].detach() for k in fake} + generated_img.update({f"id_{k}": identity[k].detach() for k in identity}) + generated_img.update({f"rec_{k}": rec[k].detach() for k in rec}) + + return { + "loss": { + "g": {ln: loss_g[ln].mean().item() for ln in loss_g}, + "d": {ln: loss_d[ln].mean().item() for ln in loss_d}, + }, + "img": { + "heatmap": heatmap, + "generated": generated_img + } + } + + trainer = Engine(_step) + trainer.logger = logger + trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g) + trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d) + + RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") + RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d") + + to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d, + lr_scheduler_g=lr_scheduler_g) + to_save.update({f"generator_{k}": generators[k] for k in generators}) + to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) + + setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5, + filename_prefix=config.name, to_save=to_save, + print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED, + metrics_to_print=["loss_g", "loss_d"], + save_interval_event=Events.ITERATION_COMPLETED( + every=config.checkpoints.interval) | Events.COMPLETED) + + @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration)) + def terminate(engine): + engine.terminate() + + if idist.get_rank() == 0: + # Create a logger + 
tb_logger = TensorboardLogger(log_dir=config.output_dir) + tb_writer = tb_logger.writer + + # Attach the logger to the trainer to log training loss at each iteration + def global_step_transform(*args, **kwargs): + return trainer.state.iteration + + def output_transform(output): + loss = dict() + for tl in output["loss"]: + if isinstance(output["loss"][tl], dict): + for l in output["loss"][tl]: + loss[f"{tl}_{l}"] = output["loss"][tl][l] + else: + loss[tl] = output["loss"][tl] + return loss + + tb_logger.attach( + trainer, + log_handler=OutputHandler( + tag="loss", + metric_names=["loss_g", "loss_d"], + global_step_transform=global_step_transform, + output_transform=output_transform + ), + event_name=Events.ITERATION_COMPLETED(every=50) + ) + tb_logger.attach( + trainer, + log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"), + event_name=Events.ITERATION_STARTED(every=50) + ) + + @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval)) + def show_images(engine): + tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()), + engine.state.iteration) + tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()), + engine.state.iteration) + + @trainer.on(Events.COMPLETED) + @idist.one_rank_only() + def _(): + # We need to close the logger with we are done + tb_logger.close() + + return trainer + + +def run(task, config, logger): + assert torch.backends.cudnn.enabled + torch.backends.cudnn.benchmark = True + logger.info(f"start task {task}") + if task == "train": + train_dataset = data.DATASET.build_with(config.data.train.dataset) + logger.info(f"train with dataset:\n{train_dataset}") + train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) + trainer = get_trainer(config, logger) + try: + trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1) + except Exception: + import traceback + 
print(traceback.format_exc()) + else: + raise NotImplementedError(f"invalid task: {task}") diff --git a/engine/crossdomain.py b/engine/crossdomain.py index 1e9bf57..cf732d3 100644 --- a/engine/crossdomain.py +++ b/engine/crossdomain.py @@ -17,7 +17,7 @@ from data.transform import transform_pipeline from data.dataset import LMDBDataset -def baseline_trainer(config, logger): +def warmup_trainer(config, logger): model = build_model(config.model, config.distributed.model) optimizer = build_optimizer(model.parameters(), config.baseline.optimizers) loss_fn = nn.CrossEntropyLoss() @@ -66,18 +66,20 @@ def run(task, config, logger): assert torch.backends.cudnn.enabled torch.backends.cudnn.benchmark = True logger.info(f"start task {task}") - if task == "baseline": + if task == "warmup": train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path, pipeline=config.baseline.data.dataset.train.pipeline) - # train_dataset = ImageFolder(config.baseline.data.dataset.train.path, - # transform=transform_pipeline(config.baseline.data.dataset.train.pipeline)) logger.info(f"train with dataset:\n{train_dataset}") train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader) - trainer = baseline_trainer(config, logger) + trainer = warmup_trainer(config, logger) try: trainer.run(train_data_loader, max_epochs=400) except Exception: import traceback print(traceback.format_exc()) + elif task == "protonet-wo": + pass + elif task == "protonet-w": + pass else: - return NotImplemented(f"invalid task: {task}") + raise ValueError(f"invalid task: {task}") diff --git a/engine/cyclegan.py b/engine/cyclegan.py index aac8aee..2e56d52 100644 --- a/engine/cyclegan.py +++ b/engine/cyclegan.py @@ -18,7 +18,7 @@ from omegaconf import OmegaConf import data from loss.gan import GANLoss from model.weight_init import generation_init_weights -from model.residual_generator import GANImageBuffer +from model.GAN.residual_generator import GANImageBuffer from util.image import 
make_2d_grid from util.handler import setup_common_handlers from util.build import build_model, build_optimizer @@ -31,8 +31,8 @@ def get_trainer(config, logger): discriminator_b = build_model(config.model.discriminator, config.distributed.model) for m in [generator_b, generator_a, discriminator_b, discriminator_a]: generation_init_weights(m) - logger.debug(discriminator_a) - logger.debug(generator_a) + logger.info(discriminator_a) + logger.info(generator_a) optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()), config.optimizers.generator) @@ -56,8 +56,8 @@ def get_trainer(config, logger): gan_loss_cfg = OmegaConf.to_container(config.loss.gan) gan_loss_cfg.pop("weight") gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) - cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss() - id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss() + cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() + id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) @@ -93,11 +93,11 @@ def get_trainer(config, logger): real=gan_loss(discriminator_a(real_b), True, is_discriminator=True), fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True), ) - (sum(loss_d_a.values()) * 0.5).backward() loss_d_b = dict( real=gan_loss(discriminator_b(real_a), True, is_discriminator=True), fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True), ) + (sum(loss_d_a.values()) * 0.5).backward() (sum(loss_d_b.values()) * 0.5).backward() optimizer_d.step() diff --git a/engine/fewshot.py b/engine/fewshot.py new file mode 100644 index 0000000..b449c6d --- /dev/null +++ b/engine/fewshot.py @@ 
-0,0 +1,9 @@ +from data.dataset import EpisodicDataset, LMDBDataset + + +def prototypical_trainer(config, logger): + pass + + +def prototypical_dataloader(config): + pass diff --git a/loss/fewshot/__init__.py b/loss/fewshot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loss/fewshot/prototypical.py b/loss/fewshot/prototypical.py new file mode 100644 index 0000000..604309b --- /dev/null +++ b/loss/fewshot/prototypical.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PrototypicalLoss(nn.Module): + def __init__(self): + super().__init__() + + @staticmethod + def acc(query, target, support): + prototypes = support.mean(-2) # batch_size x N_class x D + distance = PrototypicalLoss.euclidean_dist(query, prototypes) # batch_size x N_class*N_query x N_class + indices = distance.argmin(-1) # smallest distance indices + acc = torch.eq(target, indices).float().mean().item() + return acc + + @staticmethod + def euclidean_dist(x, y): + # x: B x N x D + # y: B x M x D + assert x.size(-1) == y.size(-1) and x.size(0) == y.size(0) + n = x.size(-2) + m = y.size(-2) + d = x.size(-1) + x = x.unsqueeze(2).expand(x.size(0), n, m, d) # B x N x M x D + y = y.unsqueeze(1).expand(x.size(0), n, m, d) + return torch.pow(x - y, 2).sum(-1) # B x N x M + + def forward(self, query, target, support): + """ + calculate prototypical loss + :param query: Tensor - batch_size x N_class*N_query x D + :param target: Tensor - batch_size x N_class*N_query, target id set, value must in [0, N_class) + :param support: Tensor - batch_size x N_class x N_support x D, must be ordered by class id + :return: loss item and accuracy + """ + + prototypes = support.mean(-2) # batch_size x N_class x D + distance = self.euclidean_dist(query, prototypes) # batch_size x N_class*N_query x N_class + indices = distance.argmin(-1) # smallest distance indices + acc = torch.eq(target, indices).float().mean().item() + + log_p_y = F.log_softmax(-distance, 
dim=-1) + n_class = support.size(1) + n_query = query.size(1) // n_class + batch_size = support.size(0) + target_log_indices = torch.arange(n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).reshape( + n_class * n_query, 1).view(1, n_class * n_query, 1).expand(batch_size, n_class * n_query, 1) + loss = -log_p_y.gather(2, target_log_indices).mean() # select log-probability of true class then get the mean + + return loss, acc diff --git a/main.py b/main.py index 05bcf5c..9dff274 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,8 @@ import torch import ignite import ignite.distributed as idist -from ignite.utils import manual_seed, setup_logger +from ignite.utils import manual_seed +from util.misc import setup_logger import fire from omegaconf import OmegaConf @@ -21,14 +22,12 @@ def log_basic_info(logger, config): def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False): - logger = setup_logger(name=config.name, distributed_rank=local_rank, **config.log.logger) - log_basic_info(logger, config) - if setup_random_seed: manual_seed(config.misc.random_seed + idist.get_rank()) - if setup_output_dir: - output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir - config.output_dir = str(output_dir) + output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir + config.output_dir = str(output_dir) + + if setup_output_dir and config.resume_from is None: if output_dir.exists(): # assert not any(output_dir.iterdir()), "output_dir must be empty" contains = list(output_dir.iterdir()) @@ -37,11 +36,14 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals else: if idist.get_rank() == 0: output_dir.mkdir(parents=True) - logger.info(f"mkdir -p {output_dir}") - logger.info(f"output path: {config.output_dir}") + print(f"mkdir -p {output_dir}") + if backup_config and idist.get_rank() == 0: with open(output_dir / 
"config.yml", "w+") as f: print(config.pretty(), file=f) + logger = setup_logger(name=config.name, distributed_rank=local_rank, filepath=output_dir / "train.log") + logger.info(f"output path: {config.output_dir}") + log_basic_info(logger, config) OmegaConf.set_readonly(config, True) diff --git a/model/GAN/UGATIT.py b/model/GAN/UGATIT.py new file mode 100644 index 0000000..e02c1cb --- /dev/null +++ b/model/GAN/UGATIT.py @@ -0,0 +1,253 @@ +import torch +import torch.nn as nn +from .residual_generator import ResidualBlock +from model.registry import MODEL + + +class RhoClipper(object): + def __init__(self, clip_min, clip_max): + self.clip_min = clip_min + self.clip_max = clip_max + assert clip_min < clip_max + + def __call__(self, module): + if hasattr(module, 'rho'): + w = module.rho.data + w = w.clamp(self.clip_min, self.clip_max) + module.rho.data = w + + +@MODEL.register_module("UGATIT-Generator") +class Generator(nn.Module): + def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False): + assert (num_blocks >= 0) + super(Generator, self).__init__() + self.input_channels = in_channels + self.output_channels = out_channels + self.base_channels = base_channels + self.num_blocks = num_blocks + self.img_size = img_size + self.light = light + + down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, + padding_mode="reflect", bias=False), + nn.InstanceNorm2d(base_channels), + nn.ReLU(True)] + + n_down_sampling = 2 + for i in range(n_down_sampling): + mult = 2 ** i + down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2, + padding=1, bias=False, padding_mode="reflect"), + nn.InstanceNorm2d(base_channels * mult * 2), + nn.ReLU(True)] + + # Down-Sampling Bottleneck + mult = 2 ** n_down_sampling + for i in range(num_blocks): + # TODO: change ResnetBlock to ResidualBlock, check use_bias param + down_encoder += [ResidualBlock(base_channels * mult, 
use_bias=False)] + self.down_encoder = nn.Sequential(*down_encoder) + + # Class Activation Map + self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False) + self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False) + self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True) + self.relu = nn.ReLU(True) + + # Gamma, Beta block + if self.light: + fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False), + nn.ReLU(True), + nn.Linear(base_channels * mult, base_channels * mult, bias=False), + nn.ReLU(True)] + else: + fc = [ + nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False), + nn.ReLU(True), + nn.Linear(base_channels * mult, base_channels * mult, bias=False), + nn.ReLU(True)] + self.fc = nn.Sequential(*fc) + + self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False) + self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False) + + # Up-Sampling Bottleneck + self.up_bottleneck = nn.ModuleList( + [ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)]) + + # Up-Sampling + up_decoder = [] + for i in range(n_down_sampling): + mult = 2 ** (n_down_sampling - i) + up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1, + padding=1, padding_mode="reflect", bias=False), + ILN(base_channels * mult // 2), + nn.ReLU(True)] + + up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3, + padding_mode="reflect", bias=False), + nn.Tanh()] + self.up_decoder = nn.Sequential(*up_decoder) + # self.up_decoder = nn.ModuleDict({ + # "up_1": nn.Upsample(scale_factor=2, mode='nearest'), + # "up_conv_1": nn.Sequential( + # nn.Conv2d(base_channels * 4, base_channels * 4 // 2, kernel_size=3, stride=1, + # padding=1, padding_mode="reflect", bias=False), + # ILN(base_channels * 4 // 2), + 
# nn.ReLU(True)), + # "up_2": nn.Upsample(scale_factor=2, mode='nearest'), + # "up_conv_2": nn.Sequential( + # nn.Conv2d(base_channels * 2, base_channels * 2 // 2, kernel_size=3, stride=1, + # padding=1, padding_mode="reflect", bias=False), + # ILN(base_channels * 2 // 2), + # nn.ReLU(True)), + # "up_end": nn.Sequential(nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3, + # padding_mode="reflect", bias=False), nn.Tanh()) + # }) + + def forward(self, x): + x = self.down_encoder(x) + + gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) + gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) + gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3) + + gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) + gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) + gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3) + + cam_logit = torch.cat([gap_logit, gmp_logit], 1) + + x = torch.cat([gap, gmp], 1) + x = self.relu(self.conv1x1(x)) + + heatmap = torch.sum(x, dim=1, keepdim=True) + + if self.light: + x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1) + x_ = self.fc(x_.view(x_.shape[0], -1)) + else: + x_ = self.fc(x.view(x.shape[0], -1)) + gamma, beta = self.gamma(x_), self.beta(x_) + + for ub in self.up_bottleneck: + x = ub(x, gamma, beta) + + x = self.up_decoder(x) + return x, cam_logit, heatmap + + +class ResnetAdaILNBlock(nn.Module): + def __init__(self, dim, use_bias): + super(ResnetAdaILNBlock, self).__init__() + self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect") + self.norm1 = AdaILN(dim) + self.relu1 = nn.ReLU(True) + + self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect") + self.norm2 = AdaILN(dim) + + def forward(self, x, gamma, beta): + out = self.conv1(x) + out = self.norm1(out, gamma, beta) + out = self.relu1(out) + out = self.conv2(out) + out = self.norm2(out, gamma, beta) + + return out + x + + +def instance_layer_normalization(x, 
gamma, beta, rho, eps=1e-5): + in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True) + out_in = (x - in_mean) / torch.sqrt(in_var + eps) + ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True) + out_ln = (x - ln_mean) / torch.sqrt(ln_var + eps) + out = rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - rho.expand(x.shape[0], -1, -1, -1)) * out_ln + out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3) + return out + + +class AdaILN(nn.Module): + def __init__(self, num_features, eps=1e-5, default_rho=0.9): + super(AdaILN, self).__init__() + self.eps = eps + self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) + self.rho.data.fill_(default_rho) + + def forward(self, x, gamma, beta): + return instance_layer_normalization(x, gamma, beta, self.rho, self.eps) + + +class ILN(nn.Module): + def __init__(self, num_features, eps=1e-5): + super(ILN, self).__init__() + self.eps = eps + self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) + self.gamma = nn.Parameter(torch.Tensor(1, num_features)) + self.beta = nn.Parameter(torch.Tensor(1, num_features)) + self.rho.data.fill_(0.0) + self.gamma.data.fill_(1.0) + self.beta.data.fill_(0.0) + + def forward(self, x): + return instance_layer_normalization( + x, self.gamma.expand(x.shape[0], -1), self.beta.expand(x.shape[0], -1), self.rho, self.eps) + + +@MODEL.register_module("UGATIT-Discriminator") +class Discriminator(nn.Module): + def __init__(self, in_channels, base_channels=64, num_blocks=5): + super(Discriminator, self).__init__() + encoder = [self.build_conv_block(in_channels, base_channels)] + + for i in range(1, num_blocks - 2): + mult = 2 ** (i - 1) + encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2)) + + mult = 2 ** (num_blocks - 2 - 1) + encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1)) + + self.encoder = 
nn.Sequential(*encoder) + + # Class Activation Map + mult = 2 ** (num_blocks - 2) + self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False)) + self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False)) + self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True) + self.leaky_relu = nn.LeakyReLU(0.2, True) + + self.conv = nn.utils.spectral_norm( + nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect")) + + @staticmethod + def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"): + return nn.Sequential(*[ + nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + bias=True, padding=padding, padding_mode=padding_mode)), + nn.LeakyReLU(0.2, True), + ]) + + def forward(self, x, return_heatmap=False): + x = self.encoder(x) + batch_size = x.size(0) + + gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) # B x C x 1 x 1, avg of per channel + gap_logit = self.gap_fc(gap.view(batch_size, -1)) + gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3) + + gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) + gmp_logit = self.gmp_fc(gmp.view(batch_size, -1)) + gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3) + + cam_logit = torch.cat([gap_logit, gmp_logit], 1) + + x = torch.cat([gap, gmp], 1) + x = self.leaky_relu(self.conv1x1(x)) + + if return_heatmap: + heatmap = torch.sum(x, dim=1, keepdim=True) + return self.conv(x), cam_logit, heatmap + else: + return self.conv(x), cam_logit diff --git a/model/GAN/__init__.py b/model/GAN/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/residual_generator.py b/model/GAN/residual_generator.py similarity index 97% rename from model/residual_generator.py rename to model/GAN/residual_generator.py index c3138be..9c4adac 100644 --- a/model/residual_generator.py +++ 
b/model/GAN/residual_generator.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import functools -from .registry import MODEL +from model.registry import MODEL def _select_norm_layer(norm_type): @@ -71,11 +71,12 @@ class GANImageBuffer(object): @MODEL.register_module() class ResidualBlock(nn.Module): - def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False): + def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None): super(ResidualBlock, self).__init__() - # Only for IN, use bias since it does not have affine parameters. - use_bias = norm_type == "IN" + if use_bias is None: + # Only for IN, use bias since it does not have affine parameters. + use_bias = norm_type == "IN" norm_layer = _select_norm_layer(norm_type) models = [nn.Sequential( nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias), diff --git a/model/__init__.py b/model/__init__.py index 328df68..cfbd292 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,3 +1,3 @@ from model.registry import MODEL -import model.residual_generator +import model.GAN.residual_generator import model.fewshot diff --git a/run.sh b/run.sh index 21127a1..02ab01c 100644 --- a/run.sh +++ b/run.sh @@ -3,12 +3,18 @@ CONFIG=$1 TASK=$2 GPUS=$3 +MORE_ARG=${*:4} _command="print(len('${GPUS}'.split(',')))" GPU_COUNT=$(python3 -c "${_command}") echo "GPU_COUNT:${GPU_COUNT}" +echo CUDA_VISIBLE_DEVICES=$GPUS \ +PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \ + main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG" + CUDA_VISIBLE_DEVICES=$GPUS \ PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \ - main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed + main.py "$TASK" "$CONFIG" "$MORE_ARG" --backup_config 
--setup_output_dir --setup_random_seed + diff --git a/util/handler.py b/util/handler.py index 180b872..15fd41f 100644 --- a/util/handler.py +++ b/util/handler.py @@ -39,6 +39,7 @@ def setup_common_handlers( :param checkpoint_kwargs: :return: """ + @trainer.on(Events.STARTED) @idist.one_rank_only() def print_dataloader_size(engine): @@ -79,6 +80,8 @@ def setup_common_handlers( engine.logger.info(print_str) if to_save is not None: + checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False), + **checkpoint_kwargs) if resume_from is not None: @trainer.on(Events.STARTED) def resume(engine): @@ -89,5 +92,4 @@ def setup_common_handlers( Checkpoint.load_objects(to_load=to_save, checkpoint=ckp) engine.logger.info(f"resume from a checkpoint {checkpoint_path}") if save_interval_event is not None: - checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs) trainer.add_event_handler(save_interval_event, checkpoint_handler) diff --git a/util/misc.py b/util/misc.py new file mode 100644 index 0000000..eac66f3 --- /dev/null +++ b/util/misc.py @@ -0,0 +1,85 @@ +import logging +from typing import Optional + + +def setup_logger( + name: Optional[str] = None, + level: int = logging.INFO, + logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s", + filepath: Optional[str] = None, + file_level: int = logging.DEBUG, + distributed_rank: Optional[int] = None, +) -> logging.Logger: + """Setups logger: name, level, format etc. + + Args: + name (str, optional): new name for the logger. If None, the standard logger is used. + level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG + logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s` + filepath (str, optional): Optional logging file path. If not None, logs are written to the file. + file_level (int): Optional logging level for logging file. 
+ distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers. + If None, distributed_rank is initialized to the rank of process. + + Returns: + logging.Logger + + For example, to improve logs readability when training with a trainer and evaluator: + + .. code-block:: python + + from ignite.utils import setup_logger + + trainer = ... + evaluator = ... + + trainer.logger = setup_logger("trainer") + evaluator.logger = setup_logger("evaluator") + + trainer.run(data, max_epochs=10) + + # Logs will look like + # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5. + # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23 + # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1. + # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02 + # ... + + """ + logger = logging.getLogger(name) + + # don't propagate to ancestors + # the problem here is to attach handlers to loggers + # should we provide a default configuration less open ? 
+ if name is not None: + logger.propagate = False + + # Remove previous handlers + if logger.hasHandlers(): + for h in list(logger.handlers): + logger.removeHandler(h) + + formatter = logging.Formatter(logger_format) + + if distributed_rank is None: + import ignite.distributed as idist + + distributed_rank = idist.get_rank() + + if distributed_rank > 0: + logger.addHandler(logging.NullHandler()) + else: + logger.setLevel(level) + + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + if filepath is not None: + fh = logging.FileHandler(filepath) + fh.setLevel(file_level) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger diff --git a/util/registry.py b/util/registry.py index 6b59af1..6fd4a75 100644 --- a/util/registry.py +++ b/util/registry.py @@ -67,7 +67,11 @@ class _Registry: if default_args is not None: for name, value in default_args.items(): args.setdefault(name, value) - return obj_cls(**args) + try: + obj = obj_cls(**args) + except TypeError as e: + raise TypeError(f"invalid argument in {args} when try to build {obj_cls}\n") from e + return obj class ModuleRegistry(_Registry):