From 61e04de8a5b0f1e8e823a96dd66af46edf2d5cf7 Mon Sep 17 00:00:00 2001
From: Ray Wong
Date: Thu, 17 Sep 2020 09:34:53 +0800
Subject: [PATCH] TAFG

---
 .idea/deployment.xml          |   2 +-
 configs/synthesizers/TAFG.yml |  38 +++---
 data/dataset.py               |  15 +--
 engine/MUNIT.py               |  26 ++--
 engine/TAFG.py                | 106 +++++++++------
 engine/base/i2i.py            |  15 ++-
 loss/I2I/perceptual_loss.py   |   9 +-
 model/GAN/TAFG.py             | 243 +++++++---------------------------
 model/GAN/base.py             |   2 +-
 9 files changed, 168 insertions(+), 288 deletions(-)

diff --git a/.idea/deployment.xml b/.idea/deployment.xml
index 8ccfb5e..0198fe4 100644
--- a/.idea/deployment.xml
+++ b/.idea/deployment.xml
@@ -1,6 +1,6 @@
-
+
diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml
index 2d85a03..939681f 100644
--- a/configs/synthesizers/TAFG.yml
+++ b/configs/synthesizers/TAFG.yml
@@ -1,4 +1,4 @@
-name: TAFG
+name: TAFG-vox2
 engine: TAFG
 result_dir: ./result
 max_pairs: 1500000
@@ -11,11 +11,11 @@ handler:
   n_saved: 2
   tensorboard:
     scalar: 100 # log scalar `scalar` times per epoch
-    image: 2 # log image `image` times per epoch
+    image: 4 # log image `image` times per epoch
 
 misc:
-  random_seed: 324
+  random_seed: 123
 
 model:
   generator:
     _type: Generator
@@ -24,7 +24,9 @@ model:
     style_in_channels: 3
     content_in_channels: 24
     num_adain_blocks: 8
-    num_res_blocks: 0
+    num_res_blocks: 8
+    use_spectral_norm: True
+    style_use_fc: False
   discriminator:
     _type: MultiScaleDiscriminator
     num_scale: 2
@@ -51,26 +53,22 @@ loss:
     criterion: 'L1'
     style_loss: False
    perceptual_loss: True
-    weight: 10
-  style:
-    layer_weights:
-      "3": 1
-    criterion: 'L1'
-    style_loss: True
-    perceptual_loss: False
-    weight: 10
-  fm:
-    level: 1
-    weight: 10
+    weight: 0
   recon:
     level: 1
     weight: 10
   style_recon:
     level: 1
-    weight: 0
+    weight: 5
+  content_recon:
+    level: 1
+    weight: 10
   edge:
     weight: 10
     hed_pretrained_model_path: ./network-bsds500.pytorch
+  cycle:
+    level: 1
+    weight: 10
 
 optimizers:
   generator:
@@ -91,9 +89,9 @@ data:
       target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: True
-      num_workers: 2
+      num_workers: 1
       pin_memory: True
       drop_last: True
     dataset:
@@ -116,7 +114,7 @@ data:
   test:
     which: video_dataset
    dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
@@ -145,7 +143,7 @@ data:
     pipeline:
       - Load
       - Resize:
-          size: [ 256, 256 ]
+          size: [ 128, 128 ]
       - ToTensor
       - Normalize:
          mean: [ 0.5, 0.5, 0.5 ]
diff --git a/data/dataset.py b/data/dataset.py
index 23f074f..2941889 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -203,7 +203,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
         op = Path(origin_path)
         if self.edge_type.startswith("landmark_"):
             edge_type = self.edge_type.lstrip("landmark_")
-            use_landmark = True
+            use_landmark = op.parent.name.endswith("A")
         else:
             edge_type = self.edge_type
             use_landmark = False
@@ -225,14 +225,11 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
     def __getitem__(self, idx):
         a_idx = idx % len(self.A)
         b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
-        if self.with_path:
-            output = {"a": self.A[a_idx], "b": self.B[b_idx]}
-            output["edge_a"] = output["a"][1]
-            return output
-        output = dict()
-        output["a"], path_a = self.A[a_idx]
-        output["b"], path_b = self.B[b_idx]
-        output["edge_a"] = self.get_edge(path_a)
+        output = dict(a={}, b={})
+        output["a"]["img"], output["a"]["path"] = self.A[a_idx]
+        output["b"]["img"], output["b"]["path"] = self.B[b_idx]
+        for p in "ab":
+            output[p]["edge"] = self.get_edge(output[p]["path"])
         return output
 
     def __len__(self):
diff --git a/engine/MUNIT.py b/engine/MUNIT.py
index a0ae713..8d8f41a 100644
--- a/engine/MUNIT.py
+++ b/engine/MUNIT.py
@@ -58,20 +58,20 @@ class MUNITEngineKernel(EngineKernel):
         styles = dict()
         contents = dict()
         images = dict()
+        with torch.set_grad_enabled(not inference):
+            for phase in "ab":
+                contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
+                images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
+                styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
 
-        for phase in "ab":
-            contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
-            images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
-            styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
-
-        for phase in ("a2b", "b2a"):
-            # images["a2b"] = Gb.decode(content_a, random_style_b)
-            images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
-            # contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
-            contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
-            if self.config.loss.recon.cycle.weight > 0:
-                images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
-        return dict(styles=styles, contents=contents, images=images)
+            for phase in ("a2b", "b2a"):
+                # images["a2b"] = Gb.decode(content_a, random_style_b)
+                images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
+                # contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
+                contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
+                if self.config.loss.recon.cycle.weight > 0:
+                    images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
+            return dict(styles=styles, contents=contents, images=images)
 
     def criterion_generators(self, batch, generated) -> dict:
         loss = dict()
diff --git a/engine/TAFG.py b/engine/TAFG.py
index 20a253c..5871d28 100644
--- a/engine/TAFG.py
+++ b/engine/TAFG.py
@@ -3,8 +3,7 @@ from itertools import chain
 import ignite.distributed as idist
 import torch
 import torch.nn as nn
-from ignite.engine import Events
-from omegaconf import read_write, OmegaConf
+from omegaconf import OmegaConf
 
 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
@@ -21,17 +20,14 @@ class TAFGEngineKernel(EngineKernel):
         perceptual_loss_cfg.pop("weight")
         self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
 
-        style_loss_cfg = OmegaConf.to_container(config.loss.style)
-        style_loss_cfg.pop("weight")
-        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
-
         gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
         gan_loss_cfg.pop("weight")
         self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
 
-        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
         self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
         self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
+        self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
+        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
 
         self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
             idist.device())
@@ -67,47 +63,67 @@ class TAFGEngineKernel(EngineKernel):
     def forward(self, batch, inference=False) -> dict:
         generator = self.generators["main"]
         batch = self._process_batch(batch, inference)
+
+        styles = dict()
+        contents = dict()
+        images = dict()
         with torch.set_grad_enabled(not inference):
-            fake = dict(
-                a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
-                b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
-            )
-            return fake
+            for ph in "ab":
+                contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
+            for ph in ("a2b", "b2a"):
+                images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
+            contents["recon_a"], styles["recon_b"] = generator.encode(
+                self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
+            images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
+            images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
+            images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
+        return dict(styles=styles, contents=contents, images=images)
 
     def criterion_generators(self, batch, generated) -> dict:
         batch = self._process_batch(batch)
         loss = dict()
-        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        _, loss_style = self.style_loss(generated["a"], batch["a"])
-        loss["style"] = self.config.loss.style.weight * loss_style
-        loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
-        for phase in "ab":
-            pred_fake = self.discriminators[phase](generated[phase])
-            loss[f"gan_{phase}"] = 0
+
+        for ph in "ab":
+            loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
+                generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
+
+            pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
+            loss[f"gan_{ph}"] = 0
             for sub_pred_fake in pred_fake:
                 # last output is actual prediction
-                loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
+                loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
+        loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
+            generated["contents"]["a"], generated["contents"]["recon_a"]
+        )
+        loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
+            generated["styles"]["b"], generated["styles"]["recon_b"]
+        )
 
-            if self.config.loss.fm.weight > 0 and phase == "b":
-                pred_real = self.discriminators[phase](batch[phase])
-                loss_fm = 0
-                num_scale_discriminator = len(pred_fake)
-                for i in range(num_scale_discriminator):
-                    # last output is the final prediction, so we exclude it
-                    num_intermediate_outputs = len(pred_fake[i]) - 1
-                    for j in range(num_intermediate_outputs):
-                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
-                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
-        loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
-        loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
+        for ph in ("a2b", "b2a"):
+            if self.config.loss.perceptual.weight > 0:
+                loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
+                    batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
+                )
+        if self.config.loss.edge.weight > 0:
+            loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
+                generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
+            )
+            loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
+                generated["images"]["fake_a"], batch["b"]["edge"]
+            )
+
+        if self.config.loss.cycle.weight > 0:
+            loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
+                batch["a"]["img"], generated["images"]["cycle_a"]
+            )
         return loss
 
     def criterion_discriminators(self, batch, generated) -> dict:
         loss = dict()
         # batch = self._process_batch(batch)
         for phase in self.discriminators.keys():
-            pred_real = self.discriminators[phase](batch[phase])
-            pred_fake = self.discriminators[phase](generated[phase].detach())
+            pred_real = self.discriminators[phase](batch[phase]["img"])
+            pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
             loss[f"gan_{phase}"] = 0
             for i in range(len(pred_fake)):
                 loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -122,17 +138,25 @@ class TAFGEngineKernel(EngineKernel):
         :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
         """
         batch = self._process_batch(batch)
-        edge = batch["edge_a"][:, 0:1, :, :]
         return dict(
-            a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
-               generated["b"].detach()]
+            a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
+               batch["a"]["img"].detach(),
+               generated["images"]["a2a"].detach(),
+               generated["images"]["fake_b"].detach(),
+               generated["images"]["cycle_a"].detach(),
+               ],
+            b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
+               batch["b"]["img"].detach(),
+               generated["images"]["b2b"].detach(),
+               generated["images"]["fake_a"].detach()]
         )
 
     def change_engine(self, config, trainer):
-        @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
-        def change_config(engine):
-            with read_write(config):
-                config.loss.perceptual.weight = 5
+        pass
+        # @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
+        # def change_config(engine):
+        #     with read_write(config):
+        #         config.loss.perceptual.weight = 5
 
 
 def run(task, config, _):
diff --git a/engine/base/i2i.py b/engine/base/i2i.py
index 414db76..b27df75 100644
--- a/engine/base/i2i.py
+++ b/engine/base/i2i.py
@@ -132,9 +132,12 @@ def get_trainer(config, kernel: EngineKernel):
         generated = kernel.forward(batch)
 
         if kernel.train_generator_first:
+            # simultaneous, train G with simultaneous D
             loss_g = train_generators(batch, generated)
             loss_d = train_discriminators(batch, generated)
         else:
+            # update discriminators first, not simultaneous.
+            # train G with updated discriminators
             loss_d = train_discriminators(batch, generated)
             loss_g = train_generators(batch, generated)
 
@@ -152,8 +155,8 @@ def get_trainer(config, kernel: EngineKernel):
     kernel.change_engine(config, trainer)
 
-    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
-    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
+    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
+    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
     to_save = dict(trainer=trainer)
     to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
     to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
@@ -188,7 +191,13 @@ def get_trainer(config, kernel: EngineKernel):
         for i in range(random_start, random_start + 10):
             batch = convert_tensor(engine.state.test_dataset[i], idist.device())
             for k in batch:
-                batch[k] = batch[k].view(1, *batch[k].size())
+                if isinstance(batch[k], torch.Tensor):
+                    batch[k] = batch[k].unsqueeze(0)
+                elif isinstance(batch[k], dict):
+                    for kk in batch[k]:
+                        if isinstance(batch[k][kk], torch.Tensor):
+                            batch[k][kk] = batch[k][kk].unsqueeze(0)
+
             generated = kernel.forward(batch)
             images = kernel.intermediate_images(batch, generated)
diff --git a/loss/I2I/perceptual_loss.py b/loss/I2I/perceptual_loss.py
index 378b699..e55aaa1 100644
--- a/loss/I2I/perceptual_loss.py
+++ b/loss/I2I/perceptual_loss.py
@@ -92,6 +92,7 @@ class PerceptualLoss(nn.Module):
                  style_loss=False, norm_img=True, criterion='L1'):
         super(PerceptualLoss, self).__init__()
         self.norm_img = norm_img
+        assert perceptual_loss ^ style_loss, "There must be one and only one true in style or perceptual"
         self.perceptual_loss = perceptual_loss
         self.style_loss = style_loss
         self.layer_weights = layer_weights
@@ -127,8 +128,7 @@ class PerceptualLoss(nn.Module):
             percep_loss = 0
             for k in x_features.keys():
                 percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
-        else:
-            percep_loss = None
+            return percep_loss
 
         # calculate style loss
         if self.style_loss:
@@ -136,10 +136,7 @@ class PerceptualLoss(nn.Module):
             style_loss = 0
             for k in x_features.keys():
                 style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
                               self.layer_weights[k]
-        else:
-            style_loss = None
-
-        return percep_loss, style_loss
+            return style_loss
 
     def _gram_mat(self, x):
         """Calculate Gram matrix.
diff --git a/model/GAN/TAFG.py b/model/GAN/TAFG.py
index 33aed3f..e8976c0 100644
--- a/model/GAN/TAFG.py
+++ b/model/GAN/TAFG.py
@@ -4,16 +4,17 @@ from torchvision.models import vgg19
 
 from model.normalization import select_norm_layer
 from model.registry import MODEL
-from .base import ResidualBlock
+from .MUNIT import ContentEncoder, Fusion, Decoder
+from .base import ResBlock
 
 
 class VGG19StyleEncoder(nn.Module):
     def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
-                 vgg19_layers=(0, 5, 10, 19)):
+                 vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
         super().__init__()
         self.vgg19_layers = vgg19_layers
         self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
-        self.vgg19.requires_grad_(False)
+        self.vgg19.requires_grad_(not fix_vgg19)
 
         norm_layer = select_norm_layer(norm_type)
 
@@ -52,203 +53,57 @@ class VGG19StyleEncoder(nn.Module):
         return x.view(x.size(0), -1)
 
 
-class ContentEncoder(nn.Module):
-    def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
-        super().__init__()
-        norm_layer = select_norm_layer(norm_type)
-
-        self.start_conv = nn.Sequential(
-            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
-                      bias=True),
-            norm_layer(num_features=base_channels),
-            nn.ReLU(inplace=True)
-        )
-
-        # down sampling
-        submodules = []
-        num_down_sampling = 2
-        for i in range(num_down_sampling):
-            multiple = 2 ** i
-            submodules += [
-                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
-                          kernel_size=4, stride=2, padding=1, bias=True),
-                norm_layer(num_features=base_channels * multiple * 2),
-                nn.ReLU(inplace=True)
-            ]
-        self.encoder = nn.Sequential(*submodules)
-        res_block_channels = num_down_sampling ** 2 * base_channels
-        self.resnet = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
-
-    def forward(self, x):
-        x = self.start_conv(x)
-        x = self.encoder(x)
-        x = self.resnet(x)
-        return x
-
-
-class Decoder(nn.Module):
-    def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
-                 norm_type="LN"):
-        super(Decoder, self).__init__()
-        norm_layer = select_norm_layer(norm_type)
-        use_bias = norm_type == "IN"
-
-        res_block_channels = (2 ** 2) * base_channels
-
-        self.resnet = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
-
-        # up sampling
-        submodules = []
-        for i in range(num_down_sampling):
-            multiple = 2 ** (num_down_sampling - i)
-            submodules += [
-                nn.Upsample(scale_factor=2),
-                nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
-                          padding=2, padding_mode=padding_mode, bias=use_bias),
-                norm_layer(num_features=base_channels * multiple // 2),
-                nn.ReLU(inplace=True),
-            ]
-        self.decoder = nn.Sequential(*submodules)
-        self.end_conv = nn.Sequential(
-            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
-            nn.Tanh()
-        )
-
-    def forward(self, x):
-        x = self.resnet(x)
-        x = self.decoder(x)
-        x = self.end_conv(x)
-        return x
-
-
-class Fusion(nn.Module):
-    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
-        super().__init__()
-        norm_layer = select_norm_layer(norm_type)
-        self.start_fc = nn.Sequential(
-            nn.Linear(in_features, base_features),
-            norm_layer(base_features),
-            nn.ReLU(True),
-        )
-        self.fcs = nn.Sequential(*[
-            nn.Sequential(
-                nn.Linear(base_features, base_features),
-                norm_layer(base_features),
-                nn.ReLU(True),
-            ) for _ in range(n_blocks - 2)
-        ])
-        self.end_fc = nn.Sequential(
-            nn.Linear(base_features, out_features),
-        )
-
-    def forward(self, x):
-        x = self.start_fc(x)
-        x = self.fcs(x)
-        return self.end_fc(x)
-
-
-class StyleGenerator(nn.Module):
-    def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
-        super().__init__()
-        self.num_blocks = num_blocks
-        self.style_encoder = VGG19StyleEncoder(
-            style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
-        self.fc = nn.Sequential(
-            nn.Linear(style_dim, style_dim),
-            nn.ReLU(True),
-        )
-        res_block_channels = 2 ** 2 * base_channels
-        self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
-                             norm_type="NONE")
-
-    def forward(self, x):
-        styles = self.fusion(self.fc(self.style_encoder(x)))
-        return styles
-
-
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
-    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
-                 num_adain_blocks=8, num_res_blocks=4,
+    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
+                 style_dim=512, style_use_fc=True,
+                 num_adain_blocks=8, num_res_blocks=8,
                  base_channels=64, padding_mode="reflect"):
         super(Generator, self).__init__()
-        self.num_adain_blocks=num_adain_blocks
-        self.style_encoders = nn.ModuleDict({
-            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
-                                base_channels=base_channels, padding_mode=padding_mode),
-            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
-                                base_channels=base_channels, padding_mode=padding_mode),
-        })
-        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
-                                              padding_mode=padding_mode, norm_type="IN")
-        res_block_channels = 2 ** 2 * base_channels
-
-        self.resnet = nn.ModuleDict({
-            "a": nn.Sequential(*[
-                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
-            ]),
-            "b": nn.Sequential(*[
-                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
-            ])
-        })
-        self.adain_resnet = nn.ModuleDict({
-            "a": nn.ModuleList([
-                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
-            ]),
-            "b": nn.ModuleList([
-                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
-            ])
+        self.num_adain_blocks = num_adain_blocks
+        self.style_encoders = nn.ModuleDict(dict(
+            a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                norm_type="NONE"),
+            b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                norm_type="NONE", fix_vgg19=False)
+        ))
+        resnet_channels = 2 ** 2 * base_channels
+        self.style_converters = nn.ModuleDict(dict(
+            a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
+                     norm_type="NONE"),
+            b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
+                     norm_type="NONE"),
+        ))
+        self.content_encoders = nn.ModuleDict({
+            "a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm),
+            "b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm)
         })
-        self.decoders = nn.ModuleDict({
-            "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
-            "b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
-        })
+        self.content_resnet = nn.Sequential(*[
+            ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
+            for _ in range(num_res_blocks)
+        ])
+        self.decoders = nn.ModuleDict(dict(
+            a=Decoder(resnet_channels, out_channels, 2,
+                      num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
+            b=Decoder(resnet_channels, out_channels, 2,
+                      num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
+        ))
 
-    def forward(self, content_img, style_img, which_decoder: str = "a"):
-        x = self.content_encoder(content_img)
-        x = self.resnet[which_decoder](x)
-        styles = self.style_encoders[which_decoder](style_img)
-        styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
-        for i, ar in enumerate(self.adain_resnet[which_decoder]):
-            ar.norm1.set_style(styles[2 * i])
-            ar.norm2.set_style(styles[2 * i + 1])
-            x = ar(x)
-        return self.decoders[which_decoder](x)
+    def encode(self, content_img, style_img, which_content, which_style):
+        content = self.content_resnet(self.content_encoders[which_content](content_img))
+        style = self.style_encoders[which_style](style_img)
+        return content, style
 
+    def decode(self, content, style, which):
+        decoder = self.decoders[which]
+        as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
+        # set style for decoder
+        for i, blk in enumerate(decoder.res_blocks):
+            blk.conv1.normalization.set_style(as_param_style[2 * i])
+            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
+        return decoder(content)
 
-@MODEL.register_module("TAFG-Discriminator")
-class Discriminator(nn.Module):
-    def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
-                 padding_mode="reflect"):
-        super(Discriminator, self).__init__()
-
-        norm_layer = select_norm_layer(norm_type)
-        use_bias = norm_type == "IN"
-
-        sequence = [nn.Sequential(
-            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
-                      bias=use_bias),
-            norm_layer(num_features=base_channels),
-            nn.ReLU(inplace=True)
-        )]
-        # stacked intermediate layers,
-        # gradually increasing the number of filters
-        multiple_now = 1
-        for n in range(1, num_down_sampling + 1):
-            multiple_prev = multiple_now
-            multiple_now = min(2 ** n, 4)
-            sequence += [
-                nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
-                          padding=1, stride=2, bias=use_bias),
-                norm_layer(base_channels * multiple_now),
-                nn.LeakyReLU(0.2, inplace=True)
-            ]
-        for _ in range(num_blocks):
-            sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
-        self.model = nn.Sequential(*sequence)
-
-    def forward(self, x):
-        return self.model(x)
+    def forward(self, content_img, style_img, which_content, which_style):
+        content, style = self.encode(content_img, style_img, which_content, which_style)
+        return self.decode(content, style, which_style)
diff --git a/model/GAN/base.py b/model/GAN/base.py
index fb73169..856e3cd 100644
--- a/model/GAN/base.py
+++ b/model/GAN/base.py
@@ -185,7 +185,7 @@ class Conv2dBlock(nn.Module):
 
 class ResBlock(nn.Module):
     def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
-                 norm_type="IN", activation_type="relu", use_bias=None):
+                 norm_type="IN", activation_type="ReLU", use_bias=None):
         super().__init__()
         self.norm_type = norm_type
         if use_bias is None: