26 lines
711 B
Python
26 lines
711 B
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from model import MODEL
|
|
|
|
|
|
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    """Apply a stack of discriminators to progressively down-sampled inputs.

    Discriminator ``i`` receives the input after ``i`` applications of
    :meth:`down_sample`. The first discriminator therefore works at the
    original resolution, and each later one at roughly half the previous
    spatial size.

    Args:
        num_scale: Number of scales, i.e. how many discriminators are built.
        discriminator_cfg: Config passed to ``MODEL.build_with`` once per
            scale. NOTE(review): assumes each call builds an independent
            module (no weight sharing) — confirm against ``MODEL.build_with``.
    """

    def __init__(self, num_scale, discriminator_cfg):
        super().__init__()
        # nn.ModuleList (not a plain list) so the sub-discriminators'
        # parameters are registered with, and moved along with, this module.
        self.discriminator_list = nn.ModuleList([
            MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
        ])

    @staticmethod
    def down_sample(x):
        """Roughly halve the spatial size of ``x``.

        Uses a 3x3 average pool with stride 2 and padding 1, so each spatial
        dim H becomes floor((H - 1) / 2) + 1, i.e. ceil(H / 2).
        ``count_include_pad=False`` keeps the zero padding out of the
        border averages.
        """
        return F.avg_pool2d(
            x, kernel_size=3, stride=2, padding=1, count_include_pad=False
        )

    def forward(self, x):
        """Run every discriminator on its own scale of ``x``.

        Args:
            x: Input for the finest-scale discriminator. Must be a tensor
                accepted by ``F.avg_pool2d`` (3-D or 4-D). Presumably an
                (N, C, H, W) image batch — TODO confirm.

        Returns:
            list: One output per discriminator, ordered finest to coarsest.
            The list is empty when ``num_scale`` is 0.
        """
        results = []
        last = len(self.discriminator_list) - 1
        for index, discriminator in enumerate(self.discriminator_list):
            results.append(discriminator(x))
            # Do not down-sample after the coarsest scale; that result would
            # never be used, so it would only cost an extra pooling pass.
            if index < last:
                x = self.down_sample(x)
        return results
|