update cycle module

Ray Wong 2020-08-22 20:57:03 +08:00
parent 58ed4524bf
commit 35ab7ecd51


@@ -118,17 +118,16 @@ class ResGenerator(nn.Module):
             multiple = 2 ** i
             submodules += [
                 nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
-                          kernel_size=3,
-                          stride=2, padding=1, bias=use_bias),
+                          kernel_size=3, stride=2, padding=1, bias=use_bias),
                 norm_layer(num_features=base_channels * multiple * 2),
                 nn.ReLU(inplace=True)
             ]
         self.encoder = nn.Sequential(*submodules)
         res_block_channels = num_down_sampling ** 2 * base_channels
-        self.res_blocks = nn.ModuleList(
-            [ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
-             range(num_blocks)])
+        self.resnet_middle = nn.Sequential(
+            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
+              range(num_blocks)])
         # up sampling
         submodules = []
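
This hunk folds the Conv2d keyword arguments onto one line and, more substantively, replaces the nn.ModuleList of residual blocks with an nn.Sequential named resnet_middle, so the whole middle stack can be applied with a single call (see the forward() change in the next hunk). A minimal sketch of the pattern, using a hypothetical TinyResBlock in place of the repository's ResidualBlock (whose definition is not part of this diff):

    import torch
    from torch import nn


    class TinyResBlock(nn.Module):
        """Stand-in for ResidualBlock; the real one also takes padding_mode, norm_type, use_dropout."""

        def __init__(self, channels):
            super().__init__()
            self.body = nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.InstanceNorm2d(channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.InstanceNorm2d(channels),
            )

        def forward(self, x):
            return x + self.body(x)  # residual connection


    channels, num_blocks = 256, 9

    # Before: a ModuleList that forward() must iterate over explicitly.
    res_blocks = nn.ModuleList([TinyResBlock(channels) for _ in range(num_blocks)])

    # After: the same blocks unpacked into nn.Sequential, callable as one module.
    resnet_middle = nn.Sequential(*[TinyResBlock(channels) for _ in range(num_blocks)])

    x = torch.randn(1, channels, 64, 64)
    y = resnet_middle(x)   # replaces: for rb in res_blocks: x = rb(x)
    print(y.shape)         # torch.Size([1, 256, 64, 64])

nn.Sequential registers the blocks' parameters just like nn.ModuleList, but it also supplies the forward pass, which is why the explicit loop in forward() can be dropped.
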
@@ -149,14 +148,13 @@ class ResGenerator(nn.Module):

     def forward(self, x):
         x = self.encoder(self.start_conv(x))
-        for rb in self.res_blocks:
-            x = rb(x)
+        x = self.resnet_middle(x)
         return self.end_conv(self.decoder(x))


 @MODEL.register_module()
 class PatchDiscriminator(nn.Module):
-    def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="BN"):
+    def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
         super(PatchDiscriminator, self).__init__()
         assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
         norm_layer = _select_norm_layer(norm_type)
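
The second hunk collapses the residual-block loop in forward() into one call to resnet_middle and changes PatchDiscriminator's default norm_type from "BN" to "IN"; instance normalization is the conventional choice for CycleGAN-style PatchGAN discriminators. The helper _select_norm_layer is defined elsewhere in the repository and is not shown in this diff; the sketch below is a hypothetical stand-in illustrating the kind of mapping such a selector performs:

    import functools
    from torch import nn


    def select_norm_layer(norm_type: str):
        """Hypothetical stand-in for _select_norm_layer (the repository's version is not shown here).

        Maps a short norm name to a layer constructor. BatchNorm layers usually keep affine
        parameters and running statistics, while CycleGAN-style discriminators typically use
        InstanceNorm without them, which is why "IN" is a sensible default.
        """
        if norm_type == "BN":
            return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        if norm_type == "IN":
            return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        raise ValueError(f"Unsupported norm type: {norm_type}")


    norm_layer = select_norm_layer("IN")   # mirrors the new default in PatchDiscriminator
    print(norm_layer(num_features=64))     # InstanceNorm2d(64, eps=1e-05, ...)

With partial-bound constructors like these, both generator and discriminator can call norm_layer(num_features=...) without caring which norm was selected.
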