import argparse
import random
import time
from importlib import import_module

import numpy as np
import torch
import torchvision
from ignite.utils import convert_tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import dataset


def setup_seed(seed):
    """Seed every RNG the evaluation may draw from so runs are reproducible.

    Seeds torch (CPU and all CUDA devices), Python's ``random`` and NumPy.
    NOTE(review): which RNG ``dataset.EpisodicDataset`` uses for sampling is
    not visible here; seeding all three covers the common cases.

    :param seed: integer seed applied to every generator
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cudnn autotuning may select different kernels between runs, which
    # defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False


def euclidean_dist(x, y):
    """Compute pairwise *squared* Euclidean distances between two vector batches.

    :param x: B x N x D tensor
    :param y: B x M x D tensor
    :return: B x N x M tensor, ``out[b, i, j] = ||x[b, i] - y[b, j]||^2``
    :raises ValueError: if the last (feature) dimensions of x and y differ
    """
    if x.size(-1) != y.size(-1):
        raise ValueError(
            f"feature dimension mismatch: x has {x.size(-1)}, y has {y.size(-1)}"
        )
    # B x N x 1 x D minus B x 1 x M x D broadcasts to B x N x M x D,
    # exactly what the explicit expand() calls produced.
    diff = x.unsqueeze(-2) - y.unsqueeze(-3)
    return diff.pow(2).sum(-1)


def evaluate(query, target, support):
    """Return nearest-prototype classification accuracy for one batch of episodes.

    Each class prototype is the mean of that class's K support embeddings.
    Every query is assigned the index of its closest prototype (squared
    Euclidean distance), and that index is compared against ``target``.

    :param query: B x NK x D tensor of query embeddings
    :param target: B x NK tensor; each entry is compared with an index into the
        N dimension of ``support``
    :param support: B x N x K x D tensor of support embeddings
    :return: 0-dim float tensor, the fraction of correctly classified queries
    """
    prototypes = support.mean(-2)  # B x N x D
    distance = euclidean_dist(query, prototypes)  # B x NK x N
    predictions = distance.argmin(-1)  # B x NK
    return torch.eq(target, predictions).float().mean()


def test(lmdb_path, import_path, batch_size=10, N=5, K=5, num_episodes=100):
    """Evaluate a submitted feature extractor on N-way K-shot episodes.

    Loads ``submit.<import_path>``, builds its model via ``make_model()`` and
    embeds every query and support image. The A views per image are averaged
    into one embedding, and nearest-prototype accuracy is computed with
    ``evaluate``.

    :param lmdb_path: path passed to ``dataset.LMDBDataset``
    :param import_path: module name under the ``submit`` package that
        exposes ``make_model()``
    :param batch_size: number of episodes per DataLoader batch
    :param N: ways (classes per episode)
    :param K: shots (support images per class)
    :param num_episodes: fourth positional argument of
        ``dataset.EpisodicDataset``; presumably the number of episodes
        (TODO confirm against the dataset module)
    :return: mean accuracy over all episodes as a Python float (NaN if the
        loader yields nothing)
    :raises RuntimeError: if extracted query features don't line up with
        the targets
    """
    origin_dataset = dataset.LMDBDataset(lmdb_path, transform=None)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    submit = import_module(f"submit.{import_path}")
    extractor = submit.make_model()
    extractor.to(device)
    # Evaluation must not run dropout / batch-statistics BatchNorm.
    extractor.eval()

    episodic_dataset = dataset.EpisodicDataset(origin_dataset, N, K, num_episodes)
    data_loader = DataLoader(episodic_dataset, batch_size=batch_size, pin_memory=False)

    correct = 0.0
    total = 0
    with torch.no_grad():
        for item in tqdm(data_loader):
            item = convert_tensor(item, device, non_blocking=True)
            # item["query"]:   B x NKA x 3 x W x H
            # item["support"]: B x NKA x 3 x W x H
            # item["target"]:  B x NK
            # Use the real batch size: the last batch can be smaller than
            # ``batch_size``, which would break the reshapes below.
            B = item["target"].size(0)
            A = item["query"].size(1) // item["target"].size(1)
            image_size = item["query"].shape[-3:]
            query_batch = extractor(item["query"].view([-1, *image_size])).view(B, N * K, A, -1)
            support_batch = extractor(item["support"].view([-1, *image_size])).view(B, N, K, A, -1)
            # Collapse the A views of each image into a single embedding.
            query_batch = torch.mean(query_batch, dim=-2)
            support_batch = torch.mean(support_batch, dim=-2)
            if query_batch.shape[:2] != item["target"].shape[:2]:
                raise RuntimeError(
                    f"query features {tuple(query_batch.shape[:2])} do not match "
                    f"targets {tuple(item['target'].shape[:2])}"
                )
            acc = evaluate(query_batch, item["target"], support_batch)
            # Weight by batch size so a short final batch doesn't skew the mean.
            correct += acc.item() * B
            total += B
    r = correct / total if total else float("nan")
    print(lmdb_path, r)
    return r


if __name__ == '__main__':
    setup_seed(100)
    defined_path = [
        "/data/few-shot/lmdb256/dogs.lmdb",
        "/data/few-shot/lmdb256/flowers.lmdb",
        "/data/few-shot/lmdb256/256-object.lmdb",
        "/data/few-shot/lmdb256/dtd.lmdb",
        "/data/few-shot/lmdb256/cars_train.lmdb",
        "/data/few-shot/lmdb256/cub.lmdb",
    ]
    parser = argparse.ArgumentParser(description="test")
    parser.add_argument('-i', "--import_path", required=True)
    args = parser.parse_args()
    for path in defined_path:
        test(path, args.import_path)