Ray Wong 2020-09-25 18:31:12 +08:00
parent fbea96f6d7
commit acf243cb12
11 changed files with 542 additions and 115 deletions

View File

@@ -19,6 +19,7 @@ handler:
 misc:
   random_seed: 1004
+  add_new_loss_epoch: -1
 model:
   generator:
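Note on the new key (illustration only, not part of the diff): `add_new_loss_epoch` is used further down in this commit as an epoch gate, compared against `engine.state.epoch` before the extra losses are enabled. A minimal sketch of that gating pattern, with hypothetical names:

def gated_weight(epoch, add_new_loss_epoch, weight):
    # With the default of -1 the gated term is active from epoch 1 onward;
    # a value of e.g. 10 keeps it disabled until epoch 11.
    return weight if epoch > add_new_loss_epoch else 0.0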

View File

@@ -1,4 +1,4 @@
-name: self2anime-TSIT
+name: VoxCeleb2Anime-TSIT
 engine: TSIT
 result_dir: ./result
 max_pairs: 1500000
@@ -11,7 +11,10 @@ handler:
     n_saved: 2
   tensorboard:
     scalar: 100 # log scalar `scalar` times per epoch
-    image: 2 # log image `image` times per epoch
+    image: 4 # log image `image` times per epoch
+  test:
+    random: True
+    images: 10
 misc:
@@ -86,24 +89,23 @@ data:
       target_lr: 0
       buffer_size: 50
     dataloader:
-      batch_size: 1
+      batch_size: 8
       shuffle: True
       num_workers: 2
       pin_memory: True
       drop_last: True
     dataset:
-      _type: GenerationUnpairedDatasetWithEdge
-      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
-      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
-      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
-      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
-      edge_type: "landmark_hed"
-      size: [ 128, 128 ]
+      _type: GenerationUnpairedDataset
+      root_a: "/data/i2i/faces/CelebA-Asian/trainA"
+      root_b: "/data/i2i/anime/your-name/faces"
       random_pair: True
       pipeline:
         - Load
         - Resize:
+            size: [ 170, 144 ]
+        - RandomCrop:
             size: [ 128, 128 ]
+        - RandomHorizontalFlip
         - ToTensor
         - Normalize:
             mean: [ 0.5, 0.5, 0.5 ]
@@ -118,13 +120,14 @@ data:
       drop_last: False
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/VoxCeleb2Anime/testA"
-      root_b: "/data/i2i/VoxCeleb2Anime/testB"
+      root_a: "/data/i2i/faces/CelebA-Asian/testA"
+      root_b: "/data/i2i/anime/your-name/faces"
+      with_path: True
       random_pair: False
       pipeline:
         - Load
         - Resize:
+            size: [ 170, 144 ]
+        - RandomCrop:
             size: [ 128, 128 ]
         - ToTensor
         - Normalize:
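For illustration only: the train pipeline now resizes slightly above the crop size, then random-crops and flips, instead of resizing straight to 128x128; the test pipeline gets the same Resize/RandomCrop pair without the flip. Assuming the Load/Resize/RandomCrop entries follow torchvision semantics (and that size is given as height, width), a rough equivalent is:

import torchvision.transforms as T

train_transform = T.Compose([
    T.Resize((170, 144)),        # slightly larger than the crop so the crop can wander
    T.RandomCrop((128, 128)),    # random 128x128 patch
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])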

View File

@@ -0,0 +1,171 @@
name: talking_anime
engine: talking_anime
result_dir: ./result
max_pairs: 1000000
handler:
  clear_cuda_cache: True
  set_epoch_for_dist_sampler: True
  checkpoint:
    epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
    n_saved: 2
  tensorboard:
    scalar: 100 # log scalar `scalar` times per epoch
    image: 100 # log image `image` times per epoch
  test:
    random: True
    images: 10
misc:
  random_seed: 1004
loss:
  gan:
    loss_type: hinge
    real_label_val: 1.0
    fake_label_val: 0.0
    weight: 1.0
  fm:
    level: 1
    weight: 1
  style:
    layer_weights:
      "3": 1
    criterion: 'L1'
    style_loss: True
    perceptual_loss: False
    weight: 10
  perceptual:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L1'
    style_loss: False
    perceptual_loss: True
    weight: 0
  context:
    layer_weights:
      #"13": 1
      "22": 1
    weight: 5
  recon:
    level: 1
    weight: 10
  edge:
    weight: 5
    hed_pretrained_model_path: ./network-bsds500.pytorch
model:
  face_generator:
    _type: TAFG-SingleGenerator
    _bn_to_sync_bn: False
    style_in_channels: 3
    content_in_channels: 1
    use_spectral_norm: True
    style_encoder_type: VGG19StyleEncoder
    num_style_conv: 4
    style_dim: 512
    num_adain_blocks: 8
    num_res_blocks: 8
  anime_generator:
    _type: TAFG-ResGenerator
    _bn_to_sync_bn: False
    in_channels: 6
    use_spectral_norm: True
    num_res_blocks: 8
  discriminator:
    _type: MultiScaleDiscriminator
    num_scale: 2
    discriminator_cfg:
      _type: PatchDiscriminator
      in_channels: 3
      base_channels: 64
      use_spectral: True
      need_intermediate_feature: True
optimizers:
  generator:
    _type: Adam
    lr: 0.0001
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 4e-4
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001
data:
  train:
    scheduler:
      start_proportion: 0.5
      target_lr: 0
    dataloader:
      batch_size: 8
      shuffle: True
      num_workers: 1
      pin_memory: True
      drop_last: True
    dataset:
      _type: PoseFacesWithSingleAnime
      root_face: "/data/i2i/VoxCeleb2Anime/trainA"
      root_anime: "/data/i2i/VoxCeleb2Anime/trainB"
      landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
      num_face: 2
      img_size: [ 128, 128 ]
      with_order: False
      face_pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
      anime_pipeline:
        - Load
        - Resize:
            size: [ 144, 144 ]
        - RandomCrop:
            size: [ 128, 128 ]
        - RandomHorizontalFlip
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  test:
    which: dataset
    dataloader:
      batch_size: 1
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: PoseFacesWithSingleAnime
      root_face: "/data/i2i/VoxCeleb2Anime/testA"
      root_anime: "/data/i2i/VoxCeleb2Anime/testB"
      landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
      num_face: 2
      img_size: [ 128, 128 ]
      with_order: False
      face_pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
      anime_pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
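A side note on the optimizer block (sketch only, with two Linear layers standing in for the real models): the discriminator learning rate is four times the generator's (4e-4 vs 1e-4) and both use betas (0, 0.9), a TTUR-style split that in plain torch.optim terms reads:

import torch
import torch.nn as nn

generator, discriminator = nn.Linear(8, 8), nn.Linear(8, 1)  # placeholders for the real modules
optim_g = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0, 0.9), weight_decay=1e-4)
optim_d = torch.optim.Adam(discriminator.parameters(), lr=4e-4, betas=(0, 0.9), weight_decay=1e-4)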

View File

@@ -76,7 +76,10 @@ class TAFGEngineKernel(EngineKernel):
             contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
             for ph in "ab":
                 images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
-            images["a2b"] = generator.decode(contents["a"], styles["b"], "b")
+            if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
+                styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
+                images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
             contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
                                                                       images["a2b"], "b", "b")
             images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
@@ -91,16 +94,29 @@ class TAFGEngineKernel(EngineKernel):
             loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
                 generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
-            pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"])
+            pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
             loss[f"gan_{ph}"] = 0
             for sub_pred_fake in pred_fake:
                 # last output is actual prediction
                 loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
+        if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
+            self.generators["main"].style_converters.requires_grad_(False)
+            self.generators["main"].style_encoders.requires_grad_(False)
+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
+            pred_fake = self.discriminators[ph](generated["images"]["a2b"])
+            loss["gan_a2b"] = 0
+            for sub_pred_fake in pred_fake:
+                # last output is actual prediction
+                loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
         loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
             generated["contents"]["a"], generated["contents"]["recon_a"]
         )
         loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
-            generated["styles"]["b"], generated["styles"]["recon_b"]
+            generated["styles"]["random_b"], generated["styles"]["recon_b"]
         )
         if self.config.loss.perceptual.weight > 0:
@@ -108,16 +124,18 @@ class TAFGEngineKernel(EngineKernel):
                 batch["a"]["img"], generated["images"]["a2b"]
             )
-        for ph in "ab":
-            if self.config.loss.cycle.weight > 0:
-                loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss(
-                    batch[ph]["img"], generated["images"][f"cycle_{ph}"]
-                )
-            if self.config.loss.style.weight > 0:
-                loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
-                    batch[ph]["img"], generated["images"][f"a2{ph}"]
-                )
+        if self.config.loss.cycle.weight > 0:
+            loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
+                batch["a"]["img"], generated["images"][f"cycle_a"]
+            )
+        # for ph in "ab":
+        #
+        #     if self.config.loss.style.weight > 0:
+        #         loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
+        #             batch[ph]["img"], generated["images"][f"a2{ph}"]
+        #         )
         if self.config.loss.edge.weight > 0:
             loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
                 generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
@@ -127,10 +145,21 @@ class TAFGEngineKernel(EngineKernel):
     def criterion_discriminators(self, batch, generated) -> dict:
         loss = dict()
+        # batch = self._process_batch(batch)
+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
             for phase in self.discriminators.keys():
                 pred_real = self.discriminators[phase](batch[phase]["img"])
-                pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach())
+                pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
+                pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
+                loss[f"gan_{phase}"] = 0
+                for i in range(len(pred_fake)):
+                    loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) +
+                                             self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
+                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
+        else:
+            for phase in self.discriminators.keys():
+                pred_real = self.discriminators[phase](batch[phase]["img"])
+                pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
                 loss[f"gan_{phase}"] = 0
                 for i in range(len(pred_fake)):
                     loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -145,6 +174,7 @@ class TAFGEngineKernel(EngineKernel):
         :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
         """
         batch = self._process_batch(batch)
+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
             return dict(
                 a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
                    batch["a"]["img"].detach(),
@@ -157,6 +187,17 @@ class TAFGEngineKernel(EngineKernel):
                    generated["images"]["b2b"].detach(),
                    generated["images"]["cycle_b"].detach()]
             )
+        else:
+            return dict(
+                a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
+                   batch["a"]["img"].detach(),
+                   generated["images"]["a2a"].detach(),
+                   ],
+                b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
+                   batch["b"]["img"].detach(),
+                   generated["images"]["b2b"].detach(),
+                   ]
+            )

     def change_engine(self, config, trainer):
         pass
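The discriminator criterion above now averages one real term with two fake terms (the same-domain reconstruction and the new a2b translation). If `gan_loss` is configured as hinge here, as it is in the new talking_anime config, the underlying objective is presumably the standard hinge formulation; a minimal sketch (the repo's GANLoss may differ in details):

import torch
import torch.nn.functional as F

def d_hinge(pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor:
    # discriminator side: push real logits above +1 and fake logits below -1
    return F.relu(1.0 - pred_real).mean() + F.relu(1.0 + pred_fake).mean()

def g_hinge(pred_fake: torch.Tensor) -> torch.Tensor:
    # generator side: raise the discriminator's score on fakes
    return -pred_fake.mean()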

View File

@@ -51,31 +51,19 @@ class TSITEngineKernel(EngineKernel):
     def forward(self, batch, inference=False) -> dict:
         with torch.set_grad_enabled(not inference):
             fake = dict(
-                b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
+                b=self.generators["main"](content_img=batch["a"])
             )
         return fake

     def criterion_generators(self, batch, generated) -> dict:
         loss = dict()
-        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
         for phase in "b":
             pred_fake = self.discriminators[phase](generated[phase])
             loss[f"gan_{phase}"] = 0
             for sub_pred_fake in pred_fake:
                 # last output is actual prediction
                 loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
-            if self.config.loss.fm.weight > 0 and phase == "b":
-                pred_real = self.discriminators[phase](batch[phase])
-                loss_fm = 0
-                num_scale_discriminator = len(pred_fake)
-                for i in range(num_scale_discriminator):
-                    # last output is the final prediction, so we exclude it
-                    num_intermediate_outputs = len(pred_fake[i]) - 1
-                    for j in range(num_intermediate_outputs):
-                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
-                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
         return loss

     def criterion_discriminators(self, batch, generated) -> dict:
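The removed feature-matching block (and `cal_gan_and_fm_loss` in the new engine below) both assume the multi-scale discriminator returns a list over scales, each scale being a list of intermediate feature maps with the actual prediction last. A condensed sketch of that loss, for reference only, not the repo's exact implementation:

import torch.nn as nn

def feature_matching_loss(pred_fake, pred_real, criterion=nn.L1Loss()):
    # pred_*: list (per scale) of lists of tensors; the last entry of each
    # inner list is the real/fake prediction and is skipped here
    loss, num_scales = 0.0, len(pred_fake)
    for scale_fake, scale_real in zip(pred_fake, pred_real):
        for feat_fake, feat_real in zip(scale_fake[:-1], scale_real[:-1]):
            loss = loss + criterion(feat_fake, feat_real.detach()) / num_scales
    return loss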

View File

@@ -189,7 +189,6 @@ def get_trainer(config, kernel: EngineKernel):
             for i in range(len(image_list)):
                 test_images[k].append([])
-        with torch.no_grad():
         g = torch.Generator()
         g.manual_seed(config.misc.random_seed + engine.state.epoch
                       if config.handler.test.random else config.misc.random_seed)
@@ -205,7 +204,7 @@
                 if isinstance(batch[k][kk], torch.Tensor):
                     batch[k][kk] = batch[k][kk].unsqueeze(0)
-            generated = kernel.forward(batch)
+            generated = kernel.forward(batch, inference=True)
             images = kernel.intermediate_images(batch, generated)
             for k in test_images:
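Dropping the `with torch.no_grad():` wrapper is paired with calling `kernel.forward(batch, inference=True)`: the kernels already gate gradients on that flag via `torch.set_grad_enabled(not inference)`, as in the TSIT forward above. A tiny self-contained check of the pattern:

import torch

def forward(batch, inference=False):
    with torch.set_grad_enabled(not inference):
        out = batch * 2  # stand-in for the real generator call
    return out

x = torch.ones(2, requires_grad=True)
print(forward(x, inference=True).requires_grad)   # False
print(forward(x, inference=False).requires_grad)  # True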

engine/talking_anime.py Normal file
View File

@@ -0,0 +1,153 @@
from itertools import chain

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.context_loss import ContextLoss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TAEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
        style_loss_cfg = OmegaConf.to_container(config.loss.style)
        style_loss_cfg.pop("weight")
        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        context_loss_cfg = OmegaConf.to_container(config.loss.context)
        context_loss_cfg.pop("weight")
        self.context_loss = ContextLoss(**context_loss_cfg).to(idist.device())
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
        self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
            idist.device())

    def build_models(self) -> (dict, dict):
        generators = dict(
            anime=build_model(self.config.model.anime_generator),
            face=build_model(self.config.model.face_generator)
        )
        discriminators = dict(
            anime=build_model(self.config.model.discriminator),
            face=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["face"])
        self.logger.debug(generators["face"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_after_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        with torch.set_grad_enabled(not inference):
            target_pose_anime = self.generators["anime"](
                torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
            target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])
        return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)

    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
        pred_fake = discriminator(generated_img)
        loss_gan = 0
        for sub_pred_fake in pred_fake:
            # last output is actual prediction
            loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
        if match_img is None:
            # do not cal feature match loss
            return loss_gan, 0
        pred_real = discriminator(match_img)
        loss_fm = 0
        num_scale_discriminator = len(pred_fake)
        for i in range(num_scale_discriminator):
            # last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake[i]) - 1
            for j in range(num_intermediate_outputs):
                loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
        loss_fm = self.config.loss.fm.weight * loss_fm
        return loss_gan, loss_fm

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        loss["face_style"] = self.config.loss.style.weight * self.style_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["face"], generated["fake_face"], batch["face_1"])
        loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])
        loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss(
            generated["fake_anime"], batch["face_1"], gt_is_edge=False,
        )
        if self.config.loss.perceptual.weight > 0:
            loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                generated["fake_anime"], batch["anime_img"]
            )
        if self.config.loss.context.weight > 0:
            loss["anime_context"] = self.config.loss.context.weight * self.context_loss(
                generated["fake_anime"], batch["anime_img"],
            )
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        real = {"anime": "anime_img", "face": "face_1"}
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[real[phase]])
            pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
                  generated["fake_face"].detach()]
        return dict(
            b=[img for img in images]
        )


def run(task, config, _):
    kernel = TAEngineKernel(config)
    run_kernel(task, config, kernel)
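A quick shape check of the forward() wiring in this new kernel (sizes taken from the config, purely illustrative): the anime generator receives a 3-channel face concatenated with a horizontally flipped 3-channel anime frame, matching `in_channels: 6`, and the face generator's content input is the generated anime averaged over channels, matching `content_in_channels: 1`:

import torch

face_1 = torch.randn(1, 3, 128, 128)
anime_img = torch.randn(1, 3, 128, 128)
anime_in = torch.cat([face_1, torch.flip(anime_img, dims=[3])], dim=1)
print(anime_in.shape)  # torch.Size([1, 6, 128, 128])

fake_anime = torch.randn(1, 3, 128, 128)  # stand-in for the anime generator output
print(fake_anime.mean(dim=1, keepdim=True).shape)  # torch.Size([1, 1, 128, 128])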

View File

@@ -53,6 +53,59 @@ class VGG19StyleEncoder(nn.Module):
         return x.view(x.size(0), -1)

+
+@MODEL.register_module("TAFG-ResGenerator")
+class ResGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
+        super().__init__()
+        self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
+                                              use_spectral_norm=use_spectral_norm)
+        resnet_channels = 2 ** 2 * base_channels
+        self.decoder = Decoder(resnet_channels, out_channels, 2,
+                               0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")
+
+    def forward(self, x):
+        return self.decoder(self.content_encoder(x))
+
+
+@MODEL.register_module("TAFG-SingleGenerator")
+class SingleGenerator(nn.Module):
+    def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
+                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
+                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
+        super().__init__()
+        self.num_adain_blocks = num_adain_blocks
+        if style_encoder_type == "StyleEncoder":
+            self.style_encoder = StyleEncoder(
+                style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
+                max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
+            )
+        elif style_encoder_type == "VGG19StyleEncoder":
+            self.style_encoder = VGG19StyleEncoder(
+                style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
+            )
+        else:
+            raise NotImplemented(f"do not support {style_encoder_type}")
+        resnet_channels = 2 ** 2 * base_channels
+        self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
+                                      n_blocks=3, norm_type="NONE")
+        self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
+                                              use_spectral_norm=use_spectral_norm)
+        self.decoder = Decoder(resnet_channels, out_channels, 2,
+                               num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)
+
+    def forward(self, content_img, style_img):
+        content = self.content_encoder(content_img)
+        style = self.style_encoder(style_img)
+        as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
+        # set style for decoder
+        for i, blk in enumerate(self.decoder.res_blocks):
+            blk.conv1.normalization.set_style(as_param_style[2 * i])
+            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
+        return self.decoder(content)
+
+
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
     def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
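For the new SingleGenerator, the style-path sizes work out as follows with the defaults used above (base_channels=64, num_adain_blocks=8): the Fusion head emits one chunk per AdaIN convolution, two per residual block, and each chunk presumably splits into a per-channel scale and bias inside set_style():

base_channels, num_adain_blocks = 64, 8
resnet_channels = 2 ** 2 * base_channels                   # 256
fusion_out = num_adain_blocks * 2 * resnet_channels * 2    # 8192
chunks = num_adain_blocks * 2                              # 16 AdaIN convs (conv1/conv2 in 8 blocks)
per_chunk = fusion_out // chunks                           # 512, i.e. 256 + 256 per AdaIN (assumed split)
print(resnet_channels, fusion_out, chunks, per_chunk)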

View File

@@ -3,7 +3,6 @@ import torch.nn as nn
 import torch.nn.functional as F

 from model import MODEL
-from model.normalization import AdaptiveInstanceNorm2d
 from model.normalization import select_norm_layer
@@ -62,7 +61,9 @@
 class FADE(nn.Module):
     def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
         super().__init__()
-        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
+        # self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
+        self.norm = nn.InstanceNorm2d(num_features=in_channels)
         self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                      padding_mode="zeros")
         self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
@@ -71,7 +72,7 @@ class FADE(nn.Module):
     def forward(self, x, feature):
         alpha = self.alpha_conv(feature)
         beta = self.beta_conv(feature)
-        x = self.bn(x)
+        x = self.norm(x)
         return alpha * x + beta
@@ -122,9 +123,7 @@
         self.use_spectral = use_spectral
         self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
-        self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
         self.content_stream = self.build_stream()
-        self.style_stream = self.build_stream()
         self.generator = self.build_generator()
         self.end_conv = nn.Sequential(
             conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
@@ -138,11 +137,9 @@
             m = self.num_blocks - i
             multiple_prev = multiple_now
             multiple_now = min(2 ** m, 2 ** 4)
-            stream_sequence.append(nn.Sequential(
+            stream_sequence.append(
-                AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
                 FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
-                             multiple_now * self.base_channels)
+                             multiple_now * self.base_channels))
-            ))
         return nn.ModuleList(stream_sequence)

     def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
@@ -171,22 +168,16 @@
             ))
         return nn.ModuleList(stream_sequence)

-    def forward(self, content_img, style_img):
+    def forward(self, content_img):
         c = self.content_input_layer(content_img)
-        s = self.style_input_layer(style_img)
         content_features = []
-        style_features = []
         for i in range(self.num_blocks):
-            s = self.style_stream[i](s)
             c = self.content_stream[i](c)
             content_features.append(c)
-            style_features.append(s)
         z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
         for i in range(self.num_blocks):
             m = - i - 1
             layer = self.generator[i]
-            layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
+            z = layer(z, content_features[m])
-            z = layer[0](z)
-            z = layer[1](z, content_features[m])
         return self.end_conv(z)
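After this change the TSIT generator is content-only: the style input layer and style stream are gone, a noise tensor z seeds the decoder, and FADE normalizes with InstanceNorm instead of BatchNorm before applying the content-conditioned scale and shift. A stripped-down sketch of that modulation (the real block builds its convolutions through conv_block, optionally with spectral norm):

import torch
import torch.nn as nn

class TinyFADE(nn.Module):
    def __init__(self, feature_channels, in_channels):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channels)
        self.alpha_conv = nn.Conv2d(feature_channels, in_channels, kernel_size=3, padding=1)
        self.beta_conv = nn.Conv2d(feature_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, feature):
        # scale and shift are predicted from the content feature map
        return self.alpha_conv(feature) * self.norm(x) + self.beta_conv(feature)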

tool/inspect_model.py Normal file
View File

@@ -0,0 +1,14 @@
import sys

import torch
from omegaconf import OmegaConf

from engine.util.build import build_model

config = OmegaConf.load(sys.argv[1])
generator = build_model(config.model.generator)
ckp = torch.load(sys.argv[2], map_location="cpu")
generator.module.load_state_dict(ckp["generator_main"])

View File

@@ -0,0 +1,13 @@
from pathlib import Path
import sys
from collections import defaultdict
from itertools import permutations

pids = defaultdict(list)
for p in Path(sys.argv[1]).glob("*.jpg"):
    pids[p.stem[:7]].append(p.stem)

data = []
for p in pids:
    data.extend(list(permutations(pids[p], 2)))
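This last new file groups image stems by their first seven characters (presumably a VoxCeleb-style identity/clip prefix) and emits every ordered pair within a group; a group of three frames therefore yields six pairs:

from itertools import permutations

frames = ["id00012_a", "id00012_b", "id00012_c"]  # hypothetical stems sharing a 7-character prefix
print(list(permutations(frames, 2)))
# [('id00012_a', 'id00012_b'), ('id00012_a', 'id00012_c'), ('id00012_b', 'id00012_a'),
#  ('id00012_b', 'id00012_c'), ('id00012_c', 'id00012_a'), ('id00012_c', 'id00012_b')]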