import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.vgg as vgg

# Reference layout of the VGG19 ``features`` module. The indices in
# parentheses are the layer names accepted by ``PerceptualVGG`` and by the
# keys of ``PerceptualLoss.layer_weights``.
# Sequential(
#   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (1): ReLU(inplace=True)
#   (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (3): ReLU(inplace=True)
#   (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (6): ReLU(inplace=True)
#   (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (8): ReLU(inplace=True)
#   (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (11): ReLU(inplace=True)
#   (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (13): ReLU(inplace=True)
#   (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (15): ReLU(inplace=True)
#   (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (17): ReLU(inplace=True)
#   (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (20): ReLU(inplace=True)
#   (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (22): ReLU(inplace=True)
#   (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (24): ReLU(inplace=True)
#   (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (26): ReLU(inplace=True)
#   (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#   (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (29): ReLU(inplace=True)
#   (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (31): ReLU(inplace=True)
#   (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (33): ReLU(inplace=True)
#   (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (35): ReLU(inplace=True)
#   (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )


class PerceptualVGG(nn.Module):
    """Frozen, pretrained VGG feature extractor used for perceptual loss.

    Users can choose the torchvision VGG variant and whether the input is
    normalized with ImageNet statistics. Pretrained weights are obtained from
    torchvision for the selected ``vgg_type``, so they always match it.

    Args:
        layer_name_list (list[str]): Names (string indices into
            ``vgg.features``) of the layers whose outputs ``forward``
            returns, e.g. ['4', '10'].
        vgg_type (str): Name of the torchvision VGG constructor, e.g.
            'vgg19'. Default: 'vgg19'.
        norm_image_with_imagenet_param (bool): If True, normalize the input
            with the ImageNet mean/std. The input must then be in range
            [0, 1]. Default: True.
    """

    def __init__(self, layer_name_list, vgg_type='vgg19',
                 norm_image_with_imagenet_param=True):
        super().__init__()
        self.layer_name_list = layer_name_list
        self.use_input_norm = norm_image_with_imagenet_param

        # Keep the full pretrained model in a local (not an attribute) so its
        # unused parameters (e.g. the classifier) are never registered on this
        # module; unused registered parameters trip `find_unused_parameters`.
        # NOTE: `pretrained=True` is deprecated in newer torchvision in favor
        # of `weights=`; it is kept here for compatibility with older releases.
        _vgg = getattr(vgg, vgg_type)(pretrained=True)
        num_layers = max(map(int, layer_name_list)) + 1
        assert len(_vgg.features) >= num_layers, (
            f"{vgg_type}.features has {len(_vgg.features)} layers, "
            f"but layer index {num_layers - 1} was requested")
        # Only borrow the layers up to the deepest requested one.
        self.vgg_layers = _vgg.features[:num_layers]

        if self.use_input_norm:
            # ImageNet mean and std; both assume an input in range [0, 1].
            self.register_buffer(
                'mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer(
                'std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # Freeze the feature extractor so no gradients are computed or stored
        # for the VGG weights and an optimizer can never update them.
        for param in self.vgg_layers.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Extract features from the requested VGG layers.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict[str, Tensor]: Maps each name in ``layer_name_list`` to the
                output of that layer.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for idx, layer in enumerate(self.vgg_layers):
            x = layer(x)
            name = str(idx)
            if name in self.layer_name_list:
                # Clone: a following ReLU(inplace=True) would otherwise
                # overwrite the stored tensor in place.
                output[name] = x.clone()
        return output


class PerceptualLoss(nn.Module):
    """Perceptual loss, or the commonly used style (Gram-matrix) loss.

    Exactly one of the two losses is computed per instance.

    Args:
        layer_weights (dict): The weight for each layer of the vgg features.
            Example: {'4': 1., '9': 1., '18': 1.} means the outputs of
            layers 4, 9 and 18 (the 5th, 10th and 19th layers of
            ``vgg.features``) are extracted and weighted by 1.0 when
            calculating the loss.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        norm_image_with_imagenet_param (bool): If True, normalize the input
            image in vgg with ImageNet statistics. Default: True.
        perceptual_loss (bool): If True, the perceptual loss is calculated.
            Default: True.
        style_loss (bool): If True, the style loss is calculated.
            Exactly one of ``perceptual_loss`` and ``style_loss`` must be
            True. Default: False.
        norm_img (bool): If True, the image is mapped from [-1, 1] to
            [0, 1] before feature extraction. This is different from
            ``norm_image_with_imagenet_param``, which normalizes the input
            inside the vgg forward function with dataset statistics.
            Importantly, the input image must then be in range [-1, 1].
            Default: True.
        criterion (str): Distance used between features: 'L1' (l1 loss),
            'L2' (mean squared error), or 'NL1' / 'NL2', which additionally
            instance-normalize the features for the perceptual loss only.
            Default: 'L1'.
    """

    def __init__(self, layer_weights, vgg_type='vgg19',
                 norm_image_with_imagenet_param=True, perceptual_loss=True,
                 style_loss=False, norm_img=True, criterion='L1'):
        super().__init__()
        self.norm_img = norm_img
        assert perceptual_loss ^ style_loss, \
            "There must be one and only one true in style or perceptual"
        self.perceptual_loss = perceptual_loss
        self.style_loss = style_loss
        self.layer_weights = layer_weights
        self.vgg = PerceptualVGG(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            norm_image_with_imagenet_param=norm_image_with_imagenet_param)

        self.percep_criterion, self.style_criterion = \
            self.set_criterion(criterion)

    def set_criterion(self, criterion: str):
        """Build the distance functions for the perceptual and style terms.

        Args:
            criterion (str): One of 'L1', 'L2', 'NL1', 'NL2'.

        Returns:
            tuple: ``(percep_criterion, style_criterion)``, two callables
                ``f(input, target)``. Both use ``F.l1_loss`` for '*L1' and
                ``F.mse_loss`` for '*L2'. With the 'N' prefix, only the
                perceptual criterion instance-normalizes both inputs first.
        """
        assert criterion in ["NL1", "NL2", "L1", "L2"], \
            f"Unsupported criterion: {criterion}"
        norm = F.instance_norm if criterion.startswith("N") else (lambda t: t)
        fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
        return lambda x, t: fn(norm(x), norm(t)), lambda x, t: fn(x, t)

    def forward(self, x, gt):
        """Compute the weighted perceptual or style loss.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Sum over the selected layers of
                ``layer_weights[k] * distance(x_features[k], gt_features[k])``,
                where the distance is taken on Gram matrices for the style
                loss.
        """
        if self.norm_img:
            # Map from [-1, 1] to [0, 1].
            x = (x + 1.) * 0.5
            gt = (gt + 1.) * 0.5

        # Extract vgg features; no gradient flows through the target branch.
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_loss:
            percep_loss = 0
            for k in x_features.keys():
                percep_loss += self.percep_criterion(
                    x_features[k], gt_features[k]) * self.layer_weights[k]
            return percep_loss

        # calculate style loss
        if self.style_loss:
            style_loss = 0
            for k in x_features.keys():
                style_loss += self.style_criterion(
                    self._gram_mat(x_features[k]),
                    self._gram_mat(gt_features[k])) * self.layer_weights[k]
            return style_loss

    def _gram_mat(self, x):
        """Calculate the Gram matrix normalized by ``c * h * w``.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix with shape (n, c, c).
        """
        n, c, h, w = x.size()
        # reshape (not view) works regardless of the tensor's memory layout.
        features = x.reshape(n, c, h * w)
        gram = features.bmm(features.transpose(1, 2)) / (c * h * w)
        return gram