from scipy.io import loadmat
import torch
import lmdb
import os
import pickle
from io import BytesIO
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder
from pathlib import Path
from collections import defaultdict


class CARS(Dataset):
    """Training split of a Stanford-Cars-style dataset that returns images only.

    Reads labels from ``<root>/devkit/cars_train_annos.mat`` and loads images
    named ``00001.jpg``, ``00002.jpg``, ... from ``<root>/cars_train``.

    Attributes:
        annotations: maps image file name -> class index (stored id minus 1).
        classes_list: maps class index -> list of dataset indices of that class.
    """

    def __init__(self, root, loader=default_loader, transform=None):
        self.root = Path(root)
        self.transform = transform
        self.loader = loader
        records = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0]
        # Each record is read from its end: d[-1] is used as the file name and
        # d[-2] as the class id.  The "- 1" suggests the stored ids are 1-based
        # (NOTE(review): confirm against the devkit's annotation layout).
        self.annotations = {d[-1].item(): d[-2].item() - 1 for d in records}
        self.classes_list = defaultdict(list)
        for i in range(len(self.annotations)):
            self.classes_list[self.annotations[self._file_name(i)]].append(i)

    @staticmethod
    def _file_name(index):
        """Return the image file name for 0-based dataset index ``index``."""
        return "{:05d}.jpg".format(index + 1)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, item):
        """Load (and optionally transform) image ``item``; the label is not returned.

        Raises:
            IndexError: if ``item`` is outside ``[0, len(self))``.
        """
        if not 0 <= item < len(self):
            raise IndexError("CARS index {} out of range [0, {})".format(item, len(self)))
        sample = self.loader(self.root / "cars_train" / self._file_name(item))
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


class ImprovedImageFolder(ImageFolder):
    """``ImageFolder`` that returns only the image and indexes samples by class.

    Attributes:
        classes_list: maps class index -> list of dataset indices of that class.
    """

    def __init__(self, root, loader=default_loader, transform=None):
        super().__init__(root, transform, loader=loader)
        self.classes_list = defaultdict(list)
        for index, sample in enumerate(self.samples):
            # The last element of each ImageFolder sample entry is its class index.
            self.classes_list[sample[-1]].append(index)
        # Every class found by ImageFolder must contribute at least one sample.
        assert len(self.classes_list) == len(self.classes)

    def __getitem__(self, item):
        # ImageFolder yields (sample, target); keep only the sample.
        return super().__getitem__(item)[0]


class LMDBDataset(Dataset):
    """Read-only dataset backed by an LMDB database.

    Keys read by this class: ``b"classes_list"`` and ``b"__len__"`` (pickled
    values), plus one entry per sample under its decimal index (``b"0"``,
    ``b"1"``, ...) holding data readable by ``torch.load``.

    SECURITY: values are deserialized with ``pickle.loads`` / ``torch.load``,
    both of which can execute arbitrary code — only open trusted databases.
    NOTE(review): the environment is opened eagerly in ``__init__``; confirm
    this is safe with forked multi-worker DataLoaders.
    """

    def __init__(self, lmdb_path, transform=None):
        # os.fspath lets callers pass a pathlib.Path as well as a str.
        # map_size is 2**40 bytes (1 TiB).
        self.db = lmdb.open(os.fspath(lmdb_path), map_size=1099511627776, subdir=os.path.isdir(lmdb_path),
                            readonly=True, lock=False, readahead=False, meminit=False)
        self.transform = transform
        with self.db.begin(write=False) as txn:
            self.classes_list = pickle.loads(self._get(txn, b"classes_list"))
            self._len = pickle.loads(self._get(txn, b"__len__"))

    @staticmethod
    def _get(txn, key):
        """Return the raw value for ``key``, raising ``KeyError`` if it is absent."""
        value = txn.get(key)
        if value is None:
            raise KeyError("key {!r} not found in LMDB database".format(key))
        return value

    def __len__(self):
        return self._len

    def __getitem__(self, i):
        """Deserialize sample ``i`` and apply the transform, if any.

        Raises:
            IndexError: if ``i`` is outside ``[0, len(self))``.
        """
        if not 0 <= i < self._len:
            raise IndexError("LMDBDataset index {} out of range [0, {})".format(i, self._len))
        with self.db.begin(write=False) as txn:
            sample = torch.load(BytesIO(self._get(txn, "{}".format(i).encode())))
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


class EpisodicDataset(Dataset):
    """Sample N-way K-shot episodes from a dataset exposing ``classes_list``.

    Each item is a freshly drawn random episode; the requested index is ignored.

    Args:
        origin_dataset: dataset whose items can be ``torch.stack``-ed and whose
            ``classes_list[c]`` lists the dataset indices of class ``c``
            (classes presumably keyed ``0 .. len(classes_list) - 1`` — the
            random class ids below are drawn from that range).
        num_class: N, number of classes per episode (at most the class count).
        num_set: K, number of support (and of query) samples per class.
        num_episodes: value reported by ``__len__``.
    """

    def __init__(self, origin_dataset, num_class, num_set, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class
        # "<=": using every available class in each episode is a valid setting.
        assert self.num_class <= len(self.origin.classes_list)
        self.num_set = num_set  # K
        self.num_episodes = num_episodes

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        """Return ``{"support", "query", "target"}`` for one random episode.

        ``support`` and ``query`` each stack ``num_class * num_set`` samples,
        grouped class by class in the same order; ``target`` holds the
        episode-local label (0 .. num_class - 1) for each of those positions
        and applies to both sets.
        """
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        needed = self.num_set * 2
        support_set_list = []
        query_set_list = []
        target_list = []
        for i, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]
            if len(image_list) >= needed:
                # Enough images: sample without replacement, so no image is
                # repeated within this class's support and query sets.
                idx_list = torch.randperm(len(image_list))[:needed].tolist()
            else:
                # Too few images: fall back to sampling with replacement.
                idx_list = torch.randint(high=len(image_list), size=(needed,)).tolist()
            support_set_list.extend(self.origin[image_list[idx]] for idx in idx_list[:self.num_set])
            query_set_list.extend(self.origin[image_list[idx]] for idx in idx_list[self.num_set:])
            target_list.extend([i] * self.num_set)
        return {
            "support": torch.stack(support_set_list),
            "query": torch.stack(query_set_list),
            "target": torch.tensor(target_list)
        }

    def __repr__(self):
        return "<EpisodicDataset num_class={} num_set={} num_episodes={}>".format(
            self.num_class, self.num_set, self.num_episodes)