From 776fe40199e77ffdc1bec7d12482273f5a31149c Mon Sep 17 00:00:00 2001 From: Ray Wong Date: Sat, 26 Sep 2020 17:48:26 +0800 Subject: [PATCH] change a lot --- model/__init__.py | 12 +++++++----- model/normalization.py | 2 +- model/registry.py | 1 + 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/model/__init__.py b/model/__init__.py index c4c6cf5..20c1c47 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,8 +1,10 @@ -from model.registry import MODEL +from model.registry import MODEL, NORMALIZATION import model.GAN.CycleGAN +import model.GAN.MUNIT import model.GAN.TAFG -import model.GAN.UGATIT -import model.GAN.wrapper -import model.GAN.base import model.GAN.TSIT -import model.GAN.MUNIT \ No newline at end of file +import model.GAN.UGATIT +import model.GAN.base +import model.GAN.wrapper +import model.base.normalization + diff --git a/model/normalization.py b/model/normalization.py index a5cbcf2..9e3facf 100644 --- a/model/normalization.py +++ b/model/normalization.py @@ -6,7 +6,7 @@ import torch.nn as nn def select_norm_layer(norm_type): if norm_type == "BN": - return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + return functools.partial(nn.BatchNorm2d) elif norm_type == "IN": return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) elif norm_type == "LN": diff --git a/model/registry.py b/model/registry.py index 6711b05..e0f7ea4 100644 --- a/model/registry.py +++ b/model/registry.py @@ -1,3 +1,4 @@ from util.registry import Registry MODEL = Registry("model") +NORMALIZATION = Registry("normalization") \ No newline at end of file