diff --git a/model/weight_init.py b/model/weight_init.py index 8adc46f..6a64a4c 100644 --- a/model/weight_init.py +++ b/model/weight_init.py @@ -67,4 +67,6 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02): # only normal distribution applies. normal_init(m, 1.0, init_gain) + assert isinstance(module, nn.Module) module.apply(init_func) + diff --git a/util/registry.py b/util/registry.py index ecb0c29..6b59af1 100644 --- a/util/registry.py +++ b/util/registry.py @@ -119,9 +119,9 @@ class Registry(_Registry): return self._module_dict.get(key, None) def _register_module(self, module_class, module_name=None, force=False): - if not inspect.isclass(module_class): - raise TypeError('module must be a class, ' - f'but got {type(module_class)}') + # if not inspect.isclass(module_class): + # raise TypeError('module must be a class, ' + # f'but got {type(module_class)}') if module_name is None: module_name = module_class.__name__