|
- import os
- import logging as _logging
- import typing as _typing
-
- _backend_logger: _logging.Logger = _logging.getLogger("backend")
-
-
class _BackendConfig(_typing.Mapping[str, _typing.Any]):
    """
    Read-only mapping of backend settings, tagged with the backend's name.

    ``str(config)`` yields the backend name, while item access, ``len()`` and
    iteration are delegated to the wrapped ``configurations`` mapping.
    """

    def __init__(self, name: str, configurations: _typing.Mapping[str, _typing.Any] = ...):
        """
        :param name: identifier of the backend, returned by ``str()``.
        :param configurations: settings to expose. Any value that is not a
            mapping (including the default ``...`` sentinel and ``None``)
            is treated as "no settings" and replaced by an empty dict.
        """
        self.__name: str = name
        # One isinstance check covers every "not provided" sentinel (None, ...)
        # as well as arbitrary non-mapping values, without running ``==``
        # against objects whose equality may be expensive or ill-defined.
        self.__configurations: _typing.Mapping[str, _typing.Any] = (
            configurations if isinstance(configurations, _typing.Mapping) else dict()
        )

    def __str__(self) -> str:
        """Return the backend name."""
        return self.__name

    def __getitem__(self, key: str):
        """Return the setting stored under ``key``; raises ``KeyError`` if absent."""
        return self.__configurations[key]

    def __len__(self) -> int:
        """Return the number of settings."""
        return len(self.__configurations)

    def __iter__(self):
        """Iterate over the setting keys."""
        return iter(self.__configurations)
-
-
class _DGLConfig(_BackendConfig):
    """Backend config whose name is ``"dgl"`` and which carries no settings."""

    def __init__(self):
        """Initialise the base config with the fixed name ``"dgl"``."""
        super().__init__("dgl")
-
-
class _PyGConfig(_BackendConfig):
    """Backend config whose name is ``"pyg"`` and which carries no settings."""

    def __init__(self):
        """Initialise the base config with the fixed name ``"pyg"``."""
        super().__init__("pyg")
-
-
- # class _BackendConfigGenerator:
- # ...
-
-
def _generate_backend_config() -> _BackendConfig:
    """
    Pick the backend to use and return its config object.

    The environment variable ``AUTOGL_BACKEND`` may name the preferred backend
    (``"dgl"`` or ``"pyg"``, case-insensitive):

    * unset: try DGL first, then PyTorch-Geometric;
    * set to anything else: log a warning and fall back to the same
      automatic detection;
    * set to a known backend: try that one first; if its module is missing
      but the other one is importable, log a warning and use the other.

    :return: a ``_DGLConfig`` or ``_PyGConfig`` instance.
    :raises ModuleNotFoundError: if neither ``dgl`` nor ``torch_geometric``
        can be imported.
    """
    import importlib

    # key -> (module to import, config class, name used in log messages).
    # Insertion order doubles as the automatic-detection priority.
    backends = {
        "dgl": ("dgl", _DGLConfig, "DGL"),
        "pyg": ("torch_geometric", _PyGConfig, "PyTorch-Geometric"),
    }

    def _load(key: str) -> _typing.Optional[_BackendConfig]:
        """Import the backend's module and build its config; None if the module is missing."""
        module_name, config_class, _ = backends[key]
        try:
            importlib.import_module(module_name)
        except ModuleNotFoundError:
            return None
        return config_class()

    def _generate_by_name(name: _typing.Optional[str] = None) -> _BackendConfig:
        """Resolve the backend, preferring ``name`` when it denotes a known backend."""
        requested = name.lower() if isinstance(name, str) else None
        if requested is not None and requested not in backends:
            _backend_logger.warning(" ".join((
                "The environment variable AUTOGL_BACKEND specified",
                "but is neither \"dgl\" nor \"pyg\",",
                "thus the environment variable is ignored",
                "and dependent backend for AutoGL is set automatically",
            )))
            requested = None

        # Try the requested backend first (if any), then the remaining ones
        # in their default priority order.
        order = list(backends)
        if requested is not None:
            order.remove(requested)
            order.insert(0, requested)

        for key in order:
            config = _load(key)
            if config is None:
                continue
            if requested is not None and key != requested:
                # The explicitly requested backend was unavailable; tell the
                # user which one is being substituted.
                _backend_logger.warning(" ".join((
                    f"The required backend {backends[requested][2]} is not installed,",
                    f"use {backends[key][2]} instead.",
                )))
            return config
        raise ModuleNotFoundError("Neither DGL nor PyTorch-Geometric exists")

    # os.getenv returns None when the variable is unset, which selects
    # automatic detection.
    return _generate_by_name(os.getenv("AUTOGL_BACKEND"))
-
-
class _DependentBackendMetaclass(type):
    """
    Metaclass for ``DependentBackend``.

    The backend config is generated in ``__init__`` of this metaclass, i.e.
    once, at the moment a class using this metaclass is created, and stored
    on that class as ``_backend_config``. Classes created by this metaclass
    cannot be subclassed.

    NOTE(review): the original author's stated intent is to keep the backend
    config consistent across threads in a multiprocessing runtime by
    resolving it at class-definition time; confirm against the callers.
    """

    def __new__(
            mcs, name: str, bases: _typing.Tuple[type, ...],
            namespace: _typing.Dict[str, _typing.Any]
    ):
        """
        Create the class, refusing any base that was itself built by this metaclass.

        :raises TypeError: if one of ``bases`` is an instance of this metaclass.
        """
        for base in bases:
            if isinstance(base, _DependentBackendMetaclass):
                strings = (
                    f"{base} is instance of Metaclass {_DependentBackendMetaclass}",
                    f"and MUST not be inherited/extended by <{name}> to construct"
                )
                raise TypeError(" ".join(strings))
        return super().__new__(mcs, name, bases, namespace)

    def __init__(
            cls, name: str, bases: _typing.Tuple[type, ...],
            namespace: _typing.Dict[str, _typing.Any]
    ):
        """Resolve the backend config and attach it to the newly created class."""
        super().__init__(name, bases, namespace)
        cls._backend_config: _BackendConfig = _generate_backend_config()
        # Pass the argument lazily so formatting only happens if the record
        # is actually emitted; "%s" renders it through str(), i.e. the name.
        _backend_logger.info("Adopted backend: %s", cls._backend_config)
-
-
class DependentBackend(metaclass=_DependentBackendMetaclass):
    """
    Class-level accessor for the backend chosen when this class was created.

    The class is never instantiated; use its classmethods directly. The
    ``_backend_config`` attribute they read is set by the metaclass.
    """

    def __new__(cls, *args, **kwargs):
        """Always raise ``RuntimeError``: this class only exposes classmethods."""
        message = f"The class {DependentBackend} should not be instantiated"
        raise RuntimeError(message)

    @classmethod
    def get_backend_name(cls) -> str:
        """Return the name of the adopted backend (``str()`` of its config)."""
        adopted = cls._backend_config
        return str(adopted)

    @classmethod
    def is_dgl(cls) -> bool:
        """Return True when the adopted backend config is a ``_DGLConfig``."""
        adopted = cls._backend_config
        return isinstance(adopted, _DGLConfig)

    @classmethod
    def is_pyg(cls) -> bool:
        """Return True when the adopted backend config is a ``_PyGConfig``."""
        adopted = cls._backend_config
        return isinstance(adopted, _PyGConfig)
|