diff --git a/.idea/other.xml b/.idea/other.xml new file mode 100644 index 0000000..8339243 --- /dev/null +++ b/.idea/other.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/configs/synthesizers/TAHG.yml b/configs/synthesizers/TAHG.yml new file mode 100644 index 0000000..3395036 --- /dev/null +++ b/configs/synthesizers/TAHG.yml @@ -0,0 +1,127 @@ +name: TAHG +engine: TAHG +result_dir: ./result +max_pairs: 1000000 + +distributed: + model: + # broadcast_buffers: False + +misc: + random_seed: 324 + +checkpoint: + epoch_interval: 1 # one checkpoint every 1 epoch + n_saved: 2 + +interval: + print_per_iteration: 10 # print once per 10 iteration + tensorboard: + scalar: 10 + image: 500 + +model: + generator: + _type: TAHG-Generator + style_in_channels: 3 + content_in_channels: 23 + discriminator: + _type: TAHG-Discriminator + in_channels: 3 + +loss: + gan: + loss_type: lsgan + real_label_val: 1.0 + fake_label_val: 0.0 + weight: 1.0 + edge: + criterion: 'L1' + hed_pretrained_model_path: "/root/network-bsds500.pytorch" + weight: 2 + perceptual: + layer_weights: +# "3": 1.0 + "0": 1.0 + "5": 1.0 + "10": 1.0 + "19": 1.0 + criterion: 'L2' + style_loss: True + perceptual_loss: False + weight: 100.0 + recon: + level: 1 + weight: 2 + +optimizers: + generator: + _type: Adam + lr: 0.0001 + betas: [ 0.5, 0.999 ] + weight_decay: 0.0001 + discriminator: + _type: Adam + lr: 1e-4 + betas: [ 0.5, 0.999 ] + weight_decay: 0.0001 + +data: + train: + scheduler: + start_proportion: 0.5 + target_lr: 0 + buffer_size: 50 + dataloader: + batch_size: 4 + shuffle: True + num_workers: 2 + pin_memory: True + drop_last: True + dataset: + _type: GenerationUnpairedDatasetWithEdge + root_a: "/data/i2i/VoxCeleb2Anime/trainA" + root_b: "/data/i2i/VoxCeleb2Anime/trainB" + edge_type: "hed_landmark" + random_pair: True + pipeline: + - Load + - Resize: + size: [ 256, 256 ] + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + test: + dataloader: + batch_size: 8 + shuffle: False + num_workers: 1 + pin_memory: False + drop_last: False + dataset: + _type: GenerationUnpairedDatasetWithEdge + root_a: "/data/i2i/VoxCeleb2Anime/testA" + root_b: "/data/i2i/VoxCeleb2Anime/testB" + edge_type: "hed_landmark" + random_pair: False + pipeline: + - Load + - Resize: + size: [ 256, 256 ] + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + video_dataset: + _type: SingleFolderDataset + root: "/data/i2i/VoxCeleb2Anime/test_video_frames/" + with_path: True + pipeline: + - Load + - Resize: + size: [ 256, 256 ] + - ToTensor + - Normalize: + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] diff --git a/data/dataset.py b/data/dataset.py index 37f19a9..5cb9462 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -1,5 +1,6 @@ import os import pickle +from pathlib import Path from collections import defaultdict import torch @@ -171,3 +172,33 @@ class GenerationUnpairedDataset(Dataset): def __repr__(self): return f"\nPipeline:\n{self.A.pipeline}" + + +@DATASET.register_module() +class GenerationUnpairedDatasetWithEdge(Dataset): + def __init__(self, root_a, root_b, random_pair, pipeline, edge_type): + self.edge_type = edge_type + self.A = SingleFolderDataset(root_a, pipeline, with_path=True) + self.B = SingleFolderDataset(root_b, pipeline, with_path=False) + self.random_pair = random_pair + + def get_edge(self, origin_path): + op = Path(origin_path) + add = torch.load(op.parent / f"{op.stem}.add") + return {"edge": add["edge"].float().unsqueeze(dim=0), + "additional_info": torch.cat([add["seg"].float(), add["dist"].float()], dim=0)} + + def __getitem__(self, idx): + a_idx = idx % len(self.A) + b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item() + output = dict() + output["a"], path_a = self.A[a_idx] + output.update(self.get_edge(path_a)) + output["b"] = self.B[b_idx] + return output + + def __len__(self): + return max(len(self.A), len(self.B)) + + def __repr__(self): + return f"\nPipeline:\n{self.A.pipeline}" diff --git a/engine/TAHG.py b/engine/TAHG.py new file mode 100644 index 0000000..660aca3 --- /dev/null +++ b/engine/TAHG.py @@ -0,0 +1,204 @@ +from itertools import chain +from math import ceil +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +import ignite.distributed as idist +from ignite.engine import Events, Engine +from ignite.metrics import RunningAverage +from ignite.utils import convert_tensor +from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler +from ignite.contrib.handlers.param_scheduler import PiecewiseLinear + +from omegaconf import OmegaConf, read_write + +import data +from loss.gan import GANLoss +from model.weight_init import generation_init_weights +from model.GAN.residual_generator import GANImageBuffer +from loss.I2I.edge_loss import EdgeLoss +from loss.I2I.perceptual_loss import PerceptualLoss +from util.image import make_2d_grid +from util.handler import setup_common_handlers, setup_tensorboard_handler +from util.build import build_model, build_optimizer + + +def build_lr_schedulers(optimizers, config): + g_milestones_values = [ + (0, config.optimizers.generator.lr), + (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr), + (config.max_iteration, config.data.train.scheduler.target_lr) + ] + d_milestones_values = [ + (0, config.optimizers.discriminator.lr), + (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr), + (config.max_iteration, config.data.train.scheduler.target_lr) + ] + return dict( + g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values), + d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values) + ) + + +def get_trainer(config, logger): + generator = build_model(config.model.generator, config.distributed.model) + discriminators = dict( + a=build_model(config.model.discriminator, config.distributed.model), + b=build_model(config.model.discriminator, config.distributed.model), + ) + generation_init_weights(generator) + for m in discriminators.values(): + generation_init_weights(m) + + logger.debug(discriminators["a"]) + logger.debug(generator) + + optimizers = dict( + g=build_optimizer(generator.parameters(), config.optimizers.generator), + d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator), + ) + logger.info(f"build optimizers:\n{optimizers}") + + lr_schedulers = build_lr_schedulers(optimizers, config) + logger.info(f"build lr_schedulers:\n{lr_schedulers}") + + gan_loss_cfg = OmegaConf.to_container(config.loss.gan) + gan_loss_cfg.pop("weight") + gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) + + edge_loss_cfg = OmegaConf.to_container(config.loss.edge) + edge_loss_cfg.pop("weight") + edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device()) + + perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual) + perceptual_loss_cfg.pop("weight") + perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device()) + + recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss() + + image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()} + + def _step(engine, batch): + batch = convert_tensor(batch, idist.device()) + real = dict(a=batch["a"], b=batch["b"]) + edge = batch["edge"] + additional_info = batch["additional_info"] + content_img = torch.cat([edge, additional_info], dim=1) + fake = dict( + a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"), + b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"), + ) + + optimizers["g"].zero_grad() + loss_g = dict() + for d in "ab": + discriminators[d].requires_grad_(False) + pred_fake = discriminators[d](fake[d]) + loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True) + _, t = perceptual_loss(fake[d], real[d]) + loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t + loss_g["edge"] = config.loss.edge.weight * edge_loss(fake["b"], real["a"], gt_is_edge=False) + loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"]) + sum(loss_g.values()).backward() + optimizers["g"].step() + + for discriminator in discriminators.values(): + discriminator.requires_grad_(True) + + optimizers["d"].zero_grad() + loss_d = dict() + for k in discriminators.keys(): + pred_real = discriminators[k](real[k]) + pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach())) + loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) + + gan_loss(pred_fake, False, is_discriminator=True)) / 2 + sum(loss_d.values()).backward() + optimizers["d"].step() + + generated_img = {f"real_{k}": real[k].detach() for k in real} + generated_img.update({f"fake_{k}": fake[k].detach() for k in fake}) + return { + "loss": { + "g": {ln: loss_g[ln].mean().item() for ln in loss_g}, + "d": {ln: loss_d[ln].mean().item() for ln in loss_d}, + }, + "img": generated_img + } + + trainer = Engine(_step) + trainer.logger = logger + for lr_shd in lr_schedulers.values(): + trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd) + + RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") + RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d") + + to_save = dict(trainer=trainer) + to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers}) + to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers}) + to_save.update({"generator": generator}) + to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) + setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True, + end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) + + def output_transform(output): + loss = dict() + for tl in output["loss"]: + if isinstance(output["loss"][tl], dict): + for l in output["loss"][tl]: + loss[f"{tl}_{l}"] = output["loss"][tl][l] + else: + loss[tl] = output["loss"][tl] + return loss + + tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform) + if tensorboard_handler is not None: + tensorboard_handler.attach( + trainer, + log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"), + event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar) + ) + + @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image)) + def show_images(engine): + output = engine.state.output + image_order = dict( + a=["real_a", "fake_a"], + b=["real_b", "fake_b"] + ) + for k in "ab": + tensorboard_handler.writer.add_image( + f"train/{k}", + make_2d_grid([output["img"][o] for o in image_order[k]]), + engine.state.iteration + ) + + return trainer + + +def run(task, config, logger): + assert torch.backends.cudnn.enabled + torch.backends.cudnn.benchmark = True + logger.info(f"start task {task}") + with read_write(config): + config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size) + + if task == "train": + train_dataset = data.DATASET.build_with(config.data.train.dataset) + logger.info(f"train with dataset:\n{train_dataset}") + train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) + trainer = get_trainer(config, logger) + if idist.get_rank() == 0: + test_dataset = data.DATASET.build_with(config.data.test.dataset) + trainer.state.test_dataset = test_dataset + try: + trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader))) + except Exception: + import traceback + print(traceback.format_exc()) + else: + return NotImplemented(f"invalid task: {task}") diff --git a/loss/I2I/__init__.py b/loss/I2I/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loss/I2I/edge_loss.py b/loss/I2I/edge_loss.py new file mode 100644 index 0000000..a10f751 --- /dev/null +++ b/loss/I2I/edge_loss.py @@ -0,0 +1,129 @@ +from pathlib import Path + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class HED(nn.Module): + def __init__(self, pretrained_model_path, norm_img=True): + """ + HED module to get edge + :param pretrained_model_path: path to pretrained HED. + :param norm_img(bool): If True, the image will be normed to [0, 1]. Note that + this is different from the `use_input_norm` which norm the input in + in forward function of vgg according to the statistics of dataset. + Importantly, the input image must be in range [-1, 1]. + """ + super().__init__() + self.norm_img = norm_img + + self.vgg_nets = nn.ModuleList([torch.nn.Sequential( + torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ), torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ), torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ), torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ), torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + )]) + + self.score_nets = nn.ModuleList([ + torch.nn.Conv2d(in_channels=i, out_channels=1, kernel_size=1, stride=1, padding=0) + for i in [64, 128, 256, 512, 512] + ]) + + self.combine_net = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0), + torch.nn.Sigmoid() + ) + + self.load_weights(pretrained_model_path) + self.register_buffer('mean', torch.Tensor([104.00698793, 116.66876762, 122.67891434]).view(1, 3, 1, 1)) + for v in self.parameters(): + v.requies_grad = False + + def load_weights(self, pretrained_model_path): + checkpoint_path = Path(pretrained_model_path) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found") + ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu") + m = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"} + + def replace_key(key): + if key.startswith("moduleVgg"): + return f"vgg_nets.{m[key[9:12]]}{key[12:]}" + elif key.startswith("moduleScore"): + return f"score_nets.{m[key[11:14]]}{key[14:]}" + elif key.startswith("moduleCombine"): + return f"combine_net{key[13:]}" + else: + raise ValueError("wrong checkpoint for HED") + + module_dict = {replace_key(k): v for k, v in ckp.items()} + self.load_state_dict(module_dict, strict=True) + + def forward(self, x): + if self.norm_img: + x = (x + 1.) * 0.5 + x = x * 255.0 - self.mean + img_size = (x.size(2), x.size(3)) + + to_combine = [] + for i in range(5): + x = self.vgg_nets[i](x) + score_x = self.score_nets[i](x) + to_combine.append(F.interpolate(input=score_x, size=img_size, mode='bilinear', align_corners=False)) + out = self.combine_net(torch.cat(to_combine, 1)) + return out.clamp(0.0, 1.0) + + +class EdgeLoss(nn.Module): + def __init__(self, edge_extractor_type="HED", norm_img=True, criterion='L1', **kwargs): + super(EdgeLoss, self).__init__() + if edge_extractor_type == "HED": + pretrained_model_path = kwargs.get("hed_pretrained_model_path") + self.edge_extractor = HED(pretrained_model_path, norm_img) + else: + raise NotImplemented(f"do not support edge_extractor_type {edge_extractor_type}") + + if criterion == 'L1': + self.criterion = nn.L1Loss() + elif criterion == "L2": + self.criterion = nn.MSELoss() + else: + raise NotImplementedError(f'{criterion} criterion has not been supported in this version.') + + def forward(self, x, gt, gt_is_edge=True): + edge = self.edge_extractor(x) + if not gt_is_edge: + gt = self.edge_extractor(gt.detach()) + loss = self.criterion(edge, gt) + return loss diff --git a/loss/I2I/perceptual_loss.py b/loss/I2I/perceptual_loss.py new file mode 100644 index 0000000..a390063 --- /dev/null +++ b/loss/I2I/perceptual_loss.py @@ -0,0 +1,155 @@ +import torch +import torch.nn as nn +import torchvision.models.vgg as vgg + + +class PerceptualVGG(nn.Module): + """VGG network used in calculating perceptual loss. + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + Args: + layer_name_list (list[str]): According to the index in this list, + forward function will return the corresponding features. This + list contains the name each layer in `vgg.feature`. An example + of this list is ['4', '10']. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. + Importantly, the input feature must in the range [0, 1]. + Default: True. + """ + + def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True): + super(PerceptualVGG, self).__init__() + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + + # get vgg model and load pretrained vgg weight + # remove _vgg from attributes to avoid `find_unused_parameters` bug + _vgg = getattr(vgg, vgg_type)(pretrained=True) + num_layers = max(map(int, layer_name_list)) + 1 + assert len(_vgg.features) >= num_layers + # only borrow layers that will be used from _vgg to avoid unused params + self.vgg_layers = _vgg.features[:num_layers] + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer( + 'mean', + torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [-1, 1] + self.register_buffer( + 'std', + torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + for v in self.vgg_layers.parameters(): + v.requies_grad = False + + def forward(self, x): + """Forward function. + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + Returns: + Tensor: Forward results. + """ + + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for i, l in enumerate(self.vgg_layers): + x = l(x) + if str(i) in self.layer_name_list: + output[str(i)] = x.clone() + + return output + + +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the + 5th, 10th and 18th feature layer will be extracted with weight 1.0 + in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + perceptual_loss (bool): If `perceptual_loss == True`, the perceptual + loss will be calculated. + Default: True. + style_loss (bool): If `style_loss == False`, the style loss will be calculated. + Default: False. + norm_img (bool): If True, the image will be normed to [0, 1]. Note that + this is different from the `use_input_norm` which norm the input in + in forward function of vgg according to the statistics of dataset. + Importantly, the input image must be in range [-1, 1]. + """ + + def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True, + style_loss=False, norm_img=True, criterion='L1'): + super(PerceptualLoss, self).__init__() + self.norm_img = norm_img + self.perceptual_loss = perceptual_loss + self.style_loss = style_loss + self.layer_weights = layer_weights + self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, + use_input_norm=use_input_norm) + + if criterion == 'L1': + self.criterion = torch.nn.L1Loss() + elif criterion == "L2": + self.criterion = torch.nn.MSELoss() + else: + raise NotImplementedError(f'{criterion} criterion has not been supported in this version.') + + def forward(self, x, gt): + """Forward function. + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + Returns: + Tensor: Forward results. + """ + + if self.norm_img: + x = (x + 1.) * 0.5 + gt = (gt + 1.) * 0.5 + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate preceptual loss + if self.perceptual_loss: + percep_loss = 0 + for k in x_features.keys(): + percep_loss += self.criterion( + x_features[k], gt_features[k]) * self.layer_weights[k] + else: + percep_loss = None + + # calculate style loss + if self.style_loss: + style_loss = 0 + for k in x_features.keys(): + style_loss += self.criterion( + self._gram_mat(x_features[k]), + self._gram_mat(gt_features[k])) * self.layer_weights[k] + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + Returns: + torch.Tensor: Gram matrix. + """ + (n, c, h, w) = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram diff --git a/model/GAN/TAHG.py b/model/GAN/TAHG.py index 963230c..63c092a 100644 --- a/model/GAN/TAHG.py +++ b/model/GAN/TAHG.py @@ -142,7 +142,7 @@ class Fusion(nn.Module): @MODEL.register_module("TAHG-Generator") class Generator(nn.Module): - def __init__(self, style_in_channels, content_in_channels, out_channels, style_dim=512, num_blocks=8, + def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"): super(Generator, self).__init__() self.num_blocks = num_blocks @@ -175,3 +175,38 @@ class Generator(nn.Module): ar.norm2.set_style(styles[2 * i + 1]) x = ar(x) return self.decoders[which_decoder](x) + + +@MODEL.register_module("TAHG-Discriminator") +class Discriminator(nn.Module): + def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN", + padding_mode="reflect"): + super(Discriminator, self).__init__() + + norm_layer = select_norm_layer(norm_type) + use_bias = norm_type == "IN" + + sequence = [nn.Sequential( + nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3, + bias=use_bias), + norm_layer(num_features=base_channels), + nn.ReLU(inplace=True) + )] + # stacked intermediate layers, + # gradually increasing the number of filters + multiple_now = 1 + for n in range(1, num_down_sampling + 1): + multiple_prev = multiple_now + multiple_now = min(2 ** n, 4) + sequence += [ + nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3, + padding=1, stride=2, bias=use_bias), + norm_layer(base_channels * multiple_now), + nn.LeakyReLU(0.2, inplace=True) + ] + for _ in range(num_blocks): + sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type)) + self.model = nn.Sequential(*sequence) + + def forward(self, x): + return self.model(x) diff --git a/model/__init__.py b/model/__init__.py index cfbd292..08e1dfe 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,3 +1,5 @@ from model.registry import MODEL import model.GAN.residual_generator +import model.GAN.TAHG +import model.GAN.UGATIT import model.fewshot diff --git a/model/normalization.py b/model/normalization.py index acfbbbd..fd7a2d8 100644 --- a/model/normalization.py +++ b/model/normalization.py @@ -37,7 +37,6 @@ class LayerNorm2d(nn.Module): def forward(self, x): ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True) x = (x - ln_mean) / torch.sqrt(ln_var + self.eps) - print(x.size()) if self.affine: return self.channel_gamma * x + self.channel_beta return x