add context loss
parent b01016edb5
commit ca55318253
loss/I2I/context_loss.py (new file, +44 lines)
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .perceptual_loss import PerceptualVGG
+
+
+class ContextLoss(nn.Module):
+    def __init__(self, layer_weights, h=0.1, vgg_type='vgg19', norm_image_with_imagenet_param=True, norm_img=True,
+                 eps=1e-5):
+        super(ContextLoss, self).__init__()
+        self.eps = eps
+        self.h = h
+        self.layer_weights = layer_weights
+        self.norm_img = norm_img
+        self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
+                                 norm_image_with_imagenet_param=norm_image_with_imagenet_param)
+
+    def single_forward(self, source_feature, target_feature):
+        mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
+        source_feature = (source_feature - mean_target_feature).view(*source_feature.size()[:2], -1)  # NxCxHW
+        target_feature = (target_feature - mean_target_feature).view(*source_feature.size()[:2], -1)  # NxCxHW
+        source_feature = F.normalize(source_feature, p=2, dim=1)
+        target_feature = F.normalize(target_feature, p=2, dim=1)
+        cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2), target_feature)) / 2  # NxHWxHW
+        rel_distance = cosine_distance / (cosine_distance.min(2, keepdim=True)[0] + self.eps)
+        w = torch.exp((1 - rel_distance) / self.h)
+        cx = w.div(w.sum(dim=2, keepdim=True))
+        cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
+        return -torch.log(cx).mean()
+
+    def forward(self, x, gt):
+        if self.norm_img:
+            x = (x + 1.) * 0.5
+            gt = (gt + 1.) * 0.5
+
+        # extract vgg features
+        x_features = self.vgg(x)
+        gt_features = self.vgg(gt.detach())
+
+        loss = 0
+        for k in x_features.keys():
+            loss += self.single_forward(x_features[k], gt_features[k]) * self.layer_weights[k]
+        return loss
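For readers of the diff: single_forward implements the contextual (CX) loss of Mechrez et al. (ECCV 2018). A compact restatement of what the code computes per VGG layer (notation is mine, not part of the commit): with source features {s_i} and target features {t_j} taken at each spatial position, centered by the target mean and L2-normalized over channels,

    d_{ij} = \frac{1 - \cos(s_i, t_j)}{2}, \qquad
    \tilde{d}_{ij} = \frac{d_{ij}}{\min_k d_{ik} + \epsilon}, \qquad
    w_{ij} = \exp\!\Big(\frac{1 - \tilde{d}_{ij}}{h}\Big), \qquad
    CX_{ij} = \frac{w_{ij}}{\sum_k w_{ik}},

    \mathcal{L}_{CX} = -\log\Big(\frac{1}{HW}\sum_j \max_i CX_{ij}\Big),

where h is the bandwidth (default 0.1) and \epsilon = 1e-5 guards the division; forward then sums these per-layer values weighted by layer_weights.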
@@ -4,6 +4,49 @@ import torch.nn.functional as F
 import torchvision.models.vgg as vgg
 
+# Sequential(
+# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (1): ReLU(inplace=True)
+# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (3): ReLU(inplace=True)
+# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+
+# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (6): ReLU(inplace=True)
+# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (8): ReLU(inplace=True)
+# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+
+# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (11): ReLU(inplace=True)
+# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (13): ReLU(inplace=True)
+# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (15): ReLU(inplace=True)
+# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (17): ReLU(inplace=True)
+# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+
+# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (20): ReLU(inplace=True)
+# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (22): ReLU(inplace=True)
+# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (24): ReLU(inplace=True)
+# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (26): ReLU(inplace=True)
+# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+
+# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (29): ReLU(inplace=True)
+# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (31): ReLU(inplace=True)
+# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (33): ReLU(inplace=True)
+# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (35): ReLU(inplace=True)
+# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+# )
 class PerceptualVGG(nn.Module):
     """VGG network used in calculating perceptual loss.
 
     In this implementation, we allow users to choose whether use normalization
@@ -15,15 +58,15 @@ class PerceptualVGG(nn.Module):
             list contains the name each layer in `vgg.feature`. An example
             of this list is ['4', '10'].
         vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
-        use_input_norm (bool): If True, normalize the input image.
+        norm_image_with_imagenet_param (bool): If True, normalize the input image.
             Importantly, the input feature must in the range [0, 1].
             Default: True.
     """
 
-    def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
+    def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True):
         super(PerceptualVGG, self).__init__()
         self.layer_name_list = layer_name_list
-        self.use_input_norm = use_input_norm
+        self.use_input_norm = norm_image_with_imagenet_param
 
         # get vgg model and load pretrained vgg weight
         # remove _vgg from attributes to avoid `find_unused_parameters` bug
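Note on the rename: use_input_norm becomes the keyword norm_image_with_imagenet_param, so existing call sites must pass the new name. A hypothetical before/after for a caller (module path and layer names are illustrative, not taken from this commit):

from loss.I2I.perceptual_loss import PerceptualVGG

# before this commit:
# vgg = PerceptualVGG(layer_name_list=['4', '9'], vgg_type='vgg19', use_input_norm=True)

# after this commit, the same behaviour is spelled:
vgg = PerceptualVGG(layer_name_list=['4', '9'], vgg_type='vgg19',
                    norm_image_with_imagenet_param=True)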
@@ -75,7 +118,7 @@ class PerceptualLoss(nn.Module):
             in calculating losses.
         vgg_type (str): The type of vgg network used as feature extractor.
             Default: 'vgg19'.
-        use_input_norm (bool): If True, normalize the input image in vgg.
+        norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg.
             Default: True.
         perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
             loss will be calculated.
@@ -88,7 +131,7 @@ class PerceptualLoss(nn.Module):
         Importantly, the input image must be in range [-1, 1].
     """
 
-    def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
+    def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True,
                  style_loss=False, norm_img=True, criterion='L1'):
         super(PerceptualLoss, self).__init__()
         self.norm_img = norm_img
@@ -97,7 +140,7 @@ class PerceptualLoss(nn.Module):
         self.style_loss = style_loss
         self.layer_weights = layer_weights
         self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
-                                 use_input_norm=use_input_norm)
+                                 norm_image_with_imagenet_param=norm_image_with_imagenet_param)
 
         self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
 
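Putting the commit together, a minimal usage sketch of the new loss (module path, layer keys, and weights are assumptions, not part of the commit; constructing the loss pulls pretrained VGG19 weights on first use):

import torch

from loss.I2I.context_loss import ContextLoss

# keys index modules of torchvision's vgg19.features; see the Sequential listing above
criterion = ContextLoss(layer_weights={'8': 1.0, '17': 1.0}, h=0.1)

x = (torch.rand(1, 3, 64, 64) * 2 - 1).requires_grad_()  # generated image in [-1, 1]; norm_img=True rescales to [0, 1]
gt = torch.rand(1, 3, 64, 64) * 2 - 1                    # reference image; detached inside forward
loss = criterion(x, gt)  # scalar: per-layer contextual losses weighted by layer_weights
loss.backward()          # gradients reach x only, since gt is detached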