import torch.nn as nn
import functools


def select_norm_layer(norm_type: str):
    """Return a factory that builds a 2D normalization layer.

    The returned callable is invoked with the layer's constructor arguments
    (typically the number of channels), e.g. ``select_norm_layer("BN")(64)``.

    Args:
        norm_type: One of
            ``"BN"``   -> ``nn.BatchNorm2d`` with ``affine=True`` and
                          ``track_running_stats=True``;
            ``"IN"``   -> ``nn.InstanceNorm2d`` with ``affine=False`` and
                          ``track_running_stats=False``;
            ``"NONE"`` -> a factory that ignores its arguments and returns
                          ``nn.Identity()``.

    Returns:
        A callable producing an ``nn.Module``.

    Raises:
        NotImplementedError: If ``norm_type`` is not one of the names above.
    """
    if norm_type == "BN":
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == "IN":
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == "NONE":
        # Accept any arguments so this factory can be swapped in wherever
        # the BN/IN partials are called (e.g. with channels plus kwargs).
        return lambda *args, **kwargs: nn.Identity()
    else:
        # NotImplemented is a constant, not an exception; calling it would
        # raise an unrelated TypeError and hide this message.
        raise NotImplementedError(f'normalization layer {norm_type} is not found')