raycv/tool/process/generate_edge.py
2020-08-30 14:44:40 +08:00

46 lines
1.5 KiB
Python

import numpy as np
from skimage import feature
from pathlib import Path
from torchvision.datasets.folder import is_image_file, default_loader
from torchvision.transforms import functional as F
from loss.I2I.edge_loss import HED
import torch
from PIL import Image
import fire
def canny_edge(img):
    """Return the Canny edge map of a PIL image.

    The image is first converted to 8-bit grayscale (PIL mode ``"L"``) and the
    resulting array is passed to ``skimage.feature.canny`` with its default
    parameters.

    Args:
        img: a PIL image (any mode that PIL can convert to ``"L"``).

    Returns:
        The array returned by ``feature.canny`` (boolean, True on edge pixels),
        with the same height and width as ``img``.
    """
    gray = np.array(img.convert("L"))
    return feature.canny(gray)
def generate(image_folder, edge_type, save_folder, device="cuda:0"):
    """Extract an edge map for every image in ``image_folder`` and save it as PNG.

    Each output is written to ``save_folder / f"{stem}.{edge_type}.png"``, where
    ``stem`` is the input file name without its extension.

    Args:
        image_folder: directory scanned (non-recursively) for image files, as
            recognised by torchvision's ``is_image_file``.
        edge_type: ``"canny"`` (``canny_edge`` on the grayscale image) or
            ``"hed"`` (the project's ``HED`` network).
        save_folder: output directory; created (with parents) if missing.
        device: torch device the HED network and its inputs are moved to.
            Ignored for ``"canny"``.

    Raises:
        ValueError: if ``edge_type`` is not ``"canny"`` or ``"hed"``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if edge_type not in ("canny", "hed"):
        raise ValueError(f"edge_type must be 'canny' or 'hed', got {edge_type!r}")
    image_folder = Path(image_folder)
    save_folder = Path(save_folder)
    # PIL's Image.save does not create missing parent directories.
    save_folder.mkdir(parents=True, exist_ok=True)

    # Build a single "PIL image -> PIL edge image" callable once, so the loop
    # body is the same for both edge types.
    if edge_type == "hed":
        # NOTE(review): the weights path is hard-coded; norm_img=False presumably
        # means HED does not normalize its input itself — confirm in loss.I2I.edge_loss.
        hed = HED("/root/network-bsds500.pytorch", norm_img=False).to(device)

        def extract(img):
            return _hed_edge_image(hed, img, device)
    else:
        def extract(img):
            # canny_edge returns a boolean array; fromarray maps it to a 1-bit image.
            return Image.fromarray(canny_edge(img))

    # sorted() makes the processing order deterministic (glob order is
    # filesystem-dependent); is_file() skips directories with image-like names.
    for p in sorted(image_folder.glob("*")):
        if not (p.is_file() and is_image_file(p.as_posix())):
            continue
        print(p)
        rgb_img = default_loader(p.as_posix())
        extract(rgb_img).save(save_folder / f"{p.stem}.{edge_type}.png")


def _hed_edge_image(hed, rgb_img, device):
    """Run the HED network ``hed`` on ``rgb_img`` and return the edge map as a PIL image."""
    with torch.no_grad():
        img_tensor = F.to_tensor(rgb_img).to(device)
        edge_tensor = hed(img_tensor)
    # Clamp to [0, 1] before conversion; squeeze() drops singleton dimensions
    # (presumably batch/channel — depends on HED's output shape).
    return F.to_pil_image(edge_tensor.clamp(0, 1.0).squeeze().detach().cpu())
if __name__ == "__main__":
    # Expose `generate` as a command-line interface: its parameters become
    # positional / --flag arguments parsed by python-fire.
    fire.Fire(component=generate)