add ConvTranspose2d in Conv2d

This commit is contained in:
budui 2020-10-13 10:31:03 +08:00
parent a6ffab1445
commit 611901cbdf

View File

@ -52,21 +52,23 @@ class LinearBlock(nn.Module):
class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bias=None,
activation_type="ReLU", norm_type="NONE",
additional_norm_kwargs=None, pre_activation=False, **conv_kwargs):
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
pre_activation=False, use_transpose_conv=False, **conv_kwargs):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
self.pre_activation = pre_activation
conv = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d
if pre_activation:
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
self.activation = _activation(activation_type, inplace=False)
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
else:
# if caller not set bias, set bias automatically.
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)