# raycv/data/dataset/pose_transfer.py
from collections import defaultdict
from itertools import combinations, permutations
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F

from data.registry import DATASET
from data.transform import transform_pipeline
from data.util import dlib_landmark
# NOTE: SingleFolderDataset is used below but was never imported; the module
# path here is an assumption and may need adjusting to match this repo.
from data.dataset.single_folder import SingleFolderDataset


def normalize_tensor(tensor):
    """Min-max normalize a tensor to the [0, 1] range."""
    tensor = tensor.float()
    tensor -= tensor.min()
    # NOTE: a constant input makes max() zero here and divides by zero.
    tensor /= tensor.max()
    return tensor
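
# For example, normalize_tensor(torch.tensor([2., 4., 10.])) yields
# tensor([0.0, 0.25, 1.0]): subtract the min (2), divide by the new max (8).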


@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
    """Unpaired image-to-image dataset that attaches an edge map (and optional
    facial-landmark channels) to each sample from domains A and B."""

    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path,
                 size=(256, 256), with_path=False):
        assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
        self.edge_type = edge_type
        self.size = size
        self.edges_path = Path(edges_path)
        self.landmarks_path = Path(landmarks_path)
        assert self.edges_path.exists()
        assert self.landmarks_path.exists()
        self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
        self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
        self.random_pair = random_pair
        self.with_path = with_path
    def get_edge(self, origin_path):
        op = Path(origin_path)
        if self.edge_type.startswith("landmark_"):
            # str.lstrip strips a *character set*, not a prefix; slice it off instead.
            edge_type = self.edge_type[len("landmark_"):]
            # Landmark channels are only available for the face domain (folder "...A").
            use_landmark = op.parent.name.endswith("A")
        else:
            edge_type = self.edge_type
            use_landmark = False
        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
        origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
        if not use_landmark:
            return origin_edge
        landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        # edges = origin_edge * (part_labels.sum(0) == 0)  # remove edges within face
        # edges = part_edge + edges
        return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
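
    # A sketch of the concatenated tensor returned above; exact channel counts
    # depend on dlib_landmark's internals, so these are assumptions:
    #   origin_edge -> 1 channel (HED or Canny edge map)
    #   part_edge   -> 1 channel (edges of landmark-defined face parts)
    #   dist_tensor -> one normalized distance-field channel per landmark curve
    #   part_labels -> one normalized mask channel per face part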

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        # Walk B in lockstep with A, or draw a random partner for unpaired training.
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output = dict(a={}, b={})
        output["a"]["img"], output["a"]["path"] = self.A[a_idx]
        output["b"]["img"], output["b"]["path"] = self.B[b_idx]
        for p in "ab":
            output[p]["edge"] = self.get_edge(output[p]["path"])
        return output

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"


@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
    """Pairs `num_face` photos of the same person (plus their landmark-derived
    pose tensors) with one randomly drawn anime image."""

    def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
                 with_order=True):
        self.num_face = num_face
        self.landmark_path = Path(landmark_path)
        self.with_order = with_order
        self.root_face = Path(root_face)
        self.root_anime = Path(root_anime)
        self.img_size = img_size
        self.face_samples = self.iter_folders()
        self.face_pipeline = transform_pipeline(face_pipeline)
        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
    def iter_folders(self):
        # Group images by person: the first 7 characters of the file stem are
        # treated as the person ID (e.g. "0000001_xxx.jpg"-style naming).
        pics_per_person = defaultdict(list)
        for p in self.root_face.glob("*.jpg"):
            pics_per_person[p.stem[:7]].append(p.stem)
        data = []
        for p in pics_per_person:
            if len(pics_per_person[p]) >= self.num_face:
                if self.with_order:
                    # combinations keep each tuple in the folder's original order
                    data.extend(combinations(pics_per_person[p], self.num_face))
                else:
                    # permutations additionally enumerate every ordering
                    data.extend(permutations(pics_per_person[p], self.num_face))
        return data
    def read_pose(self, pose_txt):
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([part_labels, part_edge, dist_tensor])
    def __len__(self):
        return len(self.face_samples)

    def __getitem__(self, idx):
        output = dict()
        # One random anime image per sample, drawn independently of idx.
        output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
        for i, f in enumerate(self.face_samples[idx]):
            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
            output[f"stem_{i}"] = f
        return output
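

# A minimal construction sketch; the paths below are hypothetical and the
# pipeline configs are placeholders for whatever transform_pipeline expects:
if __name__ == "__main__":
    dataset = PoseFacesWithSingleAnime(
        root_face="data/faces",          # hypothetical path
        root_anime="data/anime",         # hypothetical path
        landmark_path="data/landmarks",  # hypothetical path
        num_face=2,
        face_pipeline=None,              # placeholder pipeline config
        anime_pipeline=None,             # placeholder pipeline config
        img_size=(256, 256),
    )
    print(f"{len(dataset)} face tuples; sample keys: {sorted(dataset[0].keys())}")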