import torch.nn as nn

from model.registry import NORMALIZATION


def _do_nothing(x):
    """Identity function used in place of a normalization or activation layer.

    Defined as a named module-level function (not a lambda) so that modules holding it
    as an attribute remain picklable, e.g. by ``torch.save(model)``.
    """
    return x


# Original public-ish name, kept as an alias for backward compatibility.
_DO_NO_THING_FUNC = _do_nothing

# Norm types after which the preceding layer is built without a bias by default
# (presumably because these norms re-center their input, making a bias redundant).
_NO_BIAS_NORM_TYPES = ("IN", "BN", "AdaIN", "FADE", "SPADE")


def _use_bias_checker(norm_type):
    """Return True if a layer followed by ``norm_type`` should have a bias by default."""
    return norm_type not in _NO_BIAS_NORM_TYPES


def _normalization(norm, num_features, additional_kwargs=None):
    """Build a normalization layer through the ``NORMALIZATION`` registry.

    :param norm: registry type key, or "NONE" for no normalization.
    :param num_features: passed to the builder as ``num_features``.
    :param additional_kwargs: optional extra builder kwargs; they are merged last, so they
        override ``num_features`` (and ``_type``) if they contain those keys. Not mutated.
    :return: the built layer, or the identity function when ``norm == "NONE"``.
    """
    if norm == "NONE":
        return _DO_NO_THING_FUNC
    kwargs = dict(_type=norm, num_features=num_features)
    if additional_kwargs:
        kwargs.update(additional_kwargs)
    return NORMALIZATION.build_with(kwargs)


def _activation(activation, inplace=True):
    """Build an activation layer by name.

    :param activation: one of "NONE", "ReLU", "LeakyReLU" (negative slope 0.2), "Tanh".
    :param inplace: forwarded to ReLU / LeakyReLU (ignored for Tanh).
    :return: the activation module, or the identity function for "NONE".
    :raises NotImplementedError: for any other name.
    """
    if activation == "NONE":
        return _DO_NO_THING_FUNC
    if activation == "ReLU":
        return nn.ReLU(inplace=inplace)
    if activation == "LeakyReLU":
        return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
    if activation == "Tanh":
        return nn.Tanh()
    raise NotImplementedError(f"{activation} not valid")


class LinearBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE",
                 additional_norm_kwargs=None):
        """Linear -> normalization -> activation block.

        :param in_features: input features of the linear layer.
        :param out_features: output features of the linear layer (also the norm's ``num_features``).
        :param bias: whether the linear layer has a bias; ``None`` decides from ``norm_type``
            via ``_use_bias_checker``.
        :param activation_type: name understood by ``_activation``.
        :param norm_type: ``NORMALIZATION`` registry key, or "NONE".
        :param additional_norm_kwargs: optional extra kwargs for the normalization builder.
        """
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type
        bias = _use_bias_checker(norm_type) if bias is None else bias
        self.linear = nn.Linear(in_features, out_features, bias)
        self.normalization = _normalization(norm_type, out_features, additional_norm_kwargs)
        self.activation = _activation(activation_type)

    def forward(self, x):
        return self.activation(self.normalization(self.linear(x)))


class Conv2dBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, bias=None, activation_type="ReLU", norm_type="NONE",
                 additional_norm_kwargs=None, pre_activation=False, use_transpose_conv=False, **conv_kwargs):
        """Convolution block in either post-activation or pre-activation layout.

        :param in_channels: input channels of the convolution.
        :param out_channels: output channels of the convolution.
        :param bias: whether the convolution has a bias. ``None`` (default) means: in
            post-activation mode, decided from ``norm_type`` via ``_use_bias_checker``; in
            pre-activation mode, left to the conv layer's own default. An explicit value
            is honored in both modes.
        :param activation_type: name understood by ``_activation``.
        :param norm_type: ``NORMALIZATION`` registry key, or "NONE".
        :param additional_norm_kwargs: optional extra kwargs for the normalization builder.
        :param pre_activation: if True compute ``conv(act(norm(x)))`` (norm over
            ``in_channels``); otherwise ``act(norm(conv(x)))`` (norm over ``out_channels``).
        :param use_transpose_conv: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``;
            forces ``padding_mode="zeros"``.
        :param conv_kwargs: forwarded to the convolution constructor (kernel_size, padding, ...).
        """
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type
        self.pre_activation = pre_activation

        if use_transpose_conv:
            # Only "zeros" padding mode is supported for ConvTranspose2d
            conv_kwargs["padding_mode"] = "zeros"
            conv = nn.ConvTranspose2d
        else:
            conv = nn.Conv2d

        # Submodules are created in the same order as before so that parameter
        # initialization and state_dict key order are unchanged.
        if pre_activation:
            # Previously an explicit `bias` was silently ignored in this mode.
            if bias is not None:
                conv_kwargs["bias"] = bias
            self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
            # Not in-place: presumably so that with norm "NONE" the activation does not
            # overwrite the block's input tensor, which a residual shortcut may still use.
            self.activation = _activation(activation_type, inplace=False)
            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
        else:
            # if caller not set bias, set bias automatically.
            conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
            self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
            self.activation = _activation(activation_type)

    def forward(self, x):
        if self.pre_activation:
            return self.convolution(self.activation(self.normalization(x)))
        return self.activation(self.normalization(self.convolution(x)))


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, padding_mode='reflect', activation_type="ReLU", norm_type="IN",
                 pre_activation=False, out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
        """
        Residual Conv Block: ``conv2(conv1(x)) + shortcut(x)`` with two 3x3 ``Conv2dBlock``s.

        :param in_channels: input channels (also the channels between conv1 and conv2).
        :param padding_mode: padding mode of the convolutions.
        :param activation_type: activation of conv1.
        :param norm_type: normalization of every conv block.
        :param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
        :param out_channels: output channels; defaults to ``in_channels``. When it differs
            from ``in_channels`` the shortcut is a learned 1x1 ``Conv2dBlock``.
        :param out_activation_type: activation of conv2 and of the learned shortcut.
            Defaults to "NONE" (no activation before the residual add) in the
            conv -> norm -> act layout, and to ``activation_type`` in full pre-activation.
        :param additional_norm_kwargs: optional extra kwargs for the normalization builder.
        """
        super().__init__()
        self.norm_type = norm_type
        if out_channels is None:
            out_channels = in_channels
        if out_activation_type is None:
            # Previously defaulted to `norm_type` for pre-activation (not a valid activation
            # name) and was never applied at all.
            out_activation_type = activation_type if pre_activation else "NONE"
        self.learn_skip_connection = in_channels != out_channels

        conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
                          additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
                          padding_mode=padding_mode)
        self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
        # Output-side blocks (conv2 and the learned shortcut) use `out_activation_type`.
        conv_param["activation_type"] = out_activation_type
        self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
        if self.learn_skip_connection:
            conv_param["kernel_size"] = 1
            conv_param["padding"] = 0
            self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)

    def forward(self, x):
        res = x if not self.learn_skip_connection else self.res_conv(x)
        return self.conv2(self.conv1(x)) + res