raycv/model/GAN/wrapper.py
2020-09-01 17:56:18 +08:00

26 lines
711 B
Python

import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
def __init__(self, num_scale, discriminator_cfg):
super().__init__()
self.discriminator_list = nn.ModuleList([
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
])
@staticmethod
def down_sample(x):
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
def forward(self, x):
results = []
for discriminator in self.discriminator_list:
results.append(discriminator(x))
x = self.down_sample(x)
return results