add distance

This commit is contained in:
budui 2020-09-11 22:35:59 +08:00
parent 85b5c3f589
commit 340a344e91

32
tool/encoder_distance.py Normal file
View File

@ -0,0 +1,32 @@
from pathlib import Path
import torch
#
# data = {}
#
# for i in range(1, 422 + 1):
# _, names = torch.load(f"/tmp/pt/batch{i}.pt")
# generated = torch.load(f"/tmp/pt/generated{i}.pt")
# print(len(names))
# for j, n in enumerate(names):
# data[Path(names[j]).stem] = generated[j]
#
# torch.save(data, "/tmp/data.pt")
data = torch.load("/tmp/data.pt")
videos = sorted(list(set([k.split("@")[0] for k in data.keys()])))
for idx in range(len(videos)):
print(videos[idx])
videos_data = {}
for k in data:
if k.startswith(videos[idx]):
videos_data[int(k.split("@")[-1])] = data[k]
to_save = []
for i in range(2, len(videos_data) + 1):
to_save.append(torch.mean(torch.abs(videos_data[i] - videos_data[1])).cpu())
torch.save(to_save, f"{videos[idx]}.pt")
print(f"{videos[idx]}.pt")