From 1d95f82a369ca345175139db51e06bdd02a86668 Mon Sep 17 00:00:00 2001 From: CoreLeader Date: Tue, 14 Dec 2021 11:30:00 +0800 Subject: [PATCH] Add encoder registry and decoder registry --- .../module/model/decoders/decoder_registry.py | 30 +++++++++++++ .../module/model/encoders/encoder_registry.py | 30 +++++++++++++ autogl/utils/universal_registry.py | 43 +++++++++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 autogl/module/model/decoders/decoder_registry.py create mode 100644 autogl/module/model/encoders/encoder_registry.py create mode 100644 autogl/utils/universal_registry.py diff --git a/autogl/module/model/decoders/decoder_registry.py b/autogl/module/model/decoders/decoder_registry.py new file mode 100644 index 0000000..53e6ef0 --- /dev/null +++ b/autogl/module/model/decoders/decoder_registry.py @@ -0,0 +1,30 @@ +import typing as _typing +from autogl.utils import universal_registry +from . import base_decoder + + +class DecoderUniversalRegistry(universal_registry.UniversalRegistryBase): + @classmethod + def register_decoder(cls, name: str) -> _typing.Callable[ + [_typing.Type[base_decoder.BaseAutoDecoderMaintainer]], + _typing.Type[base_decoder.BaseAutoDecoderMaintainer] + ]: + def register_decoder( + _decoder: _typing.Type[base_decoder.BaseAutoDecoderMaintainer] + ) -> _typing.Type[base_decoder.BaseAutoDecoderMaintainer]: + if not issubclass(_decoder, base_decoder.BaseAutoDecoderMaintainer): + raise TypeError + elif name in cls: + raise ValueError + else: + cls[name] = _decoder + return _decoder + + return register_decoder + + @classmethod + def get_decoder(cls, name: str) -> _typing.Type[base_decoder.BaseAutoDecoderMaintainer]: + if name not in cls: + raise ValueError(f"Decoder with name \"{name}\" not exist") + else: + return cls[name] diff --git a/autogl/module/model/encoders/encoder_registry.py b/autogl/module/model/encoders/encoder_registry.py new file mode 100644 index 0000000..a088fb5 --- /dev/null +++ b/autogl/module/model/encoders/encoder_registry.py @@ -0,0 +1,30 @@ +import typing as _typing +from autogl.utils import universal_registry +from . import base_encoder + + +class EncoderUniversalRegistry(universal_registry.UniversalRegistryBase): + @classmethod + def register_encoder(cls, name: str) -> _typing.Callable[ + [_typing.Type[base_encoder.BaseAutoEncoderMaintainer]], + _typing.Type[base_encoder.BaseAutoEncoderMaintainer] + ]: + def register_encoder( + _encoder: _typing.Type[base_encoder.BaseAutoEncoderMaintainer] + ) -> _typing.Type[base_encoder.BaseAutoEncoderMaintainer]: + if not issubclass(_encoder, base_encoder.BaseAutoEncoderMaintainer): + raise TypeError + elif name in cls: + raise ValueError + else: + cls[name] = _encoder + return _encoder + + return register_encoder + + @classmethod + def get_encoder(cls, name: str) -> _typing.Type[base_encoder.BaseAutoEncoderMaintainer]: + if name not in cls: + raise ValueError(f"Encoder with name \"{name}\" not exist") + else: + return cls[name] diff --git a/autogl/utils/universal_registry.py b/autogl/utils/universal_registry.py new file mode 100644 index 0000000..8559be9 --- /dev/null +++ b/autogl/utils/universal_registry.py @@ -0,0 +1,43 @@ +import typing as _typing + + +class _UniversalRegistryMetaclass(type, _typing.MutableMapping[str, _typing.Any]): + def __getitem__(cls, k: str) -> _typing.Any: + return cls.__universal_registry[k] + + def __setitem__(cls, k: str, v: _typing.Any) -> None: + cls.__universal_registry[k] = v + + def __delitem__(cls, k: str) -> None: + del cls.__universal_registry[k] + + def __len__(cls) -> int: + return len(cls.__universal_registry) + + def __iter__(cls) -> _typing.Iterator[str]: + return iter(cls.__universal_registry) + + @property + def _universal_registry(cls) -> _typing.Mapping[str, _typing.Any]: + return cls.__universal_registry + + def __new__( + mcs, name: str, bases: _typing.Tuple[type, ...], + namespace: _typing.Dict[str, _typing.Any] + ): + return super(_UniversalRegistryMetaclass, mcs).__new__( + mcs, name, bases, namespace + ) + + def __init__( + cls, name: str, bases: _typing.Tuple[type, ...], + namespace: _typing.Dict[str, _typing.Any] + ): + super(_UniversalRegistryMetaclass, cls).__init__( + name, bases, namespace + ) + cls.__universal_registry: _typing.MutableMapping[str, _typing.Any] = {} + + +class UniversalRegistryBase(metaclass=_UniversalRegistryMetaclass): + ...