add context loss

Ray Wong 2020-09-24 16:38:03 +08:00
parent b01016edb5
commit ca55318253
2 changed files with 93 additions and 6 deletions

loss/I2I/context_loss.py  (new file, 44 lines)

@@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F
from torch import nn

from .perceptual_loss import PerceptualVGG


class ContextLoss(nn.Module):
    def __init__(self, layer_weights, h=0.1, vgg_type='vgg19', norm_image_with_imagenet_param=True, norm_img=True,
                 eps=1e-5):
        super(ContextLoss, self).__init__()
        self.eps = eps
        self.h = h
        self.layer_weights = layer_weights
        self.norm_img = norm_img
        self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
                                 norm_image_with_imagenet_param=norm_image_with_imagenet_param)

    def single_forward(self, source_feature, target_feature):
        # center both feature maps on the target mean, then flatten to NxCxHW
        mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
        source_feature = (source_feature - mean_target_feature).view(*source_feature.size()[:2], -1)  # NxCxHW
        target_feature = (target_feature - mean_target_feature).view(*target_feature.size()[:2], -1)  # NxCxHW
        # l2-normalize along channels so the bmm below computes cosine similarities
        source_feature = F.normalize(source_feature, p=2, dim=1)
        target_feature = F.normalize(target_feature, p=2, dim=1)
        cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2), target_feature)) / 2  # NxHWxHW
        # distance of each source position relative to its closest target position
        rel_distance = cosine_distance / (cosine_distance.min(2, keepdim=True)[0] + self.eps)
        # similarity kernel with bandwidth h, normalized over target positions
        w = torch.exp((1 - rel_distance) / self.h)
        cx = w.div(w.sum(dim=2, keepdim=True))
        # best match over source positions for each target position, then average
        cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
        return -torch.log(cx).mean()

    def forward(self, x, gt):
        if self.norm_img:
            # map images from [-1, 1] to the [0, 1] range expected by PerceptualVGG
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())
        loss = 0
        for k in x_features.keys():
            loss += self.single_forward(x_features[k], gt_features[k]) * self.layer_weights[k]
        return loss
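The computation in single_forward matches the contextual loss (CX) formulation of Mechrez et al. (ECCV 2018): relative cosine distances are turned into normalized similarities with bandwidth h, and the best match per target position is averaged. A minimal usage sketch, not part of the commit; the layer key '26' and weight 1.0 are illustrative choices ('26' is relu4_4 in vgg19.features):

    import torch
    from loss.I2I.context_loss import ContextLoss

    # keys of layer_weights name indices into vgg19.features;
    # values weight that layer's contribution to the total loss
    ctx = ContextLoss(layer_weights={'26': 1.0}, h=0.1)

    x = (torch.rand(2, 3, 64, 64) * 2 - 1).requires_grad_()  # generated images in [-1, 1]
    gt = torch.rand(2, 3, 64, 64) * 2 - 1                    # reference images in [-1, 1]
    loss = ctx(x, gt)  # scalar, summed over the weighted layers
    loss.backward()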

loss/I2I/perceptual_loss.py  (49 additions, 6 deletions)

@@ -4,6 +4,49 @@ import torch.nn.functional as F
import torchvision.models.vgg as vgg
# Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (1): ReLU(inplace=True)
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (3): ReLU(inplace=True)
# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (6): ReLU(inplace=True)
# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (8): ReLU(inplace=True)
# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (11): ReLU(inplace=True)
# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (13): ReLU(inplace=True)
# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (15): ReLU(inplace=True)
# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (17): ReLU(inplace=True)
# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (20): ReLU(inplace=True)
# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (22): ReLU(inplace=True)
# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (24): ReLU(inplace=True)
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (26): ReLU(inplace=True)
# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (29): ReLU(inplace=True)
# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (31): ReLU(inplace=True)
# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (33): ReLU(inplace=True)
# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (35): ReLU(inplace=True)
# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )
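# With this indexing, the string keys in `layer_name_list` (and hence the keys
# of `layer_weights` passed to the loss modules) name positions in
# `vgg19.features`: for example, '8', '17' and '26' select the outputs of
# relu2_2, relu3_4 and relu4_4 respectively.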
class PerceptualVGG(nn.Module):
    """VGG network used in calculating perceptual loss.

    In this implementation, we allow users to choose whether to use normalization
@@ -15,15 +58,15 @@ class PerceptualVGG(nn.Module):
            list contains the name of each layer in `vgg.features`. An example
            of this list is ['4', '10'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
-        use_input_norm (bool): If True, normalize the input image.
+        norm_image_with_imagenet_param (bool): If True, normalize the input image.
            Importantly, the input feature must be in the range [0, 1].
            Default: True.
    """
-    def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
+    def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True):
        super(PerceptualVGG, self).__init__()
        self.layer_name_list = layer_name_list
-        self.use_input_norm = use_input_norm
+        self.use_input_norm = norm_image_with_imagenet_param

        # get vgg model and load pretrained vgg weight
        # remove _vgg from attributes to avoid `find_unused_parameters` bug
@@ -75,7 +118,7 @@ class PerceptualLoss(nn.Module):
            in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
-        use_input_norm (bool): If True, normalize the input image in vgg.
+        norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg.
            Default: True.
        perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
            loss will be calculated.
@@ -88,7 +131,7 @@ class PerceptualLoss(nn.Module):
    Importantly, the input image must be in range [-1, 1].
    """

-    def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
+    def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True,
                 style_loss=False, norm_img=True, criterion='L1'):
        super(PerceptualLoss, self).__init__()
        self.norm_img = norm_img
@@ -97,7 +140,7 @@ class PerceptualLoss(nn.Module):
        self.style_loss = style_loss
        self.layer_weights = layer_weights
        self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
-                                 use_input_norm=use_input_norm)
+                                 norm_image_with_imagenet_param=norm_image_with_imagenet_param)
        self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
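For reference, a sketch of constructing the loss after the rename, not part of the commit; the import path mirrors the package layout above, the layer weights are illustrative, and the forward return is assumed to be the usual (perceptual, style) pair, which this diff does not show:

    import torch
    from loss.I2I.perceptual_loss import PerceptualLoss

    # illustrative layer weights keyed by vgg19.features indices
    percep = PerceptualLoss(layer_weights={'8': 0.5, '26': 1.0},
                            vgg_type='vgg19',
                            norm_image_with_imagenet_param=True,
                            norm_img=True)

    x = torch.rand(1, 3, 128, 128) * 2 - 1   # inputs in [-1, 1], per the docstring
    gt = torch.rand(1, 3, 128, 128) * 2 - 1
    percep_loss, style_loss = percep(x, gt)  # return signature assumed, not shown in this diff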