# (file-viewer metadata: 32 lines, 859 B, Python)
from collections import defaultdict
from pathlib import Path

import torch

# One-off preprocessing step (currently disabled): merges the per-batch
# files in /tmp/pt/ into a single dict keyed by file stem and writes it to
# /tmp/data.pt, which the code below loads.
#
# data = {}
#
# for i in range(1, 422 + 1):
#     _, names = torch.load(f"/tmp/pt/batch{i}.pt")
#     generated = torch.load(f"/tmp/pt/generated{i}.pt")
#     print(len(names))
#     for j, n in enumerate(names):
#         data[Path(names[j]).stem] = generated[j]
#
# torch.save(data, "/tmp/data.pt")

def group_by_video(data):
    """Group per-frame entries by video id.

    Each key of ``data`` is split on ``"@"``: the first field is the video
    id and the last field, parsed with ``int``, is the frame index
    (e.g. ``"clip@3"`` -> video ``"clip"``, frame ``3``).

    Args:
        data: Mapping from ``"<video>@<frame>"`` keys to tensors.

    Returns:
        A dict mapping each video id to a ``{frame_index: tensor}`` dict.

    Raises:
        ValueError: If the last ``"@"`` field of a key is not an integer.
    """
    groups = defaultdict(dict)
    for key, value in data.items():
        fields = key.split("@")
        # Group on the exact video id rather than a prefix test, so that
        # e.g. "clip1" never picks up (and overwrites with) "clip10@N" frames.
        groups[fields[0]][int(fields[-1])] = value
    return dict(groups)


def frame_differences(frames):
    """Return the mean absolute difference of each later frame vs. frame 1.

    Args:
        frames: Mapping from integer frame index to tensor for one video.
            Frame ``1`` is the reference; every frame with a larger index
            is compared against it.

    Returns:
        A list of scalar CPU tensors, one per frame index greater than 1,
        in ascending frame order. Empty when there is no such frame.

    Raises:
        KeyError: If frames later than 1 exist but frame 1 is missing.
    """
    later = sorted(i for i in frames if i > 1)
    if not later:
        return []
    reference = frames[1]
    return [torch.mean(torch.abs(frames[i] - reference)).cpu() for i in later]


def main(data_path="/tmp/data.pt"):
    """Write one ``<video>.pt`` file of frame differences per video.

    Loads the dict saved at ``data_path``, groups it by video id and, for
    each video in sorted order, prints the id, saves the list produced by
    ``frame_differences`` to ``<video>.pt`` in the current directory, and
    prints the output filename.
    """
    # NOTE(review): depending on the torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects -- only load trusted files.
    data = torch.load(data_path)
    groups = group_by_video(data)
    for video in sorted(groups):
        print(video)
        out_path = f"{video}.pt"
        torch.save(frame_differences(groups[video]), out_path)
        print(out_path)


if __name__ == "__main__":
    main()