"""Few-shot evaluation script: nearest-prototype accuracy on frozen CNN features."""
import time
import warnings

import torch
import torch.nn as nn
import torchvision
from ignite.utils import convert_tensor
from torch.utils.data import DataLoader

from data import dataset


def setup_seed(seed):
    """Seed PyTorch's CPU and CUDA generators and request deterministic cuDNN.

    Only PyTorch RNGs are seeded here; Python's ``random`` module and NumPy
    are not touched.

    :param seed: integer seed passed to ``torch.manual_seed`` and
        ``torch.cuda.manual_seed_all``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick a different convolution algorithm from run to run,
    # which would defeat the deterministic flag above.
    torch.backends.cudnn.benchmark = False


def euclidean_dist(x, y):
    """Compute pairwise *squared* Euclidean distances between two batches.

    No square root is taken; the result is only used for ranking (argmin),
    where the squared distance gives the same ordering.

    :param x: B x N x D tensor.
    :param y: B x M x D tensor.
    :return: B x N x M tensor with ``out[b, i, j] = ||x[b, i] - y[b, j]||^2``.
    :raises ValueError: if the feature dimension D of ``x`` and ``y`` differs.
    """
    if x.size(-1) != y.size(-1):
        raise ValueError(
            "feature dimensions differ: {} vs {}".format(x.size(-1), y.size(-1))
        )
    # B x N x 1 x D minus B x 1 x M x D broadcasts to B x N x M x D.
    diff = x.unsqueeze(2) - y.unsqueeze(1)
    return torch.pow(diff, 2).sum(-1)


def evaluate(query, target, support):
    """Return nearest-prototype classification accuracy for a batch of episodes.

    Each class prototype is the mean of its K support embeddings; every query
    is assigned to the class whose prototype is closest in squared Euclidean
    distance.

    :param query: B x NK x D tensor of query embeddings.
    :param target: B x NK (or B x NK x 1) tensor of class indices, expected to
        index the N axis of ``support``.
    :param support: B x N x K x D tensor of support embeddings.
    :return: 0-dim float tensor, the fraction of correctly classified queries
        over the whole batch.
    """
    prototypes = support.mean(-2)  # B x N x D
    distance = euclidean_dist(query, prototypes)  # B x NK x N
    indices = distance.argmin(-1)  # B x NK
    # A trailing singleton axis on target would broadcast against indices to
    # B x NK x NK and silently produce a meaningless accuracy, so drop it.
    if target.dim() == indices.dim() + 1 and target.size(-1) == 1:
        target = target.squeeze(-1)
    return torch.eq(target, indices).float().mean()


class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension into one."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


def make_extractor(model=None, device=None):
    """Build a gradient-free feature-extraction function around a frozen model.

    :param model: optional ``nn.Module`` to use as the backbone, e.g.
        ``resnet18()`` from this file. Defaults to an ImageNet-pretrained
        torchvision ResNet-50 with its ``fc`` layer replaced by identity.
    :param device: device to place the model on. Defaults to CUDA when
        available, otherwise CPU.
    :return: callable ``extract(images)`` that runs the model in eval mode
        under ``torch.no_grad()`` and returns its output.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model is None:
        # `pretrained=True` is kept for compatibility with older torchvision
        # releases; newer ones prefer the `weights=` argument.
        model = torchvision.models.resnet50(pretrained=True)
        # Drop the classification head so the model outputs pooled features.
        model.fc = torch.nn.Identity()
    model.to(device)
    model.eval()

    def extract(images):
        with torch.no_grad():
            return model(images)

    return extract


def resnet18(model_path="ResNet18Official.pth"):
    """Build a headless ResNet-18 and load its weights from ``model_path``.

    The torchvision ResNet-18 is rebuilt as an ``nn.Sequential`` of its child
    modules with the final ``fc`` layer removed and a ``Flatten`` appended.

    :param model_path: path to a state dict; it is loaded onto the CPU.
    :return: the model in eval mode.
    """
    backbone = torchvision.models.resnet18(pretrained=False)
    layers = list(backbone.children())[:-1]
    layers.append(Flatten())
    model = torch.nn.Sequential(*layers)

    state_dict = torch.load(model_path, map_location="cpu")
    # strict=False skips keys that do not match. The Sequential wrapper names
    # parameters by position ("0.weight", ...), so a checkpoint saved with
    # torchvision's original names would load nothing without any error.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        warnings.warn(
            "{} of the model's parameters/buffers were not found in '{}' and "
            "keep their random initialisation".format(
                len(result.missing_keys), model_path
            )
        )
    model.eval()
    return model


def test():
    """Run episodic few-shot evaluation and print timing and mean accuracy.

    Samples a random N in [5, 9] and K in [1, 9], builds 100 episodes from the
    Stanford Cars dataset, embeds query/support images with the frozen
    extractor, and prints the per-batch wall time followed by the mean
    accuracy over all batches.
    """
    data_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize([int(224*1.15), int(224*1.15)]),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    origin_dataset = dataset.CARS("/data/few-shot/STANFORD-CARS/", transform=data_transform)
    # Alternative source dataset:
    # dataset.ImprovedImageFolder("/data/few-shot/mini_imagenet_full_size/train", transform=data_transform)

    N = torch.randint(5, 10, (1,)).tolist()[0]
    K = torch.randint(1, 10, (1,)).tolist()[0]
    batch_size = 2
    episodic_dataset = dataset.EpisodicDataset(
        origin_dataset,  # dataset to sample episodes from
        N,  # N (number of classes per episode)
        K,  # K (number of samples per class)
        100,  # number of tasks (episodes)
    )
    print(episodic_dataset)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps host-to-GPU copies.
    data_loader = DataLoader(
        episodic_dataset,
        batch_size=batch_size,
        pin_memory=device.type == "cuda",
    )
    extractor = make_extractor(device=device)

    accs = []
    st = time.time()
    for item in data_loader:
        item = convert_tensor(item, device, non_blocking=True)
        # Expected shapes (per the original author's notes):
        # item["query"]:   B x NK x 3 x W x H
        # item["support"]: B x NK x 3 x W x H
        # item["target"]:  B x NK
        query_images = item["query"]
        support_images = item["support"]
        # Use the real batch size: the last batch can be smaller than
        # `batch_size` when the episode count is not a multiple of it.
        episodes = query_images.size(0)
        query_batch = extractor(
            query_images.view(-1, *query_images.shape[-3:])
        ).view(episodes, N * K, -1)
        support_batch = extractor(
            support_images.view(-1, *support_images.shape[-3:])
        ).view(episodes, N, K, -1)
        # .item() waits for the GPU work to finish, so the time printed below
        # covers the actual computation rather than just the kernel launches.
        accs.append(evaluate(query_batch, item["target"], support_batch).item())
        print("time: ", time.time()-st)
        st = time.time()
    print(torch.tensor(accs).mean().item())


if __name__ == '__main__':
    setup_seed(100)
    test()