raycv/model/GAN/base.py
2020-09-01 17:56:18 +08:00

62 lines
2.4 KiB
Python

import math
import torch.nn as nn
from model.normalization import select_norm_layer
from model import MODEL
# Based on the SPADE / pix2pixHD discriminator.
@MODEL.register_module("base-PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """Convolutional patch discriminator producing a one-channel prediction map.

    The network is a list of blocks applied in order:

    1. a stem of ``Conv2d(stride=2)`` + ``LeakyReLU(0.2)``;
    2. ``num_conv - 1`` blocks of conv + normalization + ``LeakyReLU(0.2)``,
       whose width doubles per block and is capped at ``8 * base_channels``;
       every block uses stride 2 except the last one, which uses stride 1;
    3. a final ``Conv2d`` that maps to a single output channel.

    Args:
        in_channels: number of channels of the input tensor.
        base_channels: number of channels produced by the stem.
        num_conv: number of conv blocks before the final 1-channel conv
            (the stem counts as one of them).
        use_spectral: wrap the convs of the intermediate blocks with
            ``nn.utils.spectral_norm``.
        norm_type: name handed to ``select_norm_layer`` to pick the
            normalization layer; the intermediate convs get a bias term only
            when this is ``"IN"``.
        need_intermediate_feature: make ``forward`` return the output of
            every block instead of only the final prediction.
    """

    def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
                 need_intermediate_feature=False):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = math.ceil((kernel_size - 1.0) / 2)  # == 2 for a 4x4 kernel
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"
        padding_mode = "zeros"

        # Stem: plain conv, no normalization.
        sequence = [nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, False)
        )]

        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 3)
            # Only the last intermediate block keeps the spatial resolution.
            stride = 1 if i == num_conv - 1 else 2
            out_channels = base_channels * multiple_now
            sequence.append(nn.Sequential(
                self.build_conv2d(use_spectral, base_channels * multiple_prev, out_channels,
                                  kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
                norm_layer(out_channels),
                nn.LeakyReLU(0.2, inplace=False),
            ))

        # The prediction head must consume exactly what the previous block
        # emitted, so reuse the multiplier left by the loop. Recomputing it as
        # min(2 ** num_conv, 8) only matches when num_conv >= 4.
        sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
                                  padding_mode=padding_mode))
        self.conv_blocks = nn.ModuleList(sequence)

    @staticmethod
    def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
                     bias=True, padding_mode: str = 'zeros'):
        """Create a ``Conv2d``, optionally wrapped with spectral normalization.

        Args:
            use_spectral: when true, return ``nn.utils.spectral_norm(conv)``.
            in_channels: input channel count of the convolution.
            out_channels: output channel count of the convolution.
            kernel_size: forwarded to ``nn.Conv2d``.
            stride: forwarded to ``nn.Conv2d``.
            padding: forwarded to ``nn.Conv2d``.
            bias: whether the convolution has a learnable bias.
            padding_mode: forwarded to ``nn.Conv2d``.

        Returns:
            The convolution module, spectrally normalized if requested.
        """
        # ``bias`` has to be passed by keyword: the sixth positional parameter
        # of nn.Conv2d is ``dilation``, not ``bias``.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                         bias=bias, padding_mode=padding_mode)
        if not use_spectral:
            return conv
        return nn.utils.spectral_norm(conv)

    def forward(self, x):
        """Run ``x`` through every block in order.

        Returns:
            If ``need_intermediate_feature`` is set, a tuple holding the
            output of each block (the last element is the final prediction);
            otherwise only the output of the last block.
        """
        if not self.need_intermediate_feature:
            for layer in self.conv_blocks:
                x = layer(x)
            return x

        intermediate_feature = []
        for layer in self.conv_blocks:
            x = layer(x)
            intermediate_feature.append(x)
        return tuple(intermediate_feature)