import torch
import torch.nn.functional as F
from torch import nn

from .perceptual_loss import PerceptualVGG


class ContextLoss(nn.Module):
    """Contextual loss computed between VGG feature maps.

    Compares the features of a prediction and a target without requiring
    pixel-wise spatial alignment between the two images.
    """

    def __init__(self, layer_weights, h=0.1, vgg_type='vgg19',
                 norm_image_with_imagenet_param=True, norm_img=True,
                 eps=1e-5):
        super().__init__()
        self.eps = eps
        # Band-width parameter: smaller h makes the similarity sharper.
        self.h = h
        self.layer_weights = layer_weights
        self.norm_img = norm_img
        self.vgg = PerceptualVGG(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            norm_image_with_imagenet_param=norm_image_with_imagenet_param)

    def single_forward(self, source_feature, target_feature):
        """Contextual loss for a single pair of feature maps."""
        # Center both feature maps on the target mean and flatten the
        # spatial dimensions: (N, C, H*W).
        mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
        source_feature = (source_feature - mean_target_feature).view(
            *source_feature.size()[:2], -1)
        target_feature = (target_feature - mean_target_feature).view(
            *target_feature.size()[:2], -1)
        # L2-normalize along channels so the batched matrix product yields
        # cosine similarities; scale distances into [0, 1]: (N, H*W, H*W).
        source_feature = F.normalize(source_feature, p=2, dim=1)
        target_feature = F.normalize(target_feature, p=2, dim=1)
        cosine_distance = (
            1 - torch.bmm(source_feature.transpose(1, 2), target_feature)) / 2
        # Distances relative to each row's best match, turned into
        # row-normalized similarities with band-width h.
        rel_distance = cosine_distance / (
            cosine_distance.min(2, keepdim=True)[0] + self.eps)
        w = torch.exp((1 - rel_distance) / self.h)
        cx = w.div(w.sum(dim=2, keepdim=True))
        # Best-matching source position per target position, averaged over
        # positions; the loss is the negative log of that score.
        cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
        return -torch.log(cx).mean()

    def forward(self, x, gt):
        # Map images from [-1, 1] to [0, 1] before VGG feature extraction.
        if self.norm_img:
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5

        # Extract VGG features; no gradients flow through the target.
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # Weighted sum of the per-layer contextual losses.
        loss = 0
        for k in x_features.keys():
            loss += self.single_forward(x_features[k],
                                        gt_features[k]) * self.layer_weights[k]
        return loss
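

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# module). It assumes PerceptualVGG returns a dict of feature maps keyed by
# the names passed in layer_name_list; the key '34' (an index into
# torchvision's vgg19.features) is an illustrative assumption, as is the
# default weight of 1.0.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    loss_fn = ContextLoss(layer_weights={'34': 1.0})
    # Fake prediction/target images in [-1, 1], matching norm_img=True.
    pred = torch.rand(2, 3, 64, 64) * 2 - 1
    target = torch.rand(2, 3, 64, 64) * 2 - 1
    print('contextual loss:', loss_fn(pred, target).item())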