import math

import torch.nn as nn

from model.normalization import select_norm_layer
from model import MODEL


# Based on the multi-layer patch discriminator used in SPADE / pix2pixHD.
@MODEL.register_module("base-PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator that scores overlapping patches of its input.

    When ``need_intermediate_feature`` is True, ``forward`` returns the output of
    every convolutional block (e.g. for a feature-matching loss) instead of only
    the final patch score map.
    """

    def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
                 need_intermediate_feature=False):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = math.ceil((kernel_size - 1.0) / 2)
        norm_layer = select_norm_layer(norm_type)
        # Instance norm has no learnable shift by default, so keep the conv bias for "IN".
        use_bias = norm_type == "IN"
        padding_mode = "zeros"

        # First block: strided conv without normalization, as in pix2pixHD / SPADE.
        sequence = [nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, False)
        )]
        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            # Double the channel multiplier each block, capped at 8x base_channels.
            multiple_now = min(2 ** i, 8)
            # The final block before the prediction head uses stride 1.
            stride = 1 if i == num_conv - 1 else 2
            sequence.append(nn.Sequential(
                self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
                                  kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
                norm_layer(base_channels * multiple_now),
                nn.LeakyReLU(0.2, inplace=False),
            ))
        # multiple_now already holds the last block's channel multiplier, so reuse it here
        # (re-deriving it from num_conv would mismatch the loop output for num_conv < 4).
        sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
                                  padding_mode=padding_mode))
        self.conv_blocks = nn.ModuleList(sequence)

    @staticmethod
    def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
                     bias=True, padding_mode: str = 'zeros'):
        # Pass bias by keyword; positionally it would fill Conv2d's dilation argument.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                         bias=bias, padding_mode=padding_mode)
        if not use_spectral:
            return conv
        # Optionally wrap the conv with spectral normalization to stabilize GAN training.
        return nn.utils.spectral_norm(conv)

    def forward(self, x):
        if self.need_intermediate_feature:
            # Collect every block's output (useful for feature-matching losses).
            intermediate_feature = []
            for layer in self.conv_blocks:
                x = layer(x)
                intermediate_feature.append(x)
            return tuple(intermediate_feature)
        else:
            for layer in self.conv_blocks:
                x = layer(x)
            return x
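

# Minimal usage sketch (illustrative, not part of the original module). It assumes a
# 3-channel RGB input, base_channels=64, and that select_norm_layer("IN") returns an
# InstanceNorm2d-style factory per this repo's conventions.
if __name__ == "__main__":
    import torch

    disc = PatchDiscriminator(in_channels=3, base_channels=64, num_conv=4,
                              use_spectral=True, norm_type="IN",
                              need_intermediate_feature=True)
    dummy = torch.randn(1, 3, 256, 256)
    features = disc(dummy)
    # One tensor per conv block; the last one is the one-channel patch-wise score map.
    for feat in features:
        print(feat.shape)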