import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.vgg as vgg


class PerceptualVGG(nn.Module):
    """VGG network used in calculating perceptual loss.

    In this implementation, we allow users to choose whether to use
    normalization in the input feature and the type of vgg network. Note
    that the pretrained path must fit the vgg type.

    Args:
        layer_name_list (list[str]): According to the indices in this list,
            the forward function will return the corresponding features.
            This list contains the name of each layer in `vgg.features`. An
            example of this list is ['4', '10'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image.
            Importantly, the input feature must be in the range [0, 1].
            Default: True.
    """

    def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
        super(PerceptualVGG, self).__init__()
        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm

        # get vgg model and load pretrained vgg weights
        # remove _vgg from attributes to avoid `find_unused_parameters` bug
        _vgg = getattr(vgg, vgg_type)(pretrained=True)
        num_layers = max(map(int, layer_name_list)) + 1
        assert len(_vgg.features) >= num_layers
        # only borrow layers that will be used from _vgg to avoid unused params
        self.vgg_layers = _vgg.features[:num_layers]

        if self.use_input_norm:
            # the mean is for images with range [0, 1]
            self.register_buffer(
                'mean',
                torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for images with range [0, 1]
            self.register_buffer(
                'std',
                torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # freeze vgg; the perceptual loss should not update its weights
        for v in self.vgg_layers.parameters():
            v.requires_grad = False

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict[str, Tensor]: Features of the layers named in
                `layer_name_list`, keyed by layer name.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = {}

        for i, l in enumerate(self.vgg_layers):
            x = l(x)
            if str(i) in self.layer_name_list:
                output[str(i)] = x.clone()

        return output
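

# A minimal usage sketch added for illustration (not part of the original
# module): extract intermediate vgg19 features with PerceptualVGG. The layer
# indices '4' and '9' and the input size here are arbitrary assumptions.
def _example_extract_features():
    extractor = PerceptualVGG(layer_name_list=['4', '9'])
    feats = extractor(torch.rand(1, 3, 64, 64))  # input expected in [0, 1]
    for name, feat in feats.items():
        print(name, tuple(feat.shape))

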
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'4': 1., '9': 1., '18': 1.}, which means
            the features at indices 4, 9 and 18 of `vgg.features` will be
            extracted with weight 1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        perceptual_loss (bool): If True, the perceptual loss will be
            calculated. Default: True.
        style_loss (bool): If True, the style loss will be calculated.
            Default: False.
        norm_img (bool): If True, the image will be normalized to [0, 1].
            Note that this is different from `use_input_norm`, which
            normalizes the input inside the forward function of vgg
            according to the statistics of the dataset. Importantly, the
            input image must be in the range [-1, 1]. Default: True.
        criterion (str): Criterion used to compare features. One of 'L1',
            'L2', 'NL1' and 'NL2'; the 'N'-prefixed variants apply instance
            normalization to both feature maps before comparison.
            Default: 'L1'.
    """

    def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True,
                 perceptual_loss=True, style_loss=False, norm_img=True,
                 criterion='L1'):
        super(PerceptualLoss, self).__init__()
        self.norm_img = norm_img
        self.perceptual_loss = perceptual_loss
        self.style_loss = style_loss
        self.layer_weights = layer_weights
        self.vgg = PerceptualVGG(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm)

        self.criterion = self.set_criterion(criterion)

    def set_criterion(self, criterion: str):
        """Build the feature-comparison criterion from its name.

        'L1' and 'L2' map to `F.l1_loss` and `F.mse_loss`; the 'N'-prefixed
        variants first pass both feature maps through `F.instance_norm`, so
        only the normalized feature pattern contributes to the loss.
        """
        assert criterion in ["NL1", "NL2", "L1", "L2"]
        norm = F.instance_norm if criterion.startswith("N") else lambda x: x
        fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
        return lambda x, t: fn(norm(x), norm(t))

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple[Tensor | None, Tensor | None]: Perceptual loss and style
                loss; each is None if the corresponding term is disabled.
        """
        if self.norm_img:
            # map images from [-1, 1] to [0, 1] before feeding them to vgg
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_loss:
            percep_loss = 0
            for k in x_features.keys():
                percep_loss += self.criterion(
                    x_features[k], gt_features[k]) * self.layer_weights[k]
        else:
            percep_loss = None

        # calculate style loss
        if self.style_loss:
            style_loss = 0
            for k in x_features.keys():
                style_loss += self.criterion(
                    self._gram_mat(x_features[k]),
                    self._gram_mat(gt_features[k])) * self.layer_weights[k]
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix with shape (n, c, c).
        """
        (n, c, h, w) = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        # normalize by the feature size so the scale is resolution-independent
        gram = features.bmm(features_t) / (c * h * w)
        return gram
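

# A minimal end-to-end sketch added for illustration (not in the original
# file): build PerceptualLoss, feed fake images in [-1, 1], and read back
# both loss terms. The layer_weights values, tensor sizes and the 'NL1'
# criterion choice are assumptions, not defaults from any particular config.
if __name__ == '__main__':
    loss_fn = PerceptualLoss(
        layer_weights={'4': 1., '9': 1., '18': 1.},
        style_loss=True,    # also compute the Gram-matrix style term
        criterion='NL1')    # instance-normalized L1 on vgg features
    pred = torch.rand(2, 3, 64, 64) * 2. - 1.    # pretend output in [-1, 1]
    target = torch.rand(2, 3, 64, 64) * 2. - 1.
    percep, style = loss_fn(pred, target)
    print('perceptual:', percep.item(), 'style:', style.item())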