import torch.nn as nn
import torch.nn.functional as F

from model import MODEL


@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    """Run a stack of discriminators on progressively down-sampled copies of the input.

    ``num_scale`` discriminators are built from the same ``discriminator_cfg``
    (one ``MODEL.build_with`` call each). In ``forward``, discriminator ``i``
    sees the input after it has been down-sampled ``i`` times.

    Args:
        num_scale: Number of discriminators / scales.
        discriminator_cfg: Config passed to ``MODEL.build_with`` for each
            discriminator.
        down_sample_method: ``"avg"`` (3x3 average pooling, stride 2) or
            ``"bilinear"`` (bilinear interpolation with scale factor 0.5).

    Raises:
        ValueError: If ``down_sample_method`` is not a supported method.
    """

    _DOWN_SAMPLE_METHODS = ("avg", "bilinear")

    def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if down_sample_method not in self._DOWN_SAMPLE_METHODS:
            raise ValueError(
                f"down_sample_method must be one of {self._DOWN_SAMPLE_METHODS}, "
                f"got {down_sample_method!r}"
            )
        self.down_sample_method = down_sample_method
        # One build call per scale; the discriminators are registered as
        # submodules through nn.ModuleList.
        self.discriminator_list = nn.ModuleList(
            [MODEL.build_with(discriminator_cfg) for _ in range(num_scale)]
        )

    def down_sample(self, x):
        """Halve the spatial resolution of ``x`` using the configured method.

        NOTE(review): ``x`` is presumably an (N, C, H, W) image batch;
        bilinear ``F.interpolate`` requires 4-D input. Confirm against callers.

        Raises:
            ValueError: If ``self.down_sample_method`` was changed to an
                unsupported value after construction.
        """
        if self.down_sample_method == "avg":
            # count_include_pad=False keeps border averages from being
            # diluted by the zero padding.
            return F.avg_pool2d(
                x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False
            )
        if self.down_sample_method == "bilinear":
            return F.interpolate(
                x, scale_factor=0.5, mode="bilinear", align_corners=True
            )
        # Previously this fell through and returned None, which failed later
        # inside a discriminator with an unrelated-looking error.
        raise ValueError(
            f"unsupported down_sample_method: {self.down_sample_method!r}"
        )

    def forward(self, x):
        """Return the discriminator outputs, ordered from full resolution to coarsest.

        Args:
            x: Input passed to the first discriminator at its original resolution.

        Returns:
            A list with one entry per discriminator (empty if ``num_scale`` is 0).
        """
        results = []
        last_index = len(self.discriminator_list) - 1
        for index, discriminator in enumerate(self.discriminator_list):
            results.append(discriminator(x))
            # Skip the down-sample after the final scale. Its result would be
            # discarded, and it can fail on very small inputs (e.g. bilinear
            # interpolation of a 1-pixel dimension yields a zero-sized output).
            if index < last_index:
                x = self.down_sample(x)
        return results