import inspect
import warnings
from types import ModuleType

# omegaconf is used only to unwrap ``DictConfig`` objects in ``build_with``.
# Guarding the import keeps the registries usable with plain dict / str
# configs in environments where omegaconf is not installed.
try:
    from omegaconf import OmegaConf
    from omegaconf.dictconfig import DictConfig
except ImportError:
    OmegaConf = None
    DictConfig = None


class _Registry:
    """Base class for registries that map string keys to classes.

    Subclasses must implement :meth:`get` and :meth:`keys`; everything else
    (``len``, ``in``, ``repr`` and :meth:`build_with`) is derived from those.

    Args:
        name (str): Registry name, used in ``repr`` and error messages.
    """

    def __init__(self, name):
        self._name = name

    def get(self, key):
        """Return the entry registered under ``key``, or None if absent."""
        raise NotImplementedError

    def keys(self):
        """Return the registered keys."""
        raise NotImplementedError

    def __len__(self):
        return len(self.keys())

    def __contains__(self, key):
        # An entry whose value is None is indistinguishable from a missing one.
        return self.get(key) is not None

    @property
    def name(self):
        return self._name

    def __repr__(self):
        return f"{self.__class__.__name__}(name={self._name}, items={self.keys()})"

    def build_with(self, cfg, default_args=None):
        """Build an object from a config.

        Three config shapes are accepted:

        * ``{"_type": <name or class>, **kwargs}``
        * ``{<name>: kwargs}`` -- a single-key dict; ``kwargs`` may be a dict
          or None (meaning "no arguments")
        * ``"<name>"`` -- a bare string, built without arguments

        A ``DictConfig`` is converted to a plain container first. Keys that
        start with ``_`` (other than ``_type``) are dropped with a warning.
        The caller's ``cfg`` is never modified.

        Args:
            cfg (dict | str | DictConfig): The config, see above.
            default_args (dict, optional): Extra keyword arguments that are
                used only for names not already present in the config.

        Returns:
            object: The constructed object.

        Raises:
            KeyError: If a multi-key dict has no ``"_type"`` key, or the
                requested name is not in this registry.
            TypeError: If ``cfg``, its arguments, ``default_args`` or the
                type specifier has an unsupported type, or if the constructor
                itself raises ``TypeError``.
        """
        if DictConfig is not None and isinstance(cfg, DictConfig):
            cfg = OmegaConf.to_container(cfg)

        if isinstance(cfg, dict):
            if '_type' in cfg:
                args = cfg.copy()
                obj_type = args.pop('_type')
            elif len(cfg) == 1:
                obj_type, args = next(iter(cfg.items()))
                if args is None:
                    # e.g. YAML ``Name:`` with no value
                    args = {}
                elif isinstance(args, dict):
                    # copy so that removing ``_`` keys below cannot leak
                    # into the caller's config
                    args = args.copy()
                else:
                    raise TypeError(
                        f'arguments for {obj_type} must be a dict or None, '
                        f'but got {type(args)}')
            else:
                raise KeyError(f'the cfg dict must contain the key "_type", but got {cfg}')
        elif isinstance(cfg, str):
            obj_type = cfg
            args = {}
        else:
            raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')

        private_keys = [k for k in args if isinstance(k, str) and k.startswith("_")]
        for invalid_key in private_keys:
            warnings.warn(f"got param start with `_`: {invalid_key}, will remove it")
            args.pop(invalid_key)

        if not (isinstance(default_args, dict) or default_args is None):
            raise TypeError('default_args must be a dict or None, '
                            f'but got {type(default_args)}')

        if isinstance(obj_type, str):
            obj_cls = self.get(obj_type)
            if obj_cls is None:
                raise KeyError(f'{obj_type} is not in the {self.name} registry')
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')

        if default_args is not None:
            for name, value in default_args.items():
                args.setdefault(name, value)

        try:
            obj = obj_cls(**args)
        except TypeError as e:
            raise TypeError(f"invalid argument in {args} \nwhen try to build {obj_cls}\n") from e

        return obj


class ModuleRegistry(_Registry):
    """A registry backed by the attributes of a Python module.

    The set of valid names is captured once, at construction time.

    Args:
        name (str): Registry name.
        module (ModuleType): Module whose attributes are looked up.
        predefined_valid_list (Iterable[str], optional): If given, only names
            that are both in this list and in the module are exposed;
            otherwise every name in the module's ``__dict__`` is exposed.
    """

    def __init__(self, name, module, predefined_valid_list=None):
        super().__init__(name)
        assert isinstance(module, ModuleType), f"module must be ModuleType, but got {type(module)}"
        self._module = module
        module_names = set(self._module.__dict__.keys())
        if predefined_valid_list is not None:
            self._valid_set = set(predefined_valid_list) & module_names
        else:
            self._valid_set = module_names

    def keys(self):
        """Return the valid names as a sorted tuple (stable across runs)."""
        return tuple(sorted(self._valid_set))

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding module attribute, or None if ``key`` is
            not one of the valid names.
        """
        if key not in self._valid_set:
            return None
        return getattr(self._module, key)


class Registry(_Registry):
    """A registry to map strings to classes.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        super().__init__(name)
        self._module_dict = {}

    def keys(self):
        """Return the registered names in registration order."""
        return tuple(self._module_dict.keys())

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or None if not registered.
        """
        return self._module_dict.get(key)

    def _register_module(self, module_class, module_name=None, force=False):
        """Store ``module_class`` under ``module_name``.

        Args:
            module_class: Object to register.
            module_name (str, optional): Key to register under; defaults to
                ``module_class.__name__``.
            force (bool): Overwrite an existing entry with the same name.

        Raises:
            KeyError: If the name is already taken by a different object and
                ``force`` is False.
        """
        if module_name is None:
            module_name = module_class.__name__

        if not force and module_name in self._module_dict:
            existing = self._module_dict[module_name]
            if existing == module_class:
                # Registering the same object again is harmless.
                warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
                return
            raise KeyError(f'{module_name}:{existing} is already registered in {self.name} '
                           f'so {module_class} can not be registered')

        self._module_dict[module_name] = module_class

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            ``module`` itself when ``module`` is given; otherwise a decorator
            that registers the decorated class and returns it unchanged.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is neither
                None nor a str.
            KeyError: If the name is already registered to a different class
                and ``force`` is False.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # Validate up front for both usages: a non-str name could never be
        # resolved by ``build_with``, and for the decorator form this also
        # surfaces ``@x.register_module`` written without parentheses.
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(module_class=cls, module_name=name, force=force)
            return cls

        return _register