from pathlib import Path

import torch
import torch.nn as nn
from torch.nn import functional as F

# (in_channels, out_channels, num_convs) of the five VGG16 stages that feed HED's side outputs.
_HED_VGG_STAGES = (
    (3, 64, 2),
    (64, 128, 2),
    (128, 256, 3),
    (256, 512, 3),
    (512, 512, 3),
)

# Per-channel means (0-255 scale) subtracted from the input in HED.forward.
# NOTE(review): these values look like the BGR means of the original HED release;
# confirm the channel order of the images fed to this module.
_HED_PIXEL_MEAN = (104.00698793, 116.66876762, 122.67891434)

# Stage-name fragment used in checkpoint keys (e.g. "moduleVggOne") -> stage index.
_HED_STAGE_NAMES = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"}


def _make_vgg_stage(in_channels, out_channels, num_convs, downsample):
    """Build one VGG stage: optional 2x2 max-pool, then ``num_convs`` (3x3 conv, ReLU) pairs.

    The layer order fixes the Sequential indices and therefore the state-dict keys:
    stage 0 has convs at indices 0, 2; later stages at 1, 3 (and 5).
    """
    layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if downsample else []
    channels = in_channels
    for _ in range(num_convs):
        layers += [
            nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
        ]
        channels = out_channels
    return nn.Sequential(*layers)


class HED(nn.Module):
    def __init__(self, pretrained_model_path, norm_img=True):
        """Frozen HED network used to extract edge maps.

        :param pretrained_model_path: path to the pretrained HED checkpoint.
        :param norm_img (bool): if True, ``forward`` first maps the input from [-1, 1] to
            [0, 1], so the input image must be in range [-1, 1]. If False, the input is
            scaled by 255 directly and should already be in [0, 1]. Note that this is
            different from a ``use_input_norm`` option, which normalizes the input inside a
            VGG forward according to dataset statistics.
        :raises FileNotFoundError: if the checkpoint does not exist.
        :raises ValueError: if the checkpoint keys do not match the HED layout.
        """
        super().__init__()
        self.norm_img = norm_img
        self.vgg_nets = nn.ModuleList(
            [
                _make_vgg_stage(cin, cout, n, downsample=idx > 0)
                for idx, (cin, cout, n) in enumerate(_HED_VGG_STAGES)
            ]
        )
        # One 1x1 "side output" scorer per VGG stage.
        self.score_nets = nn.ModuleList(
            [
                nn.Conv2d(cout, 1, kernel_size=1, stride=1, padding=0)
                for _, cout, _ in _HED_VGG_STAGES
            ]
        )
        # Fuses the per-stage side outputs into a single edge probability map.
        self.combine_net = nn.Sequential(
            nn.Conv2d(len(_HED_VGG_STAGES), 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )
        self.load_weights(pretrained_model_path)
        # Registered after loading so the strict state-dict load does not expect a 'mean' key.
        self.register_buffer("mean", torch.tensor(_HED_PIXEL_MEAN).view(1, 3, 1, 1))
        # HED is a fixed extractor: freeze every parameter so it is never trained.
        self.requires_grad_(False)

    def load_weights(self, pretrained_model_path):
        """Load a pretrained HED checkpoint, renaming its keys to this module's layout.

        :param pretrained_model_path: path to a state dict whose keys start with
            ``moduleVgg``, ``moduleScore`` or ``moduleCombine``.
        :raises FileNotFoundError: if the checkpoint file does not exist.
        :raises ValueError: if a checkpoint key has an unexpected form.
        """
        checkpoint_path = Path(pretrained_model_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
        ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")

        def replace_key(key):
            if key.startswith("moduleVgg"):
                prefix, target = "moduleVgg", "vgg_nets"
            elif key.startswith("moduleScore"):
                prefix, target = "moduleScore", "score_nets"
            elif key.startswith("moduleCombine"):
                return f"combine_net{key[len('moduleCombine'):]}"
            else:
                raise ValueError(f"wrong checkpoint for HED: unexpected key '{key}'")
            # The three characters after the prefix name the stage, e.g. "One" in
            # "moduleVggOne.0.weight".
            tag_end = len(prefix) + 3
            stage = _HED_STAGE_NAMES.get(key[len(prefix):tag_end])
            if stage is None:
                raise ValueError(f"wrong checkpoint for HED: unexpected key '{key}'")
            return f"{target}.{stage}{key[tag_end:]}"

        module_dict = {replace_key(k): v for k, v in ckp.items()}
        self.load_state_dict(module_dict, strict=True)

    def forward(self, x):
        """Return an edge map of shape (N, 1, H, W) with values in [0, 1].

        :param x: image batch with 3 channels (N, 3, H, W); see ``norm_img`` for the range.
        """
        if self.norm_img:
            x = (x + 1.0) * 0.5
        x = x * 255.0 - self.mean
        img_size = (x.size(2), x.size(3))
        side_outputs = []
        for vgg_stage, score_layer in zip(self.vgg_nets, self.score_nets):
            x = vgg_stage(x)
            # Upsample each stage's score map back to the input resolution before fusing.
            side_outputs.append(
                F.interpolate(
                    input=score_layer(x), size=img_size, mode="bilinear", align_corners=False
                )
            )
        out = self.combine_net(torch.cat(side_outputs, 1))
        # Redundant after the sigmoid, but kept to guarantee the documented [0, 1] range.
        return out.clamp(0.0, 1.0)


class EdgeLoss(nn.Module):
    def __init__(self, edge_extractor_type="HED", norm_img=True, criterion="L1", **kwargs):
        """Loss between the edge map of a prediction and a target edge map.

        :param edge_extractor_type: only "HED" is supported.
        :param norm_img (bool): forwarded to ``HED`` (input range [-1, 1] if True).
        :param criterion: "L1" (``nn.L1Loss``) or "L2" (``nn.MSELoss``).
        :param kwargs: must contain ``hed_pretrained_model_path`` for the HED extractor.
        :raises NotImplementedError: for an unsupported extractor type or criterion.
        :raises ValueError: if ``hed_pretrained_model_path`` is not given.
        """
        super().__init__()
        if edge_extractor_type == "HED":
            pretrained_model_path = kwargs.get("hed_pretrained_model_path")
            if pretrained_model_path is None:
                raise ValueError(
                    "hed_pretrained_model_path is required for edge_extractor_type 'HED'"
                )
            self.edge_extractor = HED(pretrained_model_path, norm_img)
        else:
            raise NotImplementedError(f"do not support edge_extractor_type {edge_extractor_type}")

        if criterion == "L1":
            self.criterion = nn.L1Loss()
        elif criterion == "L2":
            self.criterion = nn.MSELoss()
        else:
            raise NotImplementedError(
                f"{criterion} criterion has not been supported in this version."
            )

    def forward(self, x, gt, gt_is_edge=True):
        """Compute the criterion between the edges of ``x`` and the target edges.

        :param x: predicted image batch, in the range expected by the HED extractor.
        :param gt: target edge map when ``gt_is_edge`` is True; otherwise a target image
            whose edges are extracted here without tracking gradients.
        :returns: scalar loss tensor.
        """
        edge = self.edge_extractor(x)
        if not gt_is_edge:
            # The target edges are a constant; avoid building an unused autograd graph.
            with torch.no_grad():
                gt = self.edge_extractor(gt.detach())
        return self.criterion(edge, gt)