# File listing metadata: 174 lines, 5.8 KiB, Python
import inspect
|
|
from omegaconf.dictconfig import DictConfig
|
|
from omegaconf import OmegaConf
|
|
from types import ModuleType
|
|
import warnings
|
|
|
|
class _Registry:
    """Base class for registries that map keys to buildable objects.

    Subclasses must implement :meth:`get` and :meth:`keys`. ``len()``,
    ``in``, ``repr()`` and :meth:`build_with` are all derived from those
    two methods.

    Args:
        name (str): Registry name, used in ``repr()`` and error messages.
    """

    def __init__(self, name):
        self._name = name

    def get(self, key):
        """Return the object registered under ``key``, or ``None`` if absent.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def keys(self):
        """Return the registered keys. Must be overridden by subclasses."""
        raise NotImplementedError

    def __len__(self):
        return len(self.keys())

    def __contains__(self, key):
        # Relies on ``get`` returning ``None`` for unknown keys.
        return self.get(key) is not None

    @property
    def name(self):
        """str: The registry name given at construction time."""
        return self._name

    def __repr__(self):
        return f"{self.__class__.__name__}(name={self._name}, items={self.keys()})"

    def build_with(self, cfg, default_args=None):
        """Build an object from a config.

        Three config forms are accepted:

        * ``{"_type": T, **kwargs}`` -- ``T`` is called with ``kwargs``.
        * ``{T: kwargs}`` (exactly one item) -- ``T`` is called with
          ``kwargs``; ``kwargs`` may be ``None``, meaning no arguments.
        * ``"T"`` -- ``T`` is looked up and called without arguments.

        ``T`` is either a string key resolved through :meth:`get` or a class
        that is used directly. A ``DictConfig`` is converted to a plain
        container first. Arguments whose name starts with ``_`` are dropped
        with a warning. The caller's ``cfg`` is never modified.

        Args:
            cfg (dict | DictConfig | str): Config in one of the forms above.
            default_args (dict, optional): Default initialization arguments,
                applied only for names not already present in the config.

        Returns:
            object: The constructed object.

        Raises:
            KeyError: If a dict ``cfg`` has no ``"_type"`` key and more or
                fewer than one item, or if a string type is not registered.
            TypeError: If ``cfg``, its arguments, ``default_args`` or the
                type have an unsupported type, or if calling the resolved
                class raises ``TypeError``.
        """
        if isinstance(cfg, DictConfig):
            cfg = OmegaConf.to_container(cfg)

        if isinstance(cfg, dict):
            if '_type' in cfg:
                args = cfg.copy()
                obj_type = args.pop('_type')
            elif len(cfg) == 1:
                obj_type, raw_args = next(iter(cfg.items()))
                if raw_args is None:
                    # e.g. a YAML entry such as ``SomeType:`` with no body.
                    args = {}
                elif hasattr(raw_args, 'keys'):
                    # Work on a copy so that the pops/setdefaults below never
                    # leak into the caller's config.
                    args = dict(raw_args)
                else:
                    raise TypeError(
                        f'arguments of {obj_type} must be a dict or None, '
                        f'but got {type(raw_args)}')
            else:
                raise KeyError(f'the cfg dict must contain the key "_type", but got {cfg}')
        elif isinstance(cfg, str):
            obj_type = cfg
            args = {}
        else:
            raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')

        # Iterate over a snapshot of the keys: deleting from a dict while
        # iterating over it raises RuntimeError.
        for k in list(args):
            if not isinstance(k, str):
                raise TypeError(f'argument names must be str, but got {type(k)}')
            if k.startswith("_"):
                warnings.warn(f"got param start with `_`: {k}, will remove it")
                del args[k]

        if not (isinstance(default_args, dict) or default_args is None):
            raise TypeError('default_args must be a dict or None, '
                            f'but got {type(default_args)}')

        if isinstance(obj_type, str):
            obj_cls = self.get(obj_type)
            if obj_cls is None:
                raise KeyError(f'{obj_type} is not in the {self.name} registry')
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')

        if default_args is not None:
            for name, value in default_args.items():
                args.setdefault(name, value)
        try:
            obj = obj_cls(**args)
        except TypeError as e:
            raise TypeError(f"invalid argument in {args} when try to build {obj_cls}\n") from e
        return obj
|
|
|
|
|
|
class ModuleRegistry(_Registry):
    """Registry backed by the attributes of an existing Python module.

    The set of valid names is computed once, at construction time, from the
    module's namespace.

    Args:
        name (str): Registry name.
        module (ModuleType): Module whose attributes are served by ``get``.
        predefined_valid_list (iterable, optional): When given, only names
            that appear both in this iterable and in the module namespace
            are valid. When omitted, every name in the module namespace is
            valid.
    """

    def __init__(self, name, module, predefined_valid_list=None):
        super().__init__(name)

        assert isinstance(module, ModuleType), f"module must be ModuleType, but got {type(module)}"
        self._module = module

        module_names = set(vars(module))
        if predefined_valid_list is None:
            self._valid_set = module_names
        else:
            self._valid_set = set(predefined_valid_list) & module_names

    def keys(self):
        """Return the valid names as a tuple (in set iteration order)."""
        valid_names = self._valid_set
        return tuple(valid_names)

    def get(self, key):
        """Look up a name in the wrapped module.

        Args:
            key (str): Attribute name to fetch from the module.

        Returns:
            object | None: The module attribute, or ``None`` when ``key`` is
            not one of the valid names.
        """
        if key in self._valid_set:
            return getattr(self._module, key)
        return None
|
|
|
|
|
|
class Registry(_Registry):
    """A registry to map strings to classes.

    Entries are added explicitly through :meth:`register_module`, either as
    a decorator or as a plain method call.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        super().__init__(name)
        self._module_dict = {}

    def keys(self):
        """Return the registered names as a tuple, in registration order."""
        return tuple(self._module_dict)

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or ``None`` if ``key`` is not
            registered.
        """
        return self._module_dict.get(key)

    def _register_module(self, module_class, module_name=None, force=False):
        """Store ``module_class`` under ``module_name``.

        Args:
            module_class: Object to register.
            module_name (str, optional): Key to register under. Defaults to
                ``module_class.__name__``.
            force (bool): Whether to overwrite an existing entry.

        Raises:
            KeyError: If the name is already registered and ``force`` is
                false.
        """
        # NOTE: ``module_class`` is not checked with ``inspect.isclass``, so
        # non-class objects are accepted as long as a name can be determined.
        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        self._module_dict[module_name] = module_class

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            The registered ``module`` when it is passed directly; otherwise
            a decorator that registers its argument and returns it unchanged.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is neither a
                str nor ``None``.
            KeyError: If the name is already registered and ``force`` is
                false.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # Validate the name before either usage form, so that a bad name is
        # rejected for direct calls as well as for the decorator form.
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(module_class=cls, module_name=name, force=force)
            return cls

        return _register
|