add distance
This commit is contained in:
parent
85b5c3f589
commit
340a344e91
32
tool/encoder_distance.py
Normal file
32
tool/encoder_distance.py
Normal file
@ -0,0 +1,32 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
#
|
||||
# data = {}
|
||||
#
|
||||
# for i in range(1, 422 + 1):
|
||||
# _, names = torch.load(f"/tmp/pt/batch{i}.pt")
|
||||
# generated = torch.load(f"/tmp/pt/generated{i}.pt")
|
||||
# print(len(names))
|
||||
# for j, n in enumerate(names):
|
||||
# data[Path(names[j]).stem] = generated[j]
|
||||
#
|
||||
# torch.save(data, "/tmp/data.pt")
|
||||
|
||||
data = torch.load("/tmp/data.pt")
|
||||
|
||||
videos = sorted(list(set([k.split("@")[0] for k in data.keys()])))
|
||||
|
||||
for idx in range(len(videos)):
|
||||
print(videos[idx])
|
||||
videos_data = {}
|
||||
for k in data:
|
||||
if k.startswith(videos[idx]):
|
||||
videos_data[int(k.split("@")[-1])] = data[k]
|
||||
|
||||
to_save = []
|
||||
for i in range(2, len(videos_data) + 1):
|
||||
to_save.append(torch.mean(torch.abs(videos_data[i] - videos_data[1])).cpu())
|
||||
|
||||
torch.save(to_save, f"{videos[idx]}.pt")
|
||||
print(f"{videos[idx]}.pt")
|
||||
Loading…
Reference in New Issue
Block a user