TAHG 0.0.2

This commit is contained in:
budui 2020-08-30 14:44:40 +08:00
parent 715a2e64a1
commit 89b54105c7
8 changed files with 172 additions and 17 deletions

View File

@ -24,7 +24,7 @@ model:
generator:
_type: TAHG-Generator
style_in_channels: 3
content_in_channels: 23
content_in_channels: 1
discriminator:
_type: TAHG-Discriminator
in_channels: 3
@ -73,7 +73,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 4
batch_size: 48
shuffle: True
num_workers: 2
pin_memory: True
@ -82,12 +82,14 @@ data:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
edge_type: "hed_landmark"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
size: [128, 128]
random_pair: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
size: [128, 128]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
@ -103,12 +105,14 @@ data:
_type: GenerationUnpairedDatasetWithEdge
root_a: "/data/i2i/VoxCeleb2Anime/testA"
root_b: "/data/i2i/VoxCeleb2Anime/testB"
edge_type: "hed_landmark"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
random_pair: False
size: [128, 128]
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
size: [128, 128]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]

View File

@ -2,10 +2,12 @@ import os
import pickle
from pathlib import Path
from collections import defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import functional as F
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
import lmdb
@ -176,17 +178,20 @@ class GenerationUnpairedDataset(Dataset):
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
assert self.edges_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
add = torch.load(op.parent / f"{op.stem}.add")
return {"edge": add["edge"].float().unsqueeze(dim=0),
"additional_info": torch.cat([add["seg"].float(), add["dist"].float()], dim=0)}
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
img = Image.open(edge_path).resize(self.size)
return {"edge": F.to_tensor(img)}
def __getitem__(self, idx):
a_idx = idx % len(self.A)

0
data/util/__init__.py Normal file
View File

View File

@ -0,0 +1,67 @@
import numpy as np
import cv2
from skimage import feature
# https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/
# Polylines per facial part, expressed as index lists into the landmark array.
# Indices 0-67 follow the standard 68-point dlib/iBUG annotation (link above);
# indices 68-82 are extra upper-face points synthesized by read_keypoints
# (mirrored jaw line), which is why the face outline references range(68, 83).
# Each part is a list of polylines so open segments (eyebrows) and closed
# loops (eyes, mouth) can be rasterized separately.
DLIB_LANDMARKS_PART_LIST = [
    [list(range(0, 17)) + list(range(68, 83)) + [0]],  # face
    [range(17, 22)],  # right eyebrow
    [range(22, 27)],  # left eyebrow
    [[28, 31], range(31, 36), [35, 28]],  # nose
    [[36, 37, 38, 39], [39, 40, 41, 36]],  # right eye
    [[42, 43, 44, 45], [45, 46, 47, 42]],  # left eye
    [range(48, 55), [54, 55, 56, 57, 58, 59, 48]],  # mouth
    [range(60, 65), [64, 65, 66, 67, 60]]  # tongue
]
def dist_tensor(key_points, size=(256, 256)):
    """Build a stack of L1 distance-transform maps, one per landmark polyline.

    Every polyline of every facial part is rasterized onto a blank canvas;
    the distance transform of the canvas complement then gives a per-pixel
    "distance to this contour" map, scaled by 1/3 and clipped into uint8.

    Returns a (num_polylines, *size) uint8 array.
    """
    maps = []
    for part in DLIB_LANDMARKS_PART_LIST:
        for polyline in part:
            contour_pts = key_points[polyline, :]
            canvas = np.zeros(size, np.uint8)
            cv2.polylines(canvas, [contour_pts], isClosed=False, color=255, thickness=1)
            dist = cv2.distanceTransform(255 - canvas, cv2.DIST_L1, 3)
            maps.append(np.clip(dist / 3, 0, 255).astype(np.uint8))
    return np.stack(maps)
def read_keypoints(kp_path, origin_size=(256, 256), size=(256, 256), thickness=1):
    """Load 68 dlib landmarks from a CSV file; derive part masks and edges.

    Returns a tuple (key_points, part_labels, part_edge):
      - key_points: (83, 2) int32 array — the 68 loaded points plus 15
        synthesized upper-face points (see below).
      - part_labels: (num_parts, *size) uint8 stack of filled part masks.
      - part_edge: (*size,) uint8 image with all part contours drawn (value 1).

    NOTE(review): the rescaling multiplies (x, y) point coordinates by
    size / origin_size elementwise — this assumes the size tuples are ordered
    to match the point axes; confirm against the callers.
    """
    key_points = np.loadtxt(kp_path, delimiter=",").astype(np.int32)
    if origin_size != size:
        # resize key_points using simplest way... (plain linear scaling of
        # coordinates, no aspect-ratio handling)
        key_points = (key_points * (np.array(size) / np.array(origin_size))).astype(np.int32)
    # add upper half face by symmetry: dlib's 68 points have no forehead, so
    # mirror the jaw-line points (1..15) about the line through the two jaw
    # endpoints, with the vertical offset damped by 2/3 to flatten the arc
    face_pts = key_points[:17, :]
    face_baseline_y = (face_pts[0, 1] + face_pts[-1, 1]) // 2
    upper_symmetry_face_pts = face_pts[1:-1, :].copy()
    # keep x untouched
    upper_symmetry_face_pts[:, 1] = face_baseline_y + (face_baseline_y - upper_symmetry_face_pts[:, 1]) * 2 // 3
    # reversed order keeps the combined face contour continuous (indices 68-82)
    key_points = np.vstack((key_points, upper_symmetry_face_pts[::-1, :]))
    assert key_points.shape == (83, 2)
    part_labels = np.zeros((len(DLIB_LANDMARKS_PART_LIST), *size), np.uint8)
    part_edge = np.zeros(size, np.uint8)
    for i, edge_list in enumerate(DLIB_LANDMARKS_PART_LIST):
        # flatten the part's polylines into one index list for fill/contour
        indices = [item for sublist in edge_list for item in sublist]
        pts = key_points[indices, :]
        cv2.fillPoly(part_labels[i], pts=[pts], color=1)
        if i in [1, 2]:
            # some part of landmarks is a line (the eyebrows are open
            # polylines, so drawContours would close them incorrectly)
            cv2.polylines(part_edge, [pts], isClosed=False, color=1, thickness=thickness)
        else:
            cv2.drawContours(part_edge, [pts], 0, color=1, thickness=thickness)
    return key_points, part_labels, part_edge
def edge_map(img, part_labels, part_edge, remove_edge_within_face=True):
    """Combine Canny edges of *img* with the rasterized landmark contours.

    The image is converted to grayscale before Canny detection. When
    remove_edge_within_face is set, Canny responses that fall inside any
    facial part mask are zeroed, so the face interior is described only by
    the landmark contours in *part_edge*.
    """
    gray = np.array(img.convert("L"))
    canny_edges = feature.canny(gray)
    if remove_edge_within_face:
        outside_face = part_labels.sum(0) == 0
        canny_edges = canny_edges * outside_face  # remove edges within face
    return part_edge + canny_edges

View File

@ -85,9 +85,7 @@ def get_trainer(config, logger):
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real = dict(a=batch["a"], b=batch["b"])
edge = batch["edge"]
additional_info = batch["additional_info"]
content_img = torch.cat([edge, additional_info], dim=1)
content_img = batch["edge"]
fake = dict(
a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
@ -101,7 +99,7 @@ def get_trainer(config, logger):
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
_, t = perceptual_loss(fake[d], real[d])
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
loss_g["edge"] = config.loss.edge.weight * edge_loss(fake["b"], real["a"], gt_is_edge=False)
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img)
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
sum(loss_g.values()).backward()
optimizers["g"].step()

View File

@ -146,8 +146,12 @@ class Generator(nn.Module):
base_channels=64, padding_mode="reflect"):
super(Generator, self).__init__()
self.num_blocks = num_blocks
self.style_encoder = VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE")
self.style_encoders = nn.ModuleDict({
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE"),
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE")
})
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
padding_mode=padding_mode, norm_type="IN")
res_block_channels = 2 ** 2 * base_channels
@ -168,7 +172,7 @@ class Generator(nn.Module):
def forward(self, content_img, style_img, which_decoder: str = "a"):
x = self.content_encoder(content_img)
styles = self.fusion(self.fc(self.style_encoder(style_img)))
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
for i, ar in enumerate(self.adain_res):
ar.norm1.set_style(styles[2 * i])

View File

@ -0,0 +1,32 @@
"""Detect 68-point dlib face landmarks for every *.jpg in a folder.

Usage: python <this script> /path/to/image_folder
Writes one <stem>.txt per detected face into ./<folder-name>/, each holding
the 68 (x, y) landmark coordinates as comma-separated integers. Images with
no detected face are printed to stdout for manual inspection.
"""
import os
import sys
from pathlib import Path

import dlib
import numpy as np
from PIL import Image

imagepaths = Path(sys.argv[1])
print(imagepaths)
phase = imagepaths.name  # output folder mirrors the input folder's name
print(phase)

detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

# BUGFIX: replace isdir+makedirs (race-prone) with a single idempotent call.
os.makedirs(phase, exist_ok=True)

for ip in imagepaths.glob("*.jpg"):
    # BUGFIX: np.array (not np.asarray) copies the PIL buffer into a writable
    # array; the original `setflags(write=True)` on the read-only asarray view
    # raises ValueError on modern numpy (array does not own its data).
    img = np.array(Image.open(ip))
    dets = detector(img, 1)
    if len(dets) > 0:
        # only the first detection is kept for multi-face images
        shape = predictor(img, dets[0])
        points = np.array(
            [(shape.part(b).x, shape.part(b).y) for b in range(68)], dtype=int
        )
        save_name = os.path.join(phase, ip.stem + '.txt')  # ip.stem, not name[:-4]
        np.savetxt(save_name, points, fmt='%d', delimiter=',')
    else:
        print(ip)  # no face found

View File

@ -0,0 +1,45 @@
import numpy as np
from skimage import feature
from pathlib import Path
from torchvision.datasets.folder import is_image_file, default_loader
from torchvision.transforms import functional as F
from loss.I2I.edge_loss import HED
import torch
from PIL import Image
import fire
def canny_edge(img):
    """Return the boolean Canny edge map of a PIL image (grayscale first)."""
    grayscale = np.array(img.convert("L"))
    return feature.canny(grayscale)
def generate(image_folder, edge_type, save_folder, device="cuda:0"):
    """Precompute per-image edge maps and save them as PNGs.

    For every image <stem>.* directly inside *image_folder*, writes
    save_folder/<stem>.<edge_type>.png.

    Args:
        image_folder: directory scanned (non-recursively) for image files.
        edge_type: "hed" (neural edge detector, loads pretrained weights)
            or "canny" (skimage Canny on the grayscale image).
        save_folder: output directory; created if it does not exist.
        device: torch device string used by the HED network.
    """
    assert edge_type in ["canny", "hed"]
    image_folder = Path(image_folder)
    save_folder = Path(save_folder)
    # robustness: make sure the output directory exists before the first save
    save_folder.mkdir(parents=True, exist_ok=True)
    if edge_type == "hed":
        edge_extractor = HED("/root/network-bsds500.pytorch", norm_img=False).to(device)
    elif edge_type == "canny":
        edge_extractor = canny_edge
    else:
        # BUGFIX: the original `raise NotImplemented` raises the NotImplemented
        # *constant*, which is a TypeError at runtime — raise the exception class.
        raise NotImplementedError(edge_type)
    for p in image_folder.glob("*"):
        if is_image_file(p.as_posix()):
            rgb_img = default_loader(p)
            print(p)
            if edge_type == "hed":
                with torch.no_grad():
                    img_tensor = F.to_tensor(rgb_img).to(device)
                    edge_tensor = edge_extractor(img_tensor)
                edge = F.to_pil_image(edge_tensor.clamp(0, 1.0).squeeze().detach().cpu())
                edge.save(save_folder / f"{p.stem}.{edge_type}.png")
            elif edge_type == "canny":
                edge = edge_extractor(rgb_img)
                # BUGFIX: skimage.feature.canny returns a bool array and
                # PIL.Image.fromarray cannot handle dtype=bool — convert to a
                # 0/255 uint8 mask before saving.
                Image.fromarray(edge.astype(np.uint8) * 255).save(
                    save_folder / f"{p.stem}.{edge_type}.png")
            else:
                raise NotImplementedError(edge_type)
if __name__ == '__main__':
    # Expose `generate` as a CLI via python-fire:
    #   python <this script> IMAGE_FOLDER EDGE_TYPE SAVE_FOLDER [--device=...]
    fire.Fire(generate)