From fbea96f6d76b082a2587d0fc36e30dcaad15b6ce Mon Sep 17 00:00:00 2001
From: Ray Wong
Date: Thu, 24 Sep 2020 16:50:53 +0800
Subject: [PATCH] add new dataset type

Add PoseFacesWithSingleAnime: it groups face images by person ID, draws
num_face photos of the same person per sample (together with their
landmark pose maps), and pairs them with a randomly sampled anime image.

---
 data/dataset.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/data/dataset.py b/data/dataset.py
index 2941889..9e53028 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -1,6 +1,7 @@
 import os
 import pickle
 from collections import defaultdict
+from itertools import permutations, combinations
 from pathlib import Path
 
 import lmdb
@@ -237,3 +238,56 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
 
     def __repr__(self):
         return f"\nPipeline:\n{self.A.pipeline}"
+
+
+@DATASET.register_module()
+class PoseFacesWithSingleAnime(Dataset):
+    """Pair ``num_face`` face photos of one person with a randomly sampled anime image."""
+
+    def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
+                 with_order=True):
+        self.num_face = num_face
+        self.landmark_path = Path(landmark_path)
+        self.with_order = with_order
+        self.root_face = Path(root_face)
+        self.root_anime = Path(root_anime)
+        self.img_size = img_size
+        self.face_samples = self.iter_folders()
+        self.face_pipeline = transform_pipeline(face_pipeline)
+        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
+
+    def iter_folders(self):
+        # group image stems by person ID (the first 7 characters of the filename stem)
+        pics_per_person = defaultdict(list)
+        for p in self.root_face.glob("*.jpg"):
+            pics_per_person[p.stem[:7]].append(p.stem)
+        data = []
+        for p in pics_per_person:
+            if len(pics_per_person[p]) >= self.num_face:
+                # when order matters, every ordering of a face tuple is a distinct sample
+                if self.with_order:
+                    data.extend(list(permutations(pics_per_person[p], self.num_face)))
+                else:
+                    data.extend(list(combinations(pics_per_person[p], self.num_face)))
+        return data
+
+    def read_pose(self, pose_txt):
+        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
+        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
+        part_labels = normalize_tensor(torch.from_numpy(part_labels))
+        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
+        # stack part labels, the edge map and the keypoint-distance channels into one pose tensor
+        return torch.cat([part_labels, part_edge, dist_tensor])
+
+    def __len__(self):
+        return len(self.face_samples)
+
+    def __getitem__(self, idx):
+        output = dict()
+        # pair the face tuple with a uniformly sampled anime image
+        output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
+        for i, f in enumerate(self.face_samples[idx]):
+            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
+            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
+            output[f"stem_{i}"] = f
+        return output