From ca553182534dda462813e43204afb5b351c58697 Mon Sep 17 00:00:00 2001
From: Ray Wong
Date: Thu, 24 Sep 2020 16:38:03 +0800
Subject: [PATCH] add context loss

---
 loss/I2I/context_loss.py    | 44 +++++++++++++++++++++++++++++
 loss/I2I/perceptual_loss.py | 55 +++++++++++++++++++++++++++++++++----
 2 files changed, 93 insertions(+), 6 deletions(-)
 create mode 100644 loss/I2I/context_loss.py

diff --git a/loss/I2I/context_loss.py b/loss/I2I/context_loss.py
new file mode 100644
index 0000000..464c7c6
--- /dev/null
+++ b/loss/I2I/context_loss.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .perceptual_loss import PerceptualVGG
+
+
+class ContextLoss(nn.Module):
+    def __init__(self, layer_weights, h=0.1, vgg_type='vgg19', norm_image_with_imagenet_param=True, norm_img=True,
+                 eps=1e-5):
+        super(ContextLoss, self).__init__()
+        self.eps = eps
+        self.h = h  # bandwidth of the affinity kernel
+        self.layer_weights = layer_weights
+        self.norm_img = norm_img
+        self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
+                                 norm_image_with_imagenet_param=norm_image_with_imagenet_param)
+
+    def single_forward(self, source_feature, target_feature):
+        mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)  # per-channel target mean
+        source_feature = (source_feature - mean_target_feature).view(*source_feature.size()[:2], -1)  # NxCxHW
+        target_feature = (target_feature - mean_target_feature).view(*target_feature.size()[:2], -1)  # NxCxHW
+        source_feature = F.normalize(source_feature, p=2, dim=1)
+        target_feature = F.normalize(target_feature, p=2, dim=1)
+        cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2), target_feature)) / 2  # NxHWxHW
+        rel_distance = cosine_distance / (cosine_distance.min(2, keepdim=True)[0] + self.eps)  # relative to best match
+        w = torch.exp((1 - rel_distance) / self.h)  # affinity with bandwidth h
+        cx = w.div(w.sum(dim=2, keepdim=True))  # row-normalized affinity
+        cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)  # best match per target position, averaged
+        return -torch.log(cx).mean()
+
+    def forward(self, x, gt):
+        if self.norm_img:  # map [-1, 1] images to [0, 1]
+            x = (x + 1.) * 0.5
+            gt = (gt + 1.) * 0.5
+
+        # extract vgg features
+        x_features = self.vgg(x)
+        gt_features = self.vgg(gt.detach())
+
+        loss = 0
+        for k in x_features.keys():
+            loss += self.single_forward(x_features[k], gt_features[k]) * self.layer_weights[k]
+        return loss
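
Note on the new module: `single_forward` computes contextual similarity. Pairwise cosine distances between mean-shifted, channel-normalized VGG features are converted to relative distances, then to affinities with bandwidth `h`, and each target position keeps its best match. Below is a minimal usage sketch (not part of the patch); the layer keys and weights are illustrative placeholders, and torchvision fetches the pretrained VGG19 weights on first use:

import torch

from loss.I2I.context_loss import ContextLoss

# Illustrative keys; they index into torchvision's vgg19().features
# (see the layer listing in perceptual_loss.py).
criterion_cx = ContextLoss(layer_weights={'8': 1.0, '17': 1.0}, h=0.1)

# norm_img=True, so inputs are expected in [-1, 1] (typical generator output).
fake = (torch.rand(2, 3, 64, 64) * 2 - 1).requires_grad_()
real = torch.rand(2, 3, 64, 64) * 2 - 1

loss = criterion_cx(fake, real)  # scalar: weighted sum of -log(CX) per layer
loss.backward()

Because the affinity matrix is NxHWxHW per layer, memory grows quadratically with the spatial size of the selected feature maps, so deeper (smaller) layers are the safer choice.
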
diff --git a/loss/I2I/perceptual_loss.py b/loss/I2I/perceptual_loss.py
index e55aaa1..dd44deb 100644
--- a/loss/I2I/perceptual_loss.py
+++ b/loss/I2I/perceptual_loss.py
@@ -4,6 +4,49 @@ import torch.nn.functional as F
 import torchvision.models.vgg as vgg
 
 
+# Sequential(
+# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (1): ReLU(inplace=True)
+# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (3): ReLU(inplace=True)
+# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+
+# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (6): ReLU(inplace=True)
+# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (8): ReLU(inplace=True)
+# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+
+# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (11): ReLU(inplace=True)
+# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (13): ReLU(inplace=True)
+# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (15): ReLU(inplace=True)
+# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (17): ReLU(inplace=True)
+# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+
+# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (20): ReLU(inplace=True)
+# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (22): ReLU(inplace=True)
+# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (24): ReLU(inplace=True)
+# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (26): ReLU(inplace=True)
+# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+
+# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (29): ReLU(inplace=True)
+# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (31): ReLU(inplace=True)
+# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (33): ReLU(inplace=True)
+# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+# (35): ReLU(inplace=True)
+# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+# )
 class PerceptualVGG(nn.Module):
     """VGG network used in calculating perceptual loss.
     In this implementation, we allow users to choose whether to use normalization
""" - def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True): + def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True): super(PerceptualVGG, self).__init__() self.layer_name_list = layer_name_list - self.use_input_norm = use_input_norm + self.use_input_norm = norm_image_with_imagenet_param # get vgg model and load pretrained vgg weight # remove _vgg from attributes to avoid `find_unused_parameters` bug @@ -75,7 +118,7 @@ class PerceptualLoss(nn.Module): in calculating losses. vgg_type (str): The type of vgg network used as feature extractor. Default: 'vgg19'. - use_input_norm (bool): If True, normalize the input image in vgg. + norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg. Default: True. perceptual_loss (bool): If `perceptual_loss == True`, the perceptual loss will be calculated. @@ -88,7 +131,7 @@ class PerceptualLoss(nn.Module): Importantly, the input image must be in range [-1, 1]. """ - def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True, + def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True, style_loss=False, norm_img=True, criterion='L1'): super(PerceptualLoss, self).__init__() self.norm_img = norm_img @@ -97,7 +140,7 @@ class PerceptualLoss(nn.Module): self.style_loss = style_loss self.layer_weights = layer_weights self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, - use_input_norm=use_input_norm) + norm_image_with_imagenet_param=norm_image_with_imagenet_param) self.percep_criterion, self.style_criterion = self.set_criterion(criterion)