"""Dataset wrappers that expose a per-class index, plus an episodic sampler.

Every "origin" dataset in this module provides:

* ``__getitem__(i) -> (sample, target)``
* ``classes_list``: mapping ``class label -> list of sample indices``

``EpisodicDataset`` relies on both to draw support/query episodes.
"""

from collections import defaultdict
import os
from pathlib import Path
import pickle

import lmdb
from scipy.io import loadmat
import torch
from torch.utils.data import Dataset
import torchvision
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader


class CARS(Dataset):
    """Car images indexed by the ``devkit/cars_train_annos.mat`` annotation file.

    Images are read from ``<root>/cars_train/NNNNN.jpg`` where ``NNNNN`` is the
    1-based, zero-padded sample number.

    Args:
        root: Directory containing ``devkit/`` and ``cars_train/``.
        loader: Callable that turns an image path into a sample.
        transform: Optional callable applied to every loaded sample.
    """

    def __init__(self, root, loader=default_loader, transform=None):
        self.root = Path(root)
        self.transform = transform
        self.loader = loader

        records = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0]
        # The last field of a record is the file name and the one before it is
        # the class id; the id is shifted by -1 (presumably 1-based on disk).
        self.annotations = {d[-1].item(): d[-2].item() - 1 for d in records}

        # class label -> indices of the samples that carry this label
        self.classes_list = defaultdict(list)
        for i in range(len(self.annotations)):
            label = self.annotations["{:05d}.jpg".format(i + 1)]
            self.classes_list[label].append(i)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.annotations)

    def __getitem__(self, item):
        """Return ``(sample, target)`` for the 0-based index ``item``.

        The target is returned alongside the sample so that this class matches
        ``ImprovedImageFolder`` / ``LMDBDataset`` and works with
        ``EpisodicDataset``, which reads ``origin[i][0]``.
        """
        file_name = "{:05d}.jpg".format(item + 1)
        target = self.annotations[file_name]
        sample = self.loader(self.root / "cars_train" / file_name)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target


class ImprovedImageFolder(ImageFolder):
    """``ImageFolder`` that additionally builds ``classes_list``.

    Args:
        root: Root directory passed to ``ImageFolder``.
        loader: Callable that turns an image path into a sample.
        transform: Optional callable applied to every loaded sample.
    """

    def __init__(self, root, loader=default_loader, transform=None):
        super().__init__(root, transform, loader=loader)

        # class index -> indices of the samples that belong to it
        self.classes_list = defaultdict(list)
        for i, entry in enumerate(self.samples):
            self.classes_list[entry[-1]].append(i)
        # Every class discovered by ImageFolder must own at least one sample.
        assert len(self.classes_list) == len(self.classes)

    def __getitem__(self, item):
        """Delegate to ``ImageFolder.__getitem__``."""
        return super().__getitem__(item)


class LMDBDataset(Dataset):
    """Dataset whose samples are pickled ``(sample, target)`` pairs in LMDB.

    The database must hold the keys ``b"classes_list"`` and ``b"__len__"``
    (both pickled) and one record per sample under the key ``str(i).encode()``.

    Args:
        lmdb_path: Path to the LMDB environment (directory or single file).
        transform: Optional callable applied to every loaded sample.
    """

    def __init__(self, lmdb_path, transform=None):
        self.db = lmdb.open(
            lmdb_path,
            map_size=1099511627776,
            subdir=os.path.isdir(lmdb_path),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.transform = transform
        # SECURITY NOTE: pickle.loads executes arbitrary code embedded in the
        # payload; only open LMDB files that come from a trusted source.
        with self.db.begin(write=False) as txn:
            self.classes_list = pickle.loads(txn.get(b"classes_list"))
            self._len = pickle.loads(txn.get(b"__len__"))

    def __len__(self):
        """Return the sample count stored under ``b"__len__"``."""
        return self._len

    def __getitem__(self, i):
        """Return ``(sample, target)`` for index ``i``.

        Raises:
            IndexError: If the database holds no record for ``i``.
        """
        with self.db.begin(write=False) as txn:
            raw = txn.get("{}".format(i).encode())
        if raw is None:
            # Without this check pickle.loads(None) fails with an opaque TypeError.
            raise IndexError("no LMDB record for index {}".format(i))
        sample, target = pickle.loads(raw)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target


class EpisodicDataset(Dataset):
    """Draw random episodes (support set + query set) from an origin dataset.

    Args:
        origin_dataset: Dataset exposing ``classes_list`` and returning a
            sequence whose first element is the image from ``__getitem__``.
        num_class: Number of classes per episode; must be strictly smaller
            than the number of classes in ``origin_dataset``.
        num_set: Number of support images (and of query images) per class.
        num_episodes: Reported length of this dataset.
    """

    def __init__(self, origin_dataset, num_class, num_set, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class
        assert self.num_class < len(self.origin.classes_list)
        self.num_set = num_set  # K
        self.num_episodes = num_episodes

        # Transform applied to every image fetched from the origin dataset.
        self.t0 = torchvision.transforms.Compose([
            # torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _fetch_list_data(self, id_list):
        """Load the origin images at ``id_list`` and apply ``self.t0`` to each."""
        return [self.t0(self.origin[i][0]) for i in id_list]

    def __len__(self):
        """Return the configured number of episodes."""
        return self.num_episodes

    def __getitem__(self, _):
        """Sample one episode; the index argument is ignored.

        Returns:
            dict with

            * ``"support"``: stack of ``num_class * num_set`` support images,
            * ``"query"``: stack of ``num_class * num_set`` query images,
            * ``"target"``: episode-local class index (``0 .. num_class - 1``)
              for each position; support and query share the same ordering.
        """
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        per_class = self.num_set * 2  # num_set support + num_set query images

        support_set_list = []
        query_set_list = []
        target_list = []
        for tag, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]
            if len(image_list) >= per_class:
                # Enough images in this class: draw without replacement.
                idx_list = torch.randperm(len(image_list))[:per_class].tolist()
            else:
                # Too few images: fall back to drawing with replacement.
                idx_list = torch.randint(high=len(image_list), size=(per_class,)).tolist()

            support_ids = [image_list[j] for j in idx_list[:self.num_set]]
            query_ids = [image_list[j] for j in idx_list[self.num_set:]]
            support_set_list.extend(self._fetch_list_data(support_ids))
            query_set_list.extend(self._fetch_list_data(query_ids))
            target_list.extend([tag] * self.num_set)

        return {
            "support": torch.stack(support_set_list),
            "query": torch.stack(query_set_list),
            "target": torch.tensor(target_list),
        }

    def __repr__(self):
        """Return a short description of the episode configuration."""
        return "<{} num_class={} num_set={} num_episodes={}>".format(
            type(self).__name__, self.num_class, self.num_set, self.num_episodes
        )