# raycv/model/weight_init.py

import torch.nn as nn
from torch.nn import init


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0.0,
                 distribution='normal'):
    """Initialize ``module.weight`` with Kaiming (He) init and, if present,
    ``module.bias`` with a constant."""
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
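
# Usage sketch: Kaiming init is the usual choice for layers followed by a
# ReLU-family nonlinearity. The layer shape below is illustrative only.
#
#     conv = nn.Conv2d(3, 16, kernel_size=3)
#     kaiming_init(conv, mode='fan_out', nonlinearity='relu')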


def xavier_init(module, gain=1.0, bias=0.0, distribution='normal'):
    """Initialize ``module.weight`` with Xavier (Glorot) init and, if
    present, ``module.bias`` with a constant."""
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
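
# Usage sketch: Xavier init keeps activation variance roughly constant across
# a layer, which suits symmetric nonlinearities such as tanh; the shape and
# gain here are illustrative.
#
#     fc = nn.Linear(128, 64)
#     xavier_init(fc, gain=1.0, distribution='uniform')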


def normal_init(module, mean=0.0, std=1.0, bias=0.0):
    """Initialize ``module.weight`` from N(mean, std**2) and, if present,
    ``module.bias`` with a constant."""
    nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
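
# Usage sketch: plain Gaussian init, e.g. the DCGAN-style N(0, 0.02) often
# used for generator conv layers; the values are illustrative.
#
#     conv = nn.Conv2d(16, 3, kernel_size=3)
#     normal_init(conv, mean=0.0, std=0.02)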


def generation_init_weights(module, init_type='normal', init_gain=0.02):
    """Default initialization of network weights for image generation.

    By default, we use normal init, but xavier and kaiming might work
    better for some applications.

    Args:
        module (nn.Module): Module to be initialized.
        init_type (str): Name of the initialization method:
            ``'normal'`` | ``'xavier'`` | ``'kaiming'`` | ``'orthogonal'``.
        init_gain (float): Scaling factor for normal, xavier and
            orthogonal init.
    """

    def init_func(m):
        """Initialization function applied to every submodule.

        Args:
            m (nn.Module): Module to be initialized.
        """
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                normal_init(m, 0.0, init_gain)
            elif init_type == 'xavier':
                xavier_init(m, gain=init_gain, distribution='normal')
            elif init_type == 'kaiming':
                kaiming_init(
                    m,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    distribution='normal')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight, gain=init_gain)
                # Guard against layers created with bias=False.
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)
            else:
                raise NotImplementedError(
                    f"Initialization method '{init_type}' is not implemented")
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm layer's weight is not a matrix;
            # only normal distribution applies.
            if m.weight is not None:
                normal_init(m, 1.0, init_gain)

    assert isinstance(module, nn.Module)
    module.apply(init_func)
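

if __name__ == '__main__':
    # Minimal smoke test, assuming a toy network; the architecture below is
    # illustrative only, not a raycv model. The bias=False conv exercises
    # the bias guard in the orthogonal branch.
    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Conv2d(8, 3, kernel_size=3, padding=1, bias=False),
    )
    generation_init_weights(net, init_type='orthogonal', init_gain=0.02)
    print(net[0].weight.std().item())  # weight spread after orthogonal init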