import torch.nn as nn
import torch.nn.functional as F

from model import MODEL


@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    """Apply a stack of discriminators to successively down-sampled inputs.

    Discriminator ``i`` receives the input after ``i`` applications of
    :meth:`down_sample`. The first discriminator therefore works at the
    original resolution, and each following one at roughly half the spatial
    size of the previous scale.

    Args:
        num_scale: Number of discriminators, i.e. number of scales.
        discriminator_cfg: Config passed to ``MODEL.build_with``. A separate
            ``build_with`` call is made for every scale.
    """

    def __init__(self, num_scale, discriminator_cfg):
        super().__init__()
        self.discriminator_list = nn.ModuleList([
            MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
        ])

    @staticmethod
    def down_sample(x):
        """Down-sample ``x`` with a 3x3, stride-2 average pool (padding 1).

        ``count_include_pad=False`` excludes the padded cells from each
        average, so values at the border are not pulled toward zero.
        NOTE(review): assumes ``x`` is a (N, C, H, W) or (C, H, W) tensor, as
        required by ``F.avg_pool2d``; confirm against the discriminator inputs.
        """
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1],
                            count_include_pad=False)

    def forward(self, x):
        """Run every discriminator on its own scale of ``x``.

        Returns:
            list: One output per discriminator, ordered from the finest scale
            (original resolution) to the coarsest. Empty if there are no
            discriminators.
        """
        results = []
        last = len(self.discriminator_list) - 1
        for i, discriminator in enumerate(self.discriminator_list):
            results.append(discriminator(x))
            # Nothing consumes the tensor pooled after the last scale, so skip
            # that pooling op (and its autograd node) entirely.
            if i < last:
                x = self.down_sample(x)
        return results