"""Generate Canny or HED edge maps for every image in a folder.

Usage (via python-fire)::

    python <this_script>.py IMAGE_FOLDER EDGE_TYPE SAVE_FOLDER [--device=cuda:0]
"""
from pathlib import Path

import fire
import numpy as np
import torch
from PIL import Image
from skimage import feature
from torchvision.datasets.folder import is_image_file, default_loader
from torchvision.transforms import functional as F

from loss.I2I.edge_loss import HED

# Edge extractors supported by ``generate``.
_EDGE_TYPES = ("canny", "hed")


def canny_edge(img):
    """Return the Canny edge map of a PIL image.

    The image is converted to grayscale (PIL mode ``"L"``) and passed to
    ``skimage.feature.canny`` with its default parameters; the result is the
    array that function returns.
    """
    return feature.canny(np.array(img.convert("L")))


def generate(image_folder, edge_type, save_folder, device="cuda:0"):
    """Extract an edge map for every image directly inside ``image_folder``.

    Each result is written to ``save_folder / f"{stem}.{edge_type}.png"``.

    Args:
        image_folder: Directory scanned non-recursively. Only entries accepted
            by torchvision's ``is_image_file`` are processed, in sorted order.
        edge_type: ``"canny"`` or ``"hed"``.
        save_folder: Output directory. It is created if it does not exist.
        device: Torch device for the HED network. Ignored for ``"canny"``.

    Raises:
        ValueError: If ``edge_type`` is not one of the supported types.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``.
    if edge_type not in _EDGE_TYPES:
        raise ValueError(
            f"edge_type must be one of {_EDGE_TYPES}, got {edge_type!r}"
        )

    image_folder = Path(image_folder)
    save_folder = Path(save_folder)
    # Without this, the first save() fails when the output dir is missing.
    save_folder.mkdir(parents=True, exist_ok=True)

    if edge_type == "hed":
        # NOTE(review): the weights path is hard-coded; consider exposing it
        # as a parameter if this script is run outside the original setup.
        edge_extractor = HED("/root/network-bsds500.pytorch", norm_img=False).to(device)

    # sorted() makes the processing (and log) order reproducible.
    for p in sorted(image_folder.glob("*")):
        if not is_image_file(p.as_posix()):
            continue
        rgb_img = default_loader(p)
        print(p)

        if edge_type == "hed":
            with torch.no_grad():
                # NOTE(review): the tensor is passed unbatched, as (C, H, W)
                # from to_tensor. This assumes HED accepts that layout;
                # TODO confirm.
                img_tensor = F.to_tensor(rgb_img).to(device)
                edge_tensor = edge_extractor(img_tensor)
                edge = F.to_pil_image(edge_tensor.clamp(0, 1.0).squeeze().cpu())
        else:  # "canny"
            edge = Image.fromarray(canny_edge(rgb_img))

        edge.save(save_folder / f"{p.stem}.{edge_type}.png")


if __name__ == "__main__":
    fire.Fire(generate)