remove class check in registry; add type assert in weight_init
This commit is contained in:
parent
f7843de45d
commit
206d9343cd
@ -67,4 +67,6 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
|
|||||||
# only normal distribution applies.
|
# only normal distribution applies.
|
||||||
normal_init(m, 1.0, init_gain)
|
normal_init(m, 1.0, init_gain)
|
||||||
|
|
||||||
|
assert isinstance(module, nn.Module)
|
||||||
module.apply(init_func)
|
module.apply(init_func)
|
||||||
|
|
||||||
|
|||||||
@ -119,9 +119,9 @@ class Registry(_Registry):
|
|||||||
return self._module_dict.get(key, None)
|
return self._module_dict.get(key, None)
|
||||||
|
|
||||||
def _register_module(self, module_class, module_name=None, force=False):
|
def _register_module(self, module_class, module_name=None, force=False):
|
||||||
if not inspect.isclass(module_class):
|
# if not inspect.isclass(module_class):
|
||||||
raise TypeError('module must be a class, '
|
# raise TypeError('module must be a class, '
|
||||||
f'but got {type(module_class)}')
|
# f'but got {type(module_class)}')
|
||||||
|
|
||||||
if module_name is None:
|
if module_name is None:
|
||||||
module_name = module_class.__name__
|
module_name = module_class.__name__
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user