import torch.nn as nn
from torch.nn import init


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0.0,
                 distribution='normal'):
    """Initialize ``module.weight`` with Kaiming init and ``module.bias`` with a constant."""
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def xavier_init(module, gain=1.0, bias=0.0, distribution='normal'):
    """Initialize ``module.weight`` with Xavier init and ``module.bias`` with a constant."""
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def normal_init(module, mean=0.0, std=1.0, bias=0.0):
    """Initialize ``module.weight`` from a normal distribution and ``module.bias`` with a constant."""
    nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def generation_init_weights(module, init_type='normal', init_gain=0.02):
    """Default initialization of network weights for image generation.

    By default, we use normal init, but xavier and kaiming might work
    better for some applications.

    Args:
        module (nn.Module): Module to be initialized.
        init_type (str): The name of an initialization method:
            normal | xavier | kaiming | orthogonal.
        init_gain (float): Scaling factor for normal, xavier and
            orthogonal.
    """

    def init_func(m):
        """Initialization function.

        Args:
            m (nn.Module): Module to be initialized.
        """
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                normal_init(m, 0.0, init_gain)
            elif init_type == 'xavier':
                xavier_init(m, gain=init_gain, distribution='normal')
            elif init_type == 'kaiming':
                kaiming_init(
                    m,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    distribution='normal')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight, gain=init_gain)
                # Guard against layers created without a bias term.
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias, 0.0)
            else:
                raise NotImplementedError(
                    f"Initialization method '{init_type}' is not implemented")
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm layer's weight is not a matrix;
            # only normal distribution applies.
            if m.weight is not None:
                normal_init(m, 1.0, init_gain)

    assert isinstance(module, nn.Module)
    module.apply(init_func)
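

# Minimal usage sketch (an assumption, not part of the original module): the
# toy network below is hypothetical and only illustrates the intended entry
# point, i.e. build an nn.Module, then re-initialize its Conv/Linear and
# BatchNorm2d layers with a single call to generation_init_weights.
if __name__ == '__main__':
    import torch

    net = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 3, kernel_size=3, padding=1),
    )
    # Apply Kaiming init to every Conv layer and normal init to BatchNorm2d.
    generation_init_weights(net, init_type='kaiming', init_gain=0.02)
    out = net(torch.randn(1, 3, 32, 32))
    print(out.shape)  # torch.Size([1, 3, 32, 32])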