Compare commits

...

7 Commits

Author SHA1 Message Date
776fe40199 change a lot 2020-09-26 17:48:26 +08:00
f67bcdf161 use base module rewrite TSIT 2020-09-26 17:48:10 +08:00
16f18ab2e2 func to apply sn 2020-09-26 17:47:24 +08:00
0f2b67e215 base model, Norm&Conv&ResNet 2020-09-26 17:45:51 +08:00
acf243cb12 working 2020-09-25 18:31:12 +08:00
fbea96f6d7 add new dataset type 2020-09-24 16:50:53 +08:00
ca55318253 add context loss 2020-09-24 16:38:03 +08:00
21 changed files with 980 additions and 250 deletions

View File

@ -19,6 +19,7 @@ handler:
misc:
random_seed: 1004
add_new_loss_epoch: -1
model:
generator:

View File

@ -1,4 +1,4 @@
name: self2anime-TSIT
name: VoxCeleb2Anime-TSIT
engine: TSIT
result_dir: ./result
max_pairs: 1500000
@ -11,7 +11,10 @@ handler:
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
misc:
@ -86,24 +89,23 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 1
batch_size: 8
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
edge_type: "landmark_hed"
size: [ 128, 128 ]
_type: GenerationUnpairedDataset
root_a: "/data/i2i/faces/CelebA-Asian/trainA"
root_b: "/data/i2i/anime/your-name/faces"
random_pair: True
pipeline:
- Load
- Resize:
size: [ 170, 144 ]
- RandomCrop:
size: [ 128, 128 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
@ -118,13 +120,14 @@ data:
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/VoxCeleb2Anime/testA"
root_b: "/data/i2i/VoxCeleb2Anime/testB"
with_path: True
root_a: "/data/i2i/faces/CelebA-Asian/testA"
root_b: "/data/i2i/anime/your-name/faces"
random_pair: False
pipeline:
- Load
- Resize:
size: [ 170, 144 ]
- RandomCrop:
size: [ 128, 128 ]
- ToTensor
- Normalize:

View File

@ -0,0 +1,171 @@
name: talking_anime
engine: talking_anime
result_dir: ./result
max_pairs: 1000000
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 100 # log image `image` times per epoch
test:
random: True
images: 10
misc:
random_seed: 1004
loss:
gan:
loss_type: hinge
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
fm:
level: 1
weight: 1
style:
layer_weights:
"3": 1
criterion: 'L1'
style_loss: True
perceptual_loss: False
weight: 10
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 0
context:
layer_weights:
#"13": 1
"22": 1
weight: 5
recon:
level: 1
weight: 10
edge:
weight: 5
hed_pretrained_model_path: ./network-bsds500.pytorch
model:
face_generator:
_type: TAFG-SingleGenerator
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
use_spectral_norm: True
style_encoder_type: VGG19StyleEncoder
num_style_conv: 4
style_dim: 512
num_adain_blocks: 8
num_res_blocks: 8
anime_generator:
_type: TAFG-ResGenerator
_bn_to_sync_bn: False
in_channels: 6
use_spectral_norm: True
num_res_blocks: 8
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
dataloader:
batch_size: 8
shuffle: True
num_workers: 1
pin_memory: True
drop_last: True
dataset:
_type: PoseFacesWithSingleAnime
root_face: "/data/i2i/VoxCeleb2Anime/trainA"
root_anime: "/data/i2i/VoxCeleb2Anime/trainB"
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
num_face: 2
img_size: [ 128, 128 ]
with_order: False
face_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
anime_pipeline:
- Load
- Resize:
size: [ 144, 144 ]
- RandomCrop:
size: [ 128, 128 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: PoseFacesWithSingleAnime
root_face: "/data/i2i/VoxCeleb2Anime/testA"
root_anime: "/data/i2i/VoxCeleb2Anime/testB"
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
num_face: 2
img_size: [ 128, 128 ]
with_order: False
face_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
anime_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@ -1,6 +1,7 @@
import os
import pickle
from collections import defaultdict
from itertools import permutations, combinations
from pathlib import Path
import lmdb
@ -237,3 +238,50 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
    """Dataset yielding `num_face` face crops of one person (with pose tensors)
    plus one randomly drawn anime image per item.

    Face images are grouped per person by filename stem; each sample is a tuple
    of stems of the same person, and every face comes with a pose representation
    built from its dlib landmark file.
    """

    def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
                 with_order=True):
        # num_face: how many faces of the same person one sample contains.
        # landmark_path: root of per-image landmark .txt files; the per-folder
        # layout mirrors root_face (see __getitem__).
        self.num_face = num_face
        self.landmark_path = Path(landmark_path)
        self.with_order = with_order
        self.root_face = Path(root_face)
        self.root_anime = Path(root_anime)
        self.img_size = img_size
        self.face_samples = self.iter_folders()
        self.face_pipeline = transform_pipeline(face_pipeline)
        # Anime side is an independent single-folder dataset; one image is
        # sampled at random for every __getitem__ call.
        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)

    def iter_folders(self):
        """Group face stems per person and build `num_face`-tuples of stems."""
        pics_per_person = defaultdict(list)
        for p in self.root_face.glob("*.jpg"):
            # assumes the first 7 characters of the stem identify the person
            # (VoxCeleb-style ids) — TODO confirm against the dataset layout
            pics_per_person[p.stem[:7]].append(p.stem)
        data = []
        for p in pics_per_person:
            if len(pics_per_person[p]) >= self.num_face:
                # NOTE(review): these branches look inverted — `combinations`
                # ignores ordering while `permutations` enumerates every
                # ordering, so one would expect with_order=True to use
                # permutations. Confirm intent before changing behavior.
                if self.with_order:
                    data.extend(list(combinations(pics_per_person[p], self.num_face)))
                else:
                    data.extend(list(permutations(pics_per_person[p], self.num_face)))
        return data

    def read_pose(self, pose_txt):
        """Build the pose tensor (part labels + part edge + distance maps) from one landmark file."""
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([part_labels, part_edge, dist_tensor])

    def __len__(self):
        return len(self.face_samples)

    def __getitem__(self, idx):
        output = dict()
        # A fresh random anime image is drawn on every access (not deterministic per idx).
        output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
        for i, f in enumerate(self.face_samples[idx]):
            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
            output[f"stem_{i}"] = f
        return output

View File

@ -76,11 +76,14 @@ class TAFGEngineKernel(EngineKernel):
contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
for ph in "ab":
images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
images["a2b"] = generator.decode(contents["a"], styles["b"], "b")
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
images["a2b"], "b", "b")
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
images["a2b"], "b", "b")
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
return dict(styles=styles, contents=contents, images=images)
def criterion_generators(self, batch, generated) -> dict:
@ -91,50 +94,76 @@ class TAFGEngineKernel(EngineKernel):
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"])
pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
loss[f"gan_{ph}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
generated["contents"]["a"], generated["contents"]["recon_a"]
)
loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
generated["styles"]["b"], generated["styles"]["recon_b"]
)
if self.config.loss.perceptual.weight > 0:
loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
batch["a"]["img"], generated["images"]["a2b"]
if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
self.generators["main"].style_converters.requires_grad_(False)
self.generators["main"].style_encoders.requires_grad_(False)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
pred_fake = self.discriminators[ph](generated["images"]["a2b"])
loss["gan_a2b"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
generated["contents"]["a"], generated["contents"]["recon_a"]
)
loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
generated["styles"]["random_b"], generated["styles"]["recon_b"]
)
for ph in "ab":
if self.config.loss.perceptual.weight > 0:
loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
batch["a"]["img"], generated["images"]["a2b"]
)
if self.config.loss.cycle.weight > 0:
loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss(
batch[ph]["img"], generated["images"][f"cycle_{ph}"]
)
if self.config.loss.style.weight > 0:
loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
batch[ph]["img"], generated["images"][f"a2{ph}"]
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
batch["a"]["img"], generated["images"][f"cycle_a"]
)
if self.config.loss.edge.weight > 0:
loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
)
# for ph in "ab":
#
# if self.config.loss.style.weight > 0:
# loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
# batch[ph]["img"], generated["images"][f"a2{ph}"]
# )
if self.config.loss.edge.weight > 0:
loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
# batch = self._process_batch(batch)
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) +
self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
else:
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
@ -145,18 +174,30 @@ class TAFGEngineKernel(EngineKernel):
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
batch = self._process_batch(batch)
return dict(
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
generated["images"]["a2b"].detach(),
generated["images"]["cycle_a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
generated["images"]["cycle_b"].detach()]
)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
return dict(
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
generated["images"]["a2b"].detach(),
generated["images"]["cycle_a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
generated["images"]["cycle_b"].detach()]
)
else:
return dict(
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
]
)
def change_engine(self, config, trainer):
pass

View File

@ -51,31 +51,19 @@ class TSITEngineKernel(EngineKernel):
def forward(self, batch, inference=False) -> dict:
with torch.set_grad_enabled(not inference):
fake = dict(
b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
b=self.generators["main"](content_img=batch["a"])
)
return fake
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
for phase in "b":
pred_fake = self.discriminators[phase](generated[phase])
loss[f"gan_{phase}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
if self.config.loss.fm.weight > 0 and phase == "b":
pred_real = self.discriminators[phase](batch[phase])
loss_fm = 0
num_scale_discriminator = len(pred_fake)
for i in range(num_scale_discriminator):
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
return loss
def criterion_discriminators(self, batch, generated) -> dict:

View File

@ -189,34 +189,33 @@ def get_trainer(config, kernel: EngineKernel):
for i in range(len(image_list)):
test_images[k].append([])
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed + engine.state.epoch
if config.handler.test.random else config.misc.random_seed)
random_start = \
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
for i in range(random_start, random_start + config.handler.test.images):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].unsqueeze(0)
elif isinstance(batch[k], dict):
for kk in batch[k]:
if isinstance(batch[k][kk], torch.Tensor):
batch[k][kk] = batch[k][kk].unsqueeze(0)
g = torch.Generator()
g.manual_seed(config.misc.random_seed + engine.state.epoch
if config.handler.test.random else config.misc.random_seed)
random_start = \
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
for i in range(random_start, random_start + config.handler.test.images):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].unsqueeze(0)
elif isinstance(batch[k], dict):
for kk in batch[k]:
if isinstance(batch[k][kk], torch.Tensor):
batch[k][kk] = batch[k][kk].unsqueeze(0)
generated = kernel.forward(batch)
images = kernel.intermediate_images(batch, generated)
generated = kernel.forward(batch, inference=True)
images = kernel.intermediate_images(batch, generated)
for k in test_images:
for j in range(len(images[k])):
test_images[k][j].append(images[k][j])
for k in test_images:
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
engine.state.iteration * pairs_per_iteration
)
for j in range(len(images[k])):
test_images[k][j].append(images[k][j])
for k in test_images:
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
engine.state.iteration * pairs_per_iteration
)
return trainer

153
engine/talking_anime.py Normal file
View File

@ -0,0 +1,153 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.context_loss import ContextLoss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
class TAEngineKernel(EngineKernel):
    """Engine kernel for the talking-anime task.

    Two generators cooperate: ``anime`` repaints a pose-source face into the
    anime domain, and ``face`` re-renders a real face from the generated anime
    output plus an identity reference. Each domain has its own discriminator.
    """

    def __init__(self, config):
        super().__init__(config)
        # Every criterion config carries a `weight` that is consumed at
        # loss-combination time, not by the loss module itself — strip it
        # before constructing the module (see _loss_kwargs).
        self.perceptual_loss = PerceptualLoss(**self._loss_kwargs(config.loss.perceptual)).to(idist.device())
        self.style_loss = PerceptualLoss(**self._loss_kwargs(config.loss.style)).to(idist.device())
        self.gan_loss = GANLoss(**self._loss_kwargs(config.loss.gan)).to(idist.device())
        self.context_loss = ContextLoss(**self._loss_kwargs(config.loss.context)).to(idist.device())
        # level == 1 selects L1; anything else falls back to L2.
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
        self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
            idist.device())

    @staticmethod
    def _loss_kwargs(cfg):
        # Convert an OmegaConf node into plain constructor kwargs without `weight`.
        kwargs = OmegaConf.to_container(cfg)
        kwargs.pop("weight")
        return kwargs

    def build_models(self) -> (dict, dict):
        """Build per-domain generators and discriminators and initialize their weights."""
        generators = dict(
            anime=build_model(self.config.model.anime_generator),
            face=build_model(self.config.model.face_generator)
        )
        # Both discriminators share one config node but are separate instances.
        discriminators = dict(
            anime=build_model(self.config.model.discriminator),
            face=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["face"])
        self.logger.debug(generators["face"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
        # Re-enable discriminator gradients for the upcoming discriminator step.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # Freeze discriminators so the generator step does not accumulate
        # gradients inside them.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run both generators; return dict with `fake_anime` and `fake_face`."""
        with torch.set_grad_enabled(not inference):
            # NOTE(review): the anime reference is mirrored along the width axis
            # (dim 3) before concatenation — presumably a cheap augmentation;
            # confirm intent.
            target_pose_anime = self.generators["anime"](
                torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
            # The face generator takes a single-channel content map (channel
            # mean of the generated anime) plus the identity reference face_0.
            target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])
        return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)

    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
        """Return (gan_loss, fm_loss) for `generated_img` under `discriminator`.

        fm_loss is 0 when `match_img` is None (feature matching disabled).
        """
        pred_fake = discriminator(generated_img)
        loss_gan = 0
        for sub_pred_fake in pred_fake:
            # last output is actual prediction
            loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
        if match_img is None:
            # do not cal feature match loss
            return loss_gan, 0
        pred_real = discriminator(match_img)
        loss_fm = 0
        num_scale_discriminator = len(pred_fake)
        for i in range(num_scale_discriminator):
            # last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake[i]) - 1
            for j in range(num_intermediate_outputs):
                loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
        loss_fm = self.config.loss.fm.weight * loss_fm
        return loss_gan, loss_fm

    def criterion_generators(self, batch, generated) -> dict:
        """Compute all generator-side losses as a name -> weighted-term dict."""
        loss = dict()
        loss["face_style"] = self.config.loss.style.weight * self.style_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["face"], generated["fake_face"], batch["face_1"])
        loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])
        # gt_is_edge=False: the target is a raw face image, presumably the edge
        # loss extracts its edges internally — verify against EdgeLoss.
        loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss(
            generated["fake_anime"], batch["face_1"], gt_is_edge=False,
        )
        if self.config.loss.perceptual.weight > 0:
            loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                generated["fake_anime"], batch["anime_img"]
            )
        if self.config.loss.context.weight > 0:
            loss["anime_context"] = self.config.loss.context.weight * self.context_loss(
                generated["fake_anime"], batch["anime_img"],
            )
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Real/fake GAN terms per discriminator, averaged over the two terms."""
        loss = dict()
        # Which batch entry is the "real" sample for each discriminator.
        real = {"anime": "anime_img", "face": "face_1"}
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[real[phase]])
            pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
                  generated["fake_face"].detach()]
        # `images` is already a fresh list; the former `[img for img in images]`
        # copied it element by element for no benefit.
        return dict(b=images)
def run(task, config, _):
    """Entry point for the talking-anime engine: build the kernel and hand it to run_kernel."""
    run_kernel(task, config, TAEngineKernel(config))

44
loss/I2I/context_loss.py Normal file
View File

@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F
from torch import nn
from .perceptual_loss import PerceptualVGG
class ContextLoss(nn.Module):
    """Contextual loss computed on VGG feature maps.

    For each weighted VGG layer, spatial features of the prediction are
    softly matched to target features by relative cosine distance and the
    match confidence is turned into -log(CX); the comparison is over all
    position pairs, so it does not require spatial alignment.
    Inputs are expected in [-1, 1] when `norm_img` is True.
    """

    def __init__(self, layer_weights, h=0.1, vgg_type='vgg19', norm_image_with_imagenet_param=True, norm_img=True,
                 eps=1e-5):
        # layer_weights: dict mapping vgg.features layer index (str) -> loss weight.
        # h: bandwidth of the similarity kernel; smaller h -> sharper matching.
        # norm_img: if True, inputs in [-1, 1] are rescaled to [0, 1] before VGG.
        super(ContextLoss, self).__init__()
        self.eps = eps
        self.h = h
        self.layer_weights = layer_weights
        self.norm_img = norm_img
        self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
                                 norm_image_with_imagenet_param=norm_image_with_imagenet_param)

    def single_forward(self, source_feature, target_feature):
        """Contextual loss between one pair of feature maps, shape (N, C, H, W)."""
        # Center both maps by the target's per-channel spatial mean, then
        # flatten the spatial dims.
        mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
        source_feature = (source_feature - mean_target_feature).view(*source_feature.size()[:2], -1)  # NxCxHW
        target_feature = (target_feature - mean_target_feature).view(*source_feature.size()[:2], -1)  # NxCxHW
        # L2-normalize along channels so bmm below yields cosine similarity.
        source_feature = F.normalize(source_feature, p=2, dim=1)
        target_feature = F.normalize(target_feature, p=2, dim=1)
        # Pairwise cosine distance between every source/target position,
        # mapped from [-1, 1] similarity into [0, 1] distance.
        cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2), target_feature)) / 2  # NxHWxHW
        # Distance relative to each source position's best match (eps avoids /0).
        rel_distance = cosine_distance / (cosine_distance.min(2, keepdim=True)[0] + self.eps)
        # Soft match weights, normalized over target positions.
        w = torch.exp((1 - rel_distance) / self.h)
        cx = w.div(w.sum(dim=2, keepdim=True))
        # Best match per target position, then averaged to a global similarity.
        cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
        return -torch.log(cx).mean()

    def forward(self, x, gt):
        """Return the layer-weighted contextual loss between prediction x and target gt."""
        if self.norm_img:
            # Map [-1, 1] inputs into [0, 1] for the VGG feature extractor.
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5
        # extract vgg features; the target branch carries no gradients
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())
        loss = 0
        for k in x_features.keys():
            loss += self.single_forward(x_features[k], gt_features[k]) * self.layer_weights[k]
        return loss

View File

@ -4,6 +4,49 @@ import torch.nn.functional as F
import torchvision.models.vgg as vgg
# Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (1): ReLU(inplace=True)
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (3): ReLU(inplace=True)
# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (6): ReLU(inplace=True)
# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (8): ReLU(inplace=True)
# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (11): ReLU(inplace=True)
# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (13): ReLU(inplace=True)
# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (15): ReLU(inplace=True)
# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (17): ReLU(inplace=True)
# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (20): ReLU(inplace=True)
# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (22): ReLU(inplace=True)
# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (24): ReLU(inplace=True)
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (26): ReLU(inplace=True)
# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (29): ReLU(inplace=True)
# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (31): ReLU(inplace=True)
# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (33): ReLU(inplace=True)
# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (35): ReLU(inplace=True)
# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )
class PerceptualVGG(nn.Module):
"""VGG network used in calculating perceptual loss.
In this implementation, we allow users to choose whether use normalization
@ -15,15 +58,15 @@ class PerceptualVGG(nn.Module):
list contains the name each layer in `vgg.feature`. An example
of this list is ['4', '10'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image.
norm_image_with_imagenet_param (bool): If True, normalize the input image.
Importantly, the input feature must in the range [0, 1].
Default: True.
"""
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True):
super(PerceptualVGG, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
self.use_input_norm = norm_image_with_imagenet_param
# get vgg model and load pretrained vgg weight
# remove _vgg from attributes to avoid `find_unused_parameters` bug
@ -75,7 +118,7 @@ class PerceptualLoss(nn.Module):
in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg.
Default: True.
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
loss will be calculated.
@ -88,7 +131,7 @@ class PerceptualLoss(nn.Module):
Importantly, the input image must be in range [-1, 1].
"""
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True,
style_loss=False, norm_img=True, criterion='L1'):
super(PerceptualLoss, self).__init__()
self.norm_img = norm_img
@ -97,7 +140,7 @@ class PerceptualLoss(nn.Module):
self.style_loss = style_loss
self.layer_weights = layer_weights
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
use_input_norm=use_input_norm)
norm_image_with_imagenet_param=norm_image_with_imagenet_param)
self.percep_criterion, self.style_criterion = self.set_criterion(criterion)

View File

@ -53,6 +53,59 @@ class VGG19StyleEncoder(nn.Module):
return x.view(x.size(0), -1)
@MODEL.register_module("TAFG-ResGenerator")
class ResGenerator(nn.Module):
    """Plain encoder/decoder generator without any style injection."""

    def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
        super().__init__()
        # Two downsampling stages double the width twice: base_channels * 2**2.
        bottleneck_channels = 4 * base_channels
        self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
                                              use_spectral_norm=use_spectral_norm)
        # No AdaIN blocks (0) — plain instance-norm decoder.
        self.decoder = Decoder(bottleneck_channels, out_channels, 2,
                               0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")

    def forward(self, x):
        features = self.content_encoder(x)
        return self.decoder(features)
@MODEL.register_module("TAFG-SingleGenerator")
class SingleGenerator(nn.Module):
    """Style-conditioned generator: a content encoder feeding an AdaIN decoder
    whose normalization parameters are predicted from a style image.

    :raises NotImplementedError: if `style_encoder_type` is not one of
        "StyleEncoder" or "VGG19StyleEncoder".
    """

    def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
        super().__init__()
        self.num_adain_blocks = num_adain_blocks
        if style_encoder_type == "StyleEncoder":
            self.style_encoder = StyleEncoder(
                style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
                max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
            )
        elif style_encoder_type == "VGG19StyleEncoder":
            # num_style_conv is only meaningful for the plain StyleEncoder.
            self.style_encoder = VGG19StyleEncoder(
                style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
            )
        else:
            # Fix: was `raise NotImplemented(...)` — NotImplemented is a constant,
            # not an exception class, so raising it produced a confusing TypeError.
            raise NotImplementedError(f"do not support {style_encoder_type}")
        # Two downsampling stages in the content encoder -> 2 ** 2 channel growth.
        resnet_channels = 2 ** 2 * base_channels
        # Each AdaIN residual block has two conv layers, each needing a
        # scale+shift pair of size resnet_channels, hence
        # num_adain_blocks * 2 chunks of 2 * resnet_channels values.
        self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
                                      n_blocks=3, norm_type="NONE")
        self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
                                              use_spectral_norm=use_spectral_norm)
        self.decoder = Decoder(resnet_channels, out_channels, 2,
                               num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)

    def forward(self, content_img, style_img):
        """Decode `content_img` features with AdaIN parameters derived from `style_img`."""
        content = self.content_encoder(content_img)
        style = self.style_encoder(style_img)
        as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
        # set style for decoder: each AdaIN residual block receives its own
        # pair of parameter chunks before decoding
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,

View File

@ -3,45 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
from model.normalization import AdaptiveInstanceNorm2d
from model.normalization import select_norm_layer
class ResBlock(nn.Module):
    """Residual block with a learned 1x1-projected skip connection.

    Both the main path (two 3x3 convs) and the skip path (one 1x1 conv) end
    with normalization + LeakyReLU; the two results are summed.
    """

    def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None,
                 use_spectral=True):
        super().__init__()
        self.padding_mode = padding_mode
        self.use_spectral = use_spectral
        if use_bias is None:
            # Only for IN, use bias since it does not have affine parameters.
            use_bias = norm_type == "IN"
        self.use_bias = use_bias
        norm_layer = select_norm_layer(norm_type)

        def activated(channels):
            return [norm_layer(channels), nn.LeakyReLU(0.2, inplace=True)]

        self.main = nn.Sequential(
            self.conv_block(in_channels, in_channels),
            *activated(in_channels),
            self.conv_block(in_channels, out_channels),
            *activated(out_channels),
        )
        self.skip = nn.Sequential(
            self.conv_block(in_channels, out_channels, padding=0, kernel_size=1),
            *activated(out_channels),
        )

    def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
        """3x3 (by default) conv, optionally wrapped with spectral norm."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
                         padding_mode=self.padding_mode, bias=self.use_bias)
        return nn.utils.spectral_norm(conv) if self.use_spectral else conv

    def forward(self, x):
        return self.main(x) + self.skip(x)
from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock
class Interpolation(nn.Module):
@ -59,106 +21,41 @@ class Interpolation(nn.Module):
return f"DownSampling(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
class FADE(nn.Module):
    """Feature-adaptive denormalization: parameter-free BatchNorm followed by
    a spatial affine transform whose scale/shift maps are predicted from the
    conditioning `feature` by two 3x3 convs."""

    def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
        self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                     padding_mode="zeros")
        self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                    padding_mode="zeros")

    def forward(self, x, feature):
        normalized = self.bn(x)
        scale = self.alpha_conv(feature)
        shift = self.beta_conv(feature)
        return scale * normalized + shift
class FADEResBlock(nn.Module):
    """FADE residual block: a FADE-modulated main path (two 3x3 convs) plus a
    FADE-modulated 1x1 skip path, summed and then upsampled 2x (nearest)."""

    def __init__(self, use_spectral, features_channels, in_channels, out_channels):
        super().__init__()
        self.main = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
        )
        self.skip = nn.Sequential(
            FADE(use_spectral, features_channels, in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
        )
        self.up_sample = Interpolation(2, mode="nearest")

    @staticmethod
    def forward_with_fade(module, x, feature):
        # FADE layers need the conditioning feature; plain layers take only x.
        for layer in module:
            if layer.__class__.__name__ == "FADE":
                x = layer(x, feature)
            else:
                x = layer(x)
        return x

    def forward(self, x, feature):
        # BUG FIX: the sum previously added the main path to itself
        # (`self.main` twice), leaving `self.skip` completely unused.
        out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.skip, x, feature)
        return self.up_sample(out)
def conv_block(use_spectral, in_channels, out_channels, **kwargs):
    """Build a Conv2d, optionally wrapped with spectral normalization."""
    conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    if not use_spectral:
        return conv
    return nn.utils.spectral_norm(conv)
@MODEL.register_module("TSIT-Generator")
class TSITGenerator(nn.Module):
def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3,
out_channels=3, use_spectral=True, input_layer_type="conv1x1"):
class Generator(nn.Module):
def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7,
padding_mode="reflect", activation_type="ReLU"):
super().__init__()
self.num_blocks = num_blocks
self.base_channels = base_channels
self.use_spectral = use_spectral
self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
self.content_stream = self.build_stream()
self.style_stream = self.build_stream()
self.generator = self.build_generator()
self.end_conv = nn.Sequential(
conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
nn.Tanh()
)
self.content_stream = self.build_stream(padding_mode, activation_type)
self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type,
norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3)
def build_generator(self):
stream_sequence = []
sequence = []
multiple_now = min(2 ** self.num_blocks, 2 ** 4)
for i in range(1, self.num_blocks + 1):
m = self.num_blocks - i
multiple_prev = multiple_now
multiple_now = min(2 ** m, 2 ** 4)
stream_sequence.append(nn.Sequential(
AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
multiple_now * self.base_channels)
sequence.append(nn.Sequential(
ReverseResidualBlock(
multiple_prev * base_channels, multiple_now * base_channels,
padding_mode=padding_mode, norm_type="FADE",
additional_norm_kwargs=dict(
condition_in_channels=multiple_prev * base_channels,
base_norm_type="BN",
padding_mode=padding_mode
)
),
Interpolation(2, mode="nearest")
))
return nn.ModuleList(stream_sequence)
self.generator = nn.Sequential(*sequence)
self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh",
kernel_size=7, padding_mode=padding_mode, padding=3)
def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
    """Build the image-to-feature input layer of a stream.

    "conv7x7": 7x7 conv + IN + ReLU; "conv1x1": a bare 1x1 conv.

    Raises:
        NotImplementedError: for unknown `input_layer_type` values.
    """
    if input_layer_type == "conv7x7":
        return nn.Sequential(
            conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
                       padding_mode="zeros", padding=3, bias=True),
            select_norm_layer("IN")(out_channels),
            nn.ReLU(inplace=True)
        )
    elif input_layer_type == "conv1x1":
        return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    else:
        # BUG FIX: `raise NotImplemented` raises a TypeError (NotImplemented
        # is a constant, not an exception); also report the offending value.
        raise NotImplementedError(input_layer_type)
def build_stream(self):
def build_stream(self, padding_mode, activation_type):
multiple_now = 1
stream_sequence = []
for i in range(1, self.num_blocks + 1):
@ -166,27 +63,26 @@ class TSITGenerator(nn.Module):
multiple_now = min(2 ** i, 2 ** 4)
stream_sequence.append(nn.Sequential(
Interpolation(scale_factor=0.5, mode="nearest"),
ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
use_spectral=self.use_spectral)
ResidualBlock(
multiple_prev * self.base_channels, multiple_now * self.base_channels,
padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
))
return nn.ModuleList(stream_sequence)
def forward(self, content_img, style_img):
c = self.content_input_layer(content_img)
s = self.style_input_layer(style_img)
def forward(self, content, z=None):
c = self.start_conv(content)
content_features = []
style_features = []
for i in range(self.num_blocks):
s = self.style_stream[i](s)
c = self.content_stream[i](c)
content_features.append(c)
style_features.append(s)
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
if z is None:
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
for i in range(self.num_blocks):
m = - i - 1
layer = self.generator[i]
layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
z = layer[0](z)
z = layer[1](z, content_features[m])
return self.end_conv(z)
res_block = self.generator[i][0]
res_block.conv1.normalization.set_feature(content_features[m])
res_block.conv2.normalization.set_feature(content_features[m])
if res_block.learn_skip_connection:
res_block.res_conv.normalization.set_feature(content_features[m])
return self.end_conv(self.generator(z))

View File

@ -1,8 +1,10 @@
from model.registry import MODEL
from model.registry import MODEL, NORMALIZATION
import model.GAN.CycleGAN
import model.GAN.MUNIT
import model.GAN.TAFG
import model.GAN.UGATIT
import model.GAN.wrapper
import model.GAN.base
import model.GAN.TSIT
import model.GAN.MUNIT
import model.GAN.UGATIT
import model.GAN.base
import model.GAN.wrapper
import model.base.normalization

0
model/base/__init__.py Normal file
View File

109
model/base/module.py Normal file
View File

@ -0,0 +1,109 @@
import torch.nn as nn
from model.registry import NORMALIZATION
_DO_NO_THING_FUNC = lambda x: x
def _use_bias_checker(norm_type):
    """Whether a conv before this norm type should keep its bias: the bias is
    redundant when the following normalization removes the mean."""
    return norm_type not in {"IN", "BN", "AdaIN", "FADE", "SPADE"}
def _normalization(norm, num_features, additional_kwargs=None):
    """Build a normalization layer from the registry; "NONE" yields identity."""
    if norm == "NONE":
        return _DO_NO_THING_FUNC
    build_kwargs = dict(_type=norm, num_features=num_features)
    build_kwargs.update(additional_kwargs or {})
    return NORMALIZATION.build_with(build_kwargs)
def _activation(activation):
    """Map an activation name to a module instance; "NONE" yields identity.

    Raises:
        NotImplementedError: for unknown activation names.
    """
    if activation == "NONE":
        return _DO_NO_THING_FUNC
    elif activation == "ReLU":
        return nn.ReLU(inplace=True)
    elif activation == "LeakyReLU":
        return nn.LeakyReLU(negative_slope=0.2, inplace=True)
    elif activation == "Tanh":
        return nn.Tanh()
    else:
        # BUG FIX: `raise NotImplemented(activation)` crashed with a
        # TypeError ('NotImplementedType' object is not callable).
        raise NotImplementedError(activation)
class Conv2dBlock(nn.Module):
    """Conv2d -> normalization -> activation, all configured by name."""

    def __init__(self, in_channels: int, out_channels: int, bias=None,
                 activation_type="ReLU", norm_type="NONE", **conv_kwargs):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type
        if bias is None:
            # Caller did not force a choice: drop the bias when the following
            # normalization would cancel it anyway.
            bias = _use_bias_checker(norm_type)
        conv_kwargs["bias"] = bias
        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.normalization = _normalization(norm_type, out_channels)
        self.activation = _activation(activation_type)

    def forward(self, x):
        out = self.convolution(x)
        out = self.normalization(out)
        return self.activation(out)
class ResidualBlock(nn.Module):
    """Two 3x3 conv blocks with a residual connection. When the channel count
    changes, the shortcut becomes a learned conv instead of identity."""

    def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
                 activation_type="ReLU", out_activation_type=None, norm_type="IN"):
        super().__init__()
        self.norm_type = norm_type
        out_channels = num_channels if out_channels is None else out_channels
        out_activation_type = "NONE" if out_activation_type is None else out_activation_type
        self.learn_skip_connection = num_channels != out_channels
        shared = dict(kernel_size=3, padding=1, padding_mode=padding_mode, norm_type=norm_type)
        self.conv1 = Conv2dBlock(num_channels, num_channels, activation_type=activation_type, **shared)
        self.conv2 = Conv2dBlock(num_channels, out_channels, activation_type=out_activation_type, **shared)
        if self.learn_skip_connection:
            self.res_conv = Conv2dBlock(num_channels, out_channels, activation_type=out_activation_type, **shared)

    def forward(self, x):
        shortcut = self.res_conv(x) if self.learn_skip_connection else x
        return self.conv2(self.conv1(x)) + shortcut
class ReverseConv2dBlock(nn.Module):
    """Pre-activation ordering: normalization -> activation -> Conv2d."""

    def __init__(self, in_channels: int, out_channels: int,
                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
        super().__init__()
        # Normalization acts on the *input* channels here, unlike Conv2dBlock.
        self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
        self.activation = _activation(activation_type)
        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)

    def forward(self, x):
        out = self.normalization(x)
        out = self.activation(out)
        return self.convolution(out)
class ReverseResidualBlock(nn.Module):
    """Residual block built from pre-activation conv blocks; the shortcut is
    a learned conv whenever the channel count changes."""

    def __init__(self, in_channels, out_channels, padding_mode="reflect",
                 norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
        super().__init__()
        self.learn_skip_connection = in_channels != out_channels
        shared = dict(kernel_size=3, padding=1, padding_mode=padding_mode)
        self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type,
                                        additional_norm_kwargs, **shared)
        self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type,
                                        additional_norm_kwargs, **shared)
        if self.learn_skip_connection:
            self.res_conv = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type,
                                               additional_norm_kwargs, **shared)

    def forward(self, x):
        shortcut = self.res_conv(x) if self.learn_skip_connection else x
        return self.conv2(self.conv1(x)) + shortcut

142
model/base/normalization.py Normal file
View File

@ -0,0 +1,142 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import NORMALIZATION
from model.base.module import Conv2dBlock
# Short aliases for the plain torch normalization layers, registered so they
# can be built by name through the NORMALIZATION registry.
_VALID_NORM_AND_ABBREVIATION = dict(
    IN="InstanceNorm2d",
    BN="BatchNorm2d",
)

for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
    NORMALIZATION.register_module(module=getattr(nn, name), name=abbr)
@NORMALIZATION.register_module("ADE")
class AdaptiveDenormalization(nn.Module):
def __init__(self, num_features, base_norm_type="BN"):
super().__init__()
self.num_features = num_features
self.base_norm_type = base_norm_type
self.norm = self.base_norm(num_features)
self.gamma = None
self.beta = None
self.have_set_condition = False
def base_norm(self, num_features):
if self.base_norm_type == "IN":
return nn.InstanceNorm2d(num_features)
elif self.base_norm_type == "BN":
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
def set_condition(self, gamma, beta):
self.gamma, self.beta = gamma, beta
self.have_set_condition = True
def forward(self, x):
assert self.have_set_condition
x = self.norm(x)
x = self.gamma * x + self.beta
self.have_set_condition = False
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
f"base_norm_type={self.base_norm_type})"
@NORMALIZATION.register_module("AdaIN")
class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
def __init__(self, num_features: int):
super().__init__(num_features, "IN")
self.num_features = num_features
def set_style(self, style):
style = style.view(*style.size(), 1, 1)
gamma, beta = style.chunk(2, 1)
super().set_condition(gamma, beta)
@NORMALIZATION.register_module("FADE")
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
super().__init__(num_features, base_norm_type)
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
padding_mode=padding_mode)
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
padding_mode=padding_mode)
def set_feature(self, feature):
gamma = self.gamma_conv(feature)
beta = self.beta_conv(feature)
super().set_condition(gamma, beta)
@NORMALIZATION.register_module("SPADE")
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
activation_type="ReLU", padding_mode="zeros"):
super().__init__(num_features, base_norm_type)
self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
def set_condition_image(self, condition_image):
feature = self.base_conv_block(condition_image)
gamma = self.gamma_conv(feature)
beta = self.beta_conv(feature)
super().set_condition(gamma, beta)
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """Blend instance norm and layer norm by rho, then apply gamma/beta."""
    blended = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
    return blended * gamma + beta
@NORMALIZATION.register_module("ILN")
class ILN(nn.Module):
def __init__(self, num_features, eps=1e-5):
super(ILN, self).__init__()
self.eps = eps
self.rho = nn.Parameter(torch.Tensor(num_features))
self.gamma = nn.Parameter(torch.Tensor(num_features))
self.beta = nn.Parameter(torch.Tensor(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.zeros_(self.rho)
nn.init.ones_(self.gamma)
nn.init.zeros_(self.beta)
def forward(self, x):
return _instance_layer_normalization(
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
@NORMALIZATION.register_module("AdaILN")
class AdaILN(nn.Module):
def __init__(self, num_features, eps=1e-5, default_rho=0.9):
super(AdaILN, self).__init__()
self.eps = eps
self.rho = nn.Parameter(torch.Tensor(num_features))
self.rho.data.fill_(default_rho)
self.gamma = None
self.beta = None
self.have_set_condition = False
def set_condition(self, gamma, beta):
self.gamma, self.beta = gamma, beta
self.have_set_condition = True
def forward(self, x):
assert self.have_set_condition
out = _instance_layer_normalization(
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
self.have_set_condition = False
return out

View File

@ -6,7 +6,7 @@ import torch.nn as nn
def select_norm_layer(norm_type):
if norm_type == "BN":
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
return functools.partial(nn.BatchNorm2d)
elif norm_type == "IN":
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == "LN":

View File

@ -1,3 +1,4 @@
from util.registry import Registry
MODEL = Registry("model")
NORMALIZATION = Registry("normalization")

14
tool/inspect_model.py Normal file
View File

@ -0,0 +1,14 @@
import sys
import torch
from omegaconf import OmegaConf
from engine.util.build import build_model
# Build the generator exactly as the training config describes it, then load
# a checkpoint into it so the model can be inspected interactively.
# Usage: python tool/inspect_model.py <config.yaml> <checkpoint.pth>
config = OmegaConf.load(sys.argv[1])
generator = build_model(config.model.generator)
# map_location="cpu" lets this run without a GPU.
ckp = torch.load(sys.argv[2], map_location="cpu")
# NOTE(review): `.module` presumably unwraps a DataParallel wrapper added by
# build_model — confirm against engine.util.build.
generator.module.load_state_dict(ckp["generator_main"])

View File

@ -0,0 +1,13 @@
from pathlib import Path
import sys
from collections import defaultdict
from itertools import permutations
# Group image stems by their 7-character person-id prefix, then collect every
# ordered pair of distinct stems belonging to the same person.
# Usage: python <script> <directory-of-jpgs>
pids = defaultdict(list)
for image_path in Path(sys.argv[1]).glob("*.jpg"):
    person_id = image_path.stem[:7]
    pids[person_id].append(image_path.stem)

data = []
for stems in pids.values():
    data.extend(permutations(stems, 2))

View File

@ -2,6 +2,15 @@ import logging
from pathlib import Path
from typing import Optional
import torch.nn as nn
def add_spectral_norm(module):
    """Apply spectral normalization to conv layers; idempotent (the
    `weight_u` check skips modules that are already wrapped)."""
    is_conv = isinstance(module, (nn.Conv2d, nn.ConvTranspose2d))
    if is_conv and not hasattr(module, 'weight_u'):
        return nn.utils.spectral_norm(module)
    return module
def setup_logger(
name: Optional[str] = None,