update cycle module

This commit is contained in:
Ray Wong 2020-08-22 20:57:03 +08:00
parent 58ed4524bf
commit 35ab7ecd51

View File

@ -118,16 +118,15 @@ class ResGenerator(nn.Module):
multiple = 2 ** i
submodules += [
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
kernel_size=3,
stride=2, padding=1, bias=use_bias),
kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(num_features=base_channels * multiple * 2),
nn.ReLU(inplace=True)
]
self.encoder = nn.Sequential(*submodules)
res_block_channels = num_down_sampling ** 2 * base_channels
self.res_blocks = nn.ModuleList(
[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
self.resnet_middle = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
range(num_blocks)])
# up sampling
@ -149,14 +148,13 @@ class ResGenerator(nn.Module):
def forward(self, x):
x = self.encoder(self.start_conv(x))
for rb in self.res_blocks:
x = rb(x)
x = self.resnet_middle(x)
return self.end_conv(self.decoder(x))
@MODEL.register_module()
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="BN"):
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
super(PatchDiscriminator, self).__init__()
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
norm_layer = _select_norm_layer(norm_type)