130 lines
5.8 KiB
Python
130 lines
5.8 KiB
Python
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
|
|
class HED(nn.Module):
    """Frozen HED edge extractor built on a VGG-style backbone.

    The network consists of five convolutional stages. Each stage output is
    projected to a single-channel score map by a 1x1 convolution, resized to
    the input resolution, and the five maps are fused by a final 1x1
    convolution followed by a sigmoid.
    """

    # (in_channels, out_channels, number of conv layers, leading max-pool?)
    _STAGE_SPECS = (
        (3, 64, 2, False),
        (64, 128, 2, True),
        (128, 256, 3, True),
        (256, 512, 3, True),
        (512, 512, 3, True),
    )

    def __init__(self, pretrained_model_path, norm_img=True):
        """
        Build the network, load the pretrained weights and freeze them.

        :param pretrained_model_path: path to the pretrained HED checkpoint.
        :param norm_img(bool): If True, the input is first mapped from
            [-1, 1] to [0, 1] via ``(x + 1) * 0.5``. In either case the
            result is then scaled by 255 and the per-channel mean is
            subtracted in ``forward``. With ``norm_img=True`` the input image
            must therefore be in range [-1, 1].
        :raises FileNotFoundError: if the checkpoint file does not exist.
        :raises ValueError: if the checkpoint contains an unexpected key prefix.
        """
        super().__init__()
        self.norm_img = norm_img

        self.vgg_nets = nn.ModuleList(
            [self._make_vgg_stage(*spec) for spec in self._STAGE_SPECS]
        )

        # One 1x1 "side output" scorer per backbone stage.
        self.score_nets = nn.ModuleList([
            nn.Conv2d(in_channels=channels, out_channels=1, kernel_size=1, stride=1, padding=0)
            for _, channels, _, _ in self._STAGE_SPECS
        ])

        # Fuses the five resized side outputs into one edge probability map.
        self.combine_net = nn.Sequential(
            nn.Conv2d(in_channels=len(self._STAGE_SPECS), out_channels=1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

        self.load_weights(pretrained_model_path)

        # Registered AFTER the strict load on purpose: the checkpoint has no
        # "mean" entry, so registering it earlier would make
        # load_state_dict(strict=True) fail with a missing key.
        # NOTE(review): the values look like the Caffe BGR channel means -
        # confirm the channel order of the images fed to forward().
        self.register_buffer(
            'mean',
            torch.tensor([104.00698793, 116.66876762, 122.67891434]).view(1, 3, 1, 1)
        )

        # Freeze the extractor. The original wrote to a misspelled attribute
        # ("requies_grad"), which silently left every parameter trainable.
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _make_vgg_stage(in_channels, out_channels, num_convs, pool):
        """Return one backbone stage: optional 2x2 max-pool, then conv3x3+ReLU pairs.

        The layer order fixes the indices inside the ``nn.Sequential`` and
        therefore the state-dict keys the checkpoint is mapped onto.
        """
        layers = []
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        channels = in_channels
        for _ in range(num_convs):
            layers.append(
                nn.Conv2d(in_channels=channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.ReLU(inplace=False))
            channels = out_channels
        return nn.Sequential(*layers)

    def load_weights(self, pretrained_model_path):
        """
        Load a checkpoint whose keys use the ``moduleVgg*`` / ``moduleScore*`` /
        ``moduleCombine`` naming and remap them onto this module's attributes.

        :param pretrained_model_path: path to the checkpoint file.
        :raises FileNotFoundError: if the path does not exist.
        :raises ValueError: if a key starts with none of the known prefixes.
        """
        checkpoint_path = Path(pretrained_model_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
        # Three-letter stage suffix used in the checkpoint -> ModuleList index.
        m = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"}

        def replace_key(key):
            if key.startswith("moduleVgg"):
                return f"vgg_nets.{m[key[9:12]]}{key[12:]}"
            elif key.startswith("moduleScore"):
                return f"score_nets.{m[key[11:14]]}{key[14:]}"
            elif key.startswith("moduleCombine"):
                return f"combine_net{key[13:]}"
            else:
                raise ValueError("wrong checkpoint for HED")

        module_dict = {replace_key(k): v for k, v in ckp.items()}
        self.load_state_dict(module_dict, strict=True)

    def forward(self, x):
        """
        Compute the edge probability map of ``x``.

        :param x: image batch with 3 channels in dim 1 (the first conv expects
            3 input channels); in [-1, 1] when ``norm_img`` is True.
        :return: tensor with one channel and the same spatial size as ``x``,
            clamped to [0, 1].
        """
        if self.norm_img:
            x = (x + 1.) * 0.5
        x = x * 255.0 - self.mean
        img_size = (x.size(2), x.size(3))

        to_combine = []
        for vgg_net, score_net in zip(self.vgg_nets, self.score_nets):
            x = vgg_net(x)
            score_x = score_net(x)
            # Bring every side output back to the input resolution so they
            # can be concatenated channel-wise.
            to_combine.append(F.interpolate(input=score_x, size=img_size, mode='bilinear', align_corners=False))
        out = self.combine_net(torch.cat(to_combine, 1))
        return out.clamp(0.0, 1.0)
|
|
|
|
|
|
class EdgeLoss(nn.Module):
    """Loss that compares the edge map of a prediction against a target edge map."""

    def __init__(self, edge_extractor_type="HED", norm_img=True, criterion='L1', **kwargs):
        """
        :param edge_extractor_type: edge extractor to use; only "HED" is supported.
        :param norm_img(bool): forwarded to the extractor (see ``HED``).
        :param criterion: "L1" for ``nn.L1Loss`` or "L2" for ``nn.MSELoss``.
        :param kwargs: must provide ``hed_pretrained_model_path`` when the
            extractor type is "HED".
        :raises NotImplementedError: for an unsupported extractor type or criterion.
        """
        super(EdgeLoss, self).__init__()
        if edge_extractor_type == "HED":
            pretrained_model_path = kwargs.get("hed_pretrained_model_path")
            self.edge_extractor = HED(pretrained_model_path, norm_img)
        else:
            # NotImplemented is a comparison sentinel, not an exception class;
            # calling it raises a confusing TypeError instead of this message.
            raise NotImplementedError(f"do not support edge_extractor_type {edge_extractor_type}")

        if criterion == 'L1':
            self.criterion = nn.L1Loss()
        elif criterion == "L2":
            self.criterion = nn.MSELoss()
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')

    def forward(self, x, gt, gt_is_edge=True):
        """
        :param x: predicted image batch; its edge map is extracted here.
        :param gt: target; an edge map when ``gt_is_edge`` is True, otherwise
            an image from which the edge map is extracted first.
        :param gt_is_edge(bool): whether ``gt`` is already an edge map.
        :return: the criterion value between the two edge maps.
        """
        edge = self.edge_extractor(x)
        if not gt_is_edge:
            # The target is a constant for the loss: the input is detached, so
            # skip building an autograd graph for it altogether.
            with torch.no_grad():
                gt = self.edge_extractor(gt.detach())
        loss = self.criterion(edge, gt)
        return loss
|