46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
import numpy as np
|
|
from skimage import feature
|
|
from pathlib import Path
|
|
from torchvision.datasets.folder import is_image_file, default_loader
|
|
from torchvision.transforms import functional as F
|
|
from loss.I2I.edge_loss import HED
|
|
import torch
|
|
from PIL import Image
|
|
import fire
|
|
|
|
|
|
def canny_edge(img):
    """Compute a Canny edge map for a PIL image.

    The image is first converted to single-channel grayscale (PIL mode
    ``"L"``), turned into a NumPy array, and then passed to
    ``skimage.feature.canny`` with its default parameters.

    Args:
        img: A PIL image (anything exposing ``.convert("L")``).

    Returns:
        The array produced by ``feature.canny``. Per scikit-image's
        documentation, this is a boolean edge mask with the same height and
        width as ``img``.
    """
    gray = np.array(img.convert("L"))
    return feature.canny(gray)
|
|
|
|
|
|
def generate(
    image_folder,
    edge_type,
    save_folder,
    device="cuda:0",
    hed_weights="/root/network-bsds500.pytorch",
):
    """Extract an edge map for every image in a folder and save each as a PNG.

    Every image file directly inside ``image_folder`` (non-recursive, in
    sorted order) is loaded with torchvision's ``default_loader``. Its edge
    map is written to ``save_folder / f"{stem}.{edge_type}.png"``.

    Args:
        image_folder: Directory containing the input images.
        edge_type: ``"canny"`` uses ``canny_edge`` (scikit-image Canny on the
            grayscale image). ``"hed"`` uses the project's ``HED`` network.
        save_folder: Output directory. It is created if it does not exist.
        device: Torch device for HED inference. Unused for ``"canny"``.
        hed_weights: First argument passed to ``HED``. Presumably the path to
            the pretrained BSDS500 weights; confirm against
            ``loss.I2I.edge_loss.HED``.

    Raises:
        ValueError: If ``edge_type`` is neither ``"canny"`` nor ``"hed"``.
    """
    # Use an explicit exception rather than `assert`, which is stripped
    # under `python -O`.
    if edge_type not in ("canny", "hed"):
        raise ValueError(f"edge_type must be 'canny' or 'hed', got {edge_type!r}")

    image_folder = Path(image_folder)
    save_folder = Path(save_folder)
    # Without this, saving into a missing directory fails on the first image.
    save_folder.mkdir(parents=True, exist_ok=True)

    # Resolve the extraction strategy once. Each strategy maps an RGB PIL
    # image to a PIL image ready to be saved.
    if edge_type == "hed":
        hed = HED(hed_weights, norm_img=False).to(device)
        # Inference only. This assumes HED is a torch nn.Module (it supports
        # `.to(device)`), so eval() disables any train-time behaviour.
        hed.eval()

        def extract(rgb_img):
            with torch.no_grad():
                img_tensor = F.to_tensor(rgb_img).to(device)
                edge_tensor = hed(img_tensor)
                return F.to_pil_image(edge_tensor.clamp(0, 1.0).squeeze().cpu())

    else:  # "canny"

        def extract(rgb_img):
            return Image.fromarray(canny_edge(rgb_img))

    # sorted(): glob order is filesystem-dependent; sorting makes runs
    # reproducible. is_file() skips directories whose names happen to end
    # in an image extension.
    for p in sorted(image_folder.glob("*")):
        if not (p.is_file() and is_image_file(p.as_posix())):
            continue
        # Print before loading so a failing file is identifiable in the log.
        print(p)
        rgb_img = default_loader(p.as_posix())
        extract(rgb_img).save(save_folder / f"{p.stem}.{edge_type}.png")
|
|
|
|
|
|
if __name__ == "__main__":
    # Expose `generate` as a command-line tool; python-fire maps its
    # parameters to CLI arguments/flags.
    fire.Fire(component=generate)
|