import math

import torch.nn as nn

from model.normalization import select_norm_layer
from model import MODEL


# Patch discriminator modelled on the SPADE / pix2pixHD multi-layer discriminator.
@MODEL.register_module("base-PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """Fully convolutional patch discriminator.

    Layer layout (every conv uses a 4x4 kernel):
      * one stride-2 conv + LeakyReLU(0.2), with no normalization;
      * ``num_conv - 1`` blocks of conv -> norm -> LeakyReLU(0.2). Their width
        doubles at each step and is capped at ``8 * base_channels``. All of
        them use stride 2 except the last one, which uses stride 1;
      * a final stride-1 conv that projects to a single output channel.

    Args:
        in_channels: number of channels of the input tensor.
        base_channels: number of output channels of the first conv layer.
        num_conv: number of conv blocks before the final 1-channel projection.
        use_spectral: if True, the intermediate convs are wrapped with
            ``nn.utils.spectral_norm``. The first and final convs are not.
        norm_type: key passed to ``select_norm_layer``. The intermediate convs
            get a bias term only when this is ``"IN"``.
        need_intermediate_feature: if True, ``forward`` returns the output of
            every block instead of only the final one.
    """

    def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
                 need_intermediate_feature=False):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = math.ceil((kernel_size - 1.0) / 2)  # == 2 for kernel_size 4
        # Upper bound on the width multiplier relative to base_channels.
        max_multiplier = 2 ** 3
        # NOTE(review): assumes select_norm_layer returns a callable that takes
        # a channel count and returns a module. Confirm in model.normalization.
        norm_layer = select_norm_layer(norm_type)
        # Presumably "IN" is a non-affine instance norm, so the conv keeps its
        # own bias, while other norms provide the shift. Confirm against
        # select_norm_layer.
        use_bias = norm_type == "IN"
        padding_mode = "zeros"

        sequence = [nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, False),
        )]
        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, max_multiplier)
            # The last intermediate block keeps the spatial size (stride 1).
            stride = 1 if i == num_conv - 1 else 2
            sequence.append(nn.Sequential(
                self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
                                  kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
                norm_layer(base_channels * multiple_now),
                nn.LeakyReLU(0.2, inplace=False),
            ))
        # `multiple_now` already holds the width of the last block built above.
        # The old code recomputed it as min(2 ** num_conv, 8), which only
        # matches when num_conv >= 4 and caused a channel mismatch otherwise.
        sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
                                  padding_mode=padding_mode))
        self.conv_blocks = nn.ModuleList(sequence)

    @staticmethod
    def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding, bias=True,
                     padding_mode: str = 'zeros'):
        """Build a Conv2d, optionally wrapped with spectral normalization.

        Args:
            use_spectral: if True, wrap the conv with ``nn.utils.spectral_norm``.
            in_channels: input channels of the conv.
            out_channels: output channels of the conv.
            kernel_size: conv kernel size.
            stride: conv stride.
            padding: conv padding.
            bias: whether the conv has a learnable bias.
            padding_mode: Conv2d padding mode.

        Returns:
            The conv module, spectrally normalized if requested.
        """
        # Pass dilation and bias by keyword. Positionally, the 6th Conv2d
        # argument is `dilation`, so the old call sent `bias` there.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                         dilation=1, bias=bias, padding_mode=padding_mode)
        if not use_spectral:
            return conv
        return nn.utils.spectral_norm(conv)

    def forward(self, x):
        """Run the input through every block in order.

        Args:
            x: input batch whose channel dimension equals ``in_channels``.

        Returns:
            If ``need_intermediate_feature`` is False, the output of the final
            1-channel conv. Otherwise, a tuple with the output of each block in
            order; its last entry is that final output.
        """
        intermediate_feature = []
        for layer in self.conv_blocks:
            x = layer(x)
            if self.need_intermediate_feature:
                intermediate_feature.append(x)
        if self.need_intermediate_feature:
            return tuple(intermediate_feature)
        return x