raycv/loss/I2I/edge_loss.py
2020-08-30 09:34:23 +08:00

130 lines
5.8 KiB
Python

from pathlib import Path
import torch
import torch.nn as nn
from torch.nn import functional as F
class HED(nn.Module):
    """HED (Holistically-Nested Edge Detection) network used as a frozen edge extractor.

    Five VGG-style stages each feed a 1x1 "score" conv. The five side outputs are
    bilinearly resized to the input resolution, fused by a 1x1 conv and passed
    through a sigmoid, giving a single-channel edge map. All parameters are frozen
    (``requires_grad = False``) once the pretrained weights are loaded.
    """

    # (in_channels, out_channels, number of 3x3 conv + ReLU pairs) for each VGG stage.
    # Every stage except the first starts with a 2x2 max-pool, so the Sequential
    # indices (and therefore the state-dict keys) match the checkpoint layout.
    _VGG_STAGES = ((3, 64, 2), (64, 128, 2), (128, 256, 3), (256, 512, 3), (512, 512, 3))

    def __init__(self, pretrained_model_path, norm_img=True):
        """
        HED module to get edge.

        :param pretrained_model_path: path to the pretrained HED checkpoint (see `load_weights`).
        :param norm_img(bool): If True, `forward` maps its input from [-1, 1] to [0, 1]
            first, so the input image must be in range [-1, 1]. If False, no such mapping
            is done (the input is presumably already in [0, 1], since `forward` multiplies
            by 255). Independently of this flag, `forward` always scales by 255 and
            subtracts a fixed per-channel mean.
        """
        super().__init__()
        self.norm_img = norm_img
        self.vgg_nets = nn.ModuleList([
            self._make_vgg_stage(c_in, c_out, n_convs, downsample=(stage > 0))
            for stage, (c_in, c_out, n_convs) in enumerate(self._VGG_STAGES)
        ])
        # One 1x1 conv per stage turning that stage's features into a 1-channel side output.
        self.score_nets = nn.ModuleList([
            torch.nn.Conv2d(in_channels=c_out, out_channels=1, kernel_size=1, stride=1, padding=0)
            for _, c_out, _ in self._VGG_STAGES
        ])
        # Fuses the 5 side outputs into one map, squashed to (0, 1).
        self.combine_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            torch.nn.Sigmoid()
        )
        self.load_weights(pretrained_model_path)
        self.register_buffer('mean', torch.Tensor([104.00698793, 116.66876762, 122.67891434]).view(1, 3, 1, 1))
        # Freeze the extractor; it is only used to compute edge maps for a loss.
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _make_vgg_stage(in_channels, out_channels, n_convs, downsample):
        """Build one VGG stage: optional 2x2 max-pool, then `n_convs` (3x3 conv, ReLU) pairs."""
        layers = [torch.nn.MaxPool2d(kernel_size=2, stride=2)] if downsample else []
        for i in range(n_convs):
            layers.append(torch.nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
                                          out_channels=out_channels, kernel_size=3, stride=1, padding=1))
            layers.append(torch.nn.ReLU(inplace=False))
        return torch.nn.Sequential(*layers)

    def load_weights(self, pretrained_model_path):
        """Load a checkpoint that uses ``moduleVgg*`` / ``moduleScore*`` / ``moduleCombine`` keys.

        Keys are renamed to this module's attribute names, e.g.
        ``moduleVggTwo.1.weight`` -> ``vgg_nets.1.1.weight`` and
        ``moduleScoreFiv.bias`` -> ``score_nets.4.bias``. They are then loaded with
        ``strict=True``, so every parameter must be present.

        :param pretrained_model_path: str or path-like pointing at the checkpoint file.
        :raises FileNotFoundError: if the checkpoint file does not exist.
        :raises ValueError: if a checkpoint key does not follow the expected naming scheme.
        """
        checkpoint_path = Path(pretrained_model_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
        stage_index = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"}

        def replace_key(key):
            if key.startswith("moduleVgg") and key[9:12] in stage_index:
                return f"vgg_nets.{stage_index[key[9:12]]}{key[12:]}"
            if key.startswith("moduleScore") and key[11:14] in stage_index:
                return f"score_nets.{stage_index[key[11:14]]}{key[14:]}"
            if key.startswith("moduleCombine"):
                return f"combine_net{key[13:]}"
            raise ValueError(f"wrong checkpoint for HED: unexpected key '{key}'")

        module_dict = {replace_key(k): v for k, v in ckp.items()}
        # Buffers (e.g. `mean`, when this is called after construction) are not stored in
        # the HED checkpoint; keep their current values so the strict load still succeeds.
        for name, buf in self.named_buffers():
            module_dict.setdefault(name, buf)
        self.load_state_dict(module_dict, strict=True)

    def forward(self, x):
        """Return the fused HED edge map of `x`.

        :param x: 4-D image batch (N, 3, H, W); in [-1, 1] when `norm_img` is True.
        :return: tensor of shape (N, 1, H, W) with values clamped to [0, 1].
        """
        if self.norm_img:
            x = (x + 1.) * 0.5
        # NOTE(review): `mean` looks like the Caffe BGR channel means of the original HED
        # release; x is used in the channel order it arrives in — confirm callers match
        # the order the checkpoint expects.
        x = x * 255.0 - self.mean
        img_size = (x.size(2), x.size(3))
        side_outputs = []
        for vgg_stage, score_net in zip(self.vgg_nets, self.score_nets):
            x = vgg_stage(x)
            # Resize each side output back to the input resolution before fusion.
            side_outputs.append(F.interpolate(input=score_net(x), size=img_size,
                                              mode='bilinear', align_corners=False))
        out = self.combine_net(torch.cat(side_outputs, 1))
        return out.clamp(0.0, 1.0)
class EdgeLoss(nn.Module):
    """Loss between the edge map of a prediction and a target edge map.

    The edge map of the prediction is computed by a frozen edge extractor (currently
    only HED). The target is either already an edge map (`gt_is_edge=True`) or an image
    whose edge map is extracted the same way, without tracking gradients.
    """

    def __init__(self, edge_extractor_type="HED", norm_img=True, criterion='L1', **kwargs):
        """
        :param edge_extractor_type: name of the edge extractor; only "HED" is supported.
        :param norm_img: forwarded to the extractor; when True its inputs are mapped
            from [-1, 1] to [0, 1] before edge extraction.
        :param criterion: "L1" (`nn.L1Loss`) or "L2" (`nn.MSELoss`).
        :param kwargs: extractor options; "HED" requires `hed_pretrained_model_path`.
        :raises ValueError: if `hed_pretrained_model_path` is missing for "HED".
        :raises NotImplementedError: for an unsupported extractor type or criterion.
        """
        super().__init__()
        if edge_extractor_type == "HED":
            pretrained_model_path = kwargs.get("hed_pretrained_model_path")
            if pretrained_model_path is None:
                raise ValueError("`hed_pretrained_model_path` is required when edge_extractor_type is 'HED'")
            self.edge_extractor = HED(pretrained_model_path, norm_img)
        else:
            raise NotImplementedError(f"do not support edge_extractor_type {edge_extractor_type}")
        if criterion == 'L1':
            self.criterion = nn.L1Loss()
        elif criterion == "L2":
            self.criterion = nn.MSELoss()
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')

    def forward(self, x, gt, gt_is_edge=True):
        """Return the criterion between the edge map of `x` and the target edge map.

        :param x: prediction image; gradients flow through the extractor back to it.
        :param gt: target edge map if `gt_is_edge`, otherwise a target image whose
            edge map is extracted first.
        :param gt_is_edge: whether `gt` is already an edge map.
        :return: scalar loss tensor.
        """
        edge = self.edge_extractor(x)
        if not gt_is_edge:
            # The target is a constant: don't build an autograd graph for it.
            with torch.no_grad():
                gt = self.edge_extractor(gt.detach())
        return self.criterion(edge, gt)