* add constant * add logger module * add registry and builder module * add fileio module * add requirements and setup.cfg * add config module and tests * add citest script Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8718998master
| @@ -0,0 +1,13 @@ | |||
# Install runtime and test dependencies.
pip install -r requirements/runtime.txt
pip install -r requirements/tests.txt

# linter test
# use internal project for pre-commit due to the network problem
if ! pre-commit run --all-files; then
    echo "linter test failed, please run 'pre-commit run --all-files' to check"
    # Exit statuses are limited to 0-255; `exit -1` is rejected by strict
    # POSIX shells, so report a plain failure code instead.
    exit 1
fi

# Run the unit tests with the repository root on the import path.
PYTHONPATH=. python tests/run.py
| @@ -0,0 +1,126 @@ | |||
| # Byte-compiled / optimized / DLL files | |||
| __pycache__/ | |||
| *.py[cod] | |||
| *$py.class | |||
| # C extensions | |||
| *.so | |||
| # Distribution / packaging | |||
| .Python | |||
| build/ | |||
| develop-eggs/ | |||
| dist/ | |||
| downloads/ | |||
| eggs/ | |||
| .eggs/ | |||
| lib/ | |||
| lib64/ | |||
| parts/ | |||
| sdist/ | |||
| var/ | |||
| wheels/ | |||
| *.egg-info/ | |||
| .installed.cfg | |||
| *.egg | |||
| /package | |||
| MANIFEST | |||
| # PyInstaller | |||
| # Usually these files are written by a python script from a template | |||
| # before PyInstaller builds the exe, so as to inject date/other infos into it. | |||
| *.manifest | |||
| *.spec | |||
| # Installer logs | |||
| pip-log.txt | |||
| pip-delete-this-directory.txt | |||
| # Unit test / coverage reports | |||
| htmlcov/ | |||
| .tox/ | |||
| .coverage | |||
| .coverage.* | |||
| .cache | |||
| nosetests.xml | |||
| coverage.xml | |||
| *.cover | |||
| .hypothesis/ | |||
| .pytest_cache/ | |||
| # Translations | |||
| *.mo | |||
| *.pot | |||
| # Django stuff: | |||
| *.log | |||
| local_settings.py | |||
| db.sqlite3 | |||
| # Flask stuff: | |||
| instance/ | |||
| .webassets-cache | |||
| # Scrapy stuff: | |||
| .scrapy | |||
| # Sphinx documentation | |||
| docs/_build/ | |||
| # PyBuilder | |||
| target/ | |||
| # Jupyter Notebook | |||
| .ipynb_checkpoints | |||
| # pyenv | |||
| .python-version | |||
| # celery beat schedule file | |||
| celerybeat-schedule | |||
| # SageMath parsed files | |||
| *.sage.py | |||
| # Environments | |||
| .env | |||
| .venv | |||
| env/ | |||
| venv/ | |||
| ENV/ | |||
| env.bak/ | |||
| venv.bak/ | |||
| # Spyder project settings | |||
| .spyderproject | |||
| .spyproject | |||
| # Rope project settings | |||
| .ropeproject | |||
| # mkdocs documentation | |||
| /site | |||
| # mypy | |||
| .mypy_cache/ | |||
| data | |||
| .vscode | |||
| .idea | |||
| # custom | |||
| *.pkl | |||
| *.pkl.json | |||
| *.log.json | |||
| *.whl | |||
| *.tar.gz | |||
| *.swp | |||
| *.log | |||
| *.tar.gz | |||
| source.sh | |||
| tensorboard.sh | |||
| .DS_Store | |||
| replace.sh | |||
| # Pytorch | |||
| *.pth | |||
| @@ -0,0 +1 @@ | |||
| This folder will host example configs for each model supported by maas_lib. | |||
| @@ -0,0 +1,7 @@ | |||
| { | |||
| "a": 1, | |||
| "b" : { | |||
| "c": [1,2,3], | |||
| "d" : "dd" | |||
| } | |||
| } | |||
| @@ -0,0 +1,2 @@ | |||
# Example python-format config: every public module-level name becomes a key.
a = 1
b = {'c': [1, 2, 3], 'd': 'dd'}
| @@ -0,0 +1,4 @@ | |||
| a: 1 | |||
| b: | |||
| c: [1,2,3] | |||
| d: dd | |||
| @@ -0,0 +1,5 @@ | |||
| model_dir: path/to/model | |||
| lr: 0.01 | |||
| optimizer: Adam | |||
| weight_decay: 1e-6 | |||
| save_checkpoint_epochs: 20 | |||
| @@ -0,0 +1,4 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .version import __version__ | |||
| __all__ = ['__version__'] | |||
| @@ -0,0 +1 @@ | |||
| from .io import dump, dumps, load | |||
| @@ -0,0 +1,325 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import contextlib | |||
| import os | |||
| import tempfile | |||
| from abc import ABCMeta, abstractmethod | |||
| from pathlib import Path | |||
| from typing import Generator, Union | |||
| import requests | |||
class Storage(metaclass=ABCMeta):
    """Abstract interface implemented by every storage backend.

    A backend provides four operations: ``read()`` and ``write()`` work on
    bytes, while ``read_text()`` and ``write_text()`` work on text.
    """

    @abstractmethod
    def read(self, filepath: str):
        """Read the content at ``filepath`` as bytes."""

    @abstractmethod
    def read_text(self, filepath: str):
        """Read the content at ``filepath`` as text."""

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write the bytes ``obj`` to ``filepath``."""

    @abstractmethod
    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write the text ``obj`` to ``filepath`` using ``encoding``."""
class LocalStorage(Storage):
    """Storage backend that reads and writes files on the local disk."""

    @staticmethod
    def _ensure_parent_dir(filepath: Union[str, Path]) -> None:
        """Create the parent directory of ``filepath`` if it is missing."""
        dirname = os.path.dirname(filepath)
        if dirname:
            # ``exist_ok=True`` removes the check-then-create race that
            # occurs when two writers create the same directory at once.
            os.makedirs(dirname, exist_ok=True)

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: The raw content of the file.
        """
        with open(filepath, 'rb') as f:
            return f.read()

    def read_text(self,
                  filepath: Union[str, Path],
                  encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: The decoded content of the file.
        """
        with open(filepath, 'r', encoding=encoding) as f:
            return f.read()

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``write`` creates the parent directory of ``filepath`` when it
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self._ensure_parent_dir(filepath)
        with open(filepath, 'wb') as f:
            f.write(obj)

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``write_text`` creates the parent directory of ``filepath`` when
            it does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        self._ensure_parent_dir(filepath)
        with open(filepath, 'w', encoding=encoding) as f:
            f.write(obj)

    @contextlib.contextmanager
    def as_local_path(
            self,
            filepath: Union[str,
                            Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; exists only for a unified API."""
        yield filepath
class HTTPStorage(Storage):
    """Read-only storage backend for HTTP and HTTPS URLs."""

    # NOTE(review): ``requests.get`` is called without a timeout, so a
    # stalled server blocks the caller indefinitely. Left unchanged here
    # because adding one alters behaviour for slow downloads.

    def read(self, url):
        """Download ``url`` and return the response body as bytes.

        Raises:
            requests.HTTPError: If the response status is an error.
        """
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def read_text(self, url):
        """Download ``url`` and return the response body decoded as text.

        Raises:
            requests.HTTPError: If the response status is an error.
        """
        r = requests.get(url)
        r.raise_for_status()
        return r.text

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a temporary local path.

        ``as_local_path`` is decorated by
        :meth:`contextlib.contextmanager`. It can be used in a ``with``
        statement, and the temporary file is removed when the ``with``
        block exits.

        Args:
            filepath (str): URL to download.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        # Create the file before entering ``try`` so that ``f`` is always
        # bound when the ``finally`` clause runs.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            # ``with f`` closes the handle even when the download fails, so
            # the removal below also works on platforms that refuse to
            # delete open files.
            with f:
                f.write(self.read(filepath))
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        """Unsupported: always raises ``NotImplementedError``."""
        raise NotImplementedError('write is not supported by HTTP Storage')

    def write_text(self,
                   obj: str,
                   url: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Unsupported: always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'write_text is not supported by HTTP Storage')
class OSSStorage(Storage):
    """Placeholder OSS storage backend; every operation is unimplemented."""

    def __init__(self, oss_config_file=None):
        """Not implemented yet; always raises ``NotImplementedError``."""
        # read from config file or env var
        raise NotImplementedError(
            'OSSStorage.__init__ to be implemented in the future')

    def read(self, filepath):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.read to be implemented in the future')

    def read_text(self, filepath, encoding='utf-8'):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.read_text to be implemented in the future')

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a temporary local path.

        ``as_local_path`` is decorated by
        :meth:`contextlib.contextmanager`. It can be used in a ``with``
        statement, and the temporary file is removed when the ``with``
        block exits.

        Args:
            filepath (str): OSS path to download.

        Examples:
            >>> storage = OSSStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     # do something here
        """
        # Create the file before entering ``try`` so that ``f`` is always
        # bound when the ``finally`` clause runs.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            # ``with f`` closes the handle even when ``read`` raises, so the
            # temporary file can always be removed afterwards.
            with f:
                f.write(self.read(filepath))
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.write to be implemented in the future')

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.write_text to be implemented in the future')
| G_STORAGES = {} | |||
class File(object):
    """Facade that routes file operations to a storage backend by URI.

    A URI of the form ``<prefix>://...`` is dispatched on ``<prefix>``; a
    string without ``://`` is treated as a local path. Backend instances
    are created on first use and cached in the module-level ``G_STORAGES``.
    """

    _prefix_to_storage: dict = {
        'oss': OSSStorage,
        'http': HTTPStorage,
        'https': HTTPStorage,
        'local': LocalStorage,
    }

    @staticmethod
    def _as_str(uri):
        """Convert a ``Path`` to ``str``; return any other value unchanged."""
        return str(uri) if isinstance(uri, Path) else uri

    @staticmethod
    def _get_storage(uri):
        """Return the cached storage backend responsible for ``uri``.

        Raises:
            AssertionError: If ``uri`` is not a str or its prefix is not
                registered in ``_prefix_to_storage``.
        """
        assert isinstance(uri,
                          str), f'uri should be str type, buf got {type(uri)}'
        # Split on the first '://' only, so a URI that embeds another URL
        # (e.g. in a query string) does not break the unpacking.
        prefix, separator, _ = uri.partition('://')
        storage_type = prefix if separator else 'local'

        assert storage_type in File._prefix_to_storage, \
            f'Unsupported uri {uri}, valid prefixs: '\
            f'{list(File._prefix_to_storage.keys())}'

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read data from a given ``uri`` as bytes.

        Args:
            uri (str or Path): Location to read data from.

        Returns:
            bytes: The raw content.
        """
        uri = File._as_str(uri)
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
        """Read data from a given ``uri`` as text.

        Args:
            uri (str or Path): Location to read data from.
            encoding (str): Accepted for API symmetry. Default: 'utf-8'.

        Returns:
            str: The text content.
        """
        # NOTE(review): ``encoding`` is not forwarded because the abstract
        # ``Storage.read_text`` signature (and ``HTTPStorage.read_text``)
        # takes no encoding argument; backends use their own default.
        uri = File._as_str(uri)
        storage = File._get_storage(uri)
        return storage.read_text(uri)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write bytes to a given ``uri``.

        Args:
            obj (bytes): Data to be written.
            uri (str or Path): Location to write data to.
        """
        uri = File._as_str(uri)
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
        """Write text to a given ``uri``.

        Args:
            obj (str): Data to be written.
            uri (str or Path): Location to write data to.
            encoding (str): The encoding used to write the text.
                Default: 'utf-8'.
        """
        uri = File._as_str(uri)
        storage = File._get_storage(uri)
        # Forward ``encoding``; it is part of ``Storage.write_text``.
        return storage.write_text(obj, uri, encoding=encoding)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local path for ``uri`` by delegating to its backend.

        Args:
            uri (str or Path): Location to expose as a local path.
        """
        uri = File._as_str(uri)
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
| @@ -0,0 +1,3 @@ | |||
| from .base import FormatHandler | |||
| from .json import JsonHandler | |||
| from .yaml import YamlHandler | |||
| @@ -0,0 +1,20 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from abc import ABCMeta, abstractmethod | |||
class FormatHandler(metaclass=ABCMeta):
    """Abstract base class for (de)serializers of a single file format."""

    # When `text_mode` is True the file is handled in text mode,
    # otherwise in binary mode.
    text_mode = True

    @abstractmethod
    def load(self, file, **kwargs):
        """Parse an object from the file-like object ``file``."""

    @abstractmethod
    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` into the file-like object ``file``."""

    @abstractmethod
    def dumps(self, obj, **kwargs):
        """Serialize ``obj`` and return the result."""
| @@ -0,0 +1,35 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import json | |||
| import numpy as np | |||
| from .base import FormatHandler | |||
def set_default(obj):
    """Fallback converter for values ``json`` cannot serialize natively.

    ``set`` and ``range`` become lists, ``np.ndarray`` becomes a (nested)
    list, and ``np.generic`` scalars (``np.int32``, ``np.float32``, ...)
    become plain python built-in values. Anything else raises ``TypeError``.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (set, range)):
        return list(obj)
    raise TypeError(f'{type(obj)} is unsupported for json dump')
class JsonHandler(FormatHandler):
    """Format handler that reads and writes JSON via the ``json`` module."""

    def load(self, file, **kwargs):
        """Parse JSON from the file-like object ``file``.

        Extra keyword arguments are passed to ``json.load``. They are
        accepted to match ``FormatHandler.load``, so that callers which
        forward ``**kwargs`` do not fail.
        """
        return json.load(file, **kwargs)

    def dump(self, obj, file, **kwargs):
        """Write ``obj`` as JSON into the file-like object ``file``.

        ``set_default`` is used as the ``default`` hook unless the caller
        supplies one.
        """
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dumps(self, obj, **kwargs):
        """Return ``obj`` serialized as a JSON string.

        ``set_default`` is used as the ``default`` hook unless the caller
        supplies one.
        """
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)
| @@ -0,0 +1,25 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import yaml | |||
| try: | |||
| from yaml import CDumper as Dumper | |||
| from yaml import CLoader as Loader | |||
| except ImportError: | |||
| from yaml import Loader, Dumper # type: ignore | |||
| from .base import FormatHandler # isort:skip | |||
class YamlHandler(FormatHandler):
    """Format handler that reads and writes YAML via PyYAML."""

    def load(self, file, **kwargs):
        """Parse YAML from ``file``; ``Loader`` defaults to the module's."""
        # NOTE(review): the default ``Loader``/``CLoader`` is not PyYAML's
        # safe loader. Confirm inputs are trusted, or have callers pass
        # ``Loader=yaml.SafeLoader``.
        options = {'Loader': Loader, **kwargs}
        return yaml.load(file, **options)

    def dump(self, obj, file, **kwargs):
        """Write ``obj`` as YAML into ``file``."""
        options = {'Dumper': Dumper, **kwargs}
        yaml.dump(obj, file, **options)

    def dumps(self, obj, **kwargs):
        """Return ``obj`` serialized as a YAML string."""
        options = {'Dumper': Dumper, **kwargs}
        return yaml.dump(obj, **options)
| @@ -0,0 +1,127 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| from io import BytesIO, StringIO | |||
| from pathlib import Path | |||
| from .file import File | |||
| from .format import JsonHandler, YamlHandler | |||
| format_handlers = { | |||
| 'json': JsonHandler(), | |||
| 'yaml': YamlHandler(), | |||
| 'yml': YamlHandler(), | |||
| } | |||
def load(file, file_format=None, **kwargs):
    """Load data from a json/yaml file.

    This is the unified entry point for deserializing a file, whether it is
    given as a path/URI or as an already opened file-like object.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a file-like
            object.
        file_format (str, optional): If not specified, the file format will be
            inferred from the file extension, otherwise use the specified one.
            Currently supported formats include "json", "yaml/yml".

    Examples:
        >>> load('/path/of/your/file')  # file stored on disk
        >>> load('https://path/of/your/file')  # file served over HTTP(S)
        >>> load('oss://path/of/your/file')  # file stored in OSS

    Returns:
        The content from the file.
    """
    source = str(file) if isinstance(file, Path) else file
    is_path = isinstance(source, str)

    # Infer the format from the text after the last dot of the path.
    if file_format is None and is_path:
        file_format = source.rsplit('.', 1)[-1]
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')
    handler = format_handlers[file_format]

    if not is_path:
        if not hasattr(source, 'read'):
            raise TypeError('"file" must be a filepath str or a file-object')
        return handler.load(source, **kwargs)

    # Fetch the whole content through ``File`` and parse it from memory.
    if handler.text_mode:
        buffer = StringIO(File.read_text(source))
    else:
        buffer = BytesIO(File.read(source))
    with buffer:
        return handler.load(buffer, **kwargs)
def dump(obj, file=None, file_format=None, **kwargs):
    """Dump data to json/yaml strings or files.

    This method provides a unified api for dumping data as strings or to files.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, then the object is dumped to a str, otherwise to a file
            specified by the filename or file-like object.
        file_format (str, optional): Same as :func:`load`.

    Examples:
        >>> dump('hello world', '/path/of/your/file')  # disk
        >>> dump('hello world', 'oss://path/of/your/file')  # oss

    Returns:
        The serialized string when ``file`` is None, otherwise None.

    Raises:
        ValueError: If both ``file`` and ``file_format`` are None.
        TypeError: If the format is unsupported, or ``file`` is neither a
            path nor an object with a ``write`` method.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if isinstance(file, str):
            file_format = file.split('.')[-1]
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if file is None:
        # Handlers expose ``dumps``; no handler defines ``dump_to_str``.
        return handler.dumps(obj, **kwargs)
    elif isinstance(file, str):
        # Serialize into memory first, then hand the result to ``File`` so
        # that any storage backend can be the destination.
        if handler.text_mode:
            with StringIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write_text(f.getvalue(), file)
        else:
            with BytesIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write(f.getvalue(), file)
    elif hasattr(file, 'write'):
        handler.dump(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')
def dumps(obj, format, **kwargs):
    """Serialize ``obj`` and return the result instead of writing a file.

    Args:
        obj (any): The python object to be dumped.
        format (str): Name of the target format, same as ``file_format``
            in :func:`load`.

    Examples:
        >>> dumps('hello world', 'json')
        >>> dumps('hello world', 'yaml')

    Returns:
        The value returned by the ``dumps`` method of the selected handler.

    Raises:
        TypeError: If ``format`` is not a registered format.
    """
    if format in format_handlers:
        return format_handlers[format].dumps(obj, **kwargs)
    raise TypeError(f'Unsupported format: {format}')
| @@ -0,0 +1,472 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import ast | |||
| import copy | |||
| import os | |||
| import os.path as osp | |||
| import platform | |||
| import shutil | |||
| import sys | |||
| import tempfile | |||
| import types | |||
| import uuid | |||
| from importlib import import_module | |||
| from pathlib import Path | |||
| from typing import Dict | |||
| import addict | |||
| from yapf.yapflib.yapf_api import FormatCode | |||
| from maas_lib.utils.logger import get_logger | |||
| from maas_lib.utils.pymod import (import_modules, import_modules_from_file, | |||
| validate_py_syntax) | |||
| if platform.system() == 'Windows': | |||
| import regex as re # type: ignore | |||
| else: | |||
| import re # type: ignore | |||
| logger = get_logger() | |||
| BASE_KEY = '_base_' | |||
| DELETE_KEY = '_delete_' | |||
| DEPRECATION_KEY = '_deprecation_' | |||
| RESERVED_KEYS = ['filename', 'text', 'pretty_text'] | |||
class ConfigDict(addict.Dict):
    """Dict whose values can also be read as attributes.

    Unlike a plain ``addict.Dict``, a missing key raises an error rather
    than being looked up leniently.

    Examples:
        >>> cdict = ConfigDict({'a':1232})
        >>> print(cdict.a)
        1232
    """

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        # Hold the error and raise it after the ``try`` statement so it is
        # not raised from inside an ``except`` block.
        error = None
        try:
            return super().__getattr__(name)
        except KeyError:
            error = AttributeError(f"'{self.__class__.__name__}' object has no "
                                   f"attribute '{name}'")
        except Exception as exc:
            error = exc
        raise error
| class Config: | |||
| """A facility for config and config files. | |||
| It supports common file formats as configs: python/json/yaml. The interface | |||
| is the same as a dict object and also allows access config values as | |||
| attributes. | |||
| Example: | |||
| >>> cfg = Config(dict(a=1, b=dict(c=[1,2,3], d='dd'))) | |||
| >>> cfg.a | |||
| 1 | |||
| >>> cfg.b | |||
| {'c': [1, 2, 3], 'd': 'dd'} | |||
| >>> cfg.b.d | |||
| 'dd' | |||
| >>> cfg = Config.from_file('configs/examples/config.json') | |||
| >>> cfg.filename | |||
| 'configs/examples/config.json' | |||
| >>> cfg.b | |||
| {'c': [1, 2, 3], 'd': 'dd'} | |||
| >>> cfg = Config.from_file('configs/examples/config.py') | |||
| >>> cfg.filename | |||
| "configs/examples/config.py" | |||
| >>> cfg = Config.from_file('configs/examples/config.yaml') | |||
| >>> cfg.filename | |||
| "configs/examples/config.yaml" | |||
| """ | |||
@staticmethod
def _file2dict(filename):
    """Load a config file and return its content together with its text.

    The file is copied into a temporary directory and loaded from there:
    a ``.py`` file is imported as a module, the other formats are parsed
    with ``maas_lib.fileio.load``.

    Args:
        filename (str): Path of a .py/.json/.yaml/.yml config file.

    Returns:
        tuple: ``(cfg_dict, cfg_text)`` where ``cfg_text`` is the absolute
        file path, a newline, and the raw file content.

    Raises:
        ValueError: If the file does not exist.
        IOError: If the file extension is not supported.
    """
    filename = osp.abspath(osp.expanduser(filename))
    if not osp.exists(filename):
        raise ValueError(f'File does not exists {filename}')
    fileExtname = osp.splitext(filename)[1]
    if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')

    with tempfile.TemporaryDirectory() as tmp_cfg_dir:
        tmp_cfg_file = tempfile.NamedTemporaryFile(
            dir=tmp_cfg_dir, suffix=fileExtname)
        try:
            if platform.system() == 'Windows':
                # Close the handle first so the path can be overwritten.
                tmp_cfg_file.close()
            tmp_cfg_name = osp.basename(tmp_cfg_file.name)
            shutil.copyfile(filename, tmp_cfg_file.name)

            if filename.endswith('.py'):
                module_name, mod = import_modules_from_file(
                    osp.join(tmp_cfg_dir, tmp_cfg_name))
                # Keep public names that are neither modules nor functions.
                cfg_dict = {}
                for name, value in mod.__dict__.items():
                    if not name.startswith('__') and \
                       not isinstance(value, types.ModuleType) and \
                       not isinstance(value, types.FunctionType):
                        cfg_dict[name] = value

                # delete imported module
                del sys.modules[module_name]
            elif filename.endswith(('.yml', '.yaml', '.json')):
                from maas_lib.fileio import load
                cfg_dict = load(tmp_cfg_file.name)
        finally:
            # Close the temp file even when importing or parsing fails, so
            # the handle never leaks. Closing twice is harmless.
            tmp_cfg_file.close()

    cfg_text = filename + '\n'
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        cfg_text += f.read()

    return cfg_dict, cfg_text
@staticmethod
def from_file(filename):
    """Build a :obj:`Config` from the config file at ``filename``.

    Args:
        filename (str or Path): Path of the config file.

    Returns:
        :obj:`Config`: Config holding the parsed content, the file text and
        the filename.
    """
    path = str(filename) if isinstance(filename, Path) else filename
    parsed, raw_text = Config._file2dict(path)
    return Config(parsed, cfg_text=raw_text, filename=path)
@staticmethod
def from_string(cfg_str, file_format):
    """Generate config from config str.

    The string is written to a temporary file with the given suffix, which
    is then loaded with :meth:`from_file` and removed afterwards.

    Args:
        cfg_str (str): Config str.
        file_format (str): Config file format corresponding to the
            config str. Only py/yml/yaml/json type are supported now!

    Returns:
        :obj:`Config`: Config obj.

    Raises:
        IOError: If ``file_format`` is not one of the supported suffixes.
    """
    if file_format not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')
    if file_format != '.py' and 'dict(' in cfg_str:
        # check if users specify a wrong suffix for python
        logger.warning(
            'Please check "file_format", the file format may be .py')
    with tempfile.NamedTemporaryFile(
            'w', encoding='utf-8', suffix=file_format,
            delete=False) as temp_file:
        temp_file.write(cfg_str)
    # The file is loaded only after it has been closed, because an open
    # temporary file cannot be reopened on Windows.
    try:
        return Config.from_file(temp_file.name)
    finally:
        # Remove the temporary file even when loading fails.
        os.remove(temp_file.name)
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
    """Create a config from a dict.

    Args:
        cfg_dict (dict, optional): Config content. Defaults to an empty
            dict.
        cfg_text (str, optional): Text of the config. When empty and
            ``filename`` is given, the text is read from that file.
        filename (str or Path, optional): Path the config was loaded from.

    Raises:
        TypeError: If ``cfg_dict`` is not a dict.
        KeyError: If ``cfg_dict`` contains a key from ``RESERVED_KEYS``.
    """
    if cfg_dict is None:
        cfg_dict = {}
    elif not isinstance(cfg_dict, dict):
        raise TypeError('cfg_dict must be a dict, but '
                        f'got {type(cfg_dict)}')
    for key in cfg_dict:
        if key in RESERVED_KEYS:
            raise KeyError(f'{key} is reserved for config file')

    if isinstance(filename, Path):
        filename = str(filename)

    # Use object.__setattr__ because Config.__setattr__ stores values in
    # ``_cfg_dict``, which does not exist yet.
    super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
    super(Config, self).__setattr__('_filename', filename)
    if cfg_text:
        text = cfg_text
    elif filename:
        # Read as utf-8 explicitly, as ``_file2dict`` does, instead of
        # relying on the platform default encoding.
        with open(filename, 'r', encoding='utf-8') as f:
            text = f.read()
    else:
        text = ''
    super(Config, self).__setattr__('_text', text)
@property
def filename(self):
    """Filename stored on this config at construction, or None."""
    return self._filename
@property
def text(self):
    """Text stored on this config at construction ('' when none)."""
    return self._text
@property
def pretty_text(self):
    """Config content rendered as yapf-formatted python source text.

    Top-level keys are written as ``key=value`` lines; nested dicts are
    written as ``dict(...)`` calls, or as ``{...}`` mappings when a key is
    not a valid python identifier.
    """

    indent = 4

    def _indent(s_, num_spaces):
        # Indent every line except the first by ``num_spaces`` spaces.
        s = s_.split('\n')
        if len(s) == 1:
            return s_
        first = s.pop(0)
        s = [(num_spaces * ' ') + line for line in s]
        s = '\n'.join(s)
        s = first + '\n' + s
        return s

    def _format_basic_types(k, v, use_mapping=False):
        # Render one entry; strings are wrapped in single quotes.
        if isinstance(v, str):
            v_str = f"'{v}'"
        else:
            v_str = str(v)

        if use_mapping:
            # Mapping style: "'key': value".
            k_str = f"'{k}'" if isinstance(k, str) else str(k)
            attr_str = f'{k_str}: {v_str}'
        else:
            # Keyword style: "key=value".
            attr_str = f'{str(k)}={v_str}'
        attr_str = _indent(attr_str, indent)

        return attr_str

    def _format_list(k, v, use_mapping=False):
        # check if all items in the list are dict
        if all(isinstance(_, dict) for _ in v):
            # One ``dict(...)`` per line; the trailing comma is stripped.
            v_str = '[\n'
            v_str += '\n'.join(
                f'dict({_indent(_format_dict(v_), indent)}),'
                for v_ in v).rstrip(',')
            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent) + ']'
        else:
            # Any other list is rendered through ``str``.
            attr_str = _format_basic_types(k, v, use_mapping)
        return attr_str

    def _contain_invalid_identifier(dict_str):
        # True when at least one key is not a valid python identifier.
        contain_invalid_identifier = False
        for key_name in dict_str:
            contain_invalid_identifier |= \
                (not str(key_name).isidentifier())
        return contain_invalid_identifier

    def _format_dict(input_dict, outest_level=False):
        r = ''
        s = []

        # Keys that are not identifiers cannot be keyword arguments, so the
        # dict is written as a ``{...}`` mapping instead.
        use_mapping = _contain_invalid_identifier(input_dict)
        if use_mapping:
            r += '{'
        for idx, (k, v) in enumerate(input_dict.items()):
            is_last = idx >= len(input_dict) - 1
            # No separator at the outermost level or after the last entry.
            end = '' if outest_level or is_last else ','
            if isinstance(v, dict):
                v_str = '\n' + _format_dict(v)
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: dict({v_str}'
                else:
                    attr_str = f'{str(k)}=dict({v_str}'
                attr_str = _indent(attr_str, indent) + ')' + end
            elif isinstance(v, list):
                attr_str = _format_list(k, v, use_mapping) + end
            else:
                attr_str = _format_basic_types(k, v, use_mapping) + end

            s.append(attr_str)
        r += '\n'.join(s)
        if use_mapping:
            r += '}'
        return r

    cfg_dict = self._cfg_dict.to_dict()
    text = _format_dict(cfg_dict, outest_level=True)
    # copied from setup.cfg
    yapf_style = dict(
        based_on_style='pep8',
        blank_line_before_nested_class_or_def=True,
        split_before_expression_after_opening_paren=True)
    # Let yapf reformat the generated text.
    text, _ = FormatCode(text, style_config=yapf_style, verify=True)

    return text
def __repr__(self):
    """Return a string showing the filename and the repr of the content."""
    return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
def __len__(self):
    """Return the number of top-level entries in the config."""
    return len(self._cfg_dict)
def __getattr__(self, name):
    """Delegate attribute lookup to the underlying ``_cfg_dict``."""
    return getattr(self._cfg_dict, name)
def __getitem__(self, name):
    """Delegate item lookup to the underlying ``_cfg_dict``."""
    return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
    """Store ``value`` on ``_cfg_dict``, wrapping dicts in ``ConfigDict``."""
    wrapped = ConfigDict(value) if isinstance(value, dict) else value
    self._cfg_dict.__setattr__(name, wrapped)
def __setitem__(self, name, value):
    """Store ``value`` in ``_cfg_dict``, wrapping dicts in ``ConfigDict``."""
    wrapped = ConfigDict(value) if isinstance(value, dict) else value
    self._cfg_dict.__setitem__(name, wrapped)
def __iter__(self):
    """Iterate over the underlying ``_cfg_dict``."""
    return iter(self._cfg_dict)
def __getstate__(self):
    """Return the state tuple that ``__setstate__`` unpacks."""
    return (self._cfg_dict, self._filename, self._text)
def __copy__(self):
    """Return a shallow copy that shares the attribute values of ``self``."""
    cls = self.__class__
    # Allocate without running __init__, then copy the instance dict
    # directly so that Config.__setattr__ is not involved.
    other = cls.__new__(cls)
    other.__dict__.update(self.__dict__)

    return other
def __deepcopy__(self, memo):
    """Return a copy in which every instance attribute is deep-copied."""
    cls = self.__class__
    other = cls.__new__(cls)
    # Register the copy before copying the attributes so that references
    # back to ``self`` resolve to ``other``.
    memo[id(self)] = other

    for key, value in self.__dict__.items():
        # Bypass Config.__setattr__, which would write into ``_cfg_dict``.
        super(Config, other).__setattr__(key, copy.deepcopy(value, memo))

    return other
def __setstate__(self, state):
    """Restore the attributes from the tuple made by ``__getstate__``."""
    _cfg_dict, _filename, _text = state
    # Bypass Config.__setattr__, which would write into ``_cfg_dict``.
    super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
    super(Config, self).__setattr__('_filename', _filename)
    super(Config, self).__setattr__('_text', _text)
def dump(self, file: str = None):
    """Dumps config into a file or returns a string representation of the
    config.

    If a file argument is given, saves the config to that file using the
    format defined by the file argument extension.

    Otherwise, returns a string representing the config. The formatting of
    this returned string is defined by the extension of `self.filename`. If
    `self.filename` is not defined or ends with '.py', the python-style
    `pretty_text` is returned.

    Examples:
        >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0),
        ...     item3=True, item4='test')
        >>> cfg = Config(cfg_dict=cfg_dict)
        >>> dump_file = "a.py"
        >>> cfg.dump(dump_file)

    Args:
        file (str or Path, optional): Path of the output file where the
            config will be dumped. Defaults to None.

    Returns:
        str or None: The serialized config when ``file`` is None,
        otherwise None.
    """
    from maas_lib.fileio import dump, dumps

    if isinstance(file, Path):
        file = str(file)
    cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()

    if file is None:
        if self.filename is None or self.filename.endswith('.py'):
            return self.pretty_text
        file_format = self.filename.split('.')[-1]
        # Use ``dumps`` for the string form: ``dump`` with ``file=None``
        # calls ``handler.dump_to_str``, which the handlers do not define.
        return dumps(cfg_dict, file_format)

    if file.endswith('.py'):
        with open(file, 'w', encoding='utf-8') as f:
            f.write(self.pretty_text)
        return None

    file_format = file.split('.')[-1]
    return dump(cfg_dict, file=file, file_format=file_format)
def merge_from_dict(self, options, allow_list_keys=True):
    """Merge dotted-key options into the wrapped config dict.

    Every key of ``options`` is split on ``'.'`` and expanded into nested
    mappings, e.g. ``{'model.backbone.depth': 50}`` becomes
    ``{'model': {'backbone': {'depth': 50}}}``. The expanded mapping is
    then merged into the current config with ``Config._merge_a_into_b``
    and the result replaces the wrapped config dict.

    Examples:
        >>> options = {'model.backbone.depth': 50,
        ...            'model.backbone.with_cp': True}
        >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
        >>> cfg.merge_from_dict(options)

        >>> # Address a list element by its index
        >>> cfg = Config(dict(pipeline=[
        ...     dict(type='Resize'), dict(type='RandomDistortion')]))
        >>> options = dict(pipeline={'0': dict(type='MyResize')})
        >>> cfg.merge_from_dict(options, allow_list_keys=True)

    Args:
        options (dict): dict of configs to merge from.
        allow_list_keys (bool): Forwarded to ``Config._merge_a_into_b``.
            If True, int string keys (e.g. '0', '1') are allowed in
            ``options`` and will replace the element of the corresponding
            index in the config if the config is a list. Default: True.
    """
    nested_options = {}
    for dotted_key, value in options.items():
        *parents, leaf = dotted_key.split('.')
        node = nested_options
        for part in parents:
            node.setdefault(part, ConfigDict())
            node = node[part]
        node[leaf] = value

    current = super(Config, self).__getattribute__('_cfg_dict')
    merged = Config._merge_a_into_b(
        nested_options, current, allow_list_keys=allow_list_keys)
    super(Config, self).__setattr__('_cfg_dict', merged)
def to_dict(self) -> Dict:
    """Return the config content via the wrapped dict's ``to_dict()``."""
    cfg_dict = self._cfg_dict
    return cfg_dict.to_dict()
def to_args(self, parse_fn, use_hyphen=True):
    """Convert the top-level config entries to command-line style args.

    Each top-level key ``k`` becomes the option ``--k``. Values map to
    argument tokens as follows:

    * ``bool``: ``True`` emits the bare option (``store_true`` style);
      ``False`` emits nothing.
    * ``int`` / ``str`` / ``float``: the option followed by ``str(value)``.
    * ``list`` of ``int`` / ``str`` / ``float`` / ``bool``: the option
      followed by ``str(value)`` (the whole list as one token).

    Args:
        parse_fn: a function object which takes the argument list as
            input, such as ['--foo', 'FOO'], and returns parsed args.
            An example is given as follows::

                def parse_fn(args):
                    parser = argparse.ArgumentParser(prog='PROG')
                    parser.add_argument('-x')
                    parser.add_argument('--foo')
                    return parser.parse_args(args)
        use_hyphen (bool, optional): if True, underscores in key names
            are replaced by hyphens in the generated option names.

    Return:
        Whatever ``parse_fn`` returns for the generated argument list.

    Raises:
        ValueError: if a value is neither a bool, int, str, float nor a
            list of them.
    """
    scalar_types = (int, str, float)
    args = []
    for key, value in self._cfg_dict.items():
        arg_name = f'--{key}'
        if use_hyphen:
            arg_name = arg_name.replace('_', '-')

        if isinstance(value, bool):
            # bool must be tested before int (bool is an int subclass);
            # a False flag is expressed by leaving the option out.
            if value:
                args.append(arg_name)
        elif isinstance(value, scalar_types):
            args.extend([arg_name, str(value)])
        elif isinstance(value, list):
            # Validate the elements, not the list object itself.
            invalid = [
                item for item in value
                if not isinstance(item, (int, str, float, bool))
            ]
            if invalid:
                raise ValueError(
                    'Element type in list is expected to be either '
                    f'int, str, float or bool, but got {type(invalid[0])}')
            args.extend([arg_name, str(value)])
        else:
            raise ValueError(
                'type in config file which supported to be '
                'converted to args should be either bool, '
                f'int, str, float or list of them but got type {value}')

    return parse_fn(args)
| @@ -0,0 +1,34 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
class Fields(object):
    """String names of the application fields.

    Each attribute holds the canonical string used for that field.
    """

    # Vision.
    image = 'image'
    video = 'video'
    # Language.
    nlp = 'nlp'
    # Speech / sound.
    audio = 'audio'
    # Multiple modalities combined.
    multi_modal = 'multi_modal'
class Tasks(object):
    """Names for tasks supported by maas lib.

    Holds the standard task name to use for identifying different tasks.
    This should be used to register models, pipelines, trainers.
    """

    # vision tasks
    # ``image_classfication`` is misspelled but is the established public
    # name, so it is kept; ``image_classification`` is the correctly
    # spelled alias for new code.
    image_classfication = 'image-classification'
    image_classification = image_classfication
    object_detection = 'object-detection'

    # nlp tasks
    sentiment_analysis = 'sentiment-analysis'
    fill_mask = 'fill-mask'
class InputFields(object):
    """Names of the input data fields in the input data for pipelines."""

    # Image input.
    img = 'img'
    # Text input.
    text = 'text'
    # Audio input.
    audio = 'audio'
| @@ -0,0 +1,45 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import logging | |||
| from typing import Optional | |||
# Names of loggers that already received their stream handler and level.
init_loggers = {}


def get_logger(log_file: Optional[str] = None,
               log_level: int = logging.INFO,
               file_mode: str = 'w'):
    """ Get logging logger

    The logger is named after the top-level package of this module. The
    first call attaches a stream handler and sets the logger level. Every
    call that passes ``log_file`` makes sure a file handler for that path
    is attached, so a file can still be added after the logger was first
    created without one (e.g. by an import-time ``get_logger()``).

    Args:
        log_file: Log filename, if specified, file handler will be added to
            logger
        log_level: Logging level. Applied to the logger on the first call
            and to every handler created by the current call.
        file_mode: Specifies the mode to open the file, if filename is
            specified (if filemode is unspecified, it defaults to 'w').

    Returns:
        The shared ``logging.Logger`` instance.
    """
    import os

    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    def _attach(handler):
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if logger_name not in init_loggers:
        _attach(logging.StreamHandler())
        logger.setLevel(log_level)
        init_loggers[logger_name] = True

    # TODO @wenmeng.zwm add logger setting for distributed environment
    if log_file is not None:
        # FileHandler stores the absolute path in ``baseFilename``; compare
        # against it so the same file is never attached twice.
        target = os.path.abspath(log_file)
        already_attached = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == target
            for handler in logger.handlers)
        if not already_attached:
            _attach(logging.FileHandler(log_file, file_mode))

    return logger
| @@ -0,0 +1,90 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import ast | |||
| import os | |||
| import os.path as osp | |||
| import sys | |||
| import types | |||
| from importlib import import_module | |||
| from maas_lib.utils.logger import get_logger | |||
| logger = get_logger() | |||
def import_modules_from_file(py_file: str):
    """ Import module from a certain file

    The file is syntax-checked first, then its directory is temporarily
    put at the front of ``sys.path`` and the module is imported by its
    base name (file name without extension).

    Args:
        py_file: path to a python file to be imported

    Return:
        tuple: ``(module_name, module)``.

    Raises:
        SyntaxError: raised by ``validate_py_syntax`` for an invalid file.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        dirname = './'
    module_name = osp.splitext(basefile)[0]

    # Validate before touching sys.path so a broken file leaves no trace.
    validate_py_syntax(py_file)

    sys.path.insert(0, dirname)
    try:
        mod = import_module(module_name)
    finally:
        # Drop the entry added above even when the import fails; remove by
        # value because the imported module may itself have altered
        # sys.path, making index 0 unreliable.
        try:
            sys.path.remove(dirname)
        except ValueError:
            pass
    return module_name, mod
def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, a failed import is logged as
            a warning and yields None. Otherwise, the original ImportError
            is re-raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single
        string input returns a single module; an empty or None input
        returns None.

    Raises:
        TypeError: if ``imports`` is not a str or list, or an element is
            not a str.
        ImportError: if a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')

    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # Re-raise the original exception so the module name and
                # the failure reason are not lost.
                raise
            logger.warning(f'{imp} failed to import and is ignored.')
            imported_tmp = None
        imported.append(imported_tmp)

    if single_import:
        imported = imported[0]
    return imported
def validate_py_syntax(filename):
    """Check that ``filename`` contains syntactically valid Python.

    Args:
        filename: path of the file to check; it is read as UTF-8.

    Raises:
        SyntaxError: if the content cannot be parsed; the message names
            the file and includes the parser's error.
    """
    # Setting encoding explicitly to resolve coding issue on windows
    with open(filename, 'r', encoding='utf-8') as source_file:
        source_text = source_file.read()

    try:
        ast.parse(source_text)
    except SyntaxError as err:
        message = ('There are syntax errors in config '
                   f'file {filename}: {err}')
        raise SyntaxError(message)
| @@ -0,0 +1,183 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import inspect | |||
| from maas_lib.utils.logger import get_logger | |||
| default_group = 'default' | |||
| logger = get_logger() | |||
class Registry(object):
    """ Registry which support registering modules and group them by a keyname

    If group name is not provided, modules will be registered to default group.
    Only classes can be registered.
    """

    def __init__(self, name: str):
        self._name = name
        # Mapping: group key -> {module name -> module class}.
        self._modules = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + f'({self._name})\n'
        for group_name, group in self._modules.items():
            format_str += f'group_name={group_name}, '\
                f'modules={list(group.keys())}\n'
        return format_str

    @property
    def name(self):
        """Name given to this registry at construction."""
        return self._name

    @property
    def modules(self):
        """The underlying ``{group key: {module name: class}}`` mapping."""
        return self._modules

    def list(self):
        """ logging the list of module in current registry
        """
        for group_name, group in self._modules.items():
            logger.info(f'group_name={group_name}')
            for m in group.keys():
                logger.info(f'\t{m}')
            logger.info('')

    def get(self, module_key, group_key=default_group):
        """Return the class registered as ``module_key`` in ``group_key``.

        Returns None when the group or the module name is unknown.
        """
        if group_key not in self._modules:
            return None
        else:
            return self._modules[group_key].get(module_key, None)

    def _register_module(self,
                         group_key=default_group,
                         module_name=None,
                         module_cls=None):
        """Store ``module_cls`` under ``group_key`` / ``module_name``.

        ``module_name`` defaults to the class name.

        Raises:
            TypeError: if ``module_cls`` is not a class.
            KeyError: if the name is already taken within the group.
        """
        assert isinstance(group_key,
                          str), 'group_key is required and must be str'
        if group_key not in self._modules:
            self._modules[group_key] = dict()

        if not inspect.isclass(module_cls):
            raise TypeError(f'module is not a class type: {type(module_cls)}')

        if module_name is None:
            module_name = module_cls.__name__

        if module_name in self._modules[group_key]:
            raise KeyError(f'{module_name} is already registered in '
                           f'{self._name}[{group_key}]')

        self._modules[group_key][module_name] = module_cls

    def register_module(self,
                        group_key: str = default_group,
                        module_name: str = None,
                        module_cls: type = None):
        """ Register module

        Can be called directly with ``module_cls`` or used as a class
        decorator when ``module_cls`` is omitted.

        Example:
            >>> models = Registry('models')
            >>> @models.register_module('image-classification', 'SwinT')
            >>> class SwinTransformer:
            >>>     pass

            >>> # Default group: pass the name by keyword, because the
            >>> # first positional argument is the group key.
            >>> @models.register_module(module_name='SwinDefault')
            >>> class SwinTransformerDefaultGroup:
            >>>     pass

        Args:
            group_key: Group name of which module will be registered,
                default group name is 'default'
            module_name: Module name, defaults to the class name
            module_cls: Module class object

        Returns:
            ``module_cls`` when it is given, otherwise a decorator that
            registers the decorated class and returns it unchanged.

        Raises:
            TypeError: if ``module_name`` is neither None nor a str.
        """
        if not (module_name is None or isinstance(module_name, str)):
            raise TypeError(f'module_name must be either of None, str, '
                            f'got {type(module_name)}')

        if module_cls is not None:
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls)
            return module_cls

        # if module_cls is None, should return a decorator function
        def _register(module_cls):
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls)
            return module_cls

        return _register
def build_from_cfg(cfg,
                   registry: Registry,
                   group_key: str = default_group,
                   default_args: dict = None) -> object:
    """Build an object from a config dict.

    The ``type`` entry selects what to call: a string is looked up in
    ``registry`` under ``group_key``; a class or function is used as is.
    The remaining entries (with ``default_args`` filling in missing keys)
    are passed as keyword arguments.

    Example:
        >>> models = Registry('models')
        >>> @models.register_module('image-classification', 'SwinT')
        >>> class SwinTransformer:
        >>>     pass
        >>> swint = build_from_cfg(dict(type='SwinT'), models,
        >>>                        'image-classification')
        >>> # Returns an instantiated object
        >>>
        >>> # ``type`` may also be the class (or a function) itself
        >>> swint = build_from_cfg(dict(type=SwinTransformer), models)

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        group_key (str, optional): The name of registry group from which
            module should be searched.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: if ``cfg``, ``registry``, ``default_args`` or the
            resolved ``type`` has an unsupported type.
        KeyError: if "type" is missing or not found in the registry group.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an maas_lib.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Work on a copy so the caller's config is left untouched.
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type, group_key=group_key)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} '
                           f'registry group {group_key}')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
| @@ -0,0 +1 @@ | |||
| -r requirements/runtime.txt | |||
| @@ -0,0 +1,6 @@ | |||
| docutils==0.16.0 | |||
| recommonmark | |||
| sphinx==4.0.2 | |||
| sphinx-copybutton | |||
| sphinx_markdown_tables | |||
| sphinx_rtd_theme==0.5.2 | |||
| @@ -0,0 +1,5 @@ | |||
| addict | |||
| numpy | |||
| pyyaml | |||
| requests | |||
| yapf | |||
| @@ -0,0 +1,5 @@ | |||
| expecttest | |||
| flake8 | |||
| isort==4.3.21 | |||
| pre-commit | |||
| yapf==0.30.0 | |||
| @@ -0,0 +1,24 @@ | |||
| [isort] | |||
| line_length = 79 | |||
| multi_line_output = 0 | |||
| known_standard_library = setuptools | |||
| known_first_party = maas_lib | |||
| known_third_party = json,yaml | |||
| no_lines_before = STDLIB,LOCALFOLDER | |||
| default_section = THIRDPARTY | |||
| [yapf] | |||
| BASED_ON_STYLE = pep8 | |||
| BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true | |||
| SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true | |||
| [codespell] | |||
| skip = *.ipynb | |||
| quiet-level = 3 | |||
| ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids | |||
| [flake8] | |||
| select = B,C,E,F,P,T4,W,B9 | |||
| max-line-length = 120 | |||
| ignore = F401 | |||
| exclude = docs/src,*.pyi,.git | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| from requests import HTTPError | |||
| from maas_lib.fileio.file import File, HTTPStorage, LocalStorage | |||
class FileTest(unittest.TestCase):
    """Tests for ``LocalStorage``, ``HTTPStorage`` and the ``File`` facade.

    The HTTP cases fetch a fixed remote text file, so they need network
    access.
    """

    def test_local_storage(self):
        local = LocalStorage()
        path = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names())

        # Binary round trip.
        raw = b'12345'
        local.write(raw, path)
        self.assertEqual(raw, local.read(path))

        # Text round trip on the same path.
        text = '12345'
        local.write_text(text, path)
        self.assertEqual(text, local.read_text(path))

        os.remove(path)

    def test_http_storage(self):
        remote = HTTPStorage()
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \
              '/data/test/data.txt'
        expected = 'this is test data'

        self.assertEqual(expected.encode('utf8'), remote.read(url))
        self.assertEqual(expected, remote.read_text(url))

        with remote.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(expected, infile.read())

        # Writing over HTTP is expected to be unsupported.
        with self.assertRaises(NotImplementedError):
            remote.write('dfad', url)

        # A non-existent URL is expected to surface as an HTTP error.
        with self.assertRaises(HTTPError):
            remote.read(url + 'df')

    def test_file(self):
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com'\
              '/data/test/data.txt'
        expected = 'this is test data'

        self.assertEqual(expected.encode('utf8'), File.read(url))

        with File.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(expected, infile.read())

        with self.assertRaises(NotImplementedError):
            File.write('dfad', url)

        with self.assertRaises(HTTPError):
            File.read(url + 'df')

        # Local path round trip through the same facade.
        path = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names())
        raw = b'12345'
        File.write(raw, path)
        self.assertEqual(raw, File.read(path))
        os.remove(path)
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import tempfile | |||
| import unittest | |||
| from maas_lib.fileio.io import dump, dumps, load | |||
class FileIOTest(unittest.TestCase):
    """Round-trip tests for ``dump`` / ``dumps`` / ``load``."""

    def test_format(self, format='json'):
        """Dump and reload a sample object using the ``format`` extension."""
        # Local import: this test module does not import ``os`` at file level.
        import os

        obj = [1, 2, 3, 'str', {'model': 'resnet'}]
        result_str = dumps(obj, format)
        temp_name = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names()) + '.' + format

        def _remove_temp_files():
            # Also cover the bogus-extension path in case it got created.
            for path in (temp_name, temp_name + 's'):
                if os.path.exists(path):
                    os.remove(path)

        # Registered before writing so the file is removed even on failure.
        self.addCleanup(_remove_temp_files)

        dump(obj, temp_name)
        obj_load = load(temp_name)
        self.assertEqual(obj_load, obj)
        with open(temp_name, 'r') as infile:
            self.assertEqual(result_str, infile.read())

        # An unknown extension is expected to be rejected with TypeError.
        with self.assertRaises(TypeError):
            obj_load = load(temp_name + 's')

        with self.assertRaises(TypeError):
            dump(obj, temp_name + 's')

    def test_yaml(self):
        self.test_format('yaml')
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,53 @@ | |||
| #!/usr/bin/env python | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import argparse | |||
| import os | |||
| import sys | |||
| import unittest | |||
| from fnmatch import fnmatch | |||
| def gather_test_cases(test_dir, pattern, list_tests): | |||
| case_list = [] | |||
| for dirpath, dirnames, filenames in os.walk(test_dir): | |||
| for file in filenames: | |||
| if fnmatch(file, pattern): | |||
| case_list.append(file) | |||
| test_suite = unittest.TestSuite() | |||
| for case in case_list: | |||
| test_case = unittest.defaultTestLoader.discover( | |||
| start_dir=test_dir, pattern=case) | |||
| test_suite.addTest(test_case) | |||
| if hasattr(test_case, '__iter__'): | |||
| for subcase in test_case: | |||
| if list_tests: | |||
| print(subcase) | |||
| else: | |||
| if list_tests: | |||
| print(test_case) | |||
| return test_suite | |||
def main(args):
    """Gather the tests selected by ``args`` and run them.

    Args:
        args: parsed command-line namespace providing ``test_dir``,
            ``pattern`` and ``list_tests``.

    When ``args.list_tests`` is set the tests are only listed. Otherwise
    they are run and the process exits with status 1 unless the run was
    successful.
    """
    runner = unittest.TextTestRunner()
    test_suite = gather_test_cases(
        os.path.abspath(args.test_dir), args.pattern, args.list_tests)
    if not args.list_tests:
        result = runner.run(test_suite)
        # wasSuccessful() accounts for errors as well as failures, so a
        # test that raises an unexpected exception also fails the run.
        if not result.wasSuccessful():
            sys.exit(1)
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser('test runner')
    # --list_tests: only print what would run.
    parser.add_argument(
        '--list_tests',
        action='store_true',
        help='list all tests',
    )
    # --pattern: file-name pattern selecting the test modules.
    parser.add_argument(
        '--pattern',
        default='test_*.py',
        help='test file pattern',
    )
    # --test_dir: root directory searched for tests.
    parser.add_argument(
        '--test_dir',
        default='tests',
        help='directory to be tested',
    )
    args = parser.parse_args()
    main(args)
| @@ -0,0 +1,85 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import argparse | |||
| import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| from pathlib import Path | |||
| from maas_lib.fileio import dump, load | |||
| from maas_lib.utils.config import Config | |||
| obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} | |||
class ConfigTest(unittest.TestCase):
    """Tests for loading, dumping and converting ``Config`` objects.

    The example config files are referenced by relative path, so the tests
    are expected to run from the repository root.
    """

    def test_json(self):
        cfg = Config.from_file('configs/examples/config.json')
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_yaml(self):
        cfg = Config.from_file('configs/examples/config.yaml')
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_py(self):
        cfg = Config.from_file('configs/examples/config.py')
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_dump(self):
        cfg = Config.from_file('configs/examples/config.py')
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

        # Expected renderings of the same config in each format.
        pretty_text = 'a = 1\n'
        pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
        json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
        # NOTE(review): the nested lines use one-space indentation here;
        # confirm this literal matches what the YAML dumper really emits.
        yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'

        with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
            # Without a target file the config dumps to its pretty text.
            self.assertEqual(pretty_text, cfg.dump())

            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(json_str, infile.read())

        with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(yaml_str, infile.read())

    def test_to_dict(self):
        cfg = Config.from_file('configs/examples/config.json')
        d = cfg.to_dict()
        print(d)
        self.assertTrue(isinstance(d, dict))

    def test_to_args(self):

        def parse_fn(args):
            parser = argparse.ArgumentParser(prog='PROG')
            parser.add_argument('--model-dir', default='')
            parser.add_argument('--lr', type=float, default=0.001)
            parser.add_argument('--optimizer', default='')
            parser.add_argument('--weight-decay', type=float, default=1e-7)
            parser.add_argument(
                '--save-checkpoint-epochs', type=int, default=30)
            return parser.parse_args(args)

        cfg = Config.from_file('configs/examples/plain_args.yaml')
        parsed = cfg.to_args(parse_fn)

        self.assertEqual(parsed.model_dir, 'path/to/model')
        self.assertAlmostEqual(parsed.lr, 0.01)
        self.assertAlmostEqual(parsed.weight_decay, 1e-6)
        self.assertEqual(parsed.optimizer, 'Adam')
        self.assertEqual(parsed.save_checkpoint_epochs, 20)
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,91 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.registry import Registry, build_from_cfg, default_group | |||
class RegistryTest(unittest.TestCase):
    """Tests for ``Registry`` registration, lookup and ``build_from_cfg``."""

    def test_register_class_no_task(self):
        MODELS = Registry('models')

        # A fresh registry is named and empty.
        self.assertTrue(MODELS.name == 'models')
        self.assertTrue(MODELS.modules == {})
        self.assertEqual(len(MODELS.modules), 0)

        @MODELS.register_module(module_name='cls-resnet')
        class ResNetForCls(object):
            pass

        # Without a group key the class lands in the default group.
        self.assertTrue(default_group in MODELS.modules)
        self.assertTrue(MODELS.get('cls-resnet') is ResNetForCls)

    def test_register_class_with_task(self):
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        self.assertTrue(Tasks.image_classfication in MODELS.modules)
        self.assertTrue(
            MODELS.get('SwinT', Tasks.image_classfication) is SwinTForCls)

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        self.assertTrue(Tasks.sentiment_analysis in MODELS.modules)
        self.assertTrue(
            MODELS.get('Bert', Tasks.sentiment_analysis) is
            BertForSentimentAnalysis)

        # No explicit module name: the class name is used as the key.
        @MODELS.register_module(Tasks.object_detection)
        class DETR(object):
            pass

        self.assertTrue(Tasks.object_detection in MODELS.modules)
        self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR)

        # One group per task registered above.
        self.assertEqual(len(MODELS.modules), 3)

    def test_list(self):
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        # Smoke test: both output paths should run without raising.
        MODELS.list()
        print(MODELS)

    def test_build(self):
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        cfg = dict(type='SwinT')
        model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)
        self.assertTrue(isinstance(model, SwinTForCls))

        cfg = dict(type='Bert')
        model = build_from_cfg(cfg, MODELS, Tasks.sentiment_analysis)
        self.assertTrue(isinstance(model, BertForSentimentAnalysis))

        # 'Bert' is registered under a different group, so this must fail.
        with self.assertRaises(KeyError):
            cfg = dict(type='Bert')
            model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)
| if __name__ == '__main__': | |||
| unittest.main() | |||