import random
import warnings

import torch
import torch.nn as nn
import torchvision
from ignite.utils import convert_tensor
from torch.utils.data import DataLoader

from data import dataset


def setup_seed(seed):
    """Seed Python's and PyTorch's RNGs and force deterministic cuDNN kernels.

    :param seed: integer seed shared by all generators.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode selects kernels by timing, which can vary between runs.
    torch.backends.cudnn.benchmark = False


def euclidean_dist(x, y):
    """Compute pairwise *squared* Euclidean distances between batched point sets.

    No square root is taken; squared distances are sufficient for
    nearest-neighbour (argmin) decisions.

    :param x: B x N x D tensor.
    :param y: B x M x D tensor.
    :return: B x N x M tensor with ``out[b, i, j] = ||x[b, i] - y[b, j]||^2``.
    :raises ValueError: if the last (feature) dimensions of ``x`` and ``y`` differ.
    """
    if x.size(-1) != y.size(-1):
        raise ValueError(
            f"feature dimension mismatch: {x.size(-1)} vs {y.size(-1)}"
        )
    # B x N x 1 x D broadcast against B x 1 x M x D gives B x N x M x D.
    diff = x.unsqueeze(2) - y.unsqueeze(1)
    return diff.pow(2).sum(-1)


def evaluate(query, target, support):
    """Return prototypical-network accuracy for a batch of episodes.

    Each class prototype is the mean of its K support embeddings. Each query is
    assigned to the prototype with the smallest squared Euclidean distance.

    :param query: B x NK x D query embeddings.
    :param target: class index of each query, shaped B x NK or B x NK x 1.
    :param support: B x N x K x D support embeddings.
    :return: 0-dim float tensor, the fraction of correctly classified queries.
    """
    prototypes = support.mean(-2)  # B x N x D
    distance = euclidean_dist(query, prototypes)  # B x NK x N
    indices = distance.argmin(-1)  # B x NK
    # Align target with the predictions. A B x NK x 1 target compared against
    # B x NK indices would otherwise broadcast to B x NK x NK.
    target = target.reshape(indices.shape)
    return torch.eq(target, indices).float().mean()


class Flatten(nn.Module):
    """Flatten every dimension except the leading (batch) one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


def make_extractor(device=None):
    """Build a frozen ResNet-18 feature extractor.

    :param device: device to run the backbone on. Defaults to CUDA when it is
        available (the previous hard-coded choice), otherwise CPU.
    :return: callable mapping an image batch already on ``device`` to
        features, evaluated without gradient tracking.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = resnet18()
    model.to(device)
    model.eval()

    def extract(images):
        with torch.no_grad():
            return model(images)

    return extract


def resnet18(model_path="ResNet18Official.pth"):
    """Construct a ResNet-18 backbone without its final FC layer.

    The torchvision ResNet-18 is rebuilt as an ``nn.Sequential`` of its
    children (minus ``fc``) followed by :class:`Flatten`, so it outputs one
    flat feature vector per image.

    :param model_path: checkpoint to load. The checkpoint may come from this
        Sequential (keys ``"0.weight"``, ...) or be a standard torchvision
        ResNet-18 state dict (keys ``"conv1.weight"``, ...).
    :return: the backbone in eval mode, on CPU.
    """
    model_w_fc = torchvision.models.resnet18(pretrained=False)
    seq = list(model_w_fc.children())[:-1]
    seq.append(Flatten())
    model = nn.Sequential(*seq)

    state_dict = torch.load(model_path, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=False)
    if len(result.missing_keys) == len(model.state_dict()):
        # None of the Sequential's "0.*", "1.*", ... keys matched, so this is
        # presumably a torchvision-style checkpoint. Load it into the original
        # ResNet instead: its submodules are the same objects held by `model`.
        result = model_w_fc.load_state_dict(state_dict, strict=False)
    # The dropped fc layer is irrelevant to the backbone.
    missing = [k for k in result.missing_keys if not k.startswith("fc.")]
    if missing:
        warnings.warn(
            f"{len(missing)} backbone parameters/buffers were not found in "
            f"{model_path!r} and remain randomly initialised "
            f"(e.g. {missing[:3]})"
        )
    model.eval()
    return model


def test():
    """Evaluate a prototypical-network baseline on random N-way K-shot
    episodes drawn from Stanford Cars and print the mean accuracy."""
    data_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize([int(224 * 1.15), int(224 * 1.15)]),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        ),
    ])
    origin_dataset = dataset.CARS(
        "/data/few-shot/STANFORD-CARS/", transform=data_transform
    )
    # origin_dataset = dataset.ImprovedImageFolder("/data/few-shot/mini_imagenet_full_size/train", transform=data_transform)

    N = int(torch.randint(5, 10, (1,)))
    K = int(torch.randint(1, 10, (1,)))
    batch_size = 2

    episodic_dataset = dataset.EpisodicDataset(
        origin_dataset,  # dataset to sample episodes from
        N,  # N (ways)
        K,  # K (shots)
        100,  # number of tasks (episodes)
    )
    print(episodic_dataset)
    data_loader = DataLoader(episodic_dataset, batch_size=batch_size)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    extractor = make_extractor(device)

    accs = []
    for item in data_loader:
        item = convert_tensor(item, device)
        # item["query"]:   B x NK x 3 x W x H
        # item["support"]: B x NK x 3 x W x H
        # item["target"]:  B x NK
        # The last batch can be smaller than batch_size, so take B from the data.
        b = item["query"].size(0)
        query_images = item["query"].reshape(-1, *item["query"].shape[-3:])
        support_images = item["support"].reshape(-1, *item["support"].shape[-3:])
        query_batch = extractor(query_images).view(b, N * K, -1)
        support_batch = extractor(support_images).view(b, N, K, -1)
        accs.append(evaluate(query_batch, item["target"], support_batch).item())
    print(torch.tensor(accs).mean().item())


if __name__ == "__main__":
    setup_seed(100)
    test()