30 lines
1.0 KiB
Python
30 lines
1.0 KiB
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from model import MODEL
|
|
|
|
|
|
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    """Apply several discriminators to progressively down-sampled inputs.

    The first discriminator sees the input as given; every following one
    sees the previous input after one more down-sampling step.

    Args:
        num_scale: Number of discriminators (scales) to build.
        discriminator_cfg: Config handed to ``MODEL.build_with`` once per
            scale.
        down_sample_method: ``"avg"`` for 3x3 stride-2 average pooling or
            ``"bilinear"`` for bilinear interpolation with scale factor 0.5.
    """

    _DOWN_SAMPLE_METHODS = ("avg", "bilinear")

    def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
        super().__init__()
        assert down_sample_method in self._DOWN_SAMPLE_METHODS, (
            f"down_sample_method must be one of {self._DOWN_SAMPLE_METHODS}, "
            f"got {down_sample_method!r}"
        )
        self.down_sample_method = down_sample_method

        # build_with is called once per scale.
        # NOTE(review): presumably each call returns a fresh module so the
        # scales do not share weights -- confirm against MODEL.build_with.
        self.discriminator_list = nn.ModuleList([
            MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
        ])

    def down_sample(self, x):
        """Return ``x`` down-sampled by the configured method.

        Args:
            x: Input tensor. NOTE(review): presumably ``(N, C, H, W)``;
                bilinear ``F.interpolate`` requires a 4-D input -- confirm
                against callers.

        Returns:
            The down-sampled tensor.

        Raises:
            ValueError: If ``self.down_sample_method`` is not a supported
                method.
        """
        if self.down_sample_method == "avg":
            # count_include_pad=False keeps the zero padding out of the
            # averages computed at the borders.
            return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
        if self.down_sample_method == "bilinear":
            return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
        # The constructor assert is stripped under ``python -O`` and the
        # attribute can be reassigned later; fail loudly here instead of
        # silently returning None to the next discriminator.
        raise ValueError(
            f"Unsupported down_sample_method: {self.down_sample_method!r}"
        )

    def forward(self, x):
        """Run every discriminator on its own scale of ``x``.

        Args:
            x: Input tensor passed to the first discriminator unchanged.

        Returns:
            A list with one entry per discriminator, in order from the
            original resolution to the most down-sampled one. The list is
            empty when no discriminators were built.
        """
        results = []
        last_index = len(self.discriminator_list) - 1
        for index, discriminator in enumerate(self.discriminator_list):
            results.append(discriminator(x))
            # Only down-sample when another scale follows: the result after
            # the last discriminator would be discarded, and the extra step
            # can fail on inputs that are already too small to shrink.
            if index < last_index:
                x = self.down_sample(x)
        return results
|