# raycv/tool/verify_loss.py
# Snapshot: 2020-09-05 22:00:17 +08:00 — 70 lines, 1.9 KiB, Python.
import torch
from torch.utils.data import DataLoader
from ignite.utils import convert_tensor
from omegaconf import OmegaConf
from data.dataset import SingleFolderDataset
from loss.I2I.perceptual_loss import PerceptualLoss
import ignite.distributed as idist
# YAML configuration consumed below via OmegaConf.create().
# NOTE(fix): nested-mapping indentation restored — without it the YAML is
# invalid (e.g. "mean:" at column 0 after "- Normalize:" is a parse error)
# and OmegaConf.create() would fail.
# - loss.perceptual: kwargs for PerceptualLoss (VGG layer index -> weight).
# - match_data / not_match_data: kwargs for SingleFolderDataset.
CONFIG = """
loss:
  perceptual:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'NL2'
    style_loss: False
    perceptual_loss: True
match_data:
  root: "/tmp/generated/"
  pipeline:
    - Load
    - ToTensor
    - Normalize:
        mean: [ 0.5, 0.5, 0.5 ]
        std: [ 0.5, 0.5, 0.5 ]
not_match_data:
  root: "/data/i2i/selfie2anime/trainB/"
  pipeline:
    - Load
    - ToTensor
    - Normalize:
        mean: [ 0.5, 0.5, 0.5 ]
        std: [ 0.5, 0.5, 0.5 ]
"""
# --- Pass 1: paired ("match") images --------------------------------------
# Each sample is assumed to hold the generated and reference image
# concatenated side-by-side along the last (width) dimension, so chunking
# in half recovers the (input, target) pair — TODO confirm this layout
# against the producer of /tmp/generated/.
# NOTE(fix): loop-body indentation restored; as pasted, the file was an
# IndentationError.
config = OmegaConf.create(CONFIG)
dataset = SingleFolderDataset(**config.match_data)
data_loader = DataLoader(dataset, 1, False, num_workers=1)
perceptual_loss = PerceptualLoss(**config.loss.perceptual).to("cuda:0")
pls = []
for batch in data_loader:
    with torch.no_grad():  # inference only — no autograd bookkeeping
        batch = convert_tensor(batch, "cuda:0")
        x, t = torch.chunk(batch, 2, -1)
        pl, _ = perceptual_loss(x, t)
        print(pl)
        pls.append(pl)
# Persist the per-sample losses for offline comparison.
torch.save(torch.stack(pls).cpu(), "verify_loss.match.pt")
# --- Pass 2: unrelated ("not match") images -------------------------------
# Compare each image in a batch against the next one (cyclically) to get a
# baseline loss between images that should NOT match.
dataset = SingleFolderDataset(**config.not_match_data)
data_loader = DataLoader(dataset, 4, False, num_workers=1)
pls = []
for batch in data_loader:
    with torch.no_grad():  # inference only — no autograd bookkeeping
        batch = convert_tensor(batch, "cuda:0")
        # NOTE(fix): derive the cyclic pairs from the actual batch size
        # instead of the hard-coded [(0,1),(1,2),(2,3),(3,0)], which raised
        # IndexError on a final batch smaller than 4. For n == 4 the pairs
        # are identical to the original. (If n == 1 the single image is
        # compared with itself.)
        n = batch.size(0)
        for i in range(n):
            x = batch[i].unsqueeze(dim=0)
            t = batch[(i + 1) % n].unsqueeze(dim=0)
            pl, _ = perceptual_loss(x, t)
            print(pl)
            pls.append(pl)
# Persist the per-pair losses for offline comparison.
torch.save(torch.stack(pls).cpu(), "verify_loss.not_match.pt")