update cycle module
This commit is contained in:
parent
58ed4524bf
commit
35ab7ecd51
@ -118,16 +118,15 @@ class ResGenerator(nn.Module):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=3,
|
||||
stride=2, padding=1, bias=use_bias),
|
||||
kernel_size=3, stride=2, padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
|
||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
|
||||
self.resnet_middle = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
|
||||
range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
@ -149,14 +148,13 @@ class ResGenerator(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(self.start_conv(x))
|
||||
for rb in self.res_blocks:
|
||||
x = rb(x)
|
||||
x = self.resnet_middle(x)
|
||||
return self.end_conv(self.decoder(x))
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="BN"):
|
||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
|
||||
super(PatchDiscriminator, self).__init__()
|
||||
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
||||
norm_layer = _select_norm_layer(norm_type)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user