add distance
This commit is contained in:
parent
85b5c3f589
commit
340a344e91
32
tool/encoder_distance.py
Normal file
32
tool/encoder_distance.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
#
|
||||||
|
# data = {}
|
||||||
|
#
|
||||||
|
# for i in range(1, 422 + 1):
|
||||||
|
# _, names = torch.load(f"/tmp/pt/batch{i}.pt")
|
||||||
|
# generated = torch.load(f"/tmp/pt/generated{i}.pt")
|
||||||
|
# print(len(names))
|
||||||
|
# for j, n in enumerate(names):
|
||||||
|
# data[Path(names[j]).stem] = generated[j]
|
||||||
|
#
|
||||||
|
# torch.save(data, "/tmp/data.pt")
|
||||||
|
|
||||||
|
data = torch.load("/tmp/data.pt")
|
||||||
|
|
||||||
|
videos = sorted(list(set([k.split("@")[0] for k in data.keys()])))
|
||||||
|
|
||||||
|
for idx in range(len(videos)):
|
||||||
|
print(videos[idx])
|
||||||
|
videos_data = {}
|
||||||
|
for k in data:
|
||||||
|
if k.startswith(videos[idx]):
|
||||||
|
videos_data[int(k.split("@")[-1])] = data[k]
|
||||||
|
|
||||||
|
to_save = []
|
||||||
|
for i in range(2, len(videos_data) + 1):
|
||||||
|
to_save.append(torch.mean(torch.abs(videos_data[i] - videos_data[1])).cpu())
|
||||||
|
|
||||||
|
torch.save(to_save, f"{videos[idx]}.pt")
|
||||||
|
print(f"{videos[idx]}.pt")
|
||||||
Loading…
Reference in New Issue
Block a user