From 611901cbdf78d13aa45716124175239cd712c895 Mon Sep 17 00:00:00 2001
From: budui
Date: Tue, 13 Oct 2020 10:31:03 +0800
Subject: [PATCH] add ConvTranspose2d in Conv2d

---
 model/base/module.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/model/base/module.py b/model/base/module.py
index a64e5f6..7674256 100644
--- a/model/base/module.py
+++ b/model/base/module.py
@@ -52,21 +52,23 @@ class LinearBlock(nn.Module):
 
 class Conv2dBlock(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, bias=None,
-                 activation_type="ReLU", norm_type="NONE",
-                 additional_norm_kwargs=None, pre_activation=False, **conv_kwargs):
+                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
+                 pre_activation=False, use_transpose_conv=False, **conv_kwargs):
         super().__init__()
         self.norm_type = norm_type
         self.activation_type = activation_type
         self.pre_activation = pre_activation
 
+        conv = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d
+
         if pre_activation:
             self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
             self.activation = _activation(activation_type, inplace=False)
-            self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
         else:
             # if caller not set bias, set bias automatically.
             conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
-            self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
             self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
             self.activation = _activation(activation_type)
 
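Usage note (not part of the patch): a minimal sketch of how the new use_transpose_conv
flag could be exercised, assuming Conv2dBlock forwards **conv_kwargs (kernel_size,
stride, padding, ...) to the selected convolution as shown in the diff, and that it
defines a forward() applying convolution, normalization, and activation (not visible
in this hunk). The import path is taken from the patched file.

    import torch
    from model.base.module import Conv2dBlock  # path assumed from model/base/module.py

    # Downsampling block: plain nn.Conv2d, 3 -> 64 channels, halves spatial size (32x32 -> 16x16).
    down = Conv2dBlock(3, 64, norm_type="NONE", activation_type="ReLU",
                       kernel_size=3, stride=2, padding=1)

    # Upsampling block: nn.ConvTranspose2d via the new flag, 64 -> 3 channels, doubles size (16x16 -> 32x32).
    up = Conv2dBlock(64, 3, use_transpose_conv=True, norm_type="NONE",
                     activation_type="ReLU", kernel_size=4, stride=2, padding=1)

    x = torch.randn(1, 3, 32, 32)
    y = up(down(x))     # assumes Conv2dBlock.forward() chains conv/norm/activation
    print(y.shape)      # expected: torch.Size([1, 3, 32, 32])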