update cycle module
This commit is contained in:
parent
58ed4524bf
commit
35ab7ecd51
@ -118,17 +118,16 @@ class ResGenerator(nn.Module):
|
|||||||
multiple = 2 ** i
|
multiple = 2 ** i
|
||||||
submodules += [
|
submodules += [
|
||||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||||
kernel_size=3,
|
kernel_size=3, stride=2, padding=1, bias=use_bias),
|
||||||
stride=2, padding=1, bias=use_bias),
|
|
||||||
norm_layer(num_features=base_channels * multiple * 2),
|
norm_layer(num_features=base_channels * multiple * 2),
|
||||||
nn.ReLU(inplace=True)
|
nn.ReLU(inplace=True)
|
||||||
]
|
]
|
||||||
self.encoder = nn.Sequential(*submodules)
|
self.encoder = nn.Sequential(*submodules)
|
||||||
|
|
||||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||||
self.res_blocks = nn.ModuleList(
|
self.resnet_middle = nn.Sequential(
|
||||||
[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
|
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_dropout=use_dropout) for _ in
|
||||||
range(num_blocks)])
|
range(num_blocks)])
|
||||||
|
|
||||||
# up sampling
|
# up sampling
|
||||||
submodules = []
|
submodules = []
|
||||||
@ -149,14 +148,13 @@ class ResGenerator(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.encoder(self.start_conv(x))
|
x = self.encoder(self.start_conv(x))
|
||||||
for rb in self.res_blocks:
|
x = self.resnet_middle(x)
|
||||||
x = rb(x)
|
|
||||||
return self.end_conv(self.decoder(x))
|
return self.end_conv(self.decoder(x))
|
||||||
|
|
||||||
|
|
||||||
@MODEL.register_module()
|
@MODEL.register_module()
|
||||||
class PatchDiscriminator(nn.Module):
|
class PatchDiscriminator(nn.Module):
|
||||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="BN"):
|
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
|
||||||
super(PatchDiscriminator, self).__init__()
|
super(PatchDiscriminator, self).__init__()
|
||||||
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
||||||
norm_layer = _select_norm_layer(norm_type)
|
norm_layer = _select_norm_layer(norm_type)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user