add ConvTranspose2d in Conv2d
This commit is contained in:
parent
a6ffab1445
commit
611901cbdf
@ -52,21 +52,23 @@ class LinearBlock(nn.Module):
|
|||||||
|
|
||||||
class Conv2dBlock(nn.Module):
|
class Conv2dBlock(nn.Module):
|
||||||
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
||||||
activation_type="ReLU", norm_type="NONE",
|
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
|
||||||
additional_norm_kwargs=None, pre_activation=False, **conv_kwargs):
|
pre_activation=False, use_transpose_conv=False, **conv_kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm_type = norm_type
|
self.norm_type = norm_type
|
||||||
self.activation_type = activation_type
|
self.activation_type = activation_type
|
||||||
self.pre_activation = pre_activation
|
self.pre_activation = pre_activation
|
||||||
|
|
||||||
|
conv = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d
|
||||||
|
|
||||||
if pre_activation:
|
if pre_activation:
|
||||||
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
||||||
self.activation = _activation(activation_type, inplace=False)
|
self.activation = _activation(activation_type, inplace=False)
|
||||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
|
||||||
else:
|
else:
|
||||||
# if caller not set bias, set bias automatically.
|
# if caller not set bias, set bias automatically.
|
||||||
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
||||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
|
||||||
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
|
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
|
||||||
self.activation = _activation(activation_type)
|
self.activation = _activation(activation_type)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user