# raycv/model/base/module.py
import torch.nn as nn

from model.registry import NORMALIZATION

# Identity function returned when normalization / activation is "NONE".
_DO_NOTHING_FUNC = lambda x: x


def _use_bias_checker(norm_type):
    # A bias is redundant when the layer output is immediately normalized:
    # the normalization's mean subtraction cancels any constant shift.
    return norm_type not in ["IN", "BN", "AdaIN", "FADE", "SPADE"]


def _normalization(norm, num_features, additional_kwargs=None):
    if norm == "NONE":
        return _DO_NOTHING_FUNC
    if additional_kwargs is None:
        additional_kwargs = {}
    kwargs = dict(_type=norm, num_features=num_features)
    kwargs.update(additional_kwargs)
    return NORMALIZATION.build_with(kwargs)
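
# Illustrative sketch (not executed): assuming "BN" is registered in NORMALIZATION and
# that build_with() reads the "_type" key to pick the class, a call such as
#     _normalization("BN", 64, {"momentum": 0.05})
# is expected to resolve to
#     NORMALIZATION.build_with({"_type": "BN", "num_features": 64, "momentum": 0.05})
# i.e. additional_kwargs are forwarded to the chosen normalization layer.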


def _activation(activation, inplace=True):
    if activation == "NONE":
        return _DO_NOTHING_FUNC
    elif activation == "ReLU":
        return nn.ReLU(inplace=inplace)
    elif activation == "LeakyReLU":
        return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
    elif activation == "Tanh":
        return nn.Tanh()
    else:
        raise NotImplementedError(f"{activation} is not a valid activation type")


class LinearBlock(nn.Module):
    """Linear layer followed by optional normalization and activation."""

    def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type
        # If the caller does not set `bias`, decide it from the normalization type.
        bias = _use_bias_checker(norm_type) if bias is None else bias
        self.linear = nn.Linear(in_features, out_features, bias)
        self.normalization = _normalization(norm_type, out_features)
        self.activation = _activation(activation_type)

    def forward(self, x):
        return self.activation(self.normalization(self.linear(x)))
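
# Usage sketch (illustrative shapes; names below are only examples):
#     block = LinearBlock(64, 128)              # Linear(64, 128, bias=True) -> ReLU, no norm
#     y = block(torch.randn(8, 64))             # -> shape (8, 128)
# With norm_type="BN" (assuming such an entry is registered in NORMALIZATION),
# _use_bias_checker would drop the Linear bias automatically.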


class Conv2dBlock(nn.Module):
    """Convolution (or transposed convolution) with optional normalization and activation.

    With pre_activation=True the order is norm -> activation -> conv,
    otherwise conv -> norm -> activation.
    """

    def __init__(self, in_channels: int, out_channels: int, bias=None,
                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
                 pre_activation=False, use_transpose_conv=False, **conv_kwargs):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type
        self.pre_activation = pre_activation
        if use_transpose_conv:
            # Only "zeros" padding mode is supported for ConvTranspose2d.
            conv_kwargs["padding_mode"] = "zeros"
            conv = nn.ConvTranspose2d
        else:
            conv = nn.Conv2d
        if pre_activation:
            # Normalization precedes the convolution here, so keep the conv bias
            # unless the caller explicitly disables it.
            if bias is not None:
                conv_kwargs["bias"] = bias
            self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
            self.activation = _activation(activation_type, inplace=False)
            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
        else:
            # If the caller does not set `bias`, decide it from the normalization type.
            conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
            self.convolution = conv(in_channels, out_channels, **conv_kwargs)
            self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
            self.activation = _activation(activation_type)

    def forward(self, x):
        if self.pre_activation:
            return self.convolution(self.activation(self.normalization(x)))
        return self.activation(self.normalization(self.convolution(x)))
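
# Usage sketch (illustrative; extra kwargs like kernel_size / stride / padding are
# forwarded to nn.Conv2d or nn.ConvTranspose2d):
#     down = Conv2dBlock(3, 64, kernel_size=3, stride=2, padding=1)      # conv -> ReLU
#     up = Conv2dBlock(64, 3, use_transpose_conv=True, kernel_size=4,
#                      stride=2, padding=1, activation_type="Tanh")
#     y = up(down(torch.randn(1, 3, 32, 32)))                            # -> (1, 3, 32, 32)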


class ResidualBlock(nn.Module):
    def __init__(self, in_channels,
                 padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
                 out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
        """
        Residual Conv Block

        :param in_channels: number of input channels
        :param out_channels: number of output channels; defaults to `in_channels`
        :param padding_mode: padding mode passed to the 3x3 convolutions
        :param activation_type: activation used inside the block
        :param norm_type: normalization used inside the block
        :param out_activation_type: activation of the second conv block; see the default below
        :param additional_norm_kwargs: extra kwargs forwarded to the normalization layers
        :param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
        """
        super().__init__()
        self.norm_type = norm_type
        if out_channels is None:
            out_channels = in_channels
        if out_activation_type is None:
            # Default `out_activation_type`:
            #   "NONE" when not using full pre-activation, so the residual sum is
            #   returned without an extra activation after the second conv;
            #   `activation_type` when using full pre-activation, so the second conv
            #   block keeps its norm -> activation -> conv ordering.
            out_activation_type = "NONE" if not pre_activation else activation_type
        self.learn_skip_connection = in_channels != out_channels
        conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
                          additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
                          padding_mode=padding_mode)
        self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
        # The second conv block uses `out_activation_type`, so the block output can be
        # controlled independently of the inner activation.
        self.conv2 = Conv2dBlock(in_channels, out_channels,
                                 **{**conv_param, "activation_type": out_activation_type})
        if self.learn_skip_connection:
            # 1x1 projection so the skip connection matches the output channel count.
            conv_param['kernel_size'] = 1
            conv_param['padding'] = 0
            self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)

    def forward(self, x):
        res = x if not self.learn_skip_connection else self.res_conv(x)
        return self.conv2(self.conv1(x)) + res
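

# Minimal smoke test sketch: it uses norm_type="NONE" so the NORMALIZATION registry is
# not exercised, and only assumes that the imports at the top of this module resolve.
# Shapes are illustrative.
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 16, 32, 32)
    block = ResidualBlock(16, out_channels=32, norm_type="NONE", pre_activation=True)
    print(block(x).shape)  # expected: torch.Size([2, 32, 32, 32])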