from scipy.io import loadmat
import torch
import lmdb
import os
import pickle
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import functional
from pathlib import Path
from collections import defaultdict


class _CARS(Dataset):
    """Stanford Cars training split, indexed via the devkit annotations."""

    def __init__(self, root, loader=default_loader, transform=None):
        self.root = Path(root)
        self.transform = transform
        self.loader = loader
        # Each annotation row ends with (..., class, fname); map file name -> 0-based class.
        self.annotations = loadmat(self.root / "devkit/cars_train_annos.mat")["annotations"][0]
        self.annotations = {d[-1].item(): d[-2].item() - 1 for d in self.annotations}
        # classes_list maps a class index to the sample indices belonging to it.
        self.classes_list = defaultdict(list)
        for i in range(len(self.annotations)):
            self.classes_list[self.annotations["{:05d}.jpg".format(i + 1)]].append(i)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, item):
        file_name = "{:05d}.jpg".format(item + 1)
        target = self.annotations[file_name]
        sample = self.loader(self.root / "cars_train" / file_name)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target


class ImprovedImageFolder(ImageFolder):
    """ImageFolder that additionally exposes a class-index -> sample-indices map."""

    def __init__(self, root, loader=default_loader, transform=None):
        super().__init__(root, transform, loader=loader)
        self.classes_list = defaultdict(list)
        for i in range(len(self)):
            self.classes_list[self.samples[i][-1]].append(i)
        assert len(self.classes_list) == len(self.classes)

    def __getitem__(self, item):
        return super().__getitem__(item)


class LMDBDataset(Dataset):
    """Read-only dataset backed by an LMDB database of pickled (sample, target) pairs."""

    def __init__(self, lmdb_path, transform=None):
        self.db = lmdb.open(lmdb_path, map_size=1 << 40,  # 1 TiB
                            subdir=os.path.isdir(lmdb_path),
                            readonly=True, lock=False, readahead=False, meminit=False)
        self.transform = transform
        with self.db.begin(write=False) as txn:
            self.classes_list = pickle.loads(txn.get(b"classes_list"))
            self._len = pickle.loads(txn.get(b"__len__"))

    def __len__(self):
        return self._len

    def __getitem__(self, i):
        with self.db.begin(write=False) as txn:
            sample, target = pickle.loads(txn.get("{}".format(i).encode()))
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target
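
# A minimal writer sketch for producing the key layout LMDBDataset reads:
# pickled (sample, target) pairs stored under str(i).encode(), plus the
# b"classes_list" and b"__len__" metadata keys. `build_lmdb` is a hypothetical
# helper added for illustration, not part of the original module.
def build_lmdb(dataset, lmdb_path):
    db = lmdb.open(lmdb_path, map_size=1 << 40, subdir=os.path.isdir(lmdb_path))
    classes_list = defaultdict(list)
    with db.begin(write=True) as txn:
        for i in range(len(dataset)):
            sample, target = dataset[i]  # e.g. a PIL image and its class index
            classes_list[target].append(i)
            txn.put("{}".format(i).encode(), pickle.dumps((sample, target)))
        txn.put(b"classes_list", pickle.dumps(classes_list))
        txn.put(b"__len__", pickle.dumps(len(dataset)))
    db.sync()
    db.close()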

class EpisodicDataset(Dataset):
    """Wraps a dataset exposing `classes_list` and yields N-way K-shot episodes."""

    def __init__(self, origin_dataset, num_class, num_set, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class  # N: classes sampled per episode
        assert self.num_class < len(self.origin.classes_list)
        self.num_set = num_set  # K: samples per class in each of support/query
        self.num_episodes = num_episodes
        self.t0 = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def apply_transform(self, img):
        # img1 = self.transform(img)
        # img2 = self.transform(img)
        # return [self.t0(img), self.t0(functional.hflip(img))]
        return [self.t0(img)]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set_list = []
        query_set_list = []
        target_list = []
        for i, c in enumerate(random_classes):
            image_list = self.origin.classes_list[c]
            if len(image_list) > self.num_set * 2:
                # Sample 2K distinct indices: K for support, K for query.
                idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
            else:
                # Class has too few images; sample 2K indices with replacement.
                idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
            support = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
            # The second K indices form the query set, disjoint from the support set.
            query = [self.origin[image_list[idx]][0] for idx in idx_list[self.num_set:]]
            support_set_list.extend(sum(map(self.apply_transform, support), list()))
            query_set_list.extend(sum(map(self.apply_transform, query), list()))
            # Episode-local label i is shared by this class's support and query samples.
            target_list.extend([i] * self.num_set)
        return {
            "support": torch.stack(support_set_list),
            "query": torch.stack(query_set_list),
            "target": torch.tensor(target_list)
        }

    def __repr__(self):
        return "<EpisodicDataset num_class={} num_set={} num_episodes={}>".format(
            self.num_class, self.num_set, self.num_episodes)
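

if __name__ == "__main__":
    # Usage sketch: draw 5-way 5-shot episodes from Stanford Cars. The dataset
    # root below is a placeholder path, not taken from the original file.
    from torch.utils.data import DataLoader

    base = _CARS("data/stanford_cars")  # leave transform=None; episodes apply t0
    episodes = EpisodicDataset(base, num_class=5, num_set=5, num_episodes=100)
    loader = DataLoader(episodes, batch_size=1, num_workers=0)
    batch = next(iter(loader))
    # support/query: (1, num_class * num_set, 3, 224, 224); target: (1, num_class * num_set)
    print(batch["support"].shape, batch["query"].shape, batch["target"].shape)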