temp commit
This commit is contained in:
parent
776fe40199
commit
6ea13df465
287
data/dataset.py
287
data/dataset.py
@ -1,287 +0,0 @@
|
|||||||
import os
|
|
||||||
import pickle
|
|
||||||
from collections import defaultdict
|
|
||||||
from itertools import permutations, combinations
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import lmdb
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from torchvision.datasets import ImageFolder
|
|
||||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
|
||||||
from torchvision.transforms import functional as F
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .registry import DATASET
|
|
||||||
from .transform import transform_pipeline
|
|
||||||
from .util import dlib_landmark
|
|
||||||
|
|
||||||
|
|
||||||
def default_transform_way(transform, sample):
    """Transform only the first element of ``sample``.

    Returns a new list: the transformed first element followed by the
    remaining elements of ``sample``, unchanged and in their original order.
    """
    result = [transform(sample[0])]
    result.extend(sample[1:])
    return result
|
|
||||||
|
|
||||||
|
|
||||||
class LMDBDataset(Dataset):
    """Dataset that reads pickled samples from an LMDB store written by :meth:`lmdbify`.

    The store holds one pickled sample per decimal index key (``b"0"``, ``b"1"``, ...)
    plus the metadata keys ``$$len$$``, ``$$done_pipeline$$`` and ``$$essential_attr$$``.
    Every name listed under ``$$essential_attr$$`` is restored as an attribute of
    this object.

    NOTE: values are decoded with ``pickle``; only open databases produced by
    trusted code.

    :param lmdb_path: path of the LMDB store (directory or single file).
    :param pipeline: full pipeline description; the part already applied when the
        store was built (``done_pipeline``) must be its prefix. ``None`` means no
        further transform.
    :param transform_way: callable ``(transform, sample) -> sample`` deciding how
        the remaining transform is applied to a stored sample.
    :param map_size: forwarded to ``lmdb.open``.
    :param readonly: forwarded to ``lmdb.open``.
    :param lmdb_kwargs: extra keyword arguments forwarded to ``lmdb.open``.
    :raises KeyError: if a metadata key is missing from the store.
    :raises ValueError: if ``pipeline`` does not start with the stored pipeline.
    """

    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
                 **lmdb_kwargs):
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
                            lock=False, **lmdb_kwargs)

        with self.db.begin(write=False) as txn:
            self._len = self._load(txn, b"$$len$$")
            self.done_pipeline = self._load(txn, b"$$done_pipeline$$")
            essential_attr = self._load(txn, b"$$essential_attr$$")
            for ea in essential_attr:
                setattr(self, ea, self._load(txn, f"${ea}$".encode(encoding="utf-8")))

        if pipeline is None:
            self.not_done_pipeline = []
        else:
            self.not_done_pipeline = self._remain_pipeline(pipeline)
        self.transform = transform_pipeline(self.not_done_pipeline)
        self.transform_way = transform_way

    @staticmethod
    def _load(txn, key):
        """Unpickle the value stored under ``key``; raise ``KeyError`` if it is absent."""
        raw = txn.get(key)
        if raw is None:
            raise KeyError(f"key {key!r} not found in lmdb database")
        return pickle.loads(raw)

    def _remain_pipeline(self, pipeline):
        """Return the steps of ``pipeline`` that were not applied when the store was built.

        :raises ValueError: if ``pipeline`` does not begin with ``self.done_pipeline``
            (including the case where it is shorter than the stored pipeline).
        """
        done_count = len(self.done_pipeline)
        # Comparing the prefix slice also covers a too-short pipeline, which
        # would otherwise surface as an IndexError.
        if list(pipeline[:done_count]) != list(self.done_pipeline):
            raise ValueError(
                f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
        return pipeline[done_count:]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        """Return the stored sample ``idx`` after applying the remaining transform.

        :raises IndexError: if no sample is stored under ``idx``; this lets plain
            ``for sample in dataset`` iteration terminate correctly.
        """
        with self.db.begin(write=False) as txn:
            raw = txn.get("{}".format(idx).encode())
        if raw is None:
            raise IndexError(f"index {idx} is out of range for {self.path}")
        sample = pickle.loads(raw)
        return self.transform_way(self.transform, sample)

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        """Write every sample of ``dataset`` plus its metadata into an LMDB store.

        :param dataset: indexable dataset with ``len()``; its optional
            ``essential_attr`` list names attributes to persist alongside the samples.
        :param done_pipeline: description of the pipeline already applied to the samples.
        :param lmdb_path: destination path of the LMDB store.
        """
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        try:
            with env.begin(write=True) as txn:
                for i in tqdm(range(len(dataset)), ncols=0):
                    txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
                txn.put(b"$$len$$", pickle.dumps(len(dataset)))
                txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
                essential_attr = getattr(dataset, "essential_attr", list())
                txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
                for ea in essential_attr:
                    txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
        finally:
            # Release the environment handle even if writing fails.
            env.close()
|
|
||||||
|
|
||||||
|
|
||||||
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    """``ImageFolder`` that also records which sample indices belong to each class.

    ``classes_list`` maps a class index to the list of sample indices carrying
    that label. It is named in ``essential_attr``.
    """

    def __init__(self, root, pipeline):
        # The loader is the identity, so the loaded "image" is the file path itself;
        # presumably the pipeline is expected to open the file -- confirm against the transforms.
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        self.essential_attr = ["classes_list"]
        self.classes_list = defaultdict(list)
        for sample_index, sample in enumerate(self.samples):
            # The last field of a sample entry is its class index.
            self.classes_list[sample[-1]].append(sample_index)
        assert len(self.classes_list) == len(self.classes)
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodicDataset(Dataset):
    """Sample few-shot episodes from a dataset that exposes ``classes_list``.

    Each item is one episode: ``num_class`` randomly chosen classes, each
    contributing ``num_support`` support samples and ``num_query`` query samples.

    :param origin_dataset: indexable dataset whose items have the data at position 0
        and which has a ``classes_list`` mapping class index -> list of sample indices.
    :param num_class: number of classes per episode.
    :param num_query: query samples per class.
    :param num_support: support samples per class.
    :param num_episodes: reported length of this dataset.
    :raises ValueError: if ``num_class`` exceeds the number of available classes.
    """

    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class
        available = len(self.origin.classes_list)
        # Using every class is valid: the sampler just takes the whole permutation.
        if self.num_class > available:
            raise ValueError(f"num_class={self.num_class} exceeds the {available} classes available")
        self.num_query = num_query  # query samples per class
        self.num_support = num_support  # support samples per class
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        """Return the data part (element 0) of each origin sample in ``id_list``."""
        return [self.origin[i][0] for i in id_list]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        """Build one random episode; the index argument is ignored.

        :return: dict with ``support`` and ``query`` (stacked tensors, grouped by
            class in episode order) and ``target`` (episode-local class tag of
            every query sample).
        """
        per_class = self.num_query + self.num_support
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set = []
        query_set = []
        target_set = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]

            if len(image_list) >= per_class:
                # Enough images in this class: sample without replacement.
                idx_list = torch.randperm(len(image_list))[:per_class].tolist()
            else:
                # Too few images: sample with replacement.
                idx_list = torch.randint(high=len(image_list), size=(per_class,)).tolist()

            chosen = [image_list[i] for i in idx_list]
            support_set.extend(self._fetch_list_data(chosen[:self.num_support]))
            query_set.extend(self._fetch_list_data(chosen[self.num_support:]))
            target_set.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set)
        }

    def __repr__(self):
        return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
|
|
||||||
|
|
||||||
|
|
||||||
@DATASET.register_module()
|
|
||||||
class SingleFolderDataset(Dataset):
|
|
||||||
def __init__(self, root, pipeline, with_path=False):
|
|
||||||
assert os.path.isdir(root)
|
|
||||||
self.root = root
|
|
||||||
samples = []
|
|
||||||
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
|
|
||||||
for fn in sorted(fns):
|
|
||||||
path = os.path.join(r, fn)
|
|
||||||
if has_file_allowed_extension(path, IMG_EXTENSIONS):
|
|
||||||
samples.append(path)
|
|
||||||
self.samples = samples
|
|
||||||
self.pipeline = transform_pipeline(pipeline)
|
|
||||||
self.with_path = with_path
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.samples)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
if not self.with_path:
|
|
||||||
return self.pipeline(self.samples[idx])
|
|
||||||
else:
|
|
||||||
return self.pipeline(self.samples[idx]), self.samples[idx]
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<SingleFolderDataset root={self.root} len={len(self)}>"
|
|
||||||
|
|
||||||
|
|
||||||
@DATASET.register_module()
|
|
||||||
class GenerationUnpairedDataset(Dataset):
|
|
||||||
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
|
|
||||||
self.A = SingleFolderDataset(root_a, pipeline, with_path)
|
|
||||||
self.B = SingleFolderDataset(root_b, pipeline, with_path)
|
|
||||||
self.random_pair = random_pair
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
a_idx = idx % len(self.A)
|
|
||||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
|
||||||
return dict(a=self.A[a_idx], b=self.B[b_idx])
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return max(len(self.A), len(self.B))
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_tensor(tensor):
    """Min-max scale ``tensor`` into ``[0, 1]`` and return it as a float tensor.

    The input is never modified. A tensor whose values are all equal maps to
    all zeros (instead of NaN from a division by zero), and an empty tensor is
    returned as an empty float tensor.
    """
    tensor = tensor.float()
    if tensor.numel() == 0:
        return tensor
    # Out-of-place arithmetic: ``float()`` returns the very same object for a
    # float32 input, so in-place ops would silently modify the caller's data.
    shifted = tensor - tensor.min()
    peak = shifted.max()
    if peak == 0:
        return shifted
    return shifted / peak
|
|
||||||
|
|
||||||
|
|
||||||
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
    """Unpaired A/B image dataset that also loads a pre-computed edge map per image.

    Each item is ``{"a": {...}, "b": {...}}`` where both inner dicts hold
    ``img``, ``path`` and ``edge``.

    :param root_a: image folder of domain A.
    :param root_b: image folder of domain B.
    :param random_pair: draw the B sample at random instead of by index.
    :param pipeline: pipeline description for both folders.
    :param edge_type: one of ``hed``, ``canny``, ``landmark_hed``, ``landmark_canny``.
    :param edges_path: existing directory holding ``<folder>/<stem>.<edge>.png`` files.
    :param landmarks_path: existing directory holding ``<folder>/<stem>.txt`` files.
    :param size: size the edge image is resized to; also passed to the landmark helpers.
    :param with_path: stored but not consulted by this class; the path is always returned.
    """

    _LANDMARK_PREFIX = "landmark_"

    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
                 with_path=False):
        assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
        self.edge_type = edge_type
        self.size = size
        self.edges_path = Path(edges_path)
        self.landmarks_path = Path(landmarks_path)
        assert self.edges_path.exists()
        assert self.landmarks_path.exists()
        self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
        self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
        self.random_pair = random_pair
        self.with_path = with_path

    def get_edge(self, origin_path):
        """Load the edge tensor that belongs to the image at ``origin_path``.

        For ``landmark_*`` edge types, images whose parent folder name ends with
        ``"A"`` additionally get landmark-derived channels concatenated after
        the edge map.
        """
        op = Path(origin_path)
        if self.edge_type.startswith(self._LANDMARK_PREFIX):
            # Slice the prefix off; ``str.lstrip`` would strip a *set of characters*
            # and could eat leading letters of the remaining edge name.
            edge_type = self.edge_type[len(self._LANDMARK_PREFIX):]
            use_landmark = op.parent.name.endswith("A")
        else:
            edge_type = self.edge_type
            use_landmark = False
        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
        # Context manager closes the underlying file handle once the tensor is built.
        with Image.open(edge_path) as edge_image:
            origin_edge = F.to_tensor(edge_image.resize(self.size, Image.BILINEAR))
        if not use_landmark:
            return origin_edge

        landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)

        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output = dict(a={}, b={})
        output["a"]["img"], output["a"]["path"] = self.A[a_idx]
        output["b"]["img"], output["b"]["path"] = self.B[b_idx]
        for p in "ab":
            output[p]["edge"] = self.get_edge(output[p]["path"])
        return output

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
|
||||||
|
|
||||||
|
|
||||||
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
    """Groups of face images of one person, each paired with a random anime image.

    Faces are the ``*.jpg`` files directly inside ``root_face``; files sharing the
    first seven characters of their stem are grouped together (presumably a
    person id -- confirm against the data naming scheme).

    :param root_face: folder with the face images.
    :param root_anime: folder with the anime images.
    :param landmark_path: folder holding ``<root_face name>/<stem>.txt`` landmark files.
    :param num_face: number of faces per sample.
    :param face_pipeline: pipeline description for face images.
    :param anime_pipeline: pipeline description for anime images.
    :param img_size: size passed to the landmark helpers.
    :param with_order: ``True`` enumerates ``combinations``, ``False`` enumerates
        ``permutations``. NOTE(review): the name suggests the opposite mapping --
        confirm the intent before relying on it.
    """

    def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
                 with_order=True):
        self.num_face = num_face
        self.landmark_path = Path(landmark_path)
        self.with_order = with_order
        self.root_face = Path(root_face)
        self.root_anime = Path(root_anime)
        self.img_size = img_size
        self.face_samples = self.iter_folders()
        self.face_pipeline = transform_pipeline(face_pipeline)
        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)

    def iter_folders(self):
        """Return every ``num_face``-sized group of image stems, grouped per person."""
        pics_per_person = defaultdict(list)
        # Sorted so that the index -> sample mapping does not depend on
        # filesystem enumeration order.
        for p in sorted(self.root_face.glob("*.jpg")):
            pics_per_person[p.stem[:7]].append(p.stem)
        make_groups = combinations if self.with_order else permutations
        data = []
        for stems in pics_per_person.values():
            if len(stems) >= self.num_face:
                data.extend(make_groups(stems, self.num_face))
        return data

    def read_pose(self, pose_txt):
        """Build the pose tensor for one landmark file by concatenating its parts."""
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([part_labels, part_edge, dist_tensor])

    def __len__(self):
        return len(self.face_samples)

    def __getitem__(self, idx):
        output = dict()
        output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
        for i, f in enumerate(self.face_samples[idx]):
            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
            output[f"stem_{i}"] = f
        return output
|
|
||||||
3
data/dataset/__init__.py
Normal file
3
data/dataset/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from util.misc import import_submodules
|
||||||
|
|
||||||
|
# ``__all__`` is expected to be a sequence of strings; a ``dict_keys`` view is
# not indexable, so materialise the submodule names as a list.
__all__ = list(import_submodules(__name__))
|
||||||
63
data/dataset/few-shot.py
Normal file
63
data/dataset/few-shot.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision.datasets import ImageFolder
|
||||||
|
|
||||||
|
from data.registry import DATASET
|
||||||
|
from data.transform import transform_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    """``ImageFolder`` that also records which sample indices belong to each class.

    ``classes_list`` maps a class index to the list of sample indices carrying
    that label. It is named in ``essential_attr``.
    """

    def __init__(self, root, pipeline):
        # The loader is the identity, so the loaded "image" is the file path itself;
        # presumably the pipeline is expected to open the file -- confirm against the transforms.
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        self.essential_attr = ["classes_list"]
        self.classes_list = defaultdict(list)
        for sample_index, sample in enumerate(self.samples):
            # The last field of a sample entry is its class index.
            self.classes_list[sample[-1]].append(sample_index)
        assert len(self.classes_list) == len(self.classes)
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodicDataset(Dataset):
    """Sample few-shot episodes from a dataset that exposes ``classes_list``.

    Each item is one episode: ``num_class`` randomly chosen classes, each
    contributing ``num_support`` support samples and ``num_query`` query samples.

    :param origin_dataset: indexable dataset whose items have the data at position 0
        and which has a ``classes_list`` mapping class index -> list of sample indices.
    :param num_class: number of classes per episode.
    :param num_query: query samples per class.
    :param num_support: support samples per class.
    :param num_episodes: reported length of this dataset.
    :raises ValueError: if ``num_class`` exceeds the number of available classes.
    """

    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class
        available = len(self.origin.classes_list)
        # Using every class is valid: the sampler just takes the whole permutation.
        if self.num_class > available:
            raise ValueError(f"num_class={self.num_class} exceeds the {available} classes available")
        self.num_query = num_query  # query samples per class
        self.num_support = num_support  # support samples per class
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        """Return the data part (element 0) of each origin sample in ``id_list``."""
        return [self.origin[i][0] for i in id_list]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        """Build one random episode; the index argument is ignored.

        :return: dict with ``support`` and ``query`` (stacked tensors, grouped by
            class in episode order) and ``target`` (episode-local class tag of
            every query sample).
        """
        per_class = self.num_query + self.num_support
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set = []
        query_set = []
        target_set = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]

            if len(image_list) >= per_class:
                # Enough images in this class: sample without replacement.
                idx_list = torch.randperm(len(image_list))[:per_class].tolist()
            else:
                # Too few images: sample with replacement.
                idx_list = torch.randint(high=len(image_list), size=(per_class,)).tolist()

            chosen = [image_list[i] for i in idx_list]
            support_set.extend(self._fetch_list_data(chosen[:self.num_support]))
            query_set.extend(self._fetch_list_data(chosen[self.num_support:]))
            target_set.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set)
        }

    def __repr__(self):
        return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
|
||||||
62
data/dataset/image_translation.py
Normal file
62
data/dataset/image_translation.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
||||||
|
|
||||||
|
from data.registry import DATASET
|
||||||
|
from data.transform import transform_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """All image files found under one root directory, in sorted order.

    Items are dicts with the key ``img`` (the pipeline applied to the file
    path) and, when ``with_path`` is set, the key ``path``.

    :param root: existing directory that is walked recursively (following links).
    :param pipeline: pipeline description handed to ``transform_pipeline``.
    :param with_path: also return the file path of the sample.
    """

    def __init__(self, root, pipeline, with_path=False):
        assert os.path.isdir(root)
        self.root = root
        self.samples = [
            os.path.join(folder, filename)
            for folder, _, filenames in sorted(os.walk(self.root, followlinks=True))
            for filename in sorted(filenames)
            if has_file_allowed_extension(os.path.join(folder, filename), IMG_EXTENSIONS)
        ]
        self.pipeline = transform_pipeline(pipeline)
        self.with_path = with_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = {"img": self.pipeline(self.samples[idx])}
        if not self.with_path:
            return item
        item["path"] = self.samples[idx]
        return item

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)} with_path={self.with_path}>"
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Pairs a sample from folder A with a sample from folder B.

    Items are dicts with the keys ``a`` and ``b`` (the images) and, when
    ``with_path`` is set, ``a_path`` and ``b_path``. The length is that of the
    larger folder; the smaller one wraps around via modulo indexing. With
    ``random_pair`` the B sample is drawn at random instead of by index.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
        self.A = SingleFolderDataset(root_a, pipeline, with_path)
        self.B = SingleFolderDataset(root_b, pipeline, with_path)
        self.with_path = with_path
        self.random_pair = random_pair

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        if self.random_pair:
            b_idx = torch.randint(len(self.B), (1,)).item()
        else:
            b_idx = idx % len(self.B)
        sample_a = self.A[a_idx]
        sample_b = self.B[b_idx]
        paired = {"a": sample_a["img"], "b": sample_b["img"]}
        if not self.with_path:
            return paired
        paired["a_path"] = sample_a["path"]
        paired["b_path"] = sample_b["path"]
        return paired

    def __len__(self):
        size_a, size_b = len(self.A), len(self.B)
        return max(size_a, size_b)

    def __repr__(self):
        return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||||
65
data/dataset/lmdb.py
Normal file
65
data/dataset/lmdb.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import lmdb
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from data.transform import transform_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def default_transform_way(transform, sample):
    """Transform only the first element of ``sample``.

    Returns a new list: the transformed first element followed by the
    remaining elements of ``sample``, unchanged and in their original order.
    """
    result = [transform(sample[0])]
    result.extend(sample[1:])
    return result
|
||||||
|
|
||||||
|
|
||||||
|
class LMDBDataset(Dataset):
    """Dataset that reads pickled samples from an LMDB store written by :meth:`lmdbify`.

    The store holds one pickled sample per decimal index key (``b"0"``, ``b"1"``, ...)
    plus the metadata keys ``$$len$$``, ``$$done_pipeline$$`` and ``$$essential_attr$$``.
    Every name listed under ``$$essential_attr$$`` is restored as an attribute of
    this object.

    NOTE: values are decoded with ``pickle``; only open databases produced by
    trusted code.

    :param lmdb_path: path of the LMDB store (directory or single file).
    :param pipeline: full pipeline description; the part already applied when the
        store was built (``done_pipeline``) must be its prefix. ``None`` means no
        further transform.
    :param transform_way: callable ``(transform, sample) -> sample`` deciding how
        the remaining transform is applied to a stored sample.
    :param map_size: forwarded to ``lmdb.open``.
    :param readonly: forwarded to ``lmdb.open``.
    :param lmdb_kwargs: extra keyword arguments forwarded to ``lmdb.open``.
    :raises KeyError: if a metadata key is missing from the store.
    :raises ValueError: if ``pipeline`` does not start with the stored pipeline.
    """

    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
                 **lmdb_kwargs):
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
                            lock=False, **lmdb_kwargs)

        with self.db.begin(write=False) as txn:
            self._len = self._load(txn, b"$$len$$")
            self.done_pipeline = self._load(txn, b"$$done_pipeline$$")
            essential_attr = self._load(txn, b"$$essential_attr$$")
            for ea in essential_attr:
                setattr(self, ea, self._load(txn, f"${ea}$".encode(encoding="utf-8")))

        if pipeline is None:
            self.not_done_pipeline = []
        else:
            self.not_done_pipeline = self._remain_pipeline(pipeline)
        self.transform = transform_pipeline(self.not_done_pipeline)
        self.transform_way = transform_way

    @staticmethod
    def _load(txn, key):
        """Unpickle the value stored under ``key``; raise ``KeyError`` if it is absent."""
        raw = txn.get(key)
        if raw is None:
            raise KeyError(f"key {key!r} not found in lmdb database")
        return pickle.loads(raw)

    def _remain_pipeline(self, pipeline):
        """Return the steps of ``pipeline`` that were not applied when the store was built.

        :raises ValueError: if ``pipeline`` does not begin with ``self.done_pipeline``
            (including the case where it is shorter than the stored pipeline).
        """
        done_count = len(self.done_pipeline)
        # Comparing the prefix slice also covers a too-short pipeline, which
        # would otherwise surface as an IndexError.
        if list(pipeline[:done_count]) != list(self.done_pipeline):
            raise ValueError(
                f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
        return pipeline[done_count:]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        """Return the stored sample ``idx`` after applying the remaining transform.

        :raises IndexError: if no sample is stored under ``idx``; this lets plain
            ``for sample in dataset`` iteration terminate correctly.
        """
        with self.db.begin(write=False) as txn:
            raw = txn.get("{}".format(idx).encode())
        if raw is None:
            raise IndexError(f"index {idx} is out of range for {self.path}")
        sample = pickle.loads(raw)
        return self.transform_way(self.transform, sample)

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        """Write every sample of ``dataset`` plus its metadata into an LMDB store.

        :param dataset: indexable dataset with ``len()``; its optional
            ``essential_attr`` list names attributes to persist alongside the samples.
        :param done_pipeline: description of the pipeline already applied to the samples.
        :param lmdb_path: destination path of the LMDB store.
        """
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        try:
            with env.begin(write=True) as txn:
                for i in tqdm(range(len(dataset)), ncols=0):
                    txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
                txn.put(b"$$len$$", pickle.dumps(len(dataset)))
                txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
                essential_attr = getattr(dataset, "essential_attr", list())
                txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
                for ea in essential_attr:
                    txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
        finally:
            # Release the environment handle even if writing fails.
            env.close()
|
||||||
122
data/dataset/pose_transfer.py
Normal file
122
data/dataset/pose_transfer.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from itertools import permutations, combinations
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
from data.registry import DATASET
|
||||||
|
from data.transform import transform_pipeline
|
||||||
|
from data.util import dlib_landmark
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_tensor(tensor):
    """Min-max scale ``tensor`` into ``[0, 1]`` and return it as a float tensor.

    The input is never modified. A tensor whose values are all equal maps to
    all zeros (instead of NaN from a division by zero), and an empty tensor is
    returned as an empty float tensor.
    """
    tensor = tensor.float()
    if tensor.numel() == 0:
        return tensor
    # Out-of-place arithmetic: ``float()`` returns the very same object for a
    # float32 input, so in-place ops would silently modify the caller's data.
    shifted = tensor - tensor.min()
    peak = shifted.max()
    if peak == 0:
        return shifted
    return shifted / peak
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
    """Unpaired A/B image dataset that also loads a pre-computed edge map per image.

    Each item is ``{"a": {...}, "b": {...}}`` where both inner dicts hold
    ``img``, ``path`` and ``edge``.

    :param root_a: image folder of domain A.
    :param root_b: image folder of domain B.
    :param random_pair: draw the B sample at random instead of by index.
    :param pipeline: pipeline description for both folders.
    :param edge_type: one of ``hed``, ``canny``, ``landmark_hed``, ``landmark_canny``.
    :param edges_path: existing directory holding ``<folder>/<stem>.<edge>.png`` files.
    :param landmarks_path: existing directory holding ``<folder>/<stem>.txt`` files.
    :param size: size the edge image is resized to; also passed to the landmark helpers.
    :param with_path: stored but not consulted by this class; the path is always returned.
    """

    _LANDMARK_PREFIX = "landmark_"

    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
                 with_path=False):
        # This module does not import SingleFolderDataset at the top, so bring it
        # into scope here.
        from data.dataset.image_translation import SingleFolderDataset

        assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
        self.edge_type = edge_type
        self.size = size
        self.edges_path = Path(edges_path)
        self.landmarks_path = Path(landmarks_path)
        assert self.edges_path.exists()
        assert self.landmarks_path.exists()
        self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
        self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
        self.random_pair = random_pair
        self.with_path = with_path

    @staticmethod
    def _img_and_path(sample):
        """Return ``(img, path)`` from either a dict sample or an ``(img, path)`` pair."""
        if isinstance(sample, dict):
            # Unpacking a dict would yield its *keys*, not the image and path.
            return sample["img"], sample["path"]
        img, path = sample
        return img, path

    def get_edge(self, origin_path):
        """Load the edge tensor that belongs to the image at ``origin_path``.

        For ``landmark_*`` edge types, images whose parent folder name ends with
        ``"A"`` additionally get landmark-derived channels concatenated after
        the edge map.
        """
        op = Path(origin_path)
        if self.edge_type.startswith(self._LANDMARK_PREFIX):
            # Slice the prefix off; ``str.lstrip`` would strip a *set of characters*
            # and could eat leading letters of the remaining edge name.
            edge_type = self.edge_type[len(self._LANDMARK_PREFIX):]
            use_landmark = op.parent.name.endswith("A")
        else:
            edge_type = self.edge_type
            use_landmark = False
        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
        # Context manager closes the underlying file handle once the tensor is built.
        with Image.open(edge_path) as edge_image:
            origin_edge = F.to_tensor(edge_image.resize(self.size, Image.BILINEAR))
        if not use_landmark:
            return origin_edge

        landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)

        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output = dict(a={}, b={})
        output["a"]["img"], output["a"]["path"] = self._img_and_path(self.A[a_idx])
        output["b"]["img"], output["b"]["path"] = self._img_and_path(self.B[b_idx])
        for p in "ab":
            output[p]["edge"] = self.get_edge(output[p]["path"])
        return output

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
    """Groups of face images of one person, each paired with a random anime image.

    Faces are the ``*.jpg`` files directly inside ``root_face``; files sharing the
    first seven characters of their stem are grouped together (presumably a
    person id -- confirm against the data naming scheme).

    :param root_face: folder with the face images.
    :param root_anime: folder with the anime images.
    :param landmark_path: folder holding ``<root_face name>/<stem>.txt`` landmark files.
    :param num_face: number of faces per sample.
    :param face_pipeline: pipeline description for face images.
    :param anime_pipeline: pipeline description for anime images.
    :param img_size: size passed to the landmark helpers.
    :param with_order: ``True`` enumerates ``combinations``, ``False`` enumerates
        ``permutations``. NOTE(review): the name suggests the opposite mapping --
        confirm the intent before relying on it.
    """

    def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
                 with_order=True):
        # This module does not import SingleFolderDataset at the top, so bring it
        # into scope here.
        from data.dataset.image_translation import SingleFolderDataset

        self.num_face = num_face
        self.landmark_path = Path(landmark_path)
        self.with_order = with_order
        self.root_face = Path(root_face)
        self.root_anime = Path(root_anime)
        self.img_size = img_size
        self.face_samples = self.iter_folders()
        self.face_pipeline = transform_pipeline(face_pipeline)
        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)

    def iter_folders(self):
        """Return every ``num_face``-sized group of image stems, grouped per person."""
        pics_per_person = defaultdict(list)
        # Sorted so that the index -> sample mapping does not depend on
        # filesystem enumeration order.
        for p in sorted(self.root_face.glob("*.jpg")):
            pics_per_person[p.stem[:7]].append(p.stem)
        make_groups = combinations if self.with_order else permutations
        data = []
        for stems in pics_per_person.values():
            if len(stems) >= self.num_face:
                data.extend(make_groups(stems, self.num_face))
        return data

    def read_pose(self, pose_txt):
        """Build the pose tensor for one landmark file by concatenating its parts."""
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([part_labels, part_edge, dist_tensor])

    def __len__(self):
        return len(self.face_samples)

    def __getitem__(self, idx):
        output = dict()
        anime = self.B[torch.randint(len(self.B), (1,)).item()]
        if isinstance(anime, dict):
            # Unpacking a dict would yield its *keys*, not the image and path.
            output["anime_img"], output["anime_path"] = anime["img"], anime["path"]
        else:
            output["anime_img"], output["anime_path"] = anime
        for i, f in enumerate(self.face_samples[idx]):
            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
            output[f"stem_{i}"] = f
        return output
|
||||||
@ -28,7 +28,7 @@ class Load:
|
|||||||
|
|
||||||
|
|
||||||
def transform_pipeline(pipeline_description):
    """Build a single callable from a list of transform descriptions.

    ``None`` or an empty description yields the identity function; otherwise
    every entry is built through ``TRANSFORM.build_with`` and the results are
    chained with ``transforms.Compose``.
    """
    # Check for None first: ``len(None)`` would raise TypeError.
    if pipeline_description is None or len(pipeline_description) == 0:
        return lambda x: x
    transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
    return transforms.Compose(transform_list)
|
||||||
|
|||||||
@ -0,0 +1,3 @@
|
|||||||
|
from util.misc import import_submodules
|
||||||
|
|
||||||
|
# ``__all__`` is expected to be a sequence of strings; a ``dict_keys`` view is
# not indexable, so materialise the submodule names as a list.
__all__ = list(import_submodules(__name__))
|
||||||
@ -1,10 +1,4 @@
|
|||||||
from model.registry import MODEL, NORMALIZATION
|
from model.registry import MODEL, NORMALIZATION
|
||||||
import model.GAN.CycleGAN
|
import model.GAN
|
||||||
import model.GAN.MUNIT
|
|
||||||
import model.GAN.TAFG
|
|
||||||
import model.GAN.TSIT
|
|
||||||
import model.GAN.UGATIT
|
|
||||||
import model.GAN.base
|
|
||||||
import model.GAN.wrapper
|
|
||||||
import model.base.normalization
|
import model.base.normalization
|
||||||
|
|
||||||
|
|||||||
20
util/misc.py
20
util/misc.py
@ -1,4 +1,6 @@
|
|||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import pkgutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -12,6 +14,24 @@ def add_spectral_norm(module):
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def import_submodules(package, recursive=True):
    """ Import all submodules of a module, recursively, including subpackages

    Results are keyed by the short (unqualified) module name; when ``recursive``
    is set, modules of nested packages are merged into the same dict, so a
    nested module overrides an earlier entry with the same short name.
    A module that is not a package has no submodules and yields an empty dict.

    :param package: package (name or actual module)
    :type package: str | module
    :param recursive: also import the contents of sub-packages
    :rtype: dict[str, types.ModuleType]
    """
    if isinstance(package, str):
        package = importlib.import_module(package)
    search_path = getattr(package, "__path__", None)
    if search_path is None:
        return {}
    results = {}
    # ``iter_modules`` only lists direct children. ``walk_packages`` without a
    # prefix would additionally try to import every sub-package by its *bare*
    # name, which walks the wrong package when that name matches a top-level one.
    for _, name, is_pkg in pkgutil.iter_modules(search_path):
        full_name = package.__name__ + '.' + name
        results[name] = importlib.import_module(full_name)
        if recursive and is_pkg:
            results.update(import_submodules(full_name, recursive=True))
    return results
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(
|
def setup_logger(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
level: int = logging.INFO,
|
level: int = logging.INFO,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user