commit acf243cb12 (parent fbea96f6d7)

    working
@@ -19,6 +19,7 @@ handler:

misc:
  random_seed: 1004
+  add_new_loss_epoch: -1

model:
  generator:
@@ -1,4 +1,4 @@
-name: self2anime-TSIT
+name: VoxCeleb2Anime-TSIT
engine: TSIT
result_dir: ./result
max_pairs: 1500000
@@ -11,7 +11,10 @@ handler:
    n_saved: 2
  tensorboard:
    scalar: 100 # log scalar `scalar` times per epoch
-    image: 2 # log image `image` times per epoch
+    image: 4 # log image `image` times per epoch
+  test:
+    random: True
+    images: 10


misc:
@@ -86,24 +89,23 @@ data:
      target_lr: 0
    buffer_size: 50
    dataloader:
-      batch_size: 1
+      batch_size: 8
      shuffle: True
      num_workers: 2
      pin_memory: True
      drop_last: True
    dataset:
-      _type: GenerationUnpairedDatasetWithEdge
-      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
-      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
-      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
-      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
-      edge_type: "landmark_hed"
-      size: [ 128, 128 ]
+      _type: GenerationUnpairedDataset
+      root_a: "/data/i2i/faces/CelebA-Asian/trainA"
+      root_b: "/data/i2i/anime/your-name/faces"
+      random_pair: True
      pipeline:
        - Load
        - Resize:
            size: [ 170, 144 ]
        - RandomCrop:
            size: [ 128, 128 ]
        - RandomHorizontalFlip
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
@@ -118,13 +120,14 @@ data:
      drop_last: False
    dataset:
      _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/VoxCeleb2Anime/testA"
-      root_b: "/data/i2i/VoxCeleb2Anime/testB"
      with_path: True
+      root_a: "/data/i2i/faces/CelebA-Asian/testA"
+      root_b: "/data/i2i/anime/your-name/faces"
+      random_pair: False
      pipeline:
        - Load
        - Resize:
            size: [ 170, 144 ]
        - RandomCrop:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
configs/synthesizers/talking_anime.yml (new file, 171 lines)
@@ -0,0 +1,171 @@
name: talking_anime
engine: talking_anime
result_dir: ./result
max_pairs: 1000000

handler:
  clear_cuda_cache: True
  set_epoch_for_dist_sampler: True
  checkpoint:
    epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
    n_saved: 2
  tensorboard:
    scalar: 100 # log scalar `scalar` times per epoch
    image: 100 # log image `image` times per epoch
  test:
    random: True
    images: 10

misc:
  random_seed: 1004

loss:
  gan:
    loss_type: hinge
    real_label_val: 1.0
    fake_label_val: 0.0
    weight: 1.0
  fm:
    level: 1
    weight: 1
  style:
    layer_weights:
      "3": 1
    criterion: 'L1'
    style_loss: True
    perceptual_loss: False
    weight: 10
  perceptual:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L1'
    style_loss: False
    perceptual_loss: True
    weight: 0
  context:
    layer_weights:
      #"13": 1
      "22": 1
    weight: 5
  recon:
    level: 1
    weight: 10
  edge:
    weight: 5
    hed_pretrained_model_path: ./network-bsds500.pytorch

model:
  face_generator:
    _type: TAFG-SingleGenerator
    _bn_to_sync_bn: False
    style_in_channels: 3
    content_in_channels: 1
    use_spectral_norm: True
    style_encoder_type: VGG19StyleEncoder
    num_style_conv: 4
    style_dim: 512
    num_adain_blocks: 8
    num_res_blocks: 8
  anime_generator:
    _type: TAFG-ResGenerator
    _bn_to_sync_bn: False
    in_channels: 6
    use_spectral_norm: True
    num_res_blocks: 8

  discriminator:
    _type: MultiScaleDiscriminator
    num_scale: 2
    discriminator_cfg:
      _type: PatchDiscriminator
      in_channels: 3
      base_channels: 64
      use_spectral: True
      need_intermediate_feature: True

optimizers:
  generator:
    _type: Adam
    lr: 0.0001
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 4e-4
    betas: [ 0, 0.9 ]
    weight_decay: 0.0001

data:
  train:
    scheduler:
      start_proportion: 0.5
      target_lr: 0
    dataloader:
      batch_size: 8
      shuffle: True
      num_workers: 1
      pin_memory: True
      drop_last: True
    dataset:
      _type: PoseFacesWithSingleAnime
      root_face: "/data/i2i/VoxCeleb2Anime/trainA"
      root_anime: "/data/i2i/VoxCeleb2Anime/trainB"
      landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
      num_face: 2
      img_size: [ 128, 128 ]
      with_order: False
      face_pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
      anime_pipeline:
        - Load
        - Resize:
            size: [ 144, 144 ]
        - RandomCrop:
            size: [ 128, 128 ]
        - RandomHorizontalFlip
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  test:
    which: dataset
    dataloader:
      batch_size: 1
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: PoseFacesWithSingleAnime
      root_face: "/data/i2i/VoxCeleb2Anime/testA"
      root_anime: "/data/i2i/VoxCeleb2Anime/testB"
      landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
      num_face: 2
      img_size: [ 128, 128 ]
      with_order: False
      face_pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
      anime_pipeline:
        - Load
        - Resize:
            size: [ 128, 128 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
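For orientation, the engine code added below consumes these keys through OmegaConf attribute access, popping `weight` before passing the rest as loss-module kwargs. A minimal sketch of that pattern (path taken from the file header above):

    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/synthesizers/talking_anime.yml")
    print(config.loss.gan.weight)              # attribute-style access -> 1.0
    gan_cfg = OmegaConf.to_container(config.loss.gan)
    gan_cfg.pop("weight")                      # the remaining keys become GANLoss kwargs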
@@ -76,7 +76,10 @@ class TAFGEngineKernel(EngineKernel):
            contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
            for ph in "ab":
                images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
-            images["a2b"] = generator.decode(contents["a"], styles["b"], "b")

+            if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
+                styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
+                images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
                contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
                                                                          images["a2b"], "b", "b")
                images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
@@ -91,16 +94,29 @@ class TAFGEngineKernel(EngineKernel):
            loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
                generated["images"][f"{ph}2{ph}"], batch[ph]["img"])

-            pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"])
+            pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
            loss[f"gan_{ph}"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight

+        if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
+            self.generators["main"].style_converters.requires_grad_(False)
+            self.generators["main"].style_encoders.requires_grad_(False)
+
+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
+
+            pred_fake = self.discriminators[ph](generated["images"]["a2b"])
+            loss["gan_a2b"] = 0
+            for sub_pred_fake in pred_fake:
+                # last output is actual prediction
+                loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
+
            loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
                generated["contents"]["a"], generated["contents"]["recon_a"]
            )
            loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
-                generated["styles"]["b"], generated["styles"]["recon_b"]
+                generated["styles"]["random_b"], generated["styles"]["recon_b"]
            )

        if self.config.loss.perceptual.weight > 0:
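Read together with the forward hunk above, `misc.add_new_loss_epoch` (set to -1 in the config change at the top, so active from the first epoch) works as a curriculum switch: at the boundary epoch the style converters and encoders are frozen, and past it the randomly sampled style branch (`a2b` decoded from `styles["random_b"]`) plus its content/style reconstruction terms join the objective.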
@@ -108,16 +124,18 @@ class TAFGEngineKernel(EngineKernel):
                batch["a"]["img"], generated["images"]["a2b"]
            )

-        for ph in "ab":
        if self.config.loss.cycle.weight > 0:
-            loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss(
-                batch[ph]["img"], generated["images"][f"cycle_{ph}"]
-            )
-            if self.config.loss.style.weight > 0:
-                loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
-                    batch[ph]["img"], generated["images"][f"a2{ph}"]
+            loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
+                batch["a"]["img"], generated["images"][f"cycle_a"]
            )

+        # for ph in "ab":
+        #
+        #     if self.config.loss.style.weight > 0:
+        #         loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
+        #             batch[ph]["img"], generated["images"][f"a2{ph}"]
+        #         )

        if self.config.loss.edge.weight > 0:
            loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
                generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
@@ -127,10 +145,21 @@ class TAFGEngineKernel(EngineKernel):

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        # batch = self._process_batch(batch)

+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
            for phase in self.discriminators.keys():
                pred_real = self.discriminators[phase](batch[phase]["img"])
-                pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach())
+                pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
+                pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
                loss[f"gan_{phase}"] = 0
                for i in range(len(pred_fake)):
+                    loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) +
+                                             self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
+                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
+        else:
+            for phase in self.discriminators.keys():
+                pred_real = self.discriminators[phase](batch[phase]["img"])
+                pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
+                loss[f"gan_{phase}"] = 0
+                for i in range(len(pred_fake)):
+                    loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -145,6 +174,7 @@ class TAFGEngineKernel(EngineKernel):
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        batch = self._process_batch(batch)
+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
            return dict(
                a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
                   batch["a"]["img"].detach(),
@@ -157,6 +187,17 @@ class TAFGEngineKernel(EngineKernel):
                   generated["images"]["b2b"].detach(),
                   generated["images"]["cycle_b"].detach()]
            )
+        else:
+            return dict(
+                a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
+                   batch["a"]["img"].detach(),
+                   generated["images"]["a2a"].detach(),
+                   ],
+                b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
+                   batch["b"]["img"].detach(),
+                   generated["images"]["b2b"].detach(),
+                   ]
+            )

    def change_engine(self, config, trainer):
        pass
@@ -51,31 +51,19 @@ class TSITEngineKernel(EngineKernel):
    def forward(self, batch, inference=False) -> dict:
        with torch.set_grad_enabled(not inference):
            fake = dict(
-                b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
+                b=self.generators["main"](content_img=batch["a"])
            )
            return fake

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
-        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
        for phase in "b":
            pred_fake = self.discriminators[phase](generated[phase])
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)

-            if self.config.loss.fm.weight > 0 and phase == "b":
-                pred_real = self.discriminators[phase](batch[phase])
-                loss_fm = 0
-                num_scale_discriminator = len(pred_fake)
-                for i in range(num_scale_discriminator):
-                    # last output is the final prediction, so we exclude it
-                    num_intermediate_outputs = len(pred_fake[i]) - 1
-                    for j in range(num_intermediate_outputs):
-                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
-                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
@@ -189,7 +189,6 @@ def get_trainer(config, kernel: EngineKernel):
            for i in range(len(image_list)):
                test_images[k].append([])

-        with torch.no_grad():
        g = torch.Generator()
        g.manual_seed(config.misc.random_seed + engine.state.epoch
                      if config.handler.test.random else config.misc.random_seed)
@@ -205,7 +204,7 @@ def get_trainer(config, kernel: EngineKernel):
                if isinstance(batch[k][kk], torch.Tensor):
                    batch[k][kk] = batch[k][kk].unsqueeze(0)

-            generated = kernel.forward(batch)
+            generated = kernel.forward(batch, inference=True)
            images = kernel.intermediate_images(batch, generated)

            for k in test_images:
engine/talking_anime.py (new file, 153 lines)
@@ -0,0 +1,153 @@
from itertools import chain

import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.context_loss import ContextLoss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights


class TAEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)

        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        style_loss_cfg = OmegaConf.to_container(config.loss.style)
        style_loss_cfg.pop("weight")
        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        context_loss_cfg = OmegaConf.to_container(config.loss.context)
        context_loss_cfg.pop("weight")
        self.context_loss = ContextLoss(**context_loss_cfg).to(idist.device())

        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()

        self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
            idist.device())

    def build_models(self) -> (dict, dict):
        generators = dict(
            anime=build_model(self.config.model.anime_generator),
            face=build_model(self.config.model.face_generator)
        )
        discriminators = dict(
            anime=build_model(self.config.model.discriminator),
            face=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["face"])
        self.logger.debug(generators["face"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_after_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        with torch.set_grad_enabled(not inference):
            target_pose_anime = self.generators["anime"](
                torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
            target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])

        return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)

    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
        pred_fake = discriminator(generated_img)
        loss_gan = 0
        for sub_pred_fake in pred_fake:
            # last output is actual prediction
            loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)

        if match_img is None:
            # do not cal feature match loss
            return loss_gan, 0

        pred_real = discriminator(match_img)
        loss_fm = 0
        num_scale_discriminator = len(pred_fake)
        for i in range(num_scale_discriminator):
            # last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake[i]) - 1
            for j in range(num_intermediate_outputs):
                loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
        loss_fm = self.config.loss.fm.weight * loss_fm
        return loss_gan, loss_fm

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        loss["face_style"] = self.config.loss.style.weight * self.style_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["face"], generated["fake_face"], batch["face_1"])
        loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])

        loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss(
            generated["fake_anime"], batch["face_1"], gt_is_edge=False,
        )
        if self.config.loss.perceptual.weight > 0:
            loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                generated["fake_anime"], batch["anime_img"]
            )
        if self.config.loss.context.weight > 0:
            loss["anime_context"] = self.config.loss.context.weight * self.context_loss(
                generated["fake_anime"], batch["anime_img"],
            )

        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        real = {"anime": "anime_img", "face": "face_1"}
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[real[phase]])
            pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
                  generated["fake_face"].detach()]
        return dict(
            b=[img for img in images]
        )


def run(task, config, _):
    kernel = TAEngineKernel(config)
    run_kernel(task, config, kernel)
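`GANLoss` here is the repository's own module; with `loss_type: hinge` as configured above it presumably follows the standard hinge objective. A minimal sketch under that assumption (not the actual implementation):

    import torch.nn.functional as F

    def hinge_gan_loss(pred, target_is_real, is_discriminator=False):
        # discriminator side: mean(relu(1 - D(real))) and mean(relu(1 + D(fake))), one call per term
        if is_discriminator:
            return F.relu(1.0 - pred).mean() if target_is_real else F.relu(1.0 + pred).mean()
        # generator side: -mean(D(fake)); the kernel always calls it with target_is_real=True
        return -pred.mean()

This matches the call pattern in the kernels above: `self.gan_loss(pred, True/False, is_discriminator=True)` per discriminator scale, averaging the real and fake terms.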
@@ -53,6 +53,59 @@ class VGG19StyleEncoder(nn.Module):
        return x.view(x.size(0), -1)


+@MODEL.register_module("TAFG-ResGenerator")
+class ResGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
+        super().__init__()
+        self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
+                                              use_spectral_norm=use_spectral_norm)
+        resnet_channels = 2 ** 2 * base_channels
+        self.decoder = Decoder(resnet_channels, out_channels, 2,
+                               0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")
+
+    def forward(self, x):
+        return self.decoder(self.content_encoder(x))
+
+
+@MODEL.register_module("TAFG-SingleGenerator")
+class SingleGenerator(nn.Module):
+    def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
+                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
+                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
+        super().__init__()
+        self.num_adain_blocks = num_adain_blocks
+        if style_encoder_type == "StyleEncoder":
+            self.style_encoder = StyleEncoder(
+                style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
+                max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
+            )
+        elif style_encoder_type == "VGG19StyleEncoder":
+            self.style_encoder = VGG19StyleEncoder(
+                style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
+            )
+        else:
+            raise NotImplementedError(f"do not support {style_encoder_type}")
+
+        resnet_channels = 2 ** 2 * base_channels
+        self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
+                                      n_blocks=3, norm_type="NONE")
+        self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
+                                              use_spectral_norm=use_spectral_norm)
+
+        self.decoder = Decoder(resnet_channels, out_channels, 2,
+                               num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)
+
+    def forward(self, content_img, style_img):
+        content = self.content_encoder(content_img)
+        style = self.style_encoder(style_img)
+        as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
+        # set style for decoder
+        for i, blk in enumerate(self.decoder.res_blocks):
+            blk.conv1.normalization.set_style(as_param_style[2 * i])
+            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
+        return self.decoder(content)
+
+
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
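A quick shape check on the AdaIN parameters in `SingleGenerator.forward` above, plugging in the values from `talking_anime.yml` (`base_channels=64`, `num_adain_blocks=8`); the arithmetic is derived, not taken from the source:

    resnet_channels = 2 ** 2 * 64            # 256 channels entering the decoder
    total = 8 * 2 * resnet_channels * 2      # 8 blocks x 2 convs x (scale, shift) = 8192
    per_chunk = total // (8 * 2)             # torch.chunk -> 16 tensors of 512 channels each
    assert per_chunk == 2 * resnet_channels  # one 256-ch scale + 256-ch shift per AdaIN layer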
@@ -3,7 +3,6 @@ import torch.nn as nn
import torch.nn.functional as F

from model import MODEL
-from model.normalization import AdaptiveInstanceNorm2d
from model.normalization import select_norm_layer
@@ -62,7 +61,9 @@ class Interpolation(nn.Module):
class FADE(nn.Module):
    def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
        super().__init__()
-        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
+        # self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
+        self.norm = nn.InstanceNorm2d(num_features=in_channels)

        self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                     padding_mode="zeros")
        self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
@@ -71,7 +72,7 @@ class FADE(nn.Module):
    def forward(self, x, feature):
        alpha = self.alpha_conv(feature)
        beta = self.beta_conv(feature)
-        x = self.bn(x)
+        x = self.norm(x)
        return alpha * x + beta
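Net effect of the two FADE hunks: the normalization switches from `BatchNorm2d` to a parameter-free `InstanceNorm2d`, while the output remains `alpha(feature) * norm(x) + beta(feature)`, i.e. a spatially varying denormalization in the SPADE/FADE family.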
@@ -122,9 +123,7 @@ class TSITGenerator(nn.Module):
        self.use_spectral = use_spectral

        self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
-        self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
        self.content_stream = self.build_stream()
-        self.style_stream = self.build_stream()
        self.generator = self.build_generator()
        self.end_conv = nn.Sequential(
            conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
@@ -138,11 +137,9 @@ class TSITGenerator(nn.Module):
            m = self.num_blocks - i
            multiple_prev = multiple_now
            multiple_now = min(2 ** m, 2 ** 4)
-            stream_sequence.append(nn.Sequential(
-                AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
+            stream_sequence.append(
                FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
-                             multiple_now * self.base_channels)
-            ))
+                             multiple_now * self.base_channels))
        return nn.ModuleList(stream_sequence)

    def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
@@ -171,22 +168,16 @@ class TSITGenerator(nn.Module):
            ))
        return nn.ModuleList(stream_sequence)

-    def forward(self, content_img, style_img):
+    def forward(self, content_img):
        c = self.content_input_layer(content_img)
-        s = self.style_input_layer(style_img)
        content_features = []
-        style_features = []
        for i in range(self.num_blocks):
-            s = self.style_stream[i](s)
            c = self.content_stream[i](c)
            content_features.append(c)
-            style_features.append(s)
        z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)

        for i in range(self.num_blocks):
            m = - i - 1
            layer = self.generator[i]
-            layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
-            z = layer[0](z)
-            z = layer[1](z, content_features[m])
+            z = layer(z, content_features[m])
        return self.end_conv(z)
tool/inspect_model.py (new file, 14 lines)
@@ -0,0 +1,14 @@
import sys

import torch
from omegaconf import OmegaConf

from engine.util.build import build_model

config = OmegaConf.load(sys.argv[1])


generator = build_model(config.model.generator)

ckp = torch.load(sys.argv[2], map_location="cpu")

generator.module.load_state_dict(ckp["generator_main"])
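Presumably invoked as `python tool/inspect_model.py <config.yml> <checkpoint.pth>` (the two positional `sys.argv` arguments). The `.module` access suggests `build_model` returns a wrapped model (e.g. `DataParallel`), and `generator_main` is assumed to be the checkpoint's key layout; both readings are inferences from the code. A hypothetical follow-up once the weights are loaded:

    # e.g. print the architecture and parameter count of the restored generator
    print(generator)
    print(sum(p.numel() for p in generator.parameters()))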
tool/process/permutation_face.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from pathlib import Path
import sys
from collections import defaultdict
from itertools import permutations

pids = defaultdict(list)
for p in Path(sys.argv[1]).glob("*.jpg"):
    pids[p.stem[:7]].append(p.stem)

data = []
for p in pids:
    data.extend(list(permutations(pids[p], 2)))
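The `p.stem[:7]` prefix reads like a VoxCeleb identity id (e.g. `id00012`; the naming scheme is an assumption), so the script collects every ordered pair of frames belonging to the same person; note that `data` is built in memory and never written out in the lines shown. A minimal illustration of the pairing:

    from itertools import permutations

    stems = ["id00012_a", "id00012_b", "id00012_c"]   # hypothetical frames of one identity
    print(list(permutations(stems, 2)))               # 6 ordered pairs: (a,b), (a,c), (b,a), ...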