add context loss

Ray Wong 2020-09-24 16:38:03 +08:00
parent b01016edb5
commit ca55318253
2 changed files with 93 additions and 6 deletions

loss/I2I/context_loss.py (new file, 44 lines)

@@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F
from torch import nn

from .perceptual_loss import PerceptualVGG


class ContextLoss(nn.Module):
    """Contextual loss on VGG features.

    Positions of the two feature maps are compared pairwise via a
    normalized cosine distance; the loss is the negative log of the
    aggregated contextual similarity.
    """

    def __init__(self, layer_weights, h=0.1, vgg_type='vgg19',
                 norm_image_with_imagenet_param=True, norm_img=True,
                 eps=1e-5):
        super(ContextLoss, self).__init__()
        self.eps = eps
        self.h = h  # bandwidth of the exponential similarity kernel
        self.layer_weights = layer_weights
        self.norm_img = norm_img
        self.vgg = PerceptualVGG(
            layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
            norm_image_with_imagenet_param=norm_image_with_imagenet_param)

    def single_forward(self, source_feature, target_feature):
        # Center both feature maps on the target mean and flatten the
        # spatial dimensions: NxCxHxW -> NxCxHW.
        mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
        source_feature = (source_feature - mean_target_feature).view(
            *source_feature.size()[:2], -1)  # NxCxHW
        target_feature = (target_feature - mean_target_feature).view(
            *source_feature.size()[:2], -1)  # NxCxHW
        # Pairwise cosine distance between all source/target positions.
        source_feature = F.normalize(source_feature, p=2, dim=1)
        target_feature = F.normalize(target_feature, p=2, dim=1)
        cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2),
                                         target_feature)) / 2  # NxHWxHW
        # Distances relative to each row's minimum, mapped to similarities
        # and row-normalized: cx[n, i, j] is the contextual similarity of
        # source position i and target position j.
        rel_distance = cosine_distance / (
            cosine_distance.min(2, keepdim=True)[0] + self.eps)
        w = torch.exp((1 - rel_distance) / self.h)
        cx = w.div(w.sum(dim=2, keepdim=True))
        # Best-matching source position per target position, averaged,
        # then the negative log.
        cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
        return -torch.log(cx).mean()

    def forward(self, x, gt):
        if self.norm_img:
            # Map images from [-1, 1] to [0, 1] before the VGG.
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())
        loss = 0
        for k in x_features.keys():
            loss += self.single_forward(x_features[k],
                                        gt_features[k]) * self.layer_weights[k]
        return loss
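Not part of the commit, but a minimal usage sketch may help orientation. The layer keys and tensor shapes are illustrative (the PerceptualVGG docstring in perceptual_loss.py below suggests indices such as '4' and '10'), and constructing the loss loads the pretrained VGG19 weights via torchvision:

import torch

from loss.I2I.context_loss import ContextLoss

# Keys are string indices into vgg19.features (see the listing in
# perceptual_loss.py below); values weight each layer's term.
criterion = ContextLoss(layer_weights={'4': 1.0, '10': 1.0}, h=0.1)

fake = torch.rand(2, 3, 64, 64) * 2 - 1  # generator output in [-1, 1]
fake.requires_grad_()
real = torch.rand(2, 3, 64, 64) * 2 - 1  # non-aligned reference images

loss = criterion(fake, real)  # weighted sum of per-layer CX losses
loss.backward()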

loss/I2I/perceptual_loss.py (49 additions, 6 deletions)

@@ -4,6 +4,49 @@ import torch.nn.functional as F
import torchvision.models.vgg as vgg
# Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (1): ReLU(inplace=True)
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (3): ReLU(inplace=True)
# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (6): ReLU(inplace=True)
# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (8): ReLU(inplace=True)
# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (11): ReLU(inplace=True)
# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (13): ReLU(inplace=True)
# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (15): ReLU(inplace=True)
# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (17): ReLU(inplace=True)
# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (20): ReLU(inplace=True)
# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (22): ReLU(inplace=True)
# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (24): ReLU(inplace=True)
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (26): ReLU(inplace=True)
# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (29): ReLU(inplace=True)
# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (31): ReLU(inplace=True)
# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (33): ReLU(inplace=True)
# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (35): ReLU(inplace=True)
# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )
class PerceptualVGG(nn.Module):
"""VGG network used in calculating perceptual loss.
In this implementation, we allow users to choose whether use normalization
@@ -15,15 +58,15 @@ class PerceptualVGG(nn.Module):
            list contains the name of each layer in `vgg.feature`. An example
            of this list is ['4', '10'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
-        use_input_norm (bool): If True, normalize the input image.
+        norm_image_with_imagenet_param (bool): If True, normalize the input image.
            Importantly, the input feature must be in the range [0, 1].
            Default: True.
    """
-    def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
+    def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True):
        super(PerceptualVGG, self).__init__()
        self.layer_name_list = layer_name_list
-        self.use_input_norm = use_input_norm
+        self.use_input_norm = norm_image_with_imagenet_param

        # get vgg model and load pretrained vgg weight
        # remove _vgg from attributes to avoid `find_unused_parameters` bug
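As an aside (not in the diff): the names in `layer_name_list` are string indices into `vgg19.features`, i.e. positions in the Sequential listing above. A hypothetical illustration using torchvision directly:

import torchvision.models.vgg as vgg

# '4' and '10' from the docstring example correspond to these slices:
features = vgg.vgg19(pretrained=True).features
up_to_pool1 = features[:5]     # through MaxPool2d at index 4
up_to_conv3_1 = features[:11]  # through Conv2d(128, 256) at index 10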
@@ -75,7 +118,7 @@ class PerceptualLoss(nn.Module):
            in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
-        use_input_norm (bool): If True, normalize the input image in vgg.
+        norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg.
            Default: True.
        perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
            loss will be calculated.
@@ -88,7 +131,7 @@
            Importantly, the input image must be in range [-1, 1].
    """

-    def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
+    def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True,
                 style_loss=False, norm_img=True, criterion='L1'):
        super(PerceptualLoss, self).__init__()
        self.norm_img = norm_img
@@ -97,7 +140,7 @@
        self.style_loss = style_loss
        self.layer_weights = layer_weights
        self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
-                                 use_input_norm=use_input_norm)
+                                 norm_image_with_imagenet_param=norm_image_with_imagenet_param)
        self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
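Call sites only need the keyword rename; behaviour is unchanged. A sketch of constructing the updated PerceptualLoss (layer index and weight are example values):

from loss.I2I.perceptual_loss import PerceptualLoss

percep = PerceptualLoss(
    layer_weights={'34': 1.0},            # e.g. the last vgg19 conv layer
    vgg_type='vgg19',
    norm_image_with_imagenet_param=True,  # was: use_input_norm=True
    perceptual_loss=True,
    style_loss=False,
    norm_img=True,                        # expects inputs in [-1, 1]
    criterion='L1')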