import torch.nn as nn
from torch.nn import init

def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0.0,
                 distribution='normal'):
    """Kaiming (He) initialization of ``module.weight`` plus a constant
    bias; ``distribution`` selects the uniform or normal variant."""
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
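
# Usage sketch (illustrative; the layer and its shapes are arbitrary, not
# taken from this file): ``mode='fan_out'`` preserves the magnitude of
# gradients in the backward pass, the usual choice for conv layers followed
# by ReLU.
#
#   conv = nn.Conv2d(3, 16, kernel_size=3)
#   kaiming_init(conv, mode='fan_out', nonlinearity='relu', bias=0.0)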

def xavier_init(module, gain=1.0, bias=0.0, distribution='normal'):
    """Xavier (Glorot) initialization of ``module.weight`` plus a constant
    bias; ``distribution`` selects the uniform or normal variant."""
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
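
# Usage sketch (illustrative; shapes are arbitrary): Xavier init keeps
# activation variance roughly constant across layers, and the gain can be
# matched to the non-linearity in use.
#
#   fc = nn.Linear(128, 64)
#   xavier_init(fc, gain=nn.init.calculate_gain('tanh'),
#               distribution='uniform')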

def normal_init(module, mean=0.0, std=1.0, bias=0.0):
    """Draw ``module.weight`` from a normal distribution and set the bias
    to a constant."""
    nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
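
# Usage sketch (illustrative): the DCGAN-style recipe draws conv weights
# from N(0, 0.02**2) and zeroes the bias.
#
#   conv = nn.Conv2d(3, 16, kernel_size=3)
#   normal_init(conv, mean=0.0, std=0.02, bias=0.0)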

def generation_init_weights(module, init_type='normal', init_gain=0.02):
    """Default initialization of network weights for image generation.

    By default, we use normal init, but xavier and kaiming might work
    better for some applications.

    Args:
        module (nn.Module): Module to be initialized.
        init_type (str): The name of an initialization method:
            normal | xavier | kaiming | orthogonal.
        init_gain (float): Scaling factor for normal, xavier and
            orthogonal.
    """

    def init_func(m):
        """Initialization function applied to each submodule.

        Args:
            m (nn.Module): Module to be initialized.
        """
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                normal_init(m, 0.0, init_gain)
            elif init_type == 'xavier':
                xavier_init(m, gain=init_gain, distribution='normal')
            elif init_type == 'kaiming':
                kaiming_init(
                    m,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    distribution='normal')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight, gain=init_gain)
                # Some layers are built without a bias term; guard before
                # zeroing it, as the helper functions above do.
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias, 0.0)
            else:
                raise NotImplementedError(
                    f"Initialization method '{init_type}' is not implemented")
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm's weight is a 1-D scale vector, not a matrix, so
            # only the normal-distribution init applies.
            normal_init(m, 1.0, init_gain)

    assert isinstance(module, nn.Module)
    module.apply(init_func)
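
# Minimal smoke test (illustrative sketch; the toy network and shapes are
# assumptions, not part of the file above): initialize a small conv stack
# and verify a forward pass still works.
if __name__ == '__main__':
    import torch

    net = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Conv2d(8, 3, 3, padding=1))
    generation_init_weights(net, init_type='xavier', init_gain=0.02)
    out = net(torch.randn(1, 3, 16, 16))
    print(out.shape)  # expected: torch.Size([1, 3, 16, 16])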