raycv/model/normalization.py
2020-08-28 08:16:07 +08:00

14 lines
467 B
Python

import torch.nn as nn
import functools
def select_norm_layer(norm_type):
    """Return a factory that builds the 2D normalization layer named by *norm_type*.

    The returned factory is called like the matching ``torch.nn`` constructor,
    i.e. with the number of channels (``num_features``) first, for example
    ``select_norm_layer("BN")(64)``. Extra keyword arguments (``eps``,
    ``momentum``, ...) are forwarded to the BN/IN constructors.

    Args:
        norm_type: One of
            ``"BN"``   -> ``nn.BatchNorm2d`` with learnable affine parameters
                          and running statistics,
            ``"IN"``   -> ``nn.InstanceNorm2d`` without affine parameters or
                          running statistics,
            ``"NONE"`` -> ``nn.Identity`` (no normalization).

    Returns:
        A callable that returns a new ``nn.Module`` each time it is called.

    Raises:
        NotImplementedError: If *norm_type* is not one of the names above.
    """
    if norm_type == "BN":
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == "IN":
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == "NONE":
        # Accept (and ignore) whatever arguments the BN/IN factories accept,
        # so callers can invoke every factory the same way.
        return lambda *args, **kwargs: nn.Identity()
    else:
        # NotImplementedError, not the NotImplemented constant: the constant is
        # not callable, and calling it would raise an unrelated TypeError.
        raise NotImplementedError(f'normalization layer {norm_type} is not found')