14 lines
467 B
Python
14 lines
467 B
Python
import torch.nn as nn
|
|
import functools
|
|
|
|
|
|
def select_norm_layer(norm_type: str):
    """Return a factory that builds the normalization layer named by *norm_type*.

    The returned object is a callable, not a layer instance; call it
    (e.g. with the number of feature channels) to construct the module.

    Args:
        norm_type: One of
            ``"BN"``   - ``nn.BatchNorm2d`` with ``affine=True`` and
                         ``track_running_stats=True``;
            ``"IN"``   - ``nn.InstanceNorm2d`` with ``affine=False`` and
                         ``track_running_stats=False``;
            ``"NONE"`` - a factory that ignores its arguments and returns
                         ``nn.Identity()``.

    Returns:
        A callable producing an ``nn.Module`` each time it is invoked.

    Raises:
        NotImplementedError: If *norm_type* is not one of the names above.
    """
    if norm_type == "BN":
        return functools.partial(
            nn.BatchNorm2d, affine=True, track_running_stats=True
        )
    if norm_type == "IN":
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    if norm_type == "NONE":
        # Accept (and discard) whatever arguments a caller would pass to a
        # real norm layer, so this factory is interchangeable with the
        # partials returned for "BN" and "IN".
        return lambda *args, **kwargs: nn.Identity()
    # NotImplementedError is the exception class; the NotImplemented constant
    # is not callable and would surface as an unrelated TypeError here.
    raise NotImplementedError(f'normalization layer {norm_type} is not found')
|