TAHG 0.0.2
This commit is contained in:
parent
715a2e64a1
commit
89b54105c7
@ -24,7 +24,7 @@ model:
|
||||
generator:
|
||||
_type: TAHG-Generator
|
||||
style_in_channels: 3
|
||||
content_in_channels: 23
|
||||
content_in_channels: 1
|
||||
discriminator:
|
||||
_type: TAHG-Discriminator
|
||||
in_channels: 3
|
||||
@ -73,7 +73,7 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 4
|
||||
batch_size: 48
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
@ -82,12 +82,14 @@ data:
|
||||
_type: GenerationUnpairedDatasetWithEdge
|
||||
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
|
||||
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
|
||||
edge_type: "hed_landmark"
|
||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||
edge_type: "hed"
|
||||
size: [128, 128]
|
||||
random_pair: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
size: [128, 128]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
@ -103,12 +105,14 @@ data:
|
||||
_type: GenerationUnpairedDatasetWithEdge
|
||||
root_a: "/data/i2i/VoxCeleb2Anime/testA"
|
||||
root_b: "/data/i2i/VoxCeleb2Anime/testB"
|
||||
edge_type: "hed_landmark"
|
||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||
edge_type: "hed"
|
||||
random_pair: False
|
||||
size: [128, 128]
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
size: [128, 128]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
|
||||
@ -2,10 +2,12 @@ import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.transforms import functional as F
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
|
||||
|
||||
import lmdb
|
||||
@ -176,17 +178,20 @@ class GenerationUnpairedDataset(Dataset):
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
|
||||
self.edge_type = edge_type
|
||||
self.size = size
|
||||
self.edges_path = Path(edges_path)
|
||||
assert self.edges_path.exists()
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
add = torch.load(op.parent / f"{op.stem}.add")
|
||||
return {"edge": add["edge"].float().unsqueeze(dim=0),
|
||||
"additional_info": torch.cat([add["seg"].float(), add["dist"].float()], dim=0)}
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
||||
img = Image.open(edge_path).resize(self.size)
|
||||
return {"edge": F.to_tensor(img)}
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
|
||||
0
data/util/__init__.py
Normal file
0
data/util/__init__.py
Normal file
67
data/util/dlib_landmark.py
Normal file
67
data/util/dlib_landmark.py
Normal file
@ -0,0 +1,67 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
from skimage import feature
|
||||
|
||||
# https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/
|
||||
DLIB_LANDMARKS_PART_LIST = [
|
||||
[list(range(0, 17)) + list(range(68, 83)) + [0]], # face
|
||||
[range(17, 22)], # right eyebrow
|
||||
[range(22, 27)], # left eyebrow
|
||||
[[28, 31], range(31, 36), [35, 28]], # nose
|
||||
[[36, 37, 38, 39], [39, 40, 41, 36]], # right eye
|
||||
[[42, 43, 44, 45], [45, 46, 47, 42]], # left eye
|
||||
[range(48, 55), [54, 55, 56, 57, 58, 59, 48]], # mouth
|
||||
[range(60, 65), [64, 65, 66, 67, 60]] # tongue
|
||||
]
|
||||
|
||||
|
||||
def dist_tensor(key_points, size=(256, 256)):
|
||||
dist_list = []
|
||||
for edge_list in DLIB_LANDMARKS_PART_LIST:
|
||||
for edge in edge_list:
|
||||
pts = key_points[edge, :]
|
||||
im_edge = np.zeros(size, np.uint8)
|
||||
cv2.polylines(im_edge, [pts], isClosed=False, color=255, thickness=1)
|
||||
im_dist = cv2.distanceTransform(255 - im_edge, cv2.DIST_L1, 3)
|
||||
im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
|
||||
dist_list.append(im_dist)
|
||||
return np.stack(dist_list)
|
||||
|
||||
|
||||
def read_keypoints(kp_path, origin_size=(256, 256), size=(256, 256), thickness=1):
|
||||
key_points = np.loadtxt(kp_path, delimiter=",").astype(np.int32)
|
||||
if origin_size != size:
|
||||
# resize key_points using simplest way...
|
||||
key_points = (key_points * (np.array(size) / np.array(origin_size))).astype(np.int32)
|
||||
|
||||
# add upper half face by symmetry
|
||||
face_pts = key_points[:17, :]
|
||||
face_baseline_y = (face_pts[0, 1] + face_pts[-1, 1]) // 2
|
||||
upper_symmetry_face_pts = face_pts[1:-1, :].copy()
|
||||
# keep x untouched
|
||||
upper_symmetry_face_pts[:, 1] = face_baseline_y + (face_baseline_y - upper_symmetry_face_pts[:, 1]) * 2 // 3
|
||||
key_points = np.vstack((key_points, upper_symmetry_face_pts[::-1, :]))
|
||||
|
||||
assert key_points.shape == (83, 2)
|
||||
|
||||
part_labels = np.zeros((len(DLIB_LANDMARKS_PART_LIST), *size), np.uint8)
|
||||
part_edge = np.zeros(size, np.uint8)
|
||||
for i, edge_list in enumerate(DLIB_LANDMARKS_PART_LIST):
|
||||
indices = [item for sublist in edge_list for item in sublist]
|
||||
pts = key_points[indices, :]
|
||||
cv2.fillPoly(part_labels[i], pts=[pts], color=1)
|
||||
if i in [1, 2]:
|
||||
# some part of landmarks is a line
|
||||
cv2.polylines(part_edge, [pts], isClosed=False, color=1, thickness=thickness)
|
||||
else:
|
||||
cv2.drawContours(part_edge, [pts], 0, color=1, thickness=thickness)
|
||||
|
||||
return key_points, part_labels, part_edge
|
||||
|
||||
|
||||
def edge_map(img, part_labels, part_edge, remove_edge_within_face=True):
|
||||
edges = feature.canny(np.array(img.convert("L")))
|
||||
if remove_edge_within_face:
|
||||
edges = edges * (part_labels.sum(0) == 0) # remove edges within face
|
||||
edges = part_edge + edges
|
||||
return edges
|
||||
@ -85,9 +85,7 @@ def get_trainer(config, logger):
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real = dict(a=batch["a"], b=batch["b"])
|
||||
edge = batch["edge"]
|
||||
additional_info = batch["additional_info"]
|
||||
content_img = torch.cat([edge, additional_info], dim=1)
|
||||
content_img = batch["edge"]
|
||||
fake = dict(
|
||||
a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
|
||||
@ -101,7 +99,7 @@ def get_trainer(config, logger):
|
||||
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
||||
_, t = perceptual_loss(fake[d], real[d])
|
||||
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
||||
loss_g["edge"] = config.loss.edge.weight * edge_loss(fake["b"], real["a"], gt_is_edge=False)
|
||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img)
|
||||
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
|
||||
@ -146,8 +146,12 @@ class Generator(nn.Module):
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.style_encoder = VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE")
|
||||
self.style_encoders = nn.ModuleDict({
|
||||
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE"),
|
||||
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE")
|
||||
})
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
||||
padding_mode=padding_mode, norm_type="IN")
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
@ -168,7 +172,7 @@ class Generator(nn.Module):
|
||||
|
||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||
x = self.content_encoder(content_img)
|
||||
styles = self.fusion(self.fc(self.style_encoder(style_img)))
|
||||
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
|
||||
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
||||
for i, ar in enumerate(self.adain_res):
|
||||
ar.norm1.set_style(styles[2 * i])
|
||||
|
||||
32
tool/process/detect_landmark.py
Normal file
32
tool/process/detect_landmark.py
Normal file
@ -0,0 +1,32 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import dlib
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import sys
|
||||
|
||||
imagepaths = Path(sys.argv[1])
|
||||
print(imagepaths)
|
||||
phase = imagepaths.name
|
||||
print(phase)
|
||||
|
||||
detector = dlib.get_frontal_face_detector()
|
||||
predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
|
||||
|
||||
if not os.path.isdir(phase):
|
||||
os.makedirs(phase)
|
||||
for ip in imagepaths.glob("*.jpg"):
|
||||
img = np.asarray(Image.open(ip))
|
||||
img.setflags(write=True)
|
||||
dets = detector(img, 1)
|
||||
if len(dets) > 0:
|
||||
shape = predictor(img, dets[0])
|
||||
points = np.empty([68, 2], dtype=int)
|
||||
for b in range(68):
|
||||
points[b, 0] = shape.part(b).x
|
||||
points[b, 1] = shape.part(b).y
|
||||
|
||||
save_name = os.path.join(phase, ip.name[:-4] + '.txt')
|
||||
np.savetxt(save_name, points, fmt='%d', delimiter=',')
|
||||
else:
|
||||
print(ip)
|
||||
45
tool/process/generate_edge.py
Normal file
45
tool/process/generate_edge.py
Normal file
@ -0,0 +1,45 @@
|
||||
import numpy as np
|
||||
from skimage import feature
|
||||
from pathlib import Path
|
||||
from torchvision.datasets.folder import is_image_file, default_loader
|
||||
from torchvision.transforms import functional as F
|
||||
from loss.I2I.edge_loss import HED
|
||||
import torch
|
||||
from PIL import Image
|
||||
import fire
|
||||
|
||||
|
||||
def canny_edge(img):
|
||||
edge = feature.canny(np.array(img.convert("L")))
|
||||
return edge
|
||||
|
||||
|
||||
def generate(image_folder, edge_type, save_folder, device="cuda:0"):
|
||||
assert edge_type in ["canny", "hed"]
|
||||
image_folder = Path(image_folder)
|
||||
save_folder = Path(save_folder)
|
||||
if edge_type == "hed":
|
||||
edge_extractor = HED("/root/network-bsds500.pytorch", norm_img=False).to(device)
|
||||
elif edge_type == "canny":
|
||||
edge_extractor = canny_edge
|
||||
else:
|
||||
raise NotImplemented
|
||||
for p in image_folder.glob("*"):
|
||||
if is_image_file(p.as_posix()):
|
||||
rgb_img = default_loader(p)
|
||||
print(p)
|
||||
if edge_type == "hed":
|
||||
with torch.no_grad():
|
||||
img_tensor = F.to_tensor(rgb_img).to(device)
|
||||
edge_tensor = edge_extractor(img_tensor)
|
||||
edge = F.to_pil_image(edge_tensor.clamp(0, 1.0).squeeze().detach().cpu())
|
||||
edge.save(save_folder / f"{p.stem}.{edge_type}.png")
|
||||
elif edge_type == "canny":
|
||||
edge = edge_extractor(rgb_img)
|
||||
Image.fromarray(edge).save(save_folder / f"{p.stem}.{edge_type}.png")
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(generate)
|
||||
Loading…
Reference in New Issue
Block a user