move norm select to top
This commit is contained in:
parent
42d6253a1d
commit
9e8e73c988
@ -1,18 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import functools
|
|
||||||
from model.registry import MODEL
|
from model.registry import MODEL
|
||||||
|
from model.normalization import select_norm_layer
|
||||||
|
|
||||||
def _select_norm_layer(norm_type):
|
|
||||||
if norm_type == "BN":
|
|
||||||
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
|
||||||
elif norm_type == "IN":
|
|
||||||
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
|
||||||
elif norm_type == "NONE":
|
|
||||||
return lambda x: nn.Identity()
|
|
||||||
else:
|
|
||||||
raise NotImplemented(f'normalization layer {norm_type} is not found')
|
|
||||||
|
|
||||||
|
|
||||||
class GANImageBuffer(object):
|
class GANImageBuffer(object):
|
||||||
@ -77,7 +66,7 @@ class ResidualBlock(nn.Module):
|
|||||||
if use_bias is None:
|
if use_bias is None:
|
||||||
# Only for IN, use bias since it does not have affine parameters.
|
# Only for IN, use bias since it does not have affine parameters.
|
||||||
use_bias = norm_type == "IN"
|
use_bias = norm_type == "IN"
|
||||||
norm_layer = _select_norm_layer(norm_type)
|
norm_layer = select_norm_layer(norm_type)
|
||||||
models = [nn.Sequential(
|
models = [nn.Sequential(
|
||||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||||
norm_layer(num_channels),
|
norm_layer(num_channels),
|
||||||
@ -101,7 +90,7 @@ class ResGenerator(nn.Module):
|
|||||||
norm_type="IN", use_dropout=False):
|
norm_type="IN", use_dropout=False):
|
||||||
super(ResGenerator, self).__init__()
|
super(ResGenerator, self).__init__()
|
||||||
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
||||||
norm_layer = _select_norm_layer(norm_type)
|
norm_layer = select_norm_layer(norm_type)
|
||||||
use_bias = norm_type == "IN"
|
use_bias = norm_type == "IN"
|
||||||
|
|
||||||
self.start_conv = nn.Sequential(
|
self.start_conv = nn.Sequential(
|
||||||
@ -157,7 +146,7 @@ class PatchDiscriminator(nn.Module):
|
|||||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
|
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
|
||||||
super(PatchDiscriminator, self).__init__()
|
super(PatchDiscriminator, self).__init__()
|
||||||
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
||||||
norm_layer = _select_norm_layer(norm_type)
|
norm_layer = select_norm_layer(norm_type)
|
||||||
use_bias = norm_type == "IN"
|
use_bias = norm_type == "IN"
|
||||||
|
|
||||||
kernel_size = 4
|
kernel_size = 4
|
||||||
|
|||||||
@ -0,0 +1,13 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
import functools
|
||||||
|
|
||||||
|
|
||||||
|
def select_norm_layer(norm_type):
|
||||||
|
if norm_type == "BN":
|
||||||
|
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
||||||
|
elif norm_type == "IN":
|
||||||
|
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
||||||
|
elif norm_type == "NONE":
|
||||||
|
return lambda x: nn.Identity()
|
||||||
|
else:
|
||||||
|
raise NotImplemented(f'normalization layer {norm_type} is not found')
|
||||||
Loading…
Reference in New Issue
Block a user