import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.vgg as vgg

# Sequential(
#   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (1): ReLU(inplace=True)
#   (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (3): ReLU(inplace=True)
#   (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (6): ReLU(inplace=True)
#   (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (8): ReLU(inplace=True)
#   (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (11): ReLU(inplace=True)
#   (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (13): ReLU(inplace=True)
#   (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (15): ReLU(inplace=True)
#   (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (17): ReLU(inplace=True)
#   (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (20): ReLU(inplace=True)
#   (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (22): ReLU(inplace=True)
#   (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (24): ReLU(inplace=True)
#   (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (26): ReLU(inplace=True)
#   (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (29): ReLU(inplace=True)
#   (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (31): ReLU(inplace=True)
#   (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (33): ReLU(inplace=True)
#   (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (35): ReLU(inplace=True)
#   (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )
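# The listing above is a reference for choosing layer indices; it can be
# reproduced (assuming torchvision is installed) with:
#
#     import torchvision.models.vgg as vgg
#     print(vgg.vgg19().features)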


class PerceptualVGG(nn.Module):
    """VGG network used in calculating perceptual loss.

    This implementation lets users choose whether to normalize the input
    and which type of vgg network to use. Note that the pretrained weights
    must match the chosen vgg type.

    Args:
        layer_name_list (list[str]): The forward function returns the
            features corresponding to the indices in this list. Each entry
            is the name of a layer in `vgg.features`, e.g. ['4', '10'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        norm_image_with_imagenet_param (bool): If True, normalize the input
            image with the ImageNet mean and std. Importantly, the input
            image must then be in the range [0, 1]. Default: True.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 norm_image_with_imagenet_param=True):
        super(PerceptualVGG, self).__init__()
        self.layer_name_list = layer_name_list
        self.use_input_norm = norm_image_with_imagenet_param

        # get the vgg model and load pretrained vgg weights
        # remove _vgg from attributes to avoid `find_unused_parameters` bug
        _vgg = getattr(vgg, vgg_type)(pretrained=True)
        num_layers = max(map(int, layer_name_list)) + 1
        assert len(_vgg.features) >= num_layers
        # only borrow the layers that will be used from _vgg to avoid
        # keeping unused parameters
        self.vgg_layers = _vgg.features[:num_layers]

        if self.use_input_norm:
            # the mean is for images with range [0, 1]
            self.register_buffer(
                'mean',
                torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for images with range [0, 1]
            self.register_buffer(
                'std',
                torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # freeze the vgg layers; they are only used as a fixed feature extractor
        for v in self.vgg_layers.parameters():
            v.requires_grad = False

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict[str, Tensor]: Features of the requested layers, keyed by
                layer index.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = {}

        for i, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if str(i) in self.layer_name_list:
                output[str(i)] = x.clone()

        return output
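

# A minimal usage sketch for PerceptualVGG on its own (illustrative only, not
# part of the original module); the first call downloads the pretrained
# vgg19 weights:
#
#     extractor = PerceptualVGG(layer_name_list=['4', '9'], vgg_type='vgg19')
#     feats = extractor(torch.rand(1, 3, 64, 64))  # input in [0, 1]
#     # feats['4'].shape == (1, 64, 32, 32), feats['9'].shape == (1, 128, 16, 16)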


class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg features.
            Here is an example: {'4': 1., '9': 1., '18': 1.}, which means
            the 5th, 10th and 19th feature layers will be extracted with
            weight 1.0 when calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        norm_image_with_imagenet_param (bool): If True, normalize the input
            image inside the vgg network. Default: True.
        perceptual_loss (bool): If True, the perceptual loss will be
            calculated. Default: True.
        style_loss (bool): If True, the style loss will be calculated.
            Default: False.
        norm_img (bool): If True, the image will be normalized to [0, 1].
            Note that this is different from `norm_image_with_imagenet_param`,
            which normalizes the input inside the forward function of vgg
            according to the dataset statistics. Importantly, if True, the
            input image must be in the range [-1, 1].
        criterion (str): Criterion used to compare features, one of 'L1',
            'L2', 'NL1' or 'NL2'. The 'N' prefix applies instance
            normalization to the features before the perceptual comparison.
            Default: 'L1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 norm_image_with_imagenet_param=True,
                 perceptual_loss=True,
                 style_loss=False,
                 norm_img=True,
                 criterion='L1'):
        super(PerceptualLoss, self).__init__()
        self.norm_img = norm_img
        assert perceptual_loss ^ style_loss, \
            'Exactly one of perceptual_loss and style_loss must be True'
        self.perceptual_loss = perceptual_loss
        self.style_loss = style_loss
        self.layer_weights = layer_weights
        self.vgg = PerceptualVGG(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            norm_image_with_imagenet_param=norm_image_with_imagenet_param)

        self.percep_criterion, self.style_criterion = self.set_criterion(criterion)

    def set_criterion(self, criterion: str):
        """Build the perceptual and style criteria from the criterion name."""
        assert criterion in ['NL1', 'NL2', 'L1', 'L2']
        # criteria prefixed with 'N' instance-normalize the features first
        norm = F.instance_norm if criterion.startswith('N') else (lambda x: x)
        fn = F.l1_loss if criterion.endswith('L1') else F.mse_loss
        # the style criterion always compares the raw (Gram) features
        return lambda x, t: fn(norm(x), norm(t)), lambda x, t: fn(x, t)
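
    # For example (illustrative note, not original code): with criterion='NL1'
    # the perceptual criterion computes
    #     F.l1_loss(F.instance_norm(x), F.instance_norm(t))
    # while the style criterion stays a plain F.l1_loss(x, t), applied later
    # to Gram matrices.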

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: The perceptual loss or the style loss, depending on the
                configuration.
        """
        if self.norm_img:
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5

        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_loss:
            percep_loss = 0
            for k in x_features.keys():
                percep_loss += self.percep_criterion(
                    x_features[k], gt_features[k]) * self.layer_weights[k]
            return percep_loss

        # calculate style loss
        if self.style_loss:
            style_loss = 0
            for k in x_features.keys():
                style_loss += self.style_criterion(
                    self._gram_mat(x_features[k]),
                    self._gram_mat(gt_features[k])) * self.layer_weights[k]
            return style_loss

    def _gram_mat(self, x):
        """Calculate the Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix with shape (n, c, c).
        """
        (n, c, h, w) = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
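

# A minimal end-to-end sketch (an assumption, not part of the original module);
# it downloads the pretrained vgg19 weights on first run.
if __name__ == '__main__':
    loss_fn = PerceptualLoss(
        layer_weights={'4': 1., '9': 1., '18': 1.},
        perceptual_loss=True,
        style_loss=False,
        norm_img=True,  # the tensors below are in [-1, 1]
        criterion='L1')

    pred = torch.rand(2, 3, 64, 64) * 2 - 1    # stand-in prediction in [-1, 1]
    target = torch.rand(2, 3, 64, 64) * 2 - 1  # stand-in ground truth in [-1, 1]
    print(loss_fn(pred, target))               # scalar perceptual loss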