Compare commits
7 Commits
b01016edb5
...
776fe40199
| Author | SHA1 | Date | |
|---|---|---|---|
| 776fe40199 | |||
| f67bcdf161 | |||
| 16f18ab2e2 | |||
| 0f2b67e215 | |||
| acf243cb12 | |||
| fbea96f6d7 | |||
| ca55318253 |
@ -19,6 +19,7 @@ handler:
|
||||
|
||||
misc:
|
||||
random_seed: 1004
|
||||
add_new_loss_epoch: -1
|
||||
|
||||
model:
|
||||
generator:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
name: self2anime-TSIT
|
||||
name: VoxCeleb2Anime-TSIT
|
||||
engine: TSIT
|
||||
result_dir: ./result
|
||||
max_pairs: 1500000
|
||||
@ -11,7 +11,10 @@ handler:
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 2 # log image `image` times per epoch
|
||||
image: 4 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
|
||||
misc:
|
||||
@ -86,24 +89,23 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
batch_size: 8
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDatasetWithEdge
|
||||
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
|
||||
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
|
||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
edge_type: "landmark_hed"
|
||||
size: [ 128, 128 ]
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/faces/CelebA-Asian/trainA"
|
||||
root_b: "/data/i2i/anime/your-name/faces"
|
||||
random_pair: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 170, 144 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
@ -118,13 +120,14 @@ data:
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/VoxCeleb2Anime/testA"
|
||||
root_b: "/data/i2i/VoxCeleb2Anime/testB"
|
||||
with_path: True
|
||||
root_a: "/data/i2i/faces/CelebA-Asian/testA"
|
||||
root_b: "/data/i2i/anime/your-name/faces"
|
||||
random_pair: False
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 170, 144 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
|
||||
171
configs/synthesizers/talking_anime.yml
Normal file
171
configs/synthesizers/talking_anime.yml
Normal file
@ -0,0 +1,171 @@
|
||||
name: talking_anime
|
||||
engine: talking_anime
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 100 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
misc:
|
||||
random_seed: 1004
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: hinge
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
weight: 1.0
|
||||
fm:
|
||||
level: 1
|
||||
weight: 1
|
||||
style:
|
||||
layer_weights:
|
||||
"3": 1
|
||||
criterion: 'L1'
|
||||
style_loss: True
|
||||
perceptual_loss: False
|
||||
weight: 10
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 0
|
||||
context:
|
||||
layer_weights:
|
||||
#"13": 1
|
||||
"22": 1
|
||||
weight: 5
|
||||
recon:
|
||||
level: 1
|
||||
weight: 10
|
||||
edge:
|
||||
weight: 5
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
model:
|
||||
face_generator:
|
||||
_type: TAFG-SingleGenerator
|
||||
_bn_to_sync_bn: False
|
||||
style_in_channels: 3
|
||||
content_in_channels: 1
|
||||
use_spectral_norm: True
|
||||
style_encoder_type: VGG19StyleEncoder
|
||||
num_style_conv: 4
|
||||
style_dim: 512
|
||||
num_adain_blocks: 8
|
||||
num_res_blocks: 8
|
||||
anime_generator:
|
||||
_type: TAFG-ResGenerator
|
||||
_bn_to_sync_bn: False
|
||||
in_channels: 6
|
||||
use_spectral_norm: True
|
||||
num_res_blocks: 8
|
||||
|
||||
discriminator:
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
discriminator_cfg:
|
||||
_type: PatchDiscriminator
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
use_spectral: True
|
||||
need_intermediate_feature: True
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 4e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
train:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
shuffle: True
|
||||
num_workers: 1
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: PoseFacesWithSingleAnime
|
||||
root_face: "/data/i2i/VoxCeleb2Anime/trainA"
|
||||
root_anime: "/data/i2i/VoxCeleb2Anime/trainB"
|
||||
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
num_face: 2
|
||||
img_size: [ 128, 128 ]
|
||||
with_order: False
|
||||
face_pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
anime_pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 144, 144 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: dataset
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: PoseFacesWithSingleAnime
|
||||
root_face: "/data/i2i/VoxCeleb2Anime/testA"
|
||||
root_anime: "/data/i2i/VoxCeleb2Anime/testB"
|
||||
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
num_face: 2
|
||||
img_size: [ 128, 128 ]
|
||||
with_order: False
|
||||
face_pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
anime_pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from itertools import permutations, combinations
|
||||
from pathlib import Path
|
||||
|
||||
import lmdb
|
||||
@ -237,3 +238,50 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class PoseFacesWithSingleAnime(Dataset):
|
||||
def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
|
||||
with_order=True):
|
||||
self.num_face = num_face
|
||||
self.landmark_path = Path(landmark_path)
|
||||
self.with_order = with_order
|
||||
self.root_face = Path(root_face)
|
||||
self.root_anime = Path(root_anime)
|
||||
self.img_size = img_size
|
||||
self.face_samples = self.iter_folders()
|
||||
self.face_pipeline = transform_pipeline(face_pipeline)
|
||||
self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
|
||||
|
||||
def iter_folders(self):
|
||||
pics_per_person = defaultdict(list)
|
||||
for p in self.root_face.glob("*.jpg"):
|
||||
pics_per_person[p.stem[:7]].append(p.stem)
|
||||
data = []
|
||||
for p in pics_per_person:
|
||||
if len(pics_per_person[p]) >= self.num_face:
|
||||
if self.with_order:
|
||||
data.extend(list(combinations(pics_per_person[p], self.num_face)))
|
||||
else:
|
||||
data.extend(list(permutations(pics_per_person[p], self.num_face)))
|
||||
return data
|
||||
|
||||
def read_pose(self, pose_txt):
|
||||
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
|
||||
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
|
||||
part_labels = normalize_tensor(torch.from_numpy(part_labels))
|
||||
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
|
||||
return torch.cat([part_labels, part_edge, dist_tensor])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.face_samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
output = dict()
|
||||
output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
|
||||
for i, f in enumerate(self.face_samples[idx]):
|
||||
output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
|
||||
output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
|
||||
output[f"stem_{i}"] = f
|
||||
return output
|
||||
|
||||
133
engine/TAFG.py
133
engine/TAFG.py
@ -76,11 +76,14 @@ class TAFGEngineKernel(EngineKernel):
|
||||
contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
|
||||
for ph in "ab":
|
||||
images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
|
||||
images["a2b"] = generator.decode(contents["a"], styles["b"], "b")
|
||||
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
|
||||
images["a2b"], "b", "b")
|
||||
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
|
||||
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
|
||||
|
||||
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
|
||||
styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
|
||||
images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
|
||||
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
|
||||
images["a2b"], "b", "b")
|
||||
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
|
||||
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
|
||||
return dict(styles=styles, contents=contents, images=images)
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
@ -91,50 +94,76 @@ class TAFGEngineKernel(EngineKernel):
|
||||
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
|
||||
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
|
||||
|
||||
pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"])
|
||||
pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
|
||||
loss[f"gan_{ph}"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
|
||||
loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
|
||||
generated["contents"]["a"], generated["contents"]["recon_a"]
|
||||
)
|
||||
loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
generated["styles"]["b"], generated["styles"]["recon_b"]
|
||||
)
|
||||
|
||||
if self.config.loss.perceptual.weight > 0:
|
||||
loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
|
||||
batch["a"]["img"], generated["images"]["a2b"]
|
||||
if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
|
||||
self.generators["main"].style_converters.requires_grad_(False)
|
||||
self.generators["main"].style_encoders.requires_grad_(False)
|
||||
|
||||
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
|
||||
|
||||
pred_fake = self.discriminators[ph](generated["images"]["a2b"])
|
||||
loss["gan_a2b"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
|
||||
|
||||
loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
|
||||
generated["contents"]["a"], generated["contents"]["recon_a"]
|
||||
)
|
||||
loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
generated["styles"]["random_b"], generated["styles"]["recon_b"]
|
||||
)
|
||||
|
||||
for ph in "ab":
|
||||
if self.config.loss.perceptual.weight > 0:
|
||||
loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
|
||||
batch["a"]["img"], generated["images"]["a2b"]
|
||||
)
|
||||
|
||||
if self.config.loss.cycle.weight > 0:
|
||||
loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss(
|
||||
batch[ph]["img"], generated["images"][f"cycle_{ph}"]
|
||||
)
|
||||
if self.config.loss.style.weight > 0:
|
||||
loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
|
||||
batch[ph]["img"], generated["images"][f"a2{ph}"]
|
||||
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
|
||||
batch["a"]["img"], generated["images"][f"cycle_a"]
|
||||
)
|
||||
|
||||
if self.config.loss.edge.weight > 0:
|
||||
loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
|
||||
)
|
||||
# for ph in "ab":
|
||||
#
|
||||
# if self.config.loss.style.weight > 0:
|
||||
# loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
|
||||
# batch[ph]["img"], generated["images"][f"a2{ph}"]
|
||||
# )
|
||||
|
||||
if self.config.loss.edge.weight > 0:
|
||||
loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
# batch = self._process_batch(batch)
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[phase]["img"])
|
||||
pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach())
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
|
||||
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
|
||||
|
||||
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[phase]["img"])
|
||||
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
|
||||
pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) +
|
||||
self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
|
||||
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
|
||||
else:
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[phase]["img"])
|
||||
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
|
||||
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
@ -145,18 +174,30 @@ class TAFGEngineKernel(EngineKernel):
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
batch = self._process_batch(batch)
|
||||
return dict(
|
||||
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
|
||||
batch["a"]["img"].detach(),
|
||||
generated["images"]["a2a"].detach(),
|
||||
generated["images"]["a2b"].detach(),
|
||||
generated["images"]["cycle_a"].detach(),
|
||||
],
|
||||
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
|
||||
batch["b"]["img"].detach(),
|
||||
generated["images"]["b2b"].detach(),
|
||||
generated["images"]["cycle_b"].detach()]
|
||||
)
|
||||
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
|
||||
return dict(
|
||||
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
|
||||
batch["a"]["img"].detach(),
|
||||
generated["images"]["a2a"].detach(),
|
||||
generated["images"]["a2b"].detach(),
|
||||
generated["images"]["cycle_a"].detach(),
|
||||
],
|
||||
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
|
||||
batch["b"]["img"].detach(),
|
||||
generated["images"]["b2b"].detach(),
|
||||
generated["images"]["cycle_b"].detach()]
|
||||
)
|
||||
else:
|
||||
return dict(
|
||||
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
|
||||
batch["a"]["img"].detach(),
|
||||
generated["images"]["a2a"].detach(),
|
||||
],
|
||||
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
|
||||
batch["b"]["img"].detach(),
|
||||
generated["images"]["b2b"].detach(),
|
||||
]
|
||||
)
|
||||
|
||||
def change_engine(self, config, trainer):
|
||||
pass
|
||||
|
||||
@ -51,31 +51,19 @@ class TSITEngineKernel(EngineKernel):
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
with torch.set_grad_enabled(not inference):
|
||||
fake = dict(
|
||||
b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
|
||||
b=self.generators["main"](content_img=batch["a"])
|
||||
)
|
||||
return fake
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
|
||||
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
|
||||
loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
|
||||
for phase in "b":
|
||||
pred_fake = self.discriminators[phase](generated[phase])
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
|
||||
|
||||
if self.config.loss.fm.weight > 0 and phase == "b":
|
||||
pred_real = self.discriminators[phase](batch[phase])
|
||||
loss_fm = 0
|
||||
num_scale_discriminator = len(pred_fake)
|
||||
for i in range(num_scale_discriminator):
|
||||
# last output is the final prediction, so we exclude it
|
||||
num_intermediate_outputs = len(pred_fake[i]) - 1
|
||||
for j in range(num_intermediate_outputs):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
|
||||
@ -189,34 +189,33 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
for i in range(len(image_list)):
|
||||
test_images[k].append([])
|
||||
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed + engine.state.epoch
|
||||
if config.handler.test.random else config.misc.random_seed)
|
||||
random_start = \
|
||||
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
|
||||
for i in range(random_start, random_start + config.handler.test.images):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].unsqueeze(0)
|
||||
elif isinstance(batch[k], dict):
|
||||
for kk in batch[k]:
|
||||
if isinstance(batch[k][kk], torch.Tensor):
|
||||
batch[k][kk] = batch[k][kk].unsqueeze(0)
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed + engine.state.epoch
|
||||
if config.handler.test.random else config.misc.random_seed)
|
||||
random_start = \
|
||||
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
|
||||
for i in range(random_start, random_start + config.handler.test.images):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].unsqueeze(0)
|
||||
elif isinstance(batch[k], dict):
|
||||
for kk in batch[k]:
|
||||
if isinstance(batch[k][kk], torch.Tensor):
|
||||
batch[k][kk] = batch[k][kk].unsqueeze(0)
|
||||
|
||||
generated = kernel.forward(batch)
|
||||
images = kernel.intermediate_images(batch, generated)
|
||||
generated = kernel.forward(batch, inference=True)
|
||||
images = kernel.intermediate_images(batch, generated)
|
||||
|
||||
for k in test_images:
|
||||
for j in range(len(images[k])):
|
||||
test_images[k][j].append(images[k][j])
|
||||
for k in test_images:
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{k}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
|
||||
engine.state.iteration * pairs_per_iteration
|
||||
)
|
||||
for j in range(len(images[k])):
|
||||
test_images[k][j].append(images[k][j])
|
||||
for k in test_images:
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{k}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
|
||||
engine.state.iteration * pairs_per_iteration
|
||||
)
|
||||
return trainer
|
||||
|
||||
|
||||
|
||||
153
engine/talking_anime.py
Normal file
153
engine/talking_anime.py
Normal file
@ -0,0 +1,153 @@
|
||||
from itertools import chain
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.context_loss import ContextLoss
|
||||
from loss.I2I.edge_loss import EdgeLoss
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class TAEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
style_loss_cfg = OmegaConf.to_container(config.loss.style)
|
||||
style_loss_cfg.pop("weight")
|
||||
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
|
||||
context_loss_cfg = OmegaConf.to_container(config.loss.context)
|
||||
context_loss_cfg.pop("weight")
|
||||
self.context_loss = ContextLoss(**context_loss_cfg).to(idist.device())
|
||||
|
||||
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
|
||||
|
||||
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
|
||||
idist.device())
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
anime=build_model(self.config.model.anime_generator),
|
||||
face=build_model(self.config.model.face_generator)
|
||||
)
|
||||
discriminators = dict(
|
||||
anime=build_model(self.config.model.discriminator),
|
||||
face=build_model(self.config.model.discriminator)
|
||||
)
|
||||
self.logger.debug(discriminators["face"])
|
||||
self.logger.debug(generators["face"])
|
||||
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_after_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
def setup_before_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(False)
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
with torch.set_grad_enabled(not inference):
|
||||
target_pose_anime = self.generators["anime"](
|
||||
torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
|
||||
target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])
|
||||
|
||||
return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)
|
||||
|
||||
def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
|
||||
pred_fake = discriminator(generated_img)
|
||||
loss_gan = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
|
||||
|
||||
if match_img is None:
|
||||
# do not cal feature match loss
|
||||
return loss_gan, 0
|
||||
|
||||
pred_real = discriminator(match_img)
|
||||
loss_fm = 0
|
||||
num_scale_discriminator = len(pred_fake)
|
||||
for i in range(num_scale_discriminator):
|
||||
# last output is the final prediction, so we exclude it
|
||||
num_intermediate_outputs = len(pred_fake[i]) - 1
|
||||
for j in range(num_intermediate_outputs):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss_fm = self.config.loss.fm.weight * loss_fm
|
||||
return loss_gan, loss_fm
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
loss["face_style"] = self.config.loss.style.weight * self.style_loss(
|
||||
generated["fake_face"], batch["face_1"]
|
||||
)
|
||||
loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss(
|
||||
generated["fake_face"], batch["face_1"]
|
||||
)
|
||||
loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
|
||||
self.discriminators["face"], generated["fake_face"], batch["face_1"])
|
||||
loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
|
||||
self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])
|
||||
|
||||
loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["fake_anime"], batch["face_1"], gt_is_edge=False,
|
||||
)
|
||||
if self.config.loss.perceptual.weight > 0:
|
||||
loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss(
|
||||
generated["fake_anime"], batch["anime_img"]
|
||||
)
|
||||
if self.config.loss.context.weight > 0:
|
||||
loss["anime_context"] = self.config.loss.context.weight * self.context_loss(
|
||||
generated["fake_anime"], batch["anime_img"],
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
real = {"anime": "anime_img", "face": "face_1"}
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[real[phase]])
|
||||
pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach())
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
|
||||
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
:param batch:
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
|
||||
generated["fake_face"].detach()]
|
||||
return dict(
|
||||
b=[img for img in images]
|
||||
)
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = TAEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
44
loss/I2I/context_loss.py
Normal file
44
loss/I2I/context_loss.py
Normal file
@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .perceptual_loss import PerceptualVGG
|
||||
|
||||
|
||||
class ContextLoss(nn.Module):
|
||||
def __init__(self, layer_weights, h=0.1, vgg_type='vgg19', norm_image_with_imagenet_param=True, norm_img=True,
|
||||
eps=1e-5):
|
||||
super(ContextLoss, self).__init__()
|
||||
self.eps = eps
|
||||
self.h = h
|
||||
self.layer_weights = layer_weights
|
||||
self.norm_img = norm_img
|
||||
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||
norm_image_with_imagenet_param=norm_image_with_imagenet_param)
|
||||
|
||||
def single_forward(self, source_feature, target_feature):
|
||||
mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
|
||||
source_feature = (source_feature - mean_target_feature).view(*source_feature.size()[:2], -1) # NxCxHW
|
||||
target_feature = (target_feature - mean_target_feature).view(*source_feature.size()[:2], -1) # NxCxHW
|
||||
source_feature = F.normalize(source_feature, p=2, dim=1)
|
||||
target_feature = F.normalize(target_feature, p=2, dim=1)
|
||||
cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2), target_feature)) / 2 # NxHWxHW
|
||||
rel_distance = cosine_distance / (cosine_distance.min(2, keepdim=True)[0] + self.eps)
|
||||
w = torch.exp((1 - rel_distance) / self.h)
|
||||
cx = w.div(w.sum(dim=2, keepdim=True))
|
||||
cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
|
||||
return -torch.log(cx).mean()
|
||||
|
||||
def forward(self, x, gt):
|
||||
if self.norm_img:
|
||||
x = (x + 1.) * 0.5
|
||||
gt = (gt + 1.) * 0.5
|
||||
|
||||
# extract vgg features
|
||||
x_features = self.vgg(x)
|
||||
gt_features = self.vgg(gt.detach())
|
||||
|
||||
loss = 0
|
||||
for k in x_features.keys():
|
||||
loss += self.single_forward(x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
return loss
|
||||
@ -4,6 +4,49 @@ import torch.nn.functional as F
|
||||
import torchvision.models.vgg as vgg
|
||||
|
||||
|
||||
# Sequential(
|
||||
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (1): ReLU(inplace=True)
|
||||
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (3): ReLU(inplace=True)
|
||||
# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
|
||||
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (6): ReLU(inplace=True)
|
||||
# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (8): ReLU(inplace=True)
|
||||
# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
|
||||
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (11): ReLU(inplace=True)
|
||||
# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (13): ReLU(inplace=True)
|
||||
# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (15): ReLU(inplace=True)
|
||||
# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (17): ReLU(inplace=True)
|
||||
# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
|
||||
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (20): ReLU(inplace=True)
|
||||
# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (22): ReLU(inplace=True)
|
||||
# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (24): ReLU(inplace=True)
|
||||
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (26): ReLU(inplace=True)
|
||||
# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
|
||||
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (29): ReLU(inplace=True)
|
||||
# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (31): ReLU(inplace=True)
|
||||
# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (33): ReLU(inplace=True)
|
||||
# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (35): ReLU(inplace=True)
|
||||
# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
# )
|
||||
class PerceptualVGG(nn.Module):
|
||||
"""VGG network used in calculating perceptual loss.
|
||||
In this implementation, we allow users to choose whether use normalization
|
||||
@ -15,15 +58,15 @@ class PerceptualVGG(nn.Module):
|
||||
list contains the name each layer in `vgg.feature`. An example
|
||||
of this list is ['4', '10'].
|
||||
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
||||
use_input_norm (bool): If True, normalize the input image.
|
||||
norm_image_with_imagenet_param (bool): If True, normalize the input image.
|
||||
Importantly, the input feature must in the range [0, 1].
|
||||
Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
|
||||
def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True):
|
||||
super(PerceptualVGG, self).__init__()
|
||||
self.layer_name_list = layer_name_list
|
||||
self.use_input_norm = use_input_norm
|
||||
self.use_input_norm = norm_image_with_imagenet_param
|
||||
|
||||
# get vgg model and load pretrained vgg weight
|
||||
# remove _vgg from attributes to avoid `find_unused_parameters` bug
|
||||
@ -75,7 +118,7 @@ class PerceptualLoss(nn.Module):
|
||||
in calculating losses.
|
||||
vgg_type (str): The type of vgg network used as feature extractor.
|
||||
Default: 'vgg19'.
|
||||
use_input_norm (bool): If True, normalize the input image in vgg.
|
||||
norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg.
|
||||
Default: True.
|
||||
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
|
||||
loss will be calculated.
|
||||
@ -88,7 +131,7 @@ class PerceptualLoss(nn.Module):
|
||||
Importantly, the input image must be in range [-1, 1].
|
||||
"""
|
||||
|
||||
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
|
||||
def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True,
|
||||
style_loss=False, norm_img=True, criterion='L1'):
|
||||
super(PerceptualLoss, self).__init__()
|
||||
self.norm_img = norm_img
|
||||
@ -97,7 +140,7 @@ class PerceptualLoss(nn.Module):
|
||||
self.style_loss = style_loss
|
||||
self.layer_weights = layer_weights
|
||||
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||
use_input_norm=use_input_norm)
|
||||
norm_image_with_imagenet_param=norm_image_with_imagenet_param)
|
||||
|
||||
self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
|
||||
|
||||
|
||||
@ -53,6 +53,59 @@ class VGG19StyleEncoder(nn.Module):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-ResGenerator")
|
||||
class ResGenerator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
|
||||
super().__init__()
|
||||
self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
|
||||
use_spectral_norm=use_spectral_norm)
|
||||
resnet_channels = 2 ** 2 * base_channels
|
||||
self.decoder = Decoder(resnet_channels, out_channels, 2,
|
||||
0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")
|
||||
|
||||
def forward(self, x):
|
||||
return self.decoder(self.content_encoder(x))
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-SingleGenerator")
|
||||
class SingleGenerator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
|
||||
style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
|
||||
num_res_blocks=8, base_channels=64, padding_mode="reflect"):
|
||||
super().__init__()
|
||||
self.num_adain_blocks = num_adain_blocks
|
||||
if style_encoder_type == "StyleEncoder":
|
||||
self.style_encoder = StyleEncoder(
|
||||
style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
|
||||
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
|
||||
)
|
||||
elif style_encoder_type == "VGG19StyleEncoder":
|
||||
self.style_encoder = VGG19StyleEncoder(
|
||||
style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
|
||||
)
|
||||
else:
|
||||
raise NotImplemented(f"do not support {style_encoder_type}")
|
||||
|
||||
resnet_channels = 2 ** 2 * base_channels
|
||||
self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
|
||||
n_blocks=3, norm_type="NONE")
|
||||
self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
|
||||
use_spectral_norm=use_spectral_norm)
|
||||
|
||||
self.decoder = Decoder(resnet_channels, out_channels, 2,
|
||||
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)
|
||||
|
||||
def forward(self, content_img, style_img):
|
||||
content = self.content_encoder(content_img)
|
||||
style = self.style_encoder(style_img)
|
||||
as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
|
||||
# set style for decoder
|
||||
for i, blk in enumerate(self.decoder.res_blocks):
|
||||
blk.conv1.normalization.set_style(as_param_style[2 * i])
|
||||
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
|
||||
return self.decoder(content)
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
|
||||
|
||||
@ -3,45 +3,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import MODEL
|
||||
from model.normalization import AdaptiveInstanceNorm2d
|
||||
from model.normalization import select_norm_layer
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, padding_mode='zeros', norm_type="IN", use_bias=None,
|
||||
use_spectral=True):
|
||||
super().__init__()
|
||||
self.padding_mode = padding_mode
|
||||
self.use_bias = use_bias
|
||||
self.use_spectral = use_spectral
|
||||
if use_bias is None:
|
||||
# Only for IN, use bias since it does not have affine parameters.
|
||||
self.use_bias = norm_type == "IN"
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.main = nn.Sequential(
|
||||
self.conv_block(in_channels, in_channels),
|
||||
norm_layer(in_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
self.conv_block(in_channels, out_channels),
|
||||
norm_layer(out_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
self.skip = nn.Sequential(
|
||||
self.conv_block(in_channels, out_channels, padding=0, kernel_size=1),
|
||||
norm_layer(out_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
)
|
||||
|
||||
def conv_block(self, in_channels, out_channels, kernel_size=3, padding=1):
|
||||
conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
|
||||
padding_mode=self.padding_mode, bias=self.use_bias)
|
||||
if self.use_spectral:
|
||||
return nn.utils.spectral_norm(conv)
|
||||
else:
|
||||
return conv
|
||||
|
||||
def forward(self, x):
|
||||
return self.main(x) + self.skip(x)
|
||||
from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock
|
||||
|
||||
|
||||
class Interpolation(nn.Module):
|
||||
@ -59,106 +21,41 @@ class Interpolation(nn.Module):
|
||||
return f"DownSampling(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
|
||||
|
||||
|
||||
class FADE(nn.Module):
|
||||
def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
|
||||
self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
|
||||
padding_mode="zeros")
|
||||
self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
|
||||
padding_mode="zeros")
|
||||
|
||||
def forward(self, x, feature):
|
||||
alpha = self.alpha_conv(feature)
|
||||
beta = self.beta_conv(feature)
|
||||
x = self.bn(x)
|
||||
return alpha * x + beta
|
||||
|
||||
|
||||
class FADEResBlock(nn.Module):
|
||||
def __init__(self, use_spectral, features_channels, in_channels, out_channels):
|
||||
super().__init__()
|
||||
self.main = nn.Sequential(
|
||||
FADE(use_spectral, features_channels, in_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
conv_block(use_spectral, in_channels, in_channels, kernel_size=3, padding=1),
|
||||
FADE(use_spectral, features_channels, in_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
conv_block(use_spectral, in_channels, out_channels, kernel_size=3, padding=1),
|
||||
)
|
||||
self.skip = nn.Sequential(
|
||||
FADE(use_spectral, features_channels, in_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
conv_block(use_spectral, in_channels, out_channels, kernel_size=1, padding=0),
|
||||
)
|
||||
self.up_sample = Interpolation(2, mode="nearest")
|
||||
|
||||
@staticmethod
|
||||
def forward_with_fade(module, x, feature):
|
||||
for layer in module:
|
||||
if layer.__class__.__name__ == "FADE":
|
||||
x = layer(x, feature)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, feature):
|
||||
out = self.forward_with_fade(self.main, x, feature) + self.forward_with_fade(self.main, x, feature)
|
||||
return self.up_sample(out)
|
||||
|
||||
|
||||
def conv_block(use_spectral, in_channels, out_channels, **kwargs):
|
||||
conv = nn.Conv2d(in_channels, out_channels, **kwargs)
|
||||
return nn.utils.spectral_norm(conv) if use_spectral else conv
|
||||
|
||||
|
||||
@MODEL.register_module("TSIT-Generator")
|
||||
class TSITGenerator(nn.Module):
|
||||
def __init__(self, num_blocks=7, base_channels=64, content_in_channels=3, style_in_channels=3,
|
||||
out_channels=3, use_spectral=True, input_layer_type="conv1x1"):
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7,
|
||||
padding_mode="reflect", activation_type="ReLU"):
|
||||
super().__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.base_channels = base_channels
|
||||
self.use_spectral = use_spectral
|
||||
|
||||
self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
|
||||
self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
|
||||
self.content_stream = self.build_stream()
|
||||
self.style_stream = self.build_stream()
|
||||
self.generator = self.build_generator()
|
||||
self.end_conv = nn.Sequential(
|
||||
conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
|
||||
nn.Tanh()
|
||||
)
|
||||
self.content_stream = self.build_stream(padding_mode, activation_type)
|
||||
self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type,
|
||||
norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3)
|
||||
|
||||
def build_generator(self):
|
||||
stream_sequence = []
|
||||
sequence = []
|
||||
multiple_now = min(2 ** self.num_blocks, 2 ** 4)
|
||||
for i in range(1, self.num_blocks + 1):
|
||||
m = self.num_blocks - i
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** m, 2 ** 4)
|
||||
stream_sequence.append(nn.Sequential(
|
||||
AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
|
||||
FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
|
||||
multiple_now * self.base_channels)
|
||||
sequence.append(nn.Sequential(
|
||||
ReverseResidualBlock(
|
||||
multiple_prev * base_channels, multiple_now * base_channels,
|
||||
padding_mode=padding_mode, norm_type="FADE",
|
||||
additional_norm_kwargs=dict(
|
||||
condition_in_channels=multiple_prev * base_channels,
|
||||
base_norm_type="BN",
|
||||
padding_mode=padding_mode
|
||||
)
|
||||
),
|
||||
Interpolation(2, mode="nearest")
|
||||
))
|
||||
return nn.ModuleList(stream_sequence)
|
||||
self.generator = nn.Sequential(*sequence)
|
||||
self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh",
|
||||
kernel_size=7, padding_mode=padding_mode, padding=3)
|
||||
|
||||
def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
|
||||
if input_layer_type == "conv7x7":
|
||||
return nn.Sequential(
|
||||
conv_block(self.use_spectral, in_channels, out_channels, kernel_size=7, stride=1,
|
||||
padding_mode="zeros", padding=3, bias=True),
|
||||
select_norm_layer("IN")(out_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
elif input_layer_type == "conv1x1":
|
||||
return conv_block(self.use_spectral, in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
def build_stream(self):
|
||||
def build_stream(self, padding_mode, activation_type):
|
||||
multiple_now = 1
|
||||
stream_sequence = []
|
||||
for i in range(1, self.num_blocks + 1):
|
||||
@ -166,27 +63,26 @@ class TSITGenerator(nn.Module):
|
||||
multiple_now = min(2 ** i, 2 ** 4)
|
||||
stream_sequence.append(nn.Sequential(
|
||||
Interpolation(scale_factor=0.5, mode="nearest"),
|
||||
ResBlock(multiple_prev * self.base_channels, multiple_now * self.base_channels,
|
||||
use_spectral=self.use_spectral)
|
||||
ResidualBlock(
|
||||
multiple_prev * self.base_channels, multiple_now * self.base_channels,
|
||||
padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
|
||||
))
|
||||
return nn.ModuleList(stream_sequence)
|
||||
|
||||
def forward(self, content_img, style_img):
|
||||
c = self.content_input_layer(content_img)
|
||||
s = self.style_input_layer(style_img)
|
||||
def forward(self, content, z=None):
|
||||
c = self.start_conv(content)
|
||||
content_features = []
|
||||
style_features = []
|
||||
for i in range(self.num_blocks):
|
||||
s = self.style_stream[i](s)
|
||||
c = self.content_stream[i](c)
|
||||
content_features.append(c)
|
||||
style_features.append(s)
|
||||
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
|
||||
if z is None:
|
||||
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
|
||||
|
||||
for i in range(self.num_blocks):
|
||||
m = - i - 1
|
||||
layer = self.generator[i]
|
||||
layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
|
||||
z = layer[0](z)
|
||||
z = layer[1](z, content_features[m])
|
||||
return self.end_conv(z)
|
||||
res_block = self.generator[i][0]
|
||||
res_block.conv1.normalization.set_feature(content_features[m])
|
||||
res_block.conv2.normalization.set_feature(content_features[m])
|
||||
if res_block.learn_skip_connection:
|
||||
res_block.res_conv.normalization.set_feature(content_features[m])
|
||||
return self.end_conv(self.generator(z))
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from model.registry import MODEL
|
||||
from model.registry import MODEL, NORMALIZATION
|
||||
import model.GAN.CycleGAN
|
||||
import model.GAN.MUNIT
|
||||
import model.GAN.TAFG
|
||||
import model.GAN.UGATIT
|
||||
import model.GAN.wrapper
|
||||
import model.GAN.base
|
||||
import model.GAN.TSIT
|
||||
import model.GAN.MUNIT
|
||||
import model.GAN.UGATIT
|
||||
import model.GAN.base
|
||||
import model.GAN.wrapper
|
||||
import model.base.normalization
|
||||
|
||||
|
||||
0
model/base/__init__.py
Normal file
0
model/base/__init__.py
Normal file
109
model/base/module.py
Normal file
109
model/base/module.py
Normal file
@ -0,0 +1,109 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from model.registry import NORMALIZATION
|
||||
|
||||
_DO_NO_THING_FUNC = lambda x: x
|
||||
|
||||
|
||||
def _use_bias_checker(norm_type):
|
||||
return norm_type not in ["IN", "BN", "AdaIN", "FADE", "SPADE"]
|
||||
|
||||
|
||||
def _normalization(norm, num_features, additional_kwargs=None):
|
||||
if norm == "NONE":
|
||||
return _DO_NO_THING_FUNC
|
||||
|
||||
if additional_kwargs is None:
|
||||
additional_kwargs = {}
|
||||
kwargs = dict(_type=norm, num_features=num_features)
|
||||
kwargs.update(additional_kwargs)
|
||||
return NORMALIZATION.build_with(kwargs)
|
||||
|
||||
|
||||
def _activation(activation):
|
||||
if activation == "NONE":
|
||||
return _DO_NO_THING_FUNC
|
||||
elif activation == "ReLU":
|
||||
return nn.ReLU(inplace=True)
|
||||
elif activation == "LeakyReLU":
|
||||
return nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
elif activation == "Tanh":
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplemented(activation)
|
||||
|
||||
|
||||
class Conv2dBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
||||
activation_type="ReLU", norm_type="NONE", **conv_kwargs):
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
self.activation_type = activation_type
|
||||
|
||||
# if caller not set bias, set bias automatically.
|
||||
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
||||
|
||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||
self.normalization = _normalization(norm_type, out_channels)
|
||||
self.activation = _activation(activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(self.normalization(self.convolution(x)))
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
|
||||
activation_type="ReLU", out_activation_type=None, norm_type="IN"):
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = num_channels
|
||||
if out_activation_type is None:
|
||||
out_activation_type = "NONE"
|
||||
|
||||
self.learn_skip_connection = num_channels != out_channels
|
||||
|
||||
self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=activation_type)
|
||||
self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||
return self.conv2(self.conv1(x)) + res
|
||||
|
||||
|
||||
class ReverseConv2dBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int,
|
||||
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
|
||||
super().__init__()
|
||||
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
||||
self.activation = _activation(activation_type)
|
||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.convolution(self.activation(self.normalization(x)))
|
||||
|
||||
|
||||
class ReverseResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, padding_mode="reflect",
|
||||
norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
|
||||
super().__init__()
|
||||
self.learn_skip_connection = in_channels != out_channels
|
||||
self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
if self.learn_skip_connection:
|
||||
self.res_conv = ReverseConv2dBlock(
|
||||
in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
|
||||
def forward(self, x):
|
||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||
return self.conv2(self.conv1(x)) + res
|
||||
142
model/base/normalization.py
Normal file
142
model/base/normalization.py
Normal file
@ -0,0 +1,142 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import NORMALIZATION
|
||||
from model.base.module import Conv2dBlock
|
||||
|
||||
_VALID_NORM_AND_ABBREVIATION = dict(
|
||||
IN="InstanceNorm2d",
|
||||
BN="BatchNorm2d",
|
||||
)
|
||||
|
||||
for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
|
||||
NORMALIZATION.register_module(module=getattr(nn, name), name=abbr)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("ADE")
|
||||
class AdaptiveDenormalization(nn.Module):
|
||||
def __init__(self, num_features, base_norm_type="BN"):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.base_norm_type = base_norm_type
|
||||
self.norm = self.base_norm(num_features)
|
||||
self.gamma = None
|
||||
self.beta = None
|
||||
self.have_set_condition = False
|
||||
|
||||
def base_norm(self, num_features):
|
||||
if self.base_norm_type == "IN":
|
||||
return nn.InstanceNorm2d(num_features)
|
||||
elif self.base_norm_type == "BN":
|
||||
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
|
||||
|
||||
def set_condition(self, gamma, beta):
|
||||
self.gamma, self.beta = gamma, beta
|
||||
self.have_set_condition = True
|
||||
|
||||
def forward(self, x):
|
||||
assert self.have_set_condition
|
||||
x = self.norm(x)
|
||||
x = self.gamma * x + self.beta
|
||||
self.have_set_condition = False
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||
f"base_norm_type={self.base_norm_type})"
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaIN")
|
||||
class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int):
|
||||
super().__init__(num_features, "IN")
|
||||
self.num_features = num_features
|
||||
|
||||
def set_style(self, style):
|
||||
style = style.view(*style.size(), 1, 1)
|
||||
gamma, beta = style.chunk(2, 1)
|
||||
super().set_condition(gamma, beta)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("FADE")
|
||||
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
|
||||
super().__init__(num_features, base_norm_type)
|
||||
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||
padding_mode=padding_mode)
|
||||
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||
padding_mode=padding_mode)
|
||||
|
||||
def set_feature(self, feature):
|
||||
gamma = self.gamma_conv(feature)
|
||||
beta = self.beta_conv(feature)
|
||||
super().set_condition(gamma, beta)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("SPADE")
|
||||
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
|
||||
activation_type="ReLU", padding_mode="zeros"):
|
||||
super().__init__(num_features, base_norm_type)
|
||||
self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
|
||||
self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
|
||||
def set_condition_image(self, condition_image):
|
||||
feature = self.base_conv_block(condition_image)
|
||||
gamma = self.gamma_conv(feature)
|
||||
beta = self.beta_conv(feature)
|
||||
super().set_condition(gamma, beta)
|
||||
|
||||
|
||||
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
|
||||
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
|
||||
out = out * gamma + beta
|
||||
return out
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("ILN")
|
||||
class ILN(nn.Module):
|
||||
def __init__(self, num_features, eps=1e-5):
|
||||
super(ILN, self).__init__()
|
||||
self.eps = eps
|
||||
self.rho = nn.Parameter(torch.Tensor(num_features))
|
||||
self.gamma = nn.Parameter(torch.Tensor(num_features))
|
||||
self.beta = nn.Parameter(torch.Tensor(num_features))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.zeros_(self.rho)
|
||||
nn.init.ones_(self.gamma)
|
||||
nn.init.zeros_(self.beta)
|
||||
|
||||
def forward(self, x):
|
||||
return _instance_layer_normalization(
|
||||
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaILN")
|
||||
class AdaILN(nn.Module):
|
||||
def __init__(self, num_features, eps=1e-5, default_rho=0.9):
|
||||
super(AdaILN, self).__init__()
|
||||
self.eps = eps
|
||||
self.rho = nn.Parameter(torch.Tensor(num_features))
|
||||
self.rho.data.fill_(default_rho)
|
||||
|
||||
self.gamma = None
|
||||
self.beta = None
|
||||
self.have_set_condition = False
|
||||
|
||||
def set_condition(self, gamma, beta):
|
||||
self.gamma, self.beta = gamma, beta
|
||||
self.have_set_condition = True
|
||||
|
||||
def forward(self, x):
|
||||
assert self.have_set_condition
|
||||
out = _instance_layer_normalization(
|
||||
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
|
||||
self.have_set_condition = False
|
||||
return out
|
||||
@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
|
||||
def select_norm_layer(norm_type):
|
||||
if norm_type == "BN":
|
||||
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
||||
return functools.partial(nn.BatchNorm2d)
|
||||
elif norm_type == "IN":
|
||||
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
||||
elif norm_type == "LN":
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from util.registry import Registry
|
||||
|
||||
MODEL = Registry("model")
|
||||
NORMALIZATION = Registry("normalization")
|
||||
14
tool/inspect_model.py
Normal file
14
tool/inspect_model.py
Normal file
@ -0,0 +1,14 @@
|
||||
import sys
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.util.build import build_model
|
||||
|
||||
config = OmegaConf.load(sys.argv[1])
|
||||
|
||||
|
||||
generator = build_model(config.model.generator)
|
||||
|
||||
ckp = torch.load(sys.argv[2], map_location="cpu")
|
||||
|
||||
generator.module.load_state_dict(ckp["generator_main"])
|
||||
13
tool/process/permutation_face.py
Normal file
13
tool/process/permutation_face.py
Normal file
@ -0,0 +1,13 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from itertools import permutations
|
||||
|
||||
pids = defaultdict(list)
|
||||
for p in Path(sys.argv[1]).glob("*.jpg"):
|
||||
pids[p.stem[:7]].append(p.stem)
|
||||
|
||||
data = []
|
||||
for p in pids:
|
||||
data.extend(list(permutations(pids[p], 2)))
|
||||
|
||||
@ -2,6 +2,15 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def add_spectral_norm(module):
|
||||
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||
return nn.utils.spectral_norm(module)
|
||||
else:
|
||||
return module
|
||||
|
||||
|
||||
def setup_logger(
|
||||
name: Optional[str] = None,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user