move norm select to top

This commit is contained in:
budui 2020-08-28 08:16:07 +08:00
parent 42d6253a1d
commit 9e8e73c988
2 changed files with 17 additions and 15 deletions

View File

@ -1,18 +1,7 @@
import torch
import torch.nn as nn
import functools
from model.registry import MODEL
def _select_norm_layer(norm_type):
if norm_type == "BN":
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == "IN":
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == "NONE":
return lambda x: nn.Identity()
else:
raise NotImplemented(f'normalization layer {norm_type} is not found')
from model.normalization import select_norm_layer
class GANImageBuffer(object):
@ -77,7 +66,7 @@ class ResidualBlock(nn.Module):
if use_bias is None:
# Only for IN, use bias since it does not have affine parameters.
use_bias = norm_type == "IN"
norm_layer = _select_norm_layer(norm_type)
norm_layer = select_norm_layer(norm_type)
models = [nn.Sequential(
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
norm_layer(num_channels),
@ -101,7 +90,7 @@ class ResGenerator(nn.Module):
norm_type="IN", use_dropout=False):
super(ResGenerator, self).__init__()
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
norm_layer = _select_norm_layer(norm_type)
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
self.start_conv = nn.Sequential(
@ -157,7 +146,7 @@ class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
super(PatchDiscriminator, self).__init__()
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
norm_layer = _select_norm_layer(norm_type)
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
kernel_size = 4

View File

@ -0,0 +1,13 @@
import torch.nn as nn
import functools
def select_norm_layer(norm_type):
if norm_type == "BN":
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == "IN":
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == "NONE":
return lambda x: nn.Identity()
else:
raise NotImplemented(f'normalization layer {norm_type} is not found')