raycv/model/image_translation/pix2pixHD.py
2020-10-22 22:42:01 +08:00

30 lines
1.0 KiB
Python

import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    """Apply a stack of discriminators to an image pyramid (pix2pixHD style).

    ``num_scale`` discriminators are built from the same config. The first one
    sees the input at full resolution. Each following one sees the input
    down-sampled once more, by a factor of about two per level.

    Args:
        num_scale: Number of discriminators, i.e. pyramid levels.
        discriminator_cfg: Config passed to ``MODEL.build_with`` once per
            level. NOTE(review): each call is assumed to return a fresh,
            independent ``nn.Module``; confirm this in the registry.
        down_sample_method: ``"avg"`` (3x3 average pooling, stride 2,
            padding 1) or ``"bilinear"`` (bilinear interpolation with scale
            factor 0.5 and ``align_corners=True``).

    Raises:
        ValueError: If ``down_sample_method`` is not a supported value.
    """

    _DOWN_SAMPLE_METHODS = ("avg", "bilinear")

    def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
        super().__init__()
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if down_sample_method not in self._DOWN_SAMPLE_METHODS:
            raise ValueError(
                f"down_sample_method must be one of {self._DOWN_SAMPLE_METHODS}, "
                f"got {down_sample_method!r}"
            )
        self.down_sample_method = down_sample_method
        # ModuleList registers every discriminator as a submodule, so their
        # parameters are picked up by .parameters(), .to(), state_dict(), etc.
        self.discriminator_list = nn.ModuleList([
            MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
        ])

    def down_sample(self, x):
        """Return ``x`` reduced to about half its spatial size.

        Args:
            x: Image batch. The bilinear path requires a 4-D tensor; presumably
                (N, C, H, W). TODO confirm the layout with callers.

        Returns:
            The down-sampled tensor.

        Raises:
            ValueError: If ``self.down_sample_method`` holds an unsupported value.
        """
        if self.down_sample_method == "avg":
            # count_include_pad=False keeps the zero padding out of the border
            # averages, so edge pixels are not darkened.
            return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
        if self.down_sample_method == "bilinear":
            return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
        # Reachable only if the attribute was modified after construction.
        # Previously this path silently returned None.
        raise ValueError(f"unsupported down_sample_method: {self.down_sample_method!r}")

    def forward(self, x):
        """Run every discriminator on its own pyramid level.

        Args:
            x: Full-resolution input, in the format ``down_sample`` accepts.

        Returns:
            A list with one output per discriminator, ordered from the finest
            (full-resolution) level to the coarsest. It is empty when
            ``num_scale`` is 0.
        """
        results = []
        for level, discriminator in enumerate(self.discriminator_list):
            if level > 0:
                # Down-sample only when a coarser discriminator will use the
                # result. There is no trailing pass after the last level.
                x = self.down_sample(x)
            results.append(discriminator(x))
        return results