import torch
import torch.nn.functional as F
from torch import nn

from .perceptual_loss import PerceptualVGG


class ContextLoss(nn.Module):
    """Contextual loss on VGG features.

    Matches each source feature to its most similar target feature via
    normalized cosine similarities, following Mechrez et al., "The
    Contextual Loss for Image Transformation with Non-Aligned Data"
    (ECCV 2018). ``layer_weights`` maps VGG layer names to loss weights,
    ``h`` is the bandwidth of the exponential kernel, and when
    ``norm_img`` is True the inputs are assumed to lie in [-1, 1] and
    are rescaled to [0, 1] before feature extraction.
    """

    def __init__(self,
                 layer_weights,
                 h=0.1,
                 vgg_type='vgg19',
                 norm_image_with_imagenet_param=True,
                 norm_img=True,
                 eps=1e-5):
        super().__init__()
        self.eps = eps
        self.h = h
        self.layer_weights = layer_weights
        self.norm_img = norm_img
        self.vgg = PerceptualVGG(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            norm_image_with_imagenet_param=norm_image_with_imagenet_param)

    def single_forward(self, source_feature, target_feature):
        """Compute the contextual loss for one pair of feature maps."""
        # Center both feature maps on the target's spatial mean.
        mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
        source_feature = (source_feature - mean_target_feature).view(
            *source_feature.size()[:2], -1)  # N x C x HW
        target_feature = (target_feature - mean_target_feature).view(
            *target_feature.size()[:2], -1)  # N x C x HW
        # L2-normalize channels so the batched matmul yields pairwise
        # cosine similarities between all spatial positions.
        source_feature = F.normalize(source_feature, p=2, dim=1)
        target_feature = F.normalize(target_feature, p=2, dim=1)
        cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2),
                                         target_feature)) / 2  # N x HW x HW
        # Normalize each row by its minimum distance, then convert the
        # relative distances to similarities with an exponential kernel.
        rel_distance = cosine_distance / (
            cosine_distance.min(2, keepdim=True)[0] + self.eps)
        w = torch.exp((1 - rel_distance) / self.h)
        cx = w.div(w.sum(dim=2, keepdim=True))
        # Keep the best match for each target position, average over
        # positions, and take the negative log.
        cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
        return -torch.log(cx).mean()

    def forward(self, x, gt):
        if self.norm_img:
            # Map images from [-1, 1] to [0, 1].
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5
        # Extract VGG features; no gradients flow into the target.
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())
        # Accumulate the weighted per-layer losses.
        loss = 0
        for k in x_features.keys():
            loss += self.single_forward(x_features[k],
                                        gt_features[k]) * self.layer_weights[k]
        return loss
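

# A minimal usage sketch, not part of the module. It assumes that
# PerceptualVGG accepts the keyword arguments used above and returns a
# dict keyed by the names in `layer_weights`; the layer key '34' (a
# torchvision vgg19 feature index) is a guess and depends entirely on
# PerceptualVGG's naming scheme.
if __name__ == '__main__':
    loss_fn = ContextLoss(layer_weights={'34': 1.0})
    # Random stand-in images in [-1, 1] with shape (N, 3, H, W).
    pred = torch.rand(2, 3, 64, 64) * 2 - 1
    target = torch.rand(2, 3, 64, 64) * 2 - 1
    print(float(loss_fn(pred, target)))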