diff --git a/.idea/other.xml b/.idea/other.xml
new file mode 100644
index 0000000..8339243
--- /dev/null
+++ b/.idea/other.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/configs/synthesizers/TAHG.yml b/configs/synthesizers/TAHG.yml
new file mode 100644
index 0000000..3395036
--- /dev/null
+++ b/configs/synthesizers/TAHG.yml
@@ -0,0 +1,127 @@
+name: TAHG
+engine: TAHG
+result_dir: ./result
+max_pairs: 1000000
+
+distributed:
+ model:
+ # broadcast_buffers: False
+
+misc:
+ random_seed: 324
+
+checkpoint:
+ epoch_interval: 1 # one checkpoint every 1 epoch
+ n_saved: 2
+
+interval:
+ print_per_iteration: 10 # print once per 10 iteration
+ tensorboard:
+ scalar: 10
+ image: 500
+
+model:
+ generator:
+ _type: TAHG-Generator
+ style_in_channels: 3
+ content_in_channels: 23
+ discriminator:
+ _type: TAHG-Discriminator
+ in_channels: 3
+
+loss:
+ gan:
+ loss_type: lsgan
+ real_label_val: 1.0
+ fake_label_val: 0.0
+ weight: 1.0
+ edge:
+ criterion: 'L1'
+ hed_pretrained_model_path: "/root/network-bsds500.pytorch"
+ weight: 2
+ perceptual:
+ layer_weights:
+# "3": 1.0
+ "0": 1.0
+ "5": 1.0
+ "10": 1.0
+ "19": 1.0
+ criterion: 'L2'
+ style_loss: True
+ perceptual_loss: False
+ weight: 100.0
+ recon:
+ level: 1
+ weight: 2
+
+optimizers:
+ generator:
+ _type: Adam
+ lr: 0.0001
+ betas: [ 0.5, 0.999 ]
+ weight_decay: 0.0001
+ discriminator:
+ _type: Adam
+ lr: 1e-4
+ betas: [ 0.5, 0.999 ]
+ weight_decay: 0.0001
+
+data:
+ train:
+ scheduler:
+ start_proportion: 0.5
+ target_lr: 0
+ buffer_size: 50
+ dataloader:
+ batch_size: 4
+ shuffle: True
+ num_workers: 2
+ pin_memory: True
+ drop_last: True
+ dataset:
+ _type: GenerationUnpairedDatasetWithEdge
+ root_a: "/data/i2i/VoxCeleb2Anime/trainA"
+ root_b: "/data/i2i/VoxCeleb2Anime/trainB"
+ edge_type: "hed_landmark"
+ random_pair: True
+ pipeline:
+ - Load
+ - Resize:
+ size: [ 256, 256 ]
+ - ToTensor
+ - Normalize:
+ mean: [ 0.5, 0.5, 0.5 ]
+ std: [ 0.5, 0.5, 0.5 ]
+ test:
+ dataloader:
+ batch_size: 8
+ shuffle: False
+ num_workers: 1
+ pin_memory: False
+ drop_last: False
+ dataset:
+ _type: GenerationUnpairedDatasetWithEdge
+ root_a: "/data/i2i/VoxCeleb2Anime/testA"
+ root_b: "/data/i2i/VoxCeleb2Anime/testB"
+ edge_type: "hed_landmark"
+ random_pair: False
+ pipeline:
+ - Load
+ - Resize:
+ size: [ 256, 256 ]
+ - ToTensor
+ - Normalize:
+ mean: [ 0.5, 0.5, 0.5 ]
+ std: [ 0.5, 0.5, 0.5 ]
+ video_dataset:
+ _type: SingleFolderDataset
+ root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
+ with_path: True
+ pipeline:
+ - Load
+ - Resize:
+ size: [ 256, 256 ]
+ - ToTensor
+ - Normalize:
+ mean: [ 0.5, 0.5, 0.5 ]
+ std: [ 0.5, 0.5, 0.5 ]
diff --git a/data/dataset.py b/data/dataset.py
index 37f19a9..5cb9462 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -1,5 +1,6 @@
import os
import pickle
+from pathlib import Path
from collections import defaultdict
import torch
@@ -171,3 +172,33 @@ class GenerationUnpairedDataset(Dataset):
def __repr__(self):
return f"\nPipeline:\n{self.A.pipeline}"
+
+
+@DATASET.register_module()
+class GenerationUnpairedDatasetWithEdge(Dataset):
+ def __init__(self, root_a, root_b, random_pair, pipeline, edge_type):
+ self.edge_type = edge_type
+ self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
+ self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
+ self.random_pair = random_pair
+
+ def get_edge(self, origin_path):
+ op = Path(origin_path)
+ add = torch.load(op.parent / f"{op.stem}.add")
+ return {"edge": add["edge"].float().unsqueeze(dim=0),
+ "additional_info": torch.cat([add["seg"].float(), add["dist"].float()], dim=0)}
+
+ def __getitem__(self, idx):
+ a_idx = idx % len(self.A)
+ b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
+ output = dict()
+ output["a"], path_a = self.A[a_idx]
+ output.update(self.get_edge(path_a))
+ output["b"] = self.B[b_idx]
+ return output
+
+ def __len__(self):
+ return max(len(self.A), len(self.B))
+
+ def __repr__(self):
+ return f"\nPipeline:\n{self.A.pipeline}"
diff --git a/engine/TAHG.py b/engine/TAHG.py
new file mode 100644
index 0000000..660aca3
--- /dev/null
+++ b/engine/TAHG.py
@@ -0,0 +1,204 @@
+from itertools import chain
+from math import ceil
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+import ignite.distributed as idist
+from ignite.engine import Events, Engine
+from ignite.metrics import RunningAverage
+from ignite.utils import convert_tensor
+from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
+from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
+
+from omegaconf import OmegaConf, read_write
+
+import data
+from loss.gan import GANLoss
+from model.weight_init import generation_init_weights
+from model.GAN.residual_generator import GANImageBuffer
+from loss.I2I.edge_loss import EdgeLoss
+from loss.I2I.perceptual_loss import PerceptualLoss
+from util.image import make_2d_grid
+from util.handler import setup_common_handlers, setup_tensorboard_handler
+from util.build import build_model, build_optimizer
+
+
+def build_lr_schedulers(optimizers, config):
+ g_milestones_values = [
+ (0, config.optimizers.generator.lr),
+ (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
+ (config.max_iteration, config.data.train.scheduler.target_lr)
+ ]
+ d_milestones_values = [
+ (0, config.optimizers.discriminator.lr),
+ (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
+ (config.max_iteration, config.data.train.scheduler.target_lr)
+ ]
+ return dict(
+ g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
+ d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
+ )
+
+
+def get_trainer(config, logger):
+ generator = build_model(config.model.generator, config.distributed.model)
+ discriminators = dict(
+ a=build_model(config.model.discriminator, config.distributed.model),
+ b=build_model(config.model.discriminator, config.distributed.model),
+ )
+ generation_init_weights(generator)
+ for m in discriminators.values():
+ generation_init_weights(m)
+
+ logger.debug(discriminators["a"])
+ logger.debug(generator)
+
+ optimizers = dict(
+ g=build_optimizer(generator.parameters(), config.optimizers.generator),
+ d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
+ )
+ logger.info(f"build optimizers:\n{optimizers}")
+
+ lr_schedulers = build_lr_schedulers(optimizers, config)
+ logger.info(f"build lr_schedulers:\n{lr_schedulers}")
+
+ gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
+ gan_loss_cfg.pop("weight")
+ gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
+
+ edge_loss_cfg = OmegaConf.to_container(config.loss.edge)
+ edge_loss_cfg.pop("weight")
+ edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device())
+
+ perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
+ perceptual_loss_cfg.pop("weight")
+ perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
+
+ recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
+
+ image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
+
+ def _step(engine, batch):
+ batch = convert_tensor(batch, idist.device())
+ real = dict(a=batch["a"], b=batch["b"])
+ edge = batch["edge"]
+ additional_info = batch["additional_info"]
+ content_img = torch.cat([edge, additional_info], dim=1)
+ fake = dict(
+ a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
+ b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
+ )
+
+ optimizers["g"].zero_grad()
+ loss_g = dict()
+ for d in "ab":
+ discriminators[d].requires_grad_(False)
+ pred_fake = discriminators[d](fake[d])
+ loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
+ _, t = perceptual_loss(fake[d], real[d])
+ loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
+ loss_g["edge"] = config.loss.edge.weight * edge_loss(fake["b"], real["a"], gt_is_edge=False)
+ loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
+ sum(loss_g.values()).backward()
+ optimizers["g"].step()
+
+ for discriminator in discriminators.values():
+ discriminator.requires_grad_(True)
+
+ optimizers["d"].zero_grad()
+ loss_d = dict()
+ for k in discriminators.keys():
+ pred_real = discriminators[k](real[k])
+ pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach()))
+ loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) +
+ gan_loss(pred_fake, False, is_discriminator=True)) / 2
+ sum(loss_d.values()).backward()
+ optimizers["d"].step()
+
+ generated_img = {f"real_{k}": real[k].detach() for k in real}
+ generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
+ return {
+ "loss": {
+ "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
+ "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
+ },
+ "img": generated_img
+ }
+
+ trainer = Engine(_step)
+ trainer.logger = logger
+ for lr_shd in lr_schedulers.values():
+ trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
+
+ RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
+ RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
+
+ to_save = dict(trainer=trainer)
+ to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
+ to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
+ to_save.update({"generator": generator})
+ to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
+ setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
+ end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
+
+ def output_transform(output):
+ loss = dict()
+ for tl in output["loss"]:
+ if isinstance(output["loss"][tl], dict):
+ for l in output["loss"][tl]:
+ loss[f"{tl}_{l}"] = output["loss"][tl][l]
+ else:
+ loss[tl] = output["loss"][tl]
+ return loss
+
+ tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
+ if tensorboard_handler is not None:
+ tensorboard_handler.attach(
+ trainer,
+ log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
+ event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
+ )
+
+ @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
+ def show_images(engine):
+ output = engine.state.output
+ image_order = dict(
+ a=["real_a", "fake_a"],
+ b=["real_b", "fake_b"]
+ )
+ for k in "ab":
+ tensorboard_handler.writer.add_image(
+ f"train/{k}",
+ make_2d_grid([output["img"][o] for o in image_order[k]]),
+ engine.state.iteration
+ )
+
+ return trainer
+
+
+def run(task, config, logger):
+ assert torch.backends.cudnn.enabled
+ torch.backends.cudnn.benchmark = True
+ logger.info(f"start task {task}")
+ with read_write(config):
+ config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
+
+ if task == "train":
+ train_dataset = data.DATASET.build_with(config.data.train.dataset)
+ logger.info(f"train with dataset:\n{train_dataset}")
+ train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
+ trainer = get_trainer(config, logger)
+ if idist.get_rank() == 0:
+ test_dataset = data.DATASET.build_with(config.data.test.dataset)
+ trainer.state.test_dataset = test_dataset
+ try:
+ trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
+ except Exception:
+ import traceback
+ print(traceback.format_exc())
+ else:
+ return NotImplemented(f"invalid task: {task}")
diff --git a/loss/I2I/__init__.py b/loss/I2I/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/loss/I2I/edge_loss.py b/loss/I2I/edge_loss.py
new file mode 100644
index 0000000..a10f751
--- /dev/null
+++ b/loss/I2I/edge_loss.py
@@ -0,0 +1,129 @@
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class HED(nn.Module):
+ def __init__(self, pretrained_model_path, norm_img=True):
+ """
+ HED module to get edge
+ :param pretrained_model_path: path to pretrained HED.
+ :param norm_img(bool): If True, the image will be normed to [0, 1]. Note that
+ this is different from the `use_input_norm` which norm the input in
+ in forward function of vgg according to the statistics of dataset.
+ Importantly, the input image must be in range [-1, 1].
+ """
+ super().__init__()
+ self.norm_img = norm_img
+
+ self.vgg_nets = nn.ModuleList([torch.nn.Sequential(
+ torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ ), torch.nn.Sequential(
+ torch.nn.MaxPool2d(kernel_size=2, stride=2),
+ torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ ), torch.nn.Sequential(
+ torch.nn.MaxPool2d(kernel_size=2, stride=2),
+ torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ ), torch.nn.Sequential(
+ torch.nn.MaxPool2d(kernel_size=2, stride=2),
+ torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ ), torch.nn.Sequential(
+ torch.nn.MaxPool2d(kernel_size=2, stride=2),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )])
+
+ self.score_nets = nn.ModuleList([
+ torch.nn.Conv2d(in_channels=i, out_channels=1, kernel_size=1, stride=1, padding=0)
+ for i in [64, 128, 256, 512, 512]
+ ])
+
+ self.combine_net = torch.nn.Sequential(
+ torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
+ torch.nn.Sigmoid()
+ )
+
+ self.load_weights(pretrained_model_path)
+ self.register_buffer('mean', torch.Tensor([104.00698793, 116.66876762, 122.67891434]).view(1, 3, 1, 1))
+ for v in self.parameters():
+ v.requies_grad = False
+
+ def load_weights(self, pretrained_model_path):
+ checkpoint_path = Path(pretrained_model_path)
+ if not checkpoint_path.exists():
+ raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
+ ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
+ m = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"}
+
+ def replace_key(key):
+ if key.startswith("moduleVgg"):
+ return f"vgg_nets.{m[key[9:12]]}{key[12:]}"
+ elif key.startswith("moduleScore"):
+ return f"score_nets.{m[key[11:14]]}{key[14:]}"
+ elif key.startswith("moduleCombine"):
+ return f"combine_net{key[13:]}"
+ else:
+ raise ValueError("wrong checkpoint for HED")
+
+ module_dict = {replace_key(k): v for k, v in ckp.items()}
+ self.load_state_dict(module_dict, strict=True)
+
+ def forward(self, x):
+ if self.norm_img:
+ x = (x + 1.) * 0.5
+ x = x * 255.0 - self.mean
+ img_size = (x.size(2), x.size(3))
+
+ to_combine = []
+ for i in range(5):
+ x = self.vgg_nets[i](x)
+ score_x = self.score_nets[i](x)
+ to_combine.append(F.interpolate(input=score_x, size=img_size, mode='bilinear', align_corners=False))
+ out = self.combine_net(torch.cat(to_combine, 1))
+ return out.clamp(0.0, 1.0)
+
+
+class EdgeLoss(nn.Module):
+ def __init__(self, edge_extractor_type="HED", norm_img=True, criterion='L1', **kwargs):
+ super(EdgeLoss, self).__init__()
+ if edge_extractor_type == "HED":
+ pretrained_model_path = kwargs.get("hed_pretrained_model_path")
+ self.edge_extractor = HED(pretrained_model_path, norm_img)
+ else:
+ raise NotImplemented(f"do not support edge_extractor_type {edge_extractor_type}")
+
+ if criterion == 'L1':
+ self.criterion = nn.L1Loss()
+ elif criterion == "L2":
+ self.criterion = nn.MSELoss()
+ else:
+ raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
+
+ def forward(self, x, gt, gt_is_edge=True):
+ edge = self.edge_extractor(x)
+ if not gt_is_edge:
+ gt = self.edge_extractor(gt.detach())
+ loss = self.criterion(edge, gt)
+ return loss
diff --git a/loss/I2I/perceptual_loss.py b/loss/I2I/perceptual_loss.py
new file mode 100644
index 0000000..a390063
--- /dev/null
+++ b/loss/I2I/perceptual_loss.py
@@ -0,0 +1,155 @@
+import torch
+import torch.nn as nn
+import torchvision.models.vgg as vgg
+
+
+class PerceptualVGG(nn.Module):
+ """VGG network used in calculating perceptual loss.
+ In this implementation, we allow users to choose whether use normalization
+ in the input feature and the type of vgg network. Note that the pretrained
+ path must fit the vgg type.
+ Args:
+ layer_name_list (list[str]): According to the index in this list,
+ forward function will return the corresponding features. This
+ list contains the name each layer in `vgg.feature`. An example
+ of this list is ['4', '10'].
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image.
+ Importantly, the input feature must in the range [0, 1].
+ Default: True.
+ """
+
+ def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
+ super(PerceptualVGG, self).__init__()
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+
+ # get vgg model and load pretrained vgg weight
+ # remove _vgg from attributes to avoid `find_unused_parameters` bug
+ _vgg = getattr(vgg, vgg_type)(pretrained=True)
+ num_layers = max(map(int, layer_name_list)) + 1
+ assert len(_vgg.features) >= num_layers
+ # only borrow layers that will be used from _vgg to avoid unused params
+ self.vgg_layers = _vgg.features[:num_layers]
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer(
+ 'mean',
+ torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [-1, 1]
+ self.register_buffer(
+ 'std',
+ torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ for v in self.vgg_layers.parameters():
+ v.requies_grad = False
+
+ def forward(self, x):
+ """Forward function.
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ Returns:
+ Tensor: Forward results.
+ """
+
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+ output = {}
+
+ for i, l in enumerate(self.vgg_layers):
+ x = l(x)
+ if str(i) in self.layer_name_list:
+ output[str(i)] = x.clone()
+
+ return output
+
+
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the
+ 5th, 10th and 18th feature layer will be extracted with weight 1.0
+ in calculating losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+ perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
+ loss will be calculated.
+ Default: True.
+ style_loss (bool): If `style_loss == False`, the style loss will be calculated.
+ Default: False.
+ norm_img (bool): If True, the image will be normed to [0, 1]. Note that
+ this is different from the `use_input_norm` which norm the input in
+ in forward function of vgg according to the statistics of dataset.
+ Importantly, the input image must be in range [-1, 1].
+ """
+
+ def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
+ style_loss=False, norm_img=True, criterion='L1'):
+ super(PerceptualLoss, self).__init__()
+ self.norm_img = norm_img
+ self.perceptual_loss = perceptual_loss
+ self.style_loss = style_loss
+ self.layer_weights = layer_weights
+ self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
+ use_input_norm=use_input_norm)
+
+ if criterion == 'L1':
+ self.criterion = torch.nn.L1Loss()
+ elif criterion == "L2":
+ self.criterion = torch.nn.MSELoss()
+ else:
+ raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
+
+ def forward(self, x, gt):
+ """Forward function.
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+ Returns:
+ Tensor: Forward results.
+ """
+
+ if self.norm_img:
+ x = (x + 1.) * 0.5
+ gt = (gt + 1.) * 0.5
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate preceptual loss
+ if self.perceptual_loss:
+ percep_loss = 0
+ for k in x_features.keys():
+ percep_loss += self.criterion(
+ x_features[k], gt_features[k]) * self.layer_weights[k]
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_loss:
+ style_loss = 0
+ for k in x_features.keys():
+ style_loss += self.criterion(
+ self._gram_mat(x_features[k]),
+ self._gram_mat(gt_features[k])) * self.layer_weights[k]
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ (n, c, h, w) = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
diff --git a/model/GAN/TAHG.py b/model/GAN/TAHG.py
index 963230c..63c092a 100644
--- a/model/GAN/TAHG.py
+++ b/model/GAN/TAHG.py
@@ -142,7 +142,7 @@ class Fusion(nn.Module):
@MODEL.register_module("TAHG-Generator")
class Generator(nn.Module):
- def __init__(self, style_in_channels, content_in_channels, out_channels, style_dim=512, num_blocks=8,
+ def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
base_channels=64, padding_mode="reflect"):
super(Generator, self).__init__()
self.num_blocks = num_blocks
@@ -175,3 +175,38 @@ class Generator(nn.Module):
ar.norm2.set_style(styles[2 * i + 1])
x = ar(x)
return self.decoders[which_decoder](x)
+
+
+@MODEL.register_module("TAHG-Discriminator")
+class Discriminator(nn.Module):
+ def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
+ padding_mode="reflect"):
+ super(Discriminator, self).__init__()
+
+ norm_layer = select_norm_layer(norm_type)
+ use_bias = norm_type == "IN"
+
+ sequence = [nn.Sequential(
+ nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
+ bias=use_bias),
+ norm_layer(num_features=base_channels),
+ nn.ReLU(inplace=True)
+ )]
+ # stacked intermediate layers,
+ # gradually increasing the number of filters
+ multiple_now = 1
+ for n in range(1, num_down_sampling + 1):
+ multiple_prev = multiple_now
+ multiple_now = min(2 ** n, 4)
+ sequence += [
+ nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
+ padding=1, stride=2, bias=use_bias),
+ norm_layer(base_channels * multiple_now),
+ nn.LeakyReLU(0.2, inplace=True)
+ ]
+ for _ in range(num_blocks):
+ sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
+ self.model = nn.Sequential(*sequence)
+
+ def forward(self, x):
+ return self.model(x)
diff --git a/model/__init__.py b/model/__init__.py
index cfbd292..08e1dfe 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -1,3 +1,5 @@
from model.registry import MODEL
import model.GAN.residual_generator
+import model.GAN.TAHG
+import model.GAN.UGATIT
import model.fewshot
diff --git a/model/normalization.py b/model/normalization.py
index acfbbbd..fd7a2d8 100644
--- a/model/normalization.py
+++ b/model/normalization.py
@@ -37,7 +37,6 @@ class LayerNorm2d(nn.Module):
def forward(self, x):
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
- print(x.size())
if self.affine:
return self.channel_gamma * x + self.channel_beta
return x