add ConvTranspose2d support to Conv2dBlock
parent a6ffab1445
commit 611901cbdf
@@ -52,21 +52,23 @@ class LinearBlock(nn.Module):
 
 
 class Conv2dBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, bias=None,
-                 activation_type="ReLU", norm_type="NONE",
-                 additional_norm_kwargs=None, pre_activation=False, **conv_kwargs):
+                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
+                 pre_activation=False, use_transpose_conv=False, **conv_kwargs):
         super().__init__()
         self.norm_type = norm_type
         self.activation_type = activation_type
         self.pre_activation = pre_activation
 
+        conv = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d
+
         if pre_activation:
             self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
             self.activation = _activation(activation_type, inplace=False)
-            self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
         else:
             # if the caller did not set bias, set it automatically.
             conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
-            self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
             self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
             self.activation = _activation(activation_type)
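For context, here is a short runnable sketch of the pattern this commit introduces: selecting between nn.Conv2d and nn.ConvTranspose2d with a flag and forwarding the remaining keyword arguments to whichever class was chosen. The make_conv helper below is illustrative only and is not part of the repository; the shape comments follow PyTorch's documented output-size formulas for the two layers.

import torch
import torch.nn as nn

def make_conv(in_channels, out_channels, use_transpose_conv=False, **conv_kwargs):
    # Same selection trick as in the diff above: pick the layer class first,
    # then construct it with the shared kwargs.
    conv = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d
    return conv(in_channels, out_channels, **conv_kwargs)

x = torch.randn(1, 16, 32, 32)

# Regular convolution halves the spatial size:
# floor((32 + 2*1 - 4) / 2) + 1 = 16
down = make_conv(16, 32, kernel_size=4, stride=2, padding=1)
print(down(x).shape)  # torch.Size([1, 32, 16, 16])

# Transposed convolution doubles it:
# (32 - 1) * 2 - 2*1 + (4 - 1) + 1 = 64
up = make_conv(16, 32, use_transpose_conv=True, kernel_size=4, stride=2, padding=1)
print(up(x).shape)  # torch.Size([1, 32, 64, 64])

kernel_size=4, stride=2, padding=1 is a common pairing here because it gives an exact 2x change in spatial size in both directions, so the same conv_kwargs work for a downsampling block and its matching upsampling block.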