122 lines
4.7 KiB
Python
122 lines
4.7 KiB
Python
import torch.nn as nn
|
|
|
|
from model.registry import NORMALIZATION
|
|
|
|
def _DO_NO_THING_FUNC(x):
    """Identity placeholder returned when no normalization/activation is requested.

    Returns its argument unchanged.

    Defined with ``def`` rather than by assigning a ``lambda`` so that it is
    picklable by reference (its qualified name resolves to this module-level
    attribute). Blocks built with ``norm_type="NONE"`` or ``activation_type="NONE"``
    store this object as an attribute, so a lambda here would make
    ``torch.save(model)`` / sending the model to another process fail with a
    pickling error.
    """
    return x
|
|
|
|
|
|
def _use_bias_checker(norm_type):
    """Decide whether a layer whose output feeds ``norm_type`` should keep its bias.

    :param norm_type: normalization name, as passed to ``_normalization``.
    :return: ``False`` for "IN", "BN", "AdaIN", "FADE" and "SPADE";
        ``True`` for any other value (including "NONE").
    """
    # Presumably these normalizations re-center their input, which would cancel
    # a preceding bias — TODO confirm for FADE/SPADE.
    if norm_type in ("IN", "BN", "AdaIN", "FADE", "SPADE"):
        return False
    return True
|
|
|
|
|
|
def _normalization(norm, num_features, additional_kwargs=None):
    """Build the normalization layer selected by ``norm``.

    :param norm: registry type name of the normalization; ``"NONE"`` disables it.
    :param num_features: feature/channel count, forwarded to the registry as ``num_features``.
    :param additional_kwargs: optional mapping merged into the build config last,
        so its entries override ``_type``/``num_features`` on key clashes.
    :return: ``_DO_NO_THING_FUNC`` (identity) for ``"NONE"``, otherwise whatever
        ``NORMALIZATION.build_with`` returns for the assembled config.
    """
    if norm == "NONE":
        return _DO_NO_THING_FUNC

    config = {"_type": norm, "num_features": num_features}
    if additional_kwargs is not None:
        config.update(additional_kwargs)
    return NORMALIZATION.build_with(config)
|
|
|
|
|
|
def _activation(activation, inplace=True):
    """Build the activation layer named by ``activation``.

    :param activation: one of "NONE", "ReLU", "LeakyReLU" (negative slope 0.2) or "Tanh".
    :param inplace: forwarded to ReLU/LeakyReLU; ignored by Tanh and "NONE".
    :return: an ``nn`` activation module, or ``_DO_NO_THING_FUNC`` for "NONE".
    :raises NotImplementedError: for any other name.
    """
    if activation == "NONE":
        return _DO_NO_THING_FUNC

    # Checked in order with ``==`` so that matching semantics stay exactly those
    # of a plain if/elif chain.
    factories = (
        ("ReLU", lambda: nn.ReLU(inplace=inplace)),
        ("LeakyReLU", lambda: nn.LeakyReLU(negative_slope=0.2, inplace=inplace)),
        ("Tanh", nn.Tanh),
    )
    for candidate, factory in factories:
        if activation == candidate:
            return factory()

    raise NotImplementedError(f"{activation} not valid")
|
|
|
|
|
|
class LinearBlock(nn.Module):
    """Fully connected layer followed by a normalization and an activation.

    ``forward`` computes ``activation(normalization(linear(x)))``.
    """

    def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE",
                 additional_norm_kwargs=None):
        """
        :param in_features: input size of the ``nn.Linear`` layer.
        :param out_features: output size of the ``nn.Linear`` layer; also the feature
            count handed to the normalization.
        :param bias: whether ``nn.Linear`` has a bias. ``None`` lets
            ``_use_bias_checker(norm_type)`` decide.
        :param activation_type: activation name understood by ``_activation``.
        :param norm_type: normalization name understood by ``_normalization``;
            ``"NONE"`` disables normalization.
        :param additional_norm_kwargs: extra keyword arguments for the normalization
            layer (same meaning as in ``Conv2dBlock``).
        """
        super().__init__()

        self.norm_type = norm_type
        self.activation_type = activation_type

        # If the caller did not choose, let the normalization type decide whether
        # the linear layer needs its own bias.
        if bias is None:
            bias = _use_bias_checker(norm_type)
        self.linear = nn.Linear(in_features, out_features, bias)

        self.normalization = _normalization(norm_type, out_features, additional_norm_kwargs)
        self.activation = _activation(activation_type)

    def forward(self, x):
        return self.activation(self.normalization(self.linear(x)))
|
|
|
|
|
|
class Conv2dBlock(nn.Module):
    """(Transposed) 2D convolution combined with a normalization and an activation.

    Default order: ``activation(normalization(convolution(x)))``.
    With ``pre_activation=True``: ``convolution(activation(normalization(x)))``.
    """

    def __init__(self, in_channels: int, out_channels: int, bias=None,
                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
                 pre_activation=False, use_transpose_conv=False, **conv_kwargs):
        """
        :param in_channels: channels of the input tensor.
        :param out_channels: channels produced by the convolution.
        :param bias: whether the convolution has a bias. ``None`` means automatic:
            decided by ``_use_bias_checker(norm_type)`` in the default order, and
            left at the convolution class's default in pre-activation order (no
            normalization follows the convolution there).
        :param activation_type: activation name understood by ``_activation``.
        :param norm_type: normalization name understood by ``_normalization``.
        :param additional_norm_kwargs: extra keyword arguments for the normalization layer.
        :param pre_activation: apply normalization and activation before the convolution.
        :param use_transpose_conv: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        :param conv_kwargs: forwarded to the convolution constructor (kernel_size, padding, ...).
        """
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type
        self.pre_activation = pre_activation

        conv = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d

        if pre_activation:
            # Normalization runs on the block input, so it sees `in_channels`.
            self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
            # Not in-place: with norm_type "NONE" the normalization is the identity,
            # so an in-place activation would overwrite the caller's input tensor.
            self.activation = _activation(activation_type, inplace=False)
            # Honor an explicit caller choice; otherwise keep the convolution's default.
            if bias is not None:
                conv_kwargs["bias"] = bias
            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
        else:
            # If the caller did not choose, set the bias automatically from the norm type.
            conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
            self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
            self.activation = _activation(activation_type)

    def forward(self, x):
        if self.pre_activation:
            return self.convolution(self.activation(self.normalization(x)))
        return self.activation(self.normalization(self.convolution(x)))
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    def __init__(self, in_channels,
                 padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
                 out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
        """
        Residual Conv Block computing ``conv2(conv1(x)) + skip(x)``.

        ``skip`` is the identity when ``in_channels == out_channels``; otherwise it
        is a learned 3x3 ``Conv2dBlock`` (``self.res_conv``).

        :param in_channels: channels of the input tensor (also conv1's output channels).
        :param out_channels: channels of the output; defaults to ``in_channels``.
        :param padding_mode: padding mode forwarded to every convolution.
        :param activation_type: activation used by conv1 and by the learned skip convolution.
        :param norm_type: normalization used by every inner ``Conv2dBlock``.
        :param out_activation_type: activation used by conv2; defaults to
            ``activation_type``. In pre-activation mode it is applied before
            conv2's convolution, like every activation in that mode.
        :param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
        :param additional_norm_kwargs: extra keyword arguments for every normalization layer.
        """
        super().__init__()

        self.norm_type = norm_type

        if out_channels is None:
            out_channels = in_channels
        if out_activation_type is None:
            # NOTE(review): an earlier comment described a default of "NONE" (plain
            # mode) / `norm_type` (pre-activation), but the value was never applied
            # and conv2 always used `activation_type`. That effective behavior is
            # kept as the default so existing models are unchanged.
            out_activation_type = activation_type

        self.learn_skip_connection = in_channels != out_channels

        conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
                          additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
                          padding_mode=padding_mode)

        self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
        # conv2 differs from conv1 only in its activation.
        conv2_param = dict(conv_param, activation_type=out_activation_type)
        self.conv2 = Conv2dBlock(in_channels, out_channels, **conv2_param)

        if self.learn_skip_connection:
            self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)

    def forward(self, x):
        res = x if not self.learn_skip_connection else self.res_conv(x)
        return self.conv2(self.conv1(x)) + res
|