commit acf243cb12 (parent fbea96f6d7)

    working
@@ -19,6 +19,7 @@ handler:
 
 misc:
   random_seed: 1004
+  add_new_loss_epoch: -1
 
 model:
   generator:

@@ -1,4 +1,4 @@
-name: self2anime-TSIT
+name: VoxCeleb2Anime-TSIT
 engine: TSIT
 result_dir: ./result
 max_pairs: 1500000
@@ -11,7 +11,10 @@ handler:
     n_saved: 2
   tensorboard:
     scalar: 100 # log scalar `scalar` times per epoch
-    image: 2 # log image `image` times per epoch
+    image: 4 # log image `image` times per epoch
+  test:
+    random: True
+    images: 10
 
 
 misc:
@@ -86,24 +89,23 @@ data:
       target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 1
+      batch_size: 8
       shuffle: True
       num_workers: 2
       pin_memory: True
       drop_last: True
     dataset:
-      _type: GenerationUnpairedDatasetWithEdge
-      root_a: "/data/i2i/VoxCeleb2Anime/trainA"
-      root_b: "/data/i2i/VoxCeleb2Anime/trainB"
-      edges_path: "/data/i2i/VoxCeleb2Anime/edges"
-      landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
-      edge_type: "landmark_hed"
-      size: [ 128, 128 ]
+      _type: GenerationUnpairedDataset
+      root_a: "/data/i2i/faces/CelebA-Asian/trainA"
+      root_b: "/data/i2i/anime/your-name/faces"
       random_pair: True
       pipeline:
         - Load
         - Resize:
+            size: [ 170, 144 ]
+        - RandomCrop:
             size: [ 128, 128 ]
+        - RandomHorizontalFlip
         - ToTensor
         - Normalize:
             mean: [ 0.5, 0.5, 0.5 ]
@@ -118,13 +120,14 @@ data:
       drop_last: False
     dataset:
       _type: GenerationUnpairedDataset
-      root_a: "/data/i2i/VoxCeleb2Anime/testA"
-      root_b: "/data/i2i/VoxCeleb2Anime/testB"
-      with_path: True
+      root_a: "/data/i2i/faces/CelebA-Asian/testA"
+      root_b: "/data/i2i/anime/your-name/faces"
       random_pair: False
       pipeline:
         - Load
         - Resize:
+            size: [ 170, 144 ]
+        - RandomCrop:
             size: [ 128, 128 ]
         - ToTensor
         - Normalize:
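
Note: the train pipeline now overshoots to 170x144 and takes a random 128x128 crop plus a horizontal flip, instead of resizing straight to the target size, so every epoch sees slightly different crops. Roughly, in torchvision terms (a sketch only; the repo runs its own pipeline spec, and the (height, width) order of `size` is an assumption):

    import torchvision.transforms as T

    train_tf = T.Compose([
        T.Resize((170, 144)),      # overshoot the target size ...
        T.RandomCrop((128, 128)),  # ... so the crop position jitters
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # std assumed to match mean
    ])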

configs/synthesizers/talking_anime.yml | 171 (new file)
@@ -0,0 +1,171 @@
+name: talking_anime
+engine: talking_anime
+result_dir: ./result
+max_pairs: 1000000
+
+handler:
+  clear_cuda_cache: True
+  set_epoch_for_dist_sampler: True
+  checkpoint:
+    epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
+    n_saved: 2
+  tensorboard:
+    scalar: 100 # log scalar `scalar` times per epoch
+    image: 100 # log image `image` times per epoch
+  test:
+    random: True
+    images: 10
+
+misc:
+  random_seed: 1004
+
+loss:
+  gan:
+    loss_type: hinge
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    weight: 1.0
+  fm:
+    level: 1
+    weight: 1
+  style:
+    layer_weights:
+      "3": 1
+    criterion: 'L1'
+    style_loss: True
+    perceptual_loss: False
+    weight: 10
+  perceptual:
+    layer_weights:
+      "1": 0.03125
+      "6": 0.0625
+      "11": 0.125
+      "20": 0.25
+      "29": 1
+    criterion: 'L1'
+    style_loss: False
+    perceptual_loss: True
+    weight: 0
+  context:
+    layer_weights:
+      #"13": 1
+      "22": 1
+    weight: 5
+  recon:
+    level: 1
+    weight: 10
+  edge:
+    weight: 5
+    hed_pretrained_model_path: ./network-bsds500.pytorch
+
+model:
+  face_generator:
+    _type: TAFG-SingleGenerator
+    _bn_to_sync_bn: False
+    style_in_channels: 3
+    content_in_channels: 1
+    use_spectral_norm: True
+    style_encoder_type: VGG19StyleEncoder
+    num_style_conv: 4
+    style_dim: 512
+    num_adain_blocks: 8
+    num_res_blocks: 8
+  anime_generator:
+    _type: TAFG-ResGenerator
+    _bn_to_sync_bn: False
+    in_channels: 6
+    use_spectral_norm: True
+    num_res_blocks: 8
+
+  discriminator:
+    _type: MultiScaleDiscriminator
+    num_scale: 2
+    discriminator_cfg:
+      _type: PatchDiscriminator
+      in_channels: 3
+      base_channels: 64
+      use_spectral: True
+      need_intermediate_feature: True
+
+optimizers:
+  generator:
+    _type: Adam
+    lr: 0.0001
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+  discriminator:
+    _type: Adam
+    lr: 4e-4
+    betas: [ 0, 0.9 ]
+    weight_decay: 0.0001
+
+data:
+  train:
+    scheduler:
+      start_proportion: 0.5
+      target_lr: 0
+    dataloader:
+      batch_size: 8
+      shuffle: True
+      num_workers: 1
+      pin_memory: True
+      drop_last: True
+    dataset:
+      _type: PoseFacesWithSingleAnime
+      root_face: "/data/i2i/VoxCeleb2Anime/trainA"
+      root_anime: "/data/i2i/VoxCeleb2Anime/trainB"
+      landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
+      num_face: 2
+      img_size: [ 128, 128 ]
+      with_order: False
+      face_pipeline:
+        - Load
+        - Resize:
+            size: [ 128, 128 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      anime_pipeline:
+        - Load
+        - Resize:
+            size: [ 144, 144 ]
+        - RandomCrop:
+            size: [ 128, 128 ]
+        - RandomHorizontalFlip
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+  test:
+    which: dataset
+    dataloader:
+      batch_size: 1
+      shuffle: False
+      num_workers: 1
+      pin_memory: False
+      drop_last: False
+    dataset:
+      _type: PoseFacesWithSingleAnime
+      root_face: "/data/i2i/VoxCeleb2Anime/testA"
+      root_anime: "/data/i2i/VoxCeleb2Anime/testB"
+      landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
+      num_face: 2
+      img_size: [ 128, 128 ]
+      with_order: False
+      face_pipeline:
+        - Load
+        - Resize:
+            size: [ 128, 128 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
+      anime_pipeline:
+        - Load
+        - Resize:
+            size: [ 128, 128 ]
+        - ToTensor
+        - Normalize:
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
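
Note: the engine added below reads this file through OmegaConf and, for each `loss.*` block, pops `weight` before splatting the remaining keys into the loss module. A minimal sketch of that pattern:

    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/synthesizers/talking_anime.yml")
    perceptual_cfg = OmegaConf.to_container(config.loss.perceptual)
    weight = perceptual_cfg.pop("weight")  # 0 here, so the perceptual term stays off
    # loss_module = PerceptualLoss(**perceptual_cfg), as in engine/talking_anime.py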

engine/TAFG.py | 133
@@ -76,11 +76,14 @@ class TAFGEngineKernel(EngineKernel):
             contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
             for ph in "ab":
                 images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
-            images["a2b"] = generator.decode(contents["a"], styles["b"], "b")
-            contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
-                                                                      images["a2b"], "b", "b")
-            images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
-            images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
+
+            if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
+                styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
+                images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
+                contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
+                                                                          images["a2b"], "b", "b")
+                images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
+                images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
         return dict(styles=styles, contents=contents, images=images)
 
     def criterion_generators(self, batch, generated) -> dict:
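
Note: once `engine.state.epoch` exceeds `misc.add_new_loss_epoch` (-1 in the config above, so from the first epoch), `a2b` is decoded from a style vector drawn from a standard normal rather than from the encoded style of `b`. The sampling step in isolation (`torch.randn_like` already matches the source tensor's device, so the trailing `.to(idist.device())` in the hunk is a no-op):

    import torch

    # styles["b"]: (N, style_dim) style code from the encoder
    styles["random_b"] = torch.randn_like(styles["b"])  # fresh N(0, I) style per sample
    images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")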
@@ -91,50 +94,76 @@ class TAFGEngineKernel(EngineKernel):
             loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
                 generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
 
-            pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"])
+            pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
             loss[f"gan_{ph}"] = 0
             for sub_pred_fake in pred_fake:
                 # last output is actual prediction
                 loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
-        loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
-            generated["contents"]["a"], generated["contents"]["recon_a"]
-        )
-        loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
-            generated["styles"]["b"], generated["styles"]["recon_b"]
-        )
 
-        if self.config.loss.perceptual.weight > 0:
-            loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
-                batch["a"]["img"], generated["images"]["a2b"]
+        if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
+            self.generators["main"].style_converters.requires_grad_(False)
+            self.generators["main"].style_encoders.requires_grad_(False)
+
+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
+
+            pred_fake = self.discriminators[ph](generated["images"]["a2b"])
+            loss["gan_a2b"] = 0
+            for sub_pred_fake in pred_fake:
+                # last output is actual prediction
+                loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
+
+            loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
+                generated["contents"]["a"], generated["contents"]["recon_a"]
+            )
+            loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
+                generated["styles"]["random_b"], generated["styles"]["recon_b"]
             )
 
-        for ph in "ab":
+            if self.config.loss.perceptual.weight > 0:
+                loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
+                    batch["a"]["img"], generated["images"]["a2b"]
+                )
+
             if self.config.loss.cycle.weight > 0:
-                loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss(
-                    batch[ph]["img"], generated["images"][f"cycle_{ph}"]
-                )
-            if self.config.loss.style.weight > 0:
-                loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
-                    batch[ph]["img"], generated["images"][f"a2{ph}"]
+                loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
+                    batch["a"]["img"], generated["images"][f"cycle_a"]
                 )
 
-        if self.config.loss.edge.weight > 0:
-            loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
-                generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
-            )
+            # for ph in "ab":
+            #
+            #     if self.config.loss.style.weight > 0:
+            #         loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
+            #             batch[ph]["img"], generated["images"][f"a2{ph}"]
+            #         )
+
+            if self.config.loss.edge.weight > 0:
+                loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
+                    generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
+                )
 
         return loss
 
     def criterion_discriminators(self, batch, generated) -> dict:
         loss = dict()
-        # batch = self._process_batch(batch)
-        for phase in self.discriminators.keys():
-            pred_real = self.discriminators[phase](batch[phase]["img"])
-            pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach())
-            loss[f"gan_{phase}"] = 0
-            for i in range(len(pred_fake)):
-                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
-                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
+
+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
+            for phase in self.discriminators.keys():
+                pred_real = self.discriminators[phase](batch[phase]["img"])
+                pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
+                pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
+                loss[f"gan_{phase}"] = 0
+                for i in range(len(pred_fake)):
+                    loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) +
+                                             self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
+                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
+        else:
+            for phase in self.discriminators.keys():
+                pred_real = self.discriminators[phase](batch[phase]["img"])
+                pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
+                loss[f"gan_{phase}"] = 0
+                for i in range(len(pred_fake)):
+                    loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
         return loss
 
     def intermediate_images(self, batch, generated) -> dict:
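
Note: past the gate, each discriminator sees two fakes (the identity reconstruction and the `a2b` translation) plus the real image, and three hinge terms are averaged instead of two. With `loss_type: hinge`, the per-prediction terms are presumably the standard ones:

    import torch.nn.functional as F

    def hinge_d_term(pred, is_real):
        # discriminator hinge: relu(1 - D(x)) for real, relu(1 + D(x)) for fake
        return F.relu(1.0 - pred).mean() if is_real else F.relu(1.0 + pred).mean()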
@@ -145,18 +174,30 @@ class TAFGEngineKernel(EngineKernel):
         :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
         """
         batch = self._process_batch(batch)
-        return dict(
-            a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
-               batch["a"]["img"].detach(),
-               generated["images"]["a2a"].detach(),
-               generated["images"]["a2b"].detach(),
-               generated["images"]["cycle_a"].detach(),
-               ],
-            b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
-               batch["b"]["img"].detach(),
-               generated["images"]["b2b"].detach(),
-               generated["images"]["cycle_b"].detach()]
-        )
+        if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
+            return dict(
+                a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
+                   batch["a"]["img"].detach(),
+                   generated["images"]["a2a"].detach(),
+                   generated["images"]["a2b"].detach(),
+                   generated["images"]["cycle_a"].detach(),
+                   ],
+                b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
+                   batch["b"]["img"].detach(),
+                   generated["images"]["b2b"].detach(),
+                   generated["images"]["cycle_b"].detach()]
+            )
+        else:
+            return dict(
+                a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
+                   batch["a"]["img"].detach(),
+                   generated["images"]["a2a"].detach(),
+                   ],
+                b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
+                   batch["b"]["img"].detach(),
+                   generated["images"]["b2b"].detach(),
+                   ]
+            )
 
     def change_engine(self, config, trainer):
         pass

@@ -51,31 +51,19 @@ class TSITEngineKernel(EngineKernel):
     def forward(self, batch, inference=False) -> dict:
         with torch.set_grad_enabled(not inference):
             fake = dict(
-                b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
+                b=self.generators["main"](content_img=batch["a"])
             )
         return fake
 
     def criterion_generators(self, batch, generated) -> dict:
         loss = dict()
-        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
+        loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
         for phase in "b":
             pred_fake = self.discriminators[phase](generated[phase])
             loss[f"gan_{phase}"] = 0
             for sub_pred_fake in pred_fake:
                 # last output is actual prediction
                 loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
-
-            if self.config.loss.fm.weight > 0 and phase == "b":
-                pred_real = self.discriminators[phase](batch[phase])
-                loss_fm = 0
-                num_scale_discriminator = len(pred_fake)
-                for i in range(num_scale_discriminator):
-                    # last output is the final prediction, so we exclude it
-                    num_intermediate_outputs = len(pred_fake[i]) - 1
-                    for j in range(num_intermediate_outputs):
-                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
-                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
         return loss
 
     def criterion_discriminators(self, batch, generated) -> dict:
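
Note: the feature-matching block deleted here is not gone from the project; the same computation reappears as `cal_gan_and_fm_loss` in the new engine/talking_anime.py below. In outline:

    # L1 between the discriminator's intermediate features on fake vs. real,
    # averaged over scales; the last entry per scale is the prediction itself.
    loss_fm = 0
    for scale_fake, scale_real in zip(pred_fake, pred_real):
        for feat_fake, feat_real in zip(scale_fake[:-1], scale_real[:-1]):
            loss_fm += fm_loss(feat_fake, feat_real.detach()) / len(pred_fake)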
@@ -189,34 +189,33 @@ def get_trainer(config, kernel: EngineKernel):
         for i in range(len(image_list)):
             test_images[k].append([])
 
-        with torch.no_grad():
-            g = torch.Generator()
-            g.manual_seed(config.misc.random_seed + engine.state.epoch
-                          if config.handler.test.random else config.misc.random_seed)
-            random_start = \
-                torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
-            for i in range(random_start, random_start + config.handler.test.images):
-                batch = convert_tensor(engine.state.test_dataset[i], idist.device())
-                for k in batch:
-                    if isinstance(batch[k], torch.Tensor):
-                        batch[k] = batch[k].unsqueeze(0)
-                    elif isinstance(batch[k], dict):
-                        for kk in batch[k]:
-                            if isinstance(batch[k][kk], torch.Tensor):
-                                batch[k][kk] = batch[k][kk].unsqueeze(0)
+        g = torch.Generator()
+        g.manual_seed(config.misc.random_seed + engine.state.epoch
+                      if config.handler.test.random else config.misc.random_seed)
+        random_start = \
+            torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
+        for i in range(random_start, random_start + config.handler.test.images):
+            batch = convert_tensor(engine.state.test_dataset[i], idist.device())
+            for k in batch:
+                if isinstance(batch[k], torch.Tensor):
+                    batch[k] = batch[k].unsqueeze(0)
+                elif isinstance(batch[k], dict):
+                    for kk in batch[k]:
+                        if isinstance(batch[k][kk], torch.Tensor):
+                            batch[k][kk] = batch[k][kk].unsqueeze(0)
 
-                generated = kernel.forward(batch)
-                images = kernel.intermediate_images(batch, generated)
-
-                for k in test_images:
-                    for j in range(len(images[k])):
-                        test_images[k][j].append(images[k][j])
-        for k in test_images:
-            tensorboard_handler.writer.add_image(
-                f"test/{k}",
-                make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
-                engine.state.iteration * pairs_per_iteration
-            )
+            generated = kernel.forward(batch, inference=True)
+            images = kernel.intermediate_images(batch, generated)
+
+            for k in test_images:
+                for j in range(len(images[k])):
+                    test_images[k][j].append(images[k][j])
+        for k in test_images:
+            tensorboard_handler.writer.add_image(
+                f"test/{k}",
+                make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
+                engine.state.iteration * pairs_per_iteration
+            )
     return trainer
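
Note: dropping the explicit `torch.no_grad()` is safe because `kernel.forward(batch, inference=True)` now wraps its body in `torch.set_grad_enabled(not inference)`. The test window itself is still chosen deterministically from the seed, shifted by the epoch when `handler.test.random` is on; a sketch (names illustrative):

    import torch

    g = torch.Generator()
    g.manual_seed(seed + epoch if vary_per_epoch else seed)
    start = torch.randperm(num_samples - num_images, generator=g).tolist()[0]
    # torch.randint(0, num_samples - num_images, (1,), generator=g) would be cheaper
    window = range(start, start + num_images)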

engine/talking_anime.py | 153 (new file)
@@ -0,0 +1,153 @@
+from itertools import chain
+
+import ignite.distributed as idist
+import torch
+import torch.nn as nn
+from omegaconf import OmegaConf
+
+from engine.base.i2i import EngineKernel, run_kernel
+from engine.util.build import build_model
+from loss.I2I.context_loss import ContextLoss
+from loss.I2I.edge_loss import EdgeLoss
+from loss.I2I.perceptual_loss import PerceptualLoss
+from loss.gan import GANLoss
+from model.weight_init import generation_init_weights
+
+
+class TAEngineKernel(EngineKernel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
+        perceptual_loss_cfg.pop("weight")
+        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
+
+        style_loss_cfg = OmegaConf.to_container(config.loss.style)
+        style_loss_cfg.pop("weight")
+        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
+
+        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
+        gan_loss_cfg.pop("weight")
+        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
+
+        context_loss_cfg = OmegaConf.to_container(config.loss.context)
+        context_loss_cfg.pop("weight")
+        self.context_loss = ContextLoss(**context_loss_cfg).to(idist.device())
+
+        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
+        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
+
+        self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
+            idist.device())
+
+    def build_models(self) -> (dict, dict):
+        generators = dict(
+            anime=build_model(self.config.model.anime_generator),
+            face=build_model(self.config.model.face_generator)
+        )
+        discriminators = dict(
+            anime=build_model(self.config.model.discriminator),
+            face=build_model(self.config.model.discriminator)
+        )
+        self.logger.debug(discriminators["face"])
+        self.logger.debug(generators["face"])
+
+        for m in chain(generators.values(), discriminators.values()):
+            generation_init_weights(m)
+
+        return generators, discriminators
+
+    def setup_after_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(True)
+
+    def setup_before_g(self):
+        for discriminator in self.discriminators.values():
+            discriminator.requires_grad_(False)
+
+    def forward(self, batch, inference=False) -> dict:
+        with torch.set_grad_enabled(not inference):
+            target_pose_anime = self.generators["anime"](
+                torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
+            target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])
+
+        return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)
+
+    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
+        pred_fake = discriminator(generated_img)
+        loss_gan = 0
+        for sub_pred_fake in pred_fake:
+            # last output is actual prediction
+            loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
+
+        if match_img is None:
+            # do not cal feature match loss
+            return loss_gan, 0
+
+        pred_real = discriminator(match_img)
+        loss_fm = 0
+        num_scale_discriminator = len(pred_fake)
+        for i in range(num_scale_discriminator):
+            # last output is the final prediction, so we exclude it
+            num_intermediate_outputs = len(pred_fake[i]) - 1
+            for j in range(num_intermediate_outputs):
+                loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
+        loss_fm = self.config.loss.fm.weight * loss_fm
+        return loss_gan, loss_fm
+
+    def criterion_generators(self, batch, generated) -> dict:
+        loss = dict()
+        loss["face_style"] = self.config.loss.style.weight * self.style_loss(
+            generated["fake_face"], batch["face_1"]
+        )
+        loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss(
+            generated["fake_face"], batch["face_1"]
+        )
+        loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
+            self.discriminators["face"], generated["fake_face"], batch["face_1"])
+        loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
+            self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])
+
+        loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss(
+            generated["fake_anime"], batch["face_1"], gt_is_edge=False,
+        )
+        if self.config.loss.perceptual.weight > 0:
+            loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss(
+                generated["fake_anime"], batch["anime_img"]
+            )
+        if self.config.loss.context.weight > 0:
+            loss["anime_context"] = self.config.loss.context.weight * self.context_loss(
+                generated["fake_anime"], batch["anime_img"],
+            )
+
+        return loss
+
+    def criterion_discriminators(self, batch, generated) -> dict:
+        loss = dict()
+        real = {"anime": "anime_img", "face": "face_1"}
+        for phase in self.discriminators.keys():
+            pred_real = self.discriminators[phase](batch[real[phase]])
+            pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach())
+            loss[f"gan_{phase}"] = 0
+            for i in range(len(pred_fake)):
+                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
+        return loss
+
+    def intermediate_images(self, batch, generated) -> dict:
+        """
+        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+        :param batch:
+        :param generated: dict of images
+        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
+        """
+        images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
+                  generated["fake_face"].detach()]
+        return dict(
+            b=[img for img in images]
+        )
+
+
+def run(task, config, _):
+    kernel = TAEngineKernel(config)
+    run_kernel(task, config, kernel)
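
Note: the new kernel chains two generators. The 6-channel input to `anime_generator` is the driving face concatenated with a horizontally flipped reference anime frame; `face_generator` then re-renders a face from the generated anime collapsed to one channel, styled by the source face. Schematically (short names illustrative; channel counts match the config above):

    import torch

    x = torch.cat([face_1, torch.flip(anime_img, dims=[3])], dim=1)  # (N, 6, H, W)
    fake_anime = G_anime(x)                                          # (N, 3, H, W)
    gray = fake_anime.mean(dim=1, keepdim=True)                      # (N, 1, H, W) content
    fake_face = G_face(gray, face_0)                                 # styled by face_0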
@@ -53,6 +53,59 @@ class VGG19StyleEncoder(nn.Module):
         return x.view(x.size(0), -1)
 
 
+@MODEL.register_module("TAFG-ResGenerator")
+class ResGenerator(nn.Module):
+    def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
+        super().__init__()
+        self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
+                                              use_spectral_norm=use_spectral_norm)
+        resnet_channels = 2 ** 2 * base_channels
+        self.decoder = Decoder(resnet_channels, out_channels, 2,
+                               0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")
+
+    def forward(self, x):
+        return self.decoder(self.content_encoder(x))
+
+
+@MODEL.register_module("TAFG-SingleGenerator")
+class SingleGenerator(nn.Module):
+    def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
+                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
+                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
+        super().__init__()
+        self.num_adain_blocks = num_adain_blocks
+        if style_encoder_type == "StyleEncoder":
+            self.style_encoder = StyleEncoder(
+                style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
+                max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
+            )
+        elif style_encoder_type == "VGG19StyleEncoder":
+            self.style_encoder = VGG19StyleEncoder(
+                style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
+            )
+        else:
+            raise NotImplementedError(f"do not support {style_encoder_type}")
+
+        resnet_channels = 2 ** 2 * base_channels
+        self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
+                                      n_blocks=3, norm_type="NONE")
+        self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
+                                              use_spectral_norm=use_spectral_norm)
+
+        self.decoder = Decoder(resnet_channels, out_channels, 2,
+                               num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)
+
+    def forward(self, content_img, style_img):
+        content = self.content_encoder(content_img)
+        style = self.style_encoder(style_img)
+        as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
+        # set style for decoder
+        for i, blk in enumerate(self.decoder.res_blocks):
+            blk.conv1.normalization.set_style(as_param_style[2 * i])
+            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
+        return self.decoder(content)
+
+
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
     def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
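
Note: `style_converter` emits `num_adain_blocks * 2` chunks of size `2 * resnet_channels`, i.e. one (scale, shift) pair per conv in each residual block, which `set_style` stashes inside the AdaIN layers. The normalization those layers presumably apply is standard AdaIN:

    import torch

    def adain(x, style_params, eps=1e-5):
        # a sketch of what an AdaIN layer with set_style typically computes
        gamma, beta = style_params.chunk(2, dim=1)  # (N, C) each
        mean = x.mean(dim=(2, 3), keepdim=True)
        std = x.std(dim=(2, 3), keepdim=True) + eps
        return gamma[..., None, None] * (x - mean) / std + beta[..., None, None]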

@@ -3,7 +3,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from model import MODEL
-from model.normalization import AdaptiveInstanceNorm2d
 from model.normalization import select_norm_layer
 
 
@@ -62,7 +61,9 @@ class Interpolation(nn.Module):
 class FADE(nn.Module):
     def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
         super().__init__()
-        self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
+        # self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
+        self.norm = nn.InstanceNorm2d(num_features=in_channels)
+
         self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
                                      padding_mode="zeros")
         self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
@@ -71,7 +72,7 @@ class FADE(nn.Module):
     def forward(self, x, feature):
         alpha = self.alpha_conv(feature)
         beta = self.beta_conv(feature)
-        x = self.bn(x)
+        x = self.norm(x)
         return alpha * x + beta
 
 
@@ -122,9 +123,7 @@ class TSITGenerator(nn.Module):
         self.use_spectral = use_spectral
 
         self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
-        self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
         self.content_stream = self.build_stream()
-        self.style_stream = self.build_stream()
         self.generator = self.build_generator()
         self.end_conv = nn.Sequential(
             conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
@@ -138,11 +137,9 @@ class TSITGenerator(nn.Module):
             m = self.num_blocks - i
             multiple_prev = multiple_now
             multiple_now = min(2 ** m, 2 ** 4)
-            stream_sequence.append(nn.Sequential(
-                AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
+            stream_sequence.append(
                 FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
-                             multiple_now * self.base_channels)
-            ))
+                             multiple_now * self.base_channels))
         return nn.ModuleList(stream_sequence)
 
     def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
@@ -171,22 +168,16 @@ class TSITGenerator(nn.Module):
             ))
         return nn.ModuleList(stream_sequence)
 
-    def forward(self, content_img, style_img):
+    def forward(self, content_img):
         c = self.content_input_layer(content_img)
-        s = self.style_input_layer(style_img)
         content_features = []
-        style_features = []
         for i in range(self.num_blocks):
-            s = self.style_stream[i](s)
             c = self.content_stream[i](c)
             content_features.append(c)
-            style_features.append(s)
         z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
 
        for i in range(self.num_blocks):
             m = - i - 1
             layer = self.generator[i]
-            layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
-            z = layer[0](z)
-            z = layer[1](z, content_features[m])
+            z = layer(z, content_features[m])
         return self.end_conv(z)
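
Note: FADE keeps a parameter-free normalization and modulates it with a spatial affine predicted from the content feature map; this change swaps the BatchNorm for InstanceNorm, so statistics no longer mix across the batch. Stripped to essentials (plain Conv2d standing in for the repo's conv_block):

    import torch.nn as nn

    class FADESketch(nn.Module):
        def __init__(self, feature_channels, in_channels):
            super().__init__()
            self.norm = nn.InstanceNorm2d(in_channels)
            self.alpha_conv = nn.Conv2d(feature_channels, in_channels, 3, padding=1)
            self.beta_conv = nn.Conv2d(feature_channels, in_channels, 3, padding=1)

        def forward(self, x, feature):
            # per-pixel scale and shift predicted from the feature map
            return self.alpha_conv(feature) * self.norm(x) + self.beta_conv(feature)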

tool/inspect_model.py | 14 (new file)
@@ -0,0 +1,14 @@
+import sys
+import torch
+from omegaconf import OmegaConf
+
+from engine.util.build import build_model
+
+config = OmegaConf.load(sys.argv[1])
+
+
+generator = build_model(config.model.generator)
+
+ckp = torch.load(sys.argv[2], map_location="cpu")
+
+generator.module.load_state_dict(ckp["generator_main"])
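
Note: presumably invoked as below (paths illustrative); the `generator.module` access implies `build_model` returns a DataParallel/DistributedDataParallel-wrapped module, with the generator weights stored under "generator_main" in the checkpoint.

    python tool/inspect_model.py configs/synthesizers/talking_anime.yml result/<run>/checkpoint.pt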

tool/process/permutation_face.py | 13 (new file)
@@ -0,0 +1,13 @@
+from pathlib import Path
+import sys
+from collections import defaultdict
+from itertools import permutations
+
+pids = defaultdict(list)
+for p in Path(sys.argv[1]).glob("*.jpg"):
+    pids[p.stem[:7]].append(p.stem)
+
+data = []
+for p in pids:
+    data.extend(list(permutations(pids[p], 2)))
+
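
Note: `p.stem[:7]` groups frames by the first seven characters of the filename, which fits VoxCeleb-style identity prefixes (`id` plus five digits); `permutations(..., 2)` then yields every ordered (source, target) pair within one identity. For example:

    from itertools import permutations

    list(permutations(["f1", "f2", "f3"], 2))
    # [('f1', 'f2'), ('f1', 'f3'), ('f2', 'f1'), ('f2', 'f3'), ('f3', 'f1'), ('f3', 'f2')]

As committed, `data` is built but not yet written anywhere.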