import torch.nn as nn

from model.registry import NORMALIZATION

# No-op stand-in used when the normalization or activation type is "NONE".
_DO_NO_THING_FUNC = lambda x: x


def _use_bias_checker(norm_type):
    # The conv/linear bias is skipped when one of these normalization layers follows.
    return norm_type not in ["IN", "BN", "AdaIN", "FADE", "SPADE"]


def _normalization(norm, num_features, additional_kwargs=None):
    if norm == "NONE":
        return _DO_NO_THING_FUNC

    if additional_kwargs is None:
        additional_kwargs = {}
    kwargs = dict(_type=norm, num_features=num_features)
    kwargs.update(additional_kwargs)
    return NORMALIZATION.build_with(kwargs)


def _activation(activation):
    if activation == "NONE":
        return _DO_NO_THING_FUNC
    elif activation == "ReLU":
        return nn.ReLU(inplace=True)
    elif activation == "LeakyReLU":
        return nn.LeakyReLU(negative_slope=0.2, inplace=True)
    elif activation == "Tanh":
        return nn.Tanh()
    else:
        raise NotImplementedError(f"{activation} is not a valid activation type")
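
# Quick reference for the helpers above (illustrative comments only):
#   _use_bias_checker("BN")    -> False (a preceding bias is redundant when the norm re-centers features)
#   _use_bias_checker("NONE")  -> True
#   _activation("LeakyReLU")   -> nn.LeakyReLU(negative_slope=0.2, inplace=True)
#   _normalization("NONE", 64) -> identity; any other type is resolved through the
#                                 NORMALIZATION registry via build_with().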


class LinearBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
        super().__init__()

        self.norm_type = norm_type
        self.activation_type = activation_type

        bias = _use_bias_checker(norm_type) if bias is None else bias
        self.linear = nn.Linear(in_features, out_features, bias)

        self.normalization = _normalization(norm_type, out_features)
        self.activation = _activation(activation_type)

    def forward(self, x):
        return self.activation(self.normalization(self.linear(x)))


class Conv2dBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, bias=None,
                 activation_type="ReLU", norm_type="NONE", **conv_kwargs):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type

        # If the caller did not set bias explicitly, infer it from the normalization type.
        conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias

        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.normalization = _normalization(norm_type, out_channels)
        self.activation = _activation(activation_type)

    def forward(self, x):
        return self.activation(self.normalization(self.convolution(x)))
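
# Minimal usage sketch (illustrative only; the default norm_type="NONE" keeps the
# examples independent of the NORMALIZATION registry):
#
#   fc   = LinearBlock(128, 64, activation_type="LeakyReLU")   # (B, 128) -> (B, 64)
#   conv = Conv2dBlock(3, 32, kernel_size=3, padding=1)        # (B, 3, H, W) -> (B, 32, H, W)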


class ResidualBlock(nn.Module):
    def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
                 activation_type="ReLU", norm_type="IN", out_activation_type=None):
        super().__init__()
        self.norm_type = norm_type

        if out_channels is None:
            out_channels = num_channels
        if out_activation_type is None:
            out_activation_type = "NONE"

        self.learn_skip_connection = num_channels != out_channels

        self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
                                 norm_type=norm_type, activation_type=activation_type)
        self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
                                 norm_type=norm_type, activation_type=out_activation_type)

        if self.learn_skip_connection:
            self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
                                        norm_type=norm_type, activation_type=out_activation_type)

    def forward(self, x):
        res = x if not self.learn_skip_connection else self.res_conv(x)
        return self.conv2(self.conv1(x)) + res
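
# ResidualBlock preserves spatial size (3x3 convolutions with padding=1) and only adds a
# learned projection on the skip path when the channel count changes. Illustrative only,
# with norm_type="NONE" so no registry entry is required:
#
#   block = ResidualBlock(64, out_channels=128, norm_type="NONE")
#   y = block(torch.randn(1, 64, 32, 32))   # -> (1, 128, 32, 32)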


class ReverseConv2dBlock(nn.Module):
    # Applies the layers in reversed order: normalization -> activation -> convolution.
    def __init__(self, in_channels: int, out_channels: int,
                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
        super().__init__()
        self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
        self.activation = _activation(activation_type)
        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)

    def forward(self, x):
        return self.convolution(self.activation(self.normalization(x)))


class ReverseResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, padding_mode="reflect",
                 norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
        super().__init__()
        self.learn_skip_connection = in_channels != out_channels
        self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
                                        kernel_size=3, padding=1, padding_mode=padding_mode)
        self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
                                        kernel_size=3, padding=1, padding_mode=padding_mode)
        if self.learn_skip_connection:
            self.res_conv = ReverseConv2dBlock(
                in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
                kernel_size=3, padding=1, padding_mode=padding_mode)

    def forward(self, x):
        res = x if not self.learn_skip_connection else self.res_conv(x)
        return self.conv2(self.conv1(x)) + res
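

# Optional smoke test: a minimal sketch assuming torch and the surrounding `model`
# package are importable. norm_type="NONE" is used everywhere so nothing has to be
# registered in NORMALIZATION; running the file directly prints output shapes.
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 16, 32, 32)
    blocks = [
        Conv2dBlock(16, 16, kernel_size=3, padding=1),
        ResidualBlock(16, norm_type="NONE"),
        ResidualBlock(16, out_channels=32, norm_type="NONE"),
        ReverseResidualBlock(16, 32, norm_type="NONE"),
    ]
    for block in blocks:
        print(type(block).__name__, tuple(block(x).shape))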