* add constant * add logger module * add registry and builder module * add fileio module * add requirements and setup.cfg * add config module and tests * add citest script Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8718998master
| @@ -0,0 +1,13 @@ | |||||
# Install runtime and test dependencies.
pip install -r requirements/runtime.txt
pip install -r requirements/tests.txt

# linter test
# use internal project for pre-commit due to the network problem
if ! pre-commit run --all-files; then
    echo "linter test failed, please run 'pre-commit run --all-files' to check"
    # Exit statuses must be in 0-255: `exit -1` is rejected by POSIX shells
    # such as dash and silently wraps to 255 in bash.
    exit 1
fi

# Run the unit tests with the repository root on the import path.
PYTHONPATH=. python tests/run.py
| @@ -0,0 +1,126 @@ | |||||
| # Byte-compiled / optimized / DLL files | |||||
| __pycache__/ | |||||
| *.py[cod] | |||||
| *$py.class | |||||
| # C extensions | |||||
| *.so | |||||
| # Distribution / packaging | |||||
| .Python | |||||
| build/ | |||||
| develop-eggs/ | |||||
| dist/ | |||||
| downloads/ | |||||
| eggs/ | |||||
| .eggs/ | |||||
| lib/ | |||||
| lib64/ | |||||
| parts/ | |||||
| sdist/ | |||||
| var/ | |||||
| wheels/ | |||||
| *.egg-info/ | |||||
| .installed.cfg | |||||
| *.egg | |||||
| /package | |||||
| MANIFEST | |||||
| # PyInstaller | |||||
| # Usually these files are written by a python script from a template | |||||
| # before PyInstaller builds the exe, so as to inject date/other infos into it. | |||||
| *.manifest | |||||
| *.spec | |||||
| # Installer logs | |||||
| pip-log.txt | |||||
| pip-delete-this-directory.txt | |||||
| # Unit test / coverage reports | |||||
| htmlcov/ | |||||
| .tox/ | |||||
| .coverage | |||||
| .coverage.* | |||||
| .cache | |||||
| nosetests.xml | |||||
| coverage.xml | |||||
| *.cover | |||||
| .hypothesis/ | |||||
| .pytest_cache/ | |||||
| # Translations | |||||
| *.mo | |||||
| *.pot | |||||
| # Django stuff: | |||||
| *.log | |||||
| local_settings.py | |||||
| db.sqlite3 | |||||
| # Flask stuff: | |||||
| instance/ | |||||
| .webassets-cache | |||||
| # Scrapy stuff: | |||||
| .scrapy | |||||
| # Sphinx documentation | |||||
| docs/_build/ | |||||
| # PyBuilder | |||||
| target/ | |||||
| # Jupyter Notebook | |||||
| .ipynb_checkpoints | |||||
| # pyenv | |||||
| .python-version | |||||
| # celery beat schedule file | |||||
| celerybeat-schedule | |||||
| # SageMath parsed files | |||||
| *.sage.py | |||||
| # Environments | |||||
| .env | |||||
| .venv | |||||
| env/ | |||||
| venv/ | |||||
| ENV/ | |||||
| env.bak/ | |||||
| venv.bak/ | |||||
| # Spyder project settings | |||||
| .spyderproject | |||||
| .spyproject | |||||
| # Rope project settings | |||||
| .ropeproject | |||||
| # mkdocs documentation | |||||
| /site | |||||
| # mypy | |||||
| .mypy_cache/ | |||||
| data | |||||
| .vscode | |||||
| .idea | |||||
| # custom | |||||
| *.pkl | |||||
| *.pkl.json | |||||
| *.log.json | |||||
| *.whl | |||||
| *.tar.gz | |||||
| *.swp | |||||
| *.log | |||||
| *.tar.gz | |||||
| source.sh | |||||
| tensorboard.sh | |||||
| .DS_Store | |||||
| replace.sh | |||||
| # Pytorch | |||||
| *.pth | |||||
| @@ -0,0 +1 @@ | |||||
| This folder will host example configs for each model supported by maas_lib. | |||||
| @@ -0,0 +1,7 @@ | |||||
| { | |||||
| "a": 1, | |||||
| "b" : { | |||||
| "c": [1,2,3], | |||||
| "d" : "dd" | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,2 @@ | |||||
| a = 1 | |||||
| b = dict(c=[1,2,3], d='dd') | |||||
| @@ -0,0 +1,4 @@ | |||||
| a: 1 | |||||
| b: | |||||
| c: [1,2,3] | |||||
| d: dd | |||||
| @@ -0,0 +1,5 @@ | |||||
| model_dir: path/to/model | |||||
| lr: 0.01 | |||||
| optimizer: Adam | |||||
| weight_decay: 1e-6 | |||||
| save_checkpoint_epochs: 20 | |||||
| @@ -0,0 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .version import __version__ | |||||
| __all__ = ['__version__'] | |||||
| @@ -0,0 +1 @@ | |||||
| from .io import dump, dumps, load | |||||
| @@ -0,0 +1,325 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import contextlib | |||||
| import os | |||||
| import tempfile | |||||
| from abc import ABCMeta, abstractmethod | |||||
| from pathlib import Path | |||||
| from typing import Generator, Union | |||||
| import requests | |||||
class Storage(metaclass=ABCMeta):
    """Abstract class of storage.

    All backends need to implement four apis: ``read()``, ``read_text()``,
    ``write()`` and ``write_text()``. ``read()``/``write()`` handle the file
    as a byte stream and ``read_text()``/``write_text()`` handle the file as
    text.
    """

    @abstractmethod
    def read(self, filepath: str):
        """Return the content stored at ``filepath`` as bytes."""
        pass

    @abstractmethod
    def read_text(self, filepath: str):
        """Return the content stored at ``filepath`` as text."""
        pass

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Store the bytes ``obj`` at ``filepath``."""
        pass

    @abstractmethod
    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Store the text ``obj`` at ``filepath`` using ``encoding``."""
        pass
class LocalStorage(Storage):
    """Storage backend for files on the local file system."""

    @staticmethod
    def _ensure_parent_dir(filepath: Union[str, Path]) -> None:
        """Create the parent directory of ``filepath`` if it is missing."""
        dirname = os.path.dirname(filepath)
        if dirname:
            # ``exist_ok=True`` avoids the check-then-create race of
            # ``os.path.exists`` followed by ``os.makedirs`` when several
            # writers target the same new directory.
            os.makedirs(dirname, exist_ok=True)

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, 'rb') as f:
            return f.read()

    def read_text(self,
                  filepath: Union[str, Path],
                  encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        with open(filepath, 'r', encoding=encoding) as f:
            return f.read()

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``write`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self._ensure_parent_dir(filepath)
        with open(filepath, 'wb') as f:
            f.write(obj)

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``write_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        self._ensure_parent_dir(filepath)
        with open(filepath, 'w', encoding=encoding) as f:
            f.write(obj)

    @contextlib.contextmanager
    def as_local_path(
            self,
            filepath: Union[str,
                            Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged.

        The file is already local, so nothing is copied or cleaned up; the
        method exists only so that all backends share the same API.
        """
        yield filepath
class HTTPStorage(Storage):
    """HTTP and HTTPS storage. Read-only: both write methods raise."""

    def read(self, url):
        """Fetch ``url`` with a GET request and return the body as bytes.

        ``raise_for_status()`` is called on the response, so an error status
        code raises instead of returning the error page.
        """
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def read_text(self, url):
        """Fetch ``url`` with a GET request and return the body as text.

        The text is ``Response.text``, i.e. decoded by ``requests``.
        """
        r = requests.get(url)
        r.raise_for_status()
        return r.text

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a temporary local file.

        ``as_local_path`` is decorated by
        :meth:`contextlib.contextmanager`. It can be called with a ``with``
        statement, and when exiting from the ``with`` statement, the
        temporary file will be removed.

        Args:
            filepath (str): URL to download the file from.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        # Create the temp file before entering ``try`` so that ``finally``
        # never references an unbound name if creation itself fails.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            # ``with f`` closes the handle even when the download raises, so
            # the file can always be removed afterwards (required on Windows).
            with f:
                f.write(self.read(filepath))
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError('write is not supported by HTTP Storage')

    def write_text(self,
                   obj: str,
                   url: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'write_text is not supported by HTTP Storage')
class OSSStorage(Storage):
    """OSS storage.

    Placeholder backend: the constructor and every read/write method
    currently raise ``NotImplementedError``.
    """

    def __init__(self, oss_config_file=None):
        # read from config file or env var
        raise NotImplementedError(
            'OSSStorage.__init__ to be implemented in the future')

    def read(self, filepath):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.read to be implemented in the future')

    def read_text(self, filepath, encoding='utf-8'):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.read_text to be implemented in the future')

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a temporary local file.

        ``as_local_path`` is decorated by
        :meth:`contextlib.contextmanager`. It can be called with a ``with``
        statement, and when exiting from the ``with`` statement, the
        temporary file will be removed.

        Args:
            filepath (str): Location to download the file from.

        Examples:
            >>> storage = OSSStorage()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     # do something here
        """
        # Create the temp file before entering ``try`` so that ``finally``
        # never references an unbound name if creation itself fails.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            # ``with f`` closes the handle even when ``read`` raises, so the
            # file can always be removed afterwards (required on Windows).
            with f:
                f.write(self.read(filepath))
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.write to be implemented in the future')

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            'OSSStorage.write_text to be implemented in the future')
# Cache of instantiated storage backends, keyed by storage type
# ('local', 'http', ...), so each backend is constructed at most once.
G_STORAGES = {}


class File(object):
    """Facade that dispatches file operations to a storage backend.

    The backend is chosen from the scheme in front of ``://`` in the uri;
    a uri without ``://`` is treated as a local path.
    """

    _prefix_to_storage: dict = {
        'oss': OSSStorage,
        'http': HTTPStorage,
        'https': HTTPStorage,
        'local': LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        """Return the (cached) storage backend responsible for ``uri``.

        Args:
            uri (str or Path): Local path or ``<scheme>://...`` location.

        Returns:
            Storage: Backend instance for the scheme of ``uri``.
        """
        if isinstance(uri, Path):
            # Public methods are annotated as accepting Path; only the
            # string form is needed here to detect the scheme.
            uri = str(uri)
        assert isinstance(uri,
                          str), f'uri should be str type, buf got {type(uri)}'

        if '://' not in uri:
            # local path
            storage_type = 'local'
        else:
            # Split only on the first separator: the remainder may itself
            # contain '://' (e.g. a URL passed as a query parameter).
            storage_type = uri.split('://', 1)[0]

        assert storage_type in File._prefix_to_storage, \
            f'Unsupported uri {uri}, valid prefixs: '\
            f'{list(File._prefix_to_storage.keys())}'

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read data from a given ``uri`` as bytes.

        Args:
            uri (str or Path): Location to read data from.

        Returns:
            bytes: Expected bytes object.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
        """Read data from a given ``uri`` as text.

        Args:
            uri (str or Path): Location to read data from.
            encoding (str): The encoding used to decode the content.
                Ignored for http/https, where ``requests`` decodes the
                response. Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``uri``.
        """
        storage = File._get_storage(uri)
        if isinstance(storage, HTTPStorage):
            # HTTPStorage.read_text takes no encoding argument.
            return storage.read_text(uri)
        return storage.read_text(uri, encoding=encoding)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write bytes to a given ``uri``.

        Note:
            For local paths the parent directory is created if it does not
            exist.

        Args:
            obj (bytes): Data to be written.
            uri (str or Path): Location to write data to.
        """
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
        """Write text to a given ``uri``.

        Note:
            For local paths the parent directory is created if it does not
            exist.

        Args:
            obj (str): Data to be written.
            uri (str or Path): Location to write data to.
            encoding (str): The encoding used to encode ``obj``.
                Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        return storage.write_text(obj, uri, encoding=encoding)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local filesystem path holding the content of ``uri``.

        Delegates to the backend's ``as_local_path``: local paths are
        yielded unchanged, remote content is downloaded to a temporary file
        that is removed when the ``with`` block exits.

        Args:
            uri (str): Location of the file.
        """
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
| @@ -0,0 +1,3 @@ | |||||
| from .base import FormatHandler | |||||
| from .json import JsonHandler | |||||
| from .yaml import YamlHandler | |||||
| @@ -0,0 +1,20 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from abc import ABCMeta, abstractmethod | |||||
class FormatHandler(metaclass=ABCMeta):
    """Abstract interface for (de)serializing objects in one file format."""

    # if `text_mode` is True, file
    # should use text mode otherwise binary mode
    text_mode = True

    @abstractmethod
    def load(self, file, **kwargs):
        """Deserialize and return the object read from the file object."""
        pass

    @abstractmethod
    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` and write it to the file object ``file``."""
        pass

    @abstractmethod
    def dumps(self, obj, **kwargs):
        """Serialize ``obj`` and return the result instead of writing it."""
        pass
| @@ -0,0 +1,35 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import json | |||||
| import numpy as np | |||||
| from .base import FormatHandler | |||||
def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``frozenset``, ``range`` and ``np.ndarray``
    data types to list. It also converts ``np.generic`` (including
    ``np.int32``, ``np.float32``, etc.) into plain numbers of plain python
    built-in types.

    Args:
        obj: The value ``json`` could not serialize on its own.

    Returns:
        A json-serializable equivalent of ``obj``.

    Raises:
        TypeError: If ``obj`` is of none of the supported types.
    """
    # frozenset is handled alongside set; it is equally unserializable for
    # the json module and converts to a list the same way.
    if isinstance(obj, (set, frozenset, range)):
        return list(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'{type(obj)} is unsupported for json dump')
class JsonHandler(FormatHandler):
    """JSON implementation of :class:`FormatHandler`."""

    def load(self, file, **kwargs):
        """Deserialize JSON from the file object ``file``.

        ``kwargs`` are forwarded to ``json.load``. They are accepted so the
        signature matches ``FormatHandler.load`` and callers can pass
        options without a ``TypeError``.
        """
        return json.load(file, **kwargs)

    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` as JSON into the file object ``file``.

        Unless the caller supplies ``default``, ``set_default`` is used to
        convert sets, ranges and numpy values.
        """
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dumps(self, obj, **kwargs):
        """Serialize ``obj`` and return the JSON string.

        Unless the caller supplies ``default``, ``set_default`` is used to
        convert sets, ranges and numpy values.
        """
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)
| @@ -0,0 +1,25 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import yaml | |||||
| try: | |||||
| from yaml import CDumper as Dumper | |||||
| from yaml import CLoader as Loader | |||||
| except ImportError: | |||||
| from yaml import Loader, Dumper # type: ignore | |||||
| from .base import FormatHandler # isort:skip | |||||
class YamlHandler(FormatHandler):
    """YAML implementation of :class:`FormatHandler`.

    The module-level ``Loader``/``Dumper`` are used unless the caller
    passes their own through ``kwargs``.
    """

    def load(self, file, **kwargs):
        """Parse YAML from ``file`` and return the resulting object.

        NOTE(review): the default ``Loader`` is PyYAML's full loader, which
        can construct arbitrary Python objects. Do not use it on untrusted
        input; pass ``Loader=yaml.SafeLoader`` for such data.
        """
        options = {'Loader': Loader, **kwargs}
        return yaml.load(file, **options)

    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` as YAML into the file object ``file``."""
        options = {'Dumper': Dumper, **kwargs}
        yaml.dump(obj, file, **options)

    def dumps(self, obj, **kwargs):
        """Serialize ``obj`` and return the YAML text."""
        options = {'Dumper': Dumper, **kwargs}
        return yaml.dump(obj, **options)
| @@ -0,0 +1,127 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Copyright (c) OpenMMLab. All rights reserved. | |||||
| from io import BytesIO, StringIO | |||||
| from pathlib import Path | |||||
| from .file import File | |||||
| from .format import JsonHandler, YamlHandler | |||||
# Registry mapping a file-format name (also used as file extension) to the
# handler instance that serializes/deserializes it.
format_handlers = dict(
    json=JsonHandler(),
    yaml=YamlHandler(),
    yml=YamlHandler(),
)
def load(file, file_format=None, **kwargs):
    """Load data from json/yaml files.

    This method provides a unified api for loading data from serialized files.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a file-like
            object.
        file_format (str, optional): If not specified, the file format will be
            inferred from the file extension (case-insensitively), otherwise
            use the specified one. Currently supported formats include
            "json", "yaml/yml".
        **kwargs: Forwarded to the format handler's ``load``.

    Examples:
        >>> load('/path/of/your/file.json')  # file is stored on disk
        >>> load('https://path/of/your/file.yaml')  # file is stored online
        >>> load('oss://path/of/your/file.json')  # file is stored in OSS

    Returns:
        The content from the file.

    Raises:
        TypeError: If the format is unsupported (including when it cannot be
            inferred because ``file`` is a file-like object), or if ``file``
            is neither a path nor a file-like object.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None and isinstance(file, str):
        # Lower-case the extension so that e.g. ``config.JSON`` or
        # ``config.YML`` resolve to the registered handlers.
        file_format = file.split('.')[-1].lower()
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if isinstance(file, str):
        # Read the whole content through File (so remote uris work), then
        # hand the handler an in-memory stream of the matching mode.
        if handler.text_mode:
            with StringIO(File.read_text(file)) as f:
                obj = handler.load(f, **kwargs)
        else:
            with BytesIO(File.read(file)) as f:
                obj = handler.load(f, **kwargs)
    elif hasattr(file, 'read'):
        obj = handler.load(file, **kwargs)
    else:
        raise TypeError('"file" must be a filepath str or a file-object')
    return obj
def dump(obj, file=None, file_format=None, **kwargs):
    """Dump data to json/yaml strings or files.

    This method provides a unified api for dumping data as strings or to files.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, then the object is dumped to a str, otherwise to a file
            specified by the filename or file-like object.
        file_format (str, optional): Same as :func:`load`. Required when
            ``file`` is None.
        **kwargs: Forwarded to the format handler.

    Examples:
        >>> dump('hello world', '/path/of/your/file.json')  # disk
        >>> dump('hello world', 'oss://path/of/your/file.json')  # oss
        >>> dump('hello world', file_format='json')  # returns a str

    Returns:
        str or None: The serialized string when ``file`` is None, otherwise
        None.

    Raises:
        ValueError: If both ``file`` and ``file_format`` are None.
        TypeError: If the format is unsupported, or if ``file`` is neither a
            path nor a file-like object.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if isinstance(file, str):
            # Lower-case the extension so that e.g. ``out.JSON`` resolves.
            file_format = file.split('.')[-1].lower()
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if file is None:
        # FormatHandler declares ``dumps``; there is no ``dump_to_str``.
        return handler.dumps(obj, **kwargs)
    elif isinstance(file, str):
        # Serialize into memory first, then write through File so that
        # non-local uris are handled by the matching storage backend.
        if handler.text_mode:
            with StringIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write_text(f.getvalue(), file)
        else:
            with BytesIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write(f.getvalue(), file)
    elif hasattr(file, 'write'):
        handler.dump(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')
def dumps(obj, format, **kwargs):
    """Dump data to a json/yaml string.

    Args:
        obj (any): The python object to be dumped.
        format (str): Name of the output format; must be a key of
            ``format_handlers``.
        **kwargs: Forwarded to the format handler's ``dumps``.

    Examples:
        >>> dumps('hello world', 'json')
        >>> dumps('hello world', 'yaml')

    Returns:
        The value returned by the handler's ``dumps`` (the serialized data).

    Raises:
        TypeError: If ``format`` is not a supported format.
    """
    selected = format_handlers.get(format)
    if selected is None:
        raise TypeError(f'Unsupported format: {format}')
    return selected.dumps(obj, **kwargs)
| @@ -0,0 +1,472 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import ast | |||||
| import copy | |||||
| import os | |||||
| import os.path as osp | |||||
| import platform | |||||
| import shutil | |||||
| import sys | |||||
| import tempfile | |||||
| import types | |||||
| import uuid | |||||
| from importlib import import_module | |||||
| from pathlib import Path | |||||
| from typing import Dict | |||||
| import addict | |||||
| from yapf.yapflib.yapf_api import FormatCode | |||||
| from maas_lib.utils.logger import get_logger | |||||
| from maas_lib.utils.pymod import (import_modules, import_modules_from_file, | |||||
| validate_py_syntax) | |||||
| if platform.system() == 'Windows': | |||||
| import regex as re # type: ignore | |||||
| else: | |||||
| import re # type: ignore | |||||
# Module-level logger obtained from the project's logging helper.
logger = get_logger()

# NOTE(review): BASE_KEY, DELETE_KEY and DEPRECATION_KEY are not referenced
# in the visible part of this module; presumably they are special config
# keys handled elsewhere -- confirm before relying on them.
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
DEPRECATION_KEY = '_deprecation_'
# Top-level keys that ``Config.__init__`` rejects; they match the names of
# the ``Config`` properties ``filename``, ``text`` and ``pretty_text``.
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
class ConfigDict(addict.Dict):
    """ Dict whose items can also be read as attributes.

    Unlike a plain ``addict.Dict``, a missing key is an error rather than
    something handled by ``__missing__`` silently: item access raises
    ``KeyError`` and attribute access raises ``AttributeError``.

    Examples:
        >>> cdict = ConfigDict({'a':1232})
        >>> print(cdict.a)
        1232
    """

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        # The error is raised after the try statement, not from inside an
        # except clause, so it carries no implicit exception context.
        error = None
        try:
            return super().__getattr__(name)
        except KeyError:
            error = AttributeError(f"'{self.__class__.__name__}' object has no "
                                   f"attribute '{name}'")
        except Exception as exc:
            error = exc
        raise error
| class Config: | |||||
| """A facility for config and config files. | |||||
| It supports common file formats as configs: python/json/yaml. The interface | |||||
| is the same as a dict object and also allows access config values as | |||||
| attributes. | |||||
| Example: | |||||
| >>> cfg = Config(dict(a=1, b=dict(c=[1,2,3], d='dd'))) | |||||
| >>> cfg.a | |||||
| 1 | |||||
| >>> cfg.b | |||||
| {'c': [1, 2, 3], 'd': 'dd'} | |||||
| >>> cfg.b.d | |||||
| 'dd' | |||||
| >>> cfg = Config.from_file('configs/examples/config.json') | |||||
| >>> cfg.filename | |||||
| 'configs/examples/config.json' | |||||
| >>> cfg.b | |||||
| {'c': [1, 2, 3], 'd': 'dd'} | |||||
| >>> cfg = Config.from_file('configs/examples/config.py') | |||||
| >>> cfg.filename | |||||
| "configs/examples/config.py" | |||||
| >>> cfg = Config.from_file('configs/examples/config.yaml') | |||||
| >>> cfg.filename | |||||
| "configs/examples/config.yaml" | |||||
| """ | |||||
@staticmethod
def _file2dict(filename):
    """Parse a py/json/yaml/yml config file into a dict plus its raw text.

    Args:
        filename (str): Path of the config file; ``~`` is expanded and the
            path is made absolute.

    Returns:
        tuple: ``(cfg_dict, cfg_text)`` where ``cfg_text`` is the absolute
        filename, a newline, and then the file content.

    Raises:
        ValueError: If the file does not exist.
        IOError: If the file extension is not .py/.json/.yaml/.yml.
    """
    filename = osp.abspath(osp.expanduser(filename))
    if not osp.exists(filename):
        raise ValueError(f'File does not exists {filename}')
    fileExtname = osp.splitext(filename)[1]
    if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')

    # The file is parsed from a copy placed in a throwaway directory.
    with tempfile.TemporaryDirectory() as tmp_cfg_dir:
        tmp_cfg_file = tempfile.NamedTemporaryFile(
            dir=tmp_cfg_dir, suffix=fileExtname)
        # NOTE(review): presumably closed early because Windows cannot
        # reopen a file that is still open -- confirm.
        if platform.system() == 'Windows':
            tmp_cfg_file.close()
        tmp_cfg_name = osp.basename(tmp_cfg_file.name)
        shutil.copyfile(filename, tmp_cfg_file.name)

        if filename.endswith('.py'):
            # Import the copy as a module and keep its public, non-module,
            # non-function globals as the config entries.
            module_nanme, mod = import_modules_from_file(
                osp.join(tmp_cfg_dir, tmp_cfg_name))
            cfg_dict = {}
            for name, value in mod.__dict__.items():
                if not name.startswith('__') and \
                   not isinstance(value, types.ModuleType) and \
                   not isinstance(value, types.FunctionType):
                    cfg_dict[name] = value

            # delete imported module
            del sys.modules[module_nanme]
        elif filename.endswith(('.yml', '.yaml', '.json')):
            from maas_lib.fileio import load
            cfg_dict = load(tmp_cfg_file.name)
        # close temp file
        tmp_cfg_file.close()

    cfg_text = filename + '\n'
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        cfg_text += f.read()

    return cfg_dict, cfg_text
@staticmethod
def from_file(filename):
    """Build a :obj:`Config` from a py/json/yaml/yml config file.

    Args:
        filename (str or Path): Path of the config file.

    Returns:
        :obj:`Config`: Config holding the parsed dict, the text returned by
        ``Config._file2dict`` and ``filename`` as a str.
    """
    path = str(filename) if isinstance(filename, Path) else filename
    parsed, raw_text = Config._file2dict(path)
    return Config(parsed, cfg_text=raw_text, filename=path)
@staticmethod
def from_string(cfg_str, file_format):
    """Generate config from config str.

    The string is written to a temporary file with suffix ``file_format``,
    parsed with :meth:`Config.from_file`, and the temporary file is removed
    afterwards, also when parsing fails.

    Args:
        cfg_str (str): Config str.
        file_format (str): Config file format corresponding to the
           config str. Only py/yml/yaml/json type are supported now!

    Returns:
        :obj:`Config`: Config obj.

    Raises:
        IOError: If ``file_format`` is not one of
            '.py', '.json', '.yaml', '.yml'.
    """
    if file_format not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')
    if file_format != '.py' and 'dict(' in cfg_str:
        # check if users specify a wrong suffix for python
        logger.warning(
            'Please check "file_format", the file format may be .py')
    with tempfile.NamedTemporaryFile(
            'w', encoding='utf-8', suffix=file_format,
            delete=False) as temp_file:
        temp_file.write(cfg_str)
        # on windows, previous implementation cause error
        # see PR 1077 for details
    try:
        cfg = Config.from_file(temp_file.name)
    finally:
        # Remove the temp file even if parsing raises, so that a bad config
        # string does not leave files behind.
        os.remove(temp_file.name)
    return cfg
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
    """Create a config from a dict.

    Args:
        cfg_dict (dict, optional): Config content. Defaults to an empty
            dict.
        cfg_text (str, optional): Raw text of the config. If empty and
            ``filename`` is given, the text is read from that file.
        filename (str or Path, optional): File the config came from.

    Raises:
        TypeError: If ``cfg_dict`` is not a dict.
        KeyError: If ``cfg_dict`` has a top-level key in ``RESERVED_KEYS``.
    """
    if cfg_dict is None:
        cfg_dict = dict()
    elif not isinstance(cfg_dict, dict):
        raise TypeError('cfg_dict must be a dict, but '
                        f'got {type(cfg_dict)}')
    for key in cfg_dict:
        if key in RESERVED_KEYS:
            raise KeyError(f'{key} is reserved for config file')

    if isinstance(filename, Path):
        filename = str(filename)

    # Config.__setattr__ writes into ``_cfg_dict``, so the instance's own
    # attributes have to be set through object.__setattr__.
    super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
    super(Config, self).__setattr__('_filename', filename)
    if cfg_text:
        text = cfg_text
    elif filename:
        # Read as utf-8 explicitly, matching ``_file2dict``, instead of
        # depending on the platform's default encoding.
        with open(filename, 'r', encoding='utf-8') as f:
            text = f.read()
    else:
        text = ''
    super(Config, self).__setattr__('_text', text)
@property
def filename(self):
    """Filename given at construction time (None if there was none)."""
    return self._filename
@property
def text(self):
    """Raw config text stored at construction time ('' if unavailable)."""
    return self._text
@property
def pretty_text(self):
    """Config content rendered as python-style source text.

    The dict is formatted by the nested helpers below and the result is
    passed through yapf's ``FormatCode``.
    """

    indent = 4

    def _indent(s_, num_spaces):
        """Indent every line of ``s_`` except the first by ``num_spaces``."""
        s = s_.split('\n')
        if len(s) == 1:
            return s_
        first = s.pop(0)
        s = [(num_spaces * ' ') + line for line in s]
        s = '\n'.join(s)
        s = first + '\n' + s
        return s

    def _format_basic_types(k, v, use_mapping=False):
        """Format one entry as ``k=v``, or ``'k': v`` if ``use_mapping``."""
        if isinstance(v, str):
            # Strings are wrapped in single quotes without escaping.
            v_str = f"'{v}'"
        else:
            v_str = str(v)

        if use_mapping:
            k_str = f"'{k}'" if isinstance(k, str) else str(k)
            attr_str = f'{k_str}: {v_str}'
        else:
            attr_str = f'{str(k)}={v_str}'
        attr_str = _indent(attr_str, indent)

        return attr_str

    def _format_list(k, v, use_mapping=False):
        """Format a list entry; lists of dicts get one ``dict(...)`` per item."""
        # check if all items in the list are dict
        if all(isinstance(_, dict) for _ in v):
            v_str = '[\n'
            v_str += '\n'.join(
                f'dict({_indent(_format_dict(v_), indent)}),'
                for v_ in v).rstrip(',')
            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent) + ']'
        else:
            attr_str = _format_basic_types(k, v, use_mapping)
        return attr_str

    def _contain_invalid_identifier(dict_str):
        """Return True if any key is not a valid python identifier."""
        contain_invalid_identifier = False
        for key_name in dict_str:
            contain_invalid_identifier |= \
                (not str(key_name).isidentifier())
        return contain_invalid_identifier

    def _format_dict(input_dict, outest_level=False):
        """Format a dict, recursing into nested dicts and lists."""
        r = ''
        s = []

        # Keys that cannot be written as keyword arguments force the
        # ``{'key': value}`` form for this dict.
        use_mapping = _contain_invalid_identifier(input_dict)
        if use_mapping:
            r += '{'
        for idx, (k, v) in enumerate(input_dict.items()):
            is_last = idx >= len(input_dict) - 1
            # No trailing comma at the outermost level or after the last
            # entry.
            end = '' if outest_level or is_last else ','
            if isinstance(v, dict):
                v_str = '\n' + _format_dict(v)
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: dict({v_str}'
                else:
                    attr_str = f'{str(k)}=dict({v_str}'
                attr_str = _indent(attr_str, indent) + ')' + end
            elif isinstance(v, list):
                attr_str = _format_list(k, v, use_mapping) + end
            else:
                attr_str = _format_basic_types(k, v, use_mapping) + end

            s.append(attr_str)
        r += '\n'.join(s)
        if use_mapping:
            r += '}'
        return r

    cfg_dict = self._cfg_dict.to_dict()
    text = _format_dict(cfg_dict, outest_level=True)
    # copied from setup.cfg
    yapf_style = dict(
        based_on_style='pep8',
        blank_line_before_nested_class_or_def=True,
        split_before_expression_after_opening_paren=True)
    text, _ = FormatCode(text, style_config=yapf_style, verify=True)

    return text
def __repr__(self):
    """Return the filename followed by the repr of the underlying dict."""
    return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
def __len__(self):
    """Return the number of top-level entries in the config."""
    return len(self._cfg_dict)
def __getattr__(self, name):
    """Look up ``name`` on the underlying ``ConfigDict``.

    Python only calls this when normal attribute lookup fails, so the
    properties and ``_``-prefixed instance attributes are not affected.
    """
    return getattr(self._cfg_dict, name)
def __getitem__(self, name):
    """Return the config entry stored under key ``name``."""
    return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
    """Store ``value`` as config entry ``name`` via attribute syntax.

    Dict values are wrapped in ``ConfigDict`` first; the assignment goes to
    the underlying config dict, not to the ``Config`` instance itself.
    """
    wrapped = ConfigDict(value) if isinstance(value, dict) else value
    self._cfg_dict.__setattr__(name, wrapped)
def __setitem__(self, name, value):
    """Store ``value`` as config entry ``name`` via item syntax.

    Dict values are wrapped in ``ConfigDict`` before being stored in the
    underlying config dict.
    """
    wrapped = ConfigDict(value) if isinstance(value, dict) else value
    self._cfg_dict.__setitem__(name, wrapped)
def __iter__(self):
    """Return an iterator over the wrapped config dict."""
    wrapped = self._cfg_dict
    return iter(wrapped)
def __getstate__(self):
    """Return the state tuple ``(_cfg_dict, _filename, _text)``.

    ``__setstate__`` expects the tuple in exactly this order.
    """
    state = (self._cfg_dict, self._filename, self._text)
    return state
def __copy__(self):
    """Return a shallow copy whose attributes reference the same objects."""
    klass = self.__class__
    # __new__ is used so that __init__ is not run for the duplicate.
    duplicate = klass.__new__(klass)
    duplicate.__dict__.update(self.__dict__)
    return duplicate
def __deepcopy__(self, memo):
    """Return a deep copy, deep-copying every instance attribute.

    Args:
        memo (dict): The memo dict passed along by ``copy.deepcopy``.
    """
    klass = self.__class__
    clone = klass.__new__(klass)
    # Register the clone before copying attributes so that references back
    # to ``self`` resolve to the clone.
    memo[id(self)] = clone
    for attr_name, attr_value in self.__dict__.items():
        copied_value = copy.deepcopy(attr_value, memo)
        # Use the base-class setter: Config.__setattr__ writes into
        # ``_cfg_dict`` instead of the instance.
        super(Config, clone).__setattr__(attr_name, copied_value)
    return clone
def __setstate__(self, state):
    """Restore the instance from the tuple produced by ``__getstate__``.

    Args:
        state (tuple): ``(_cfg_dict, _filename, _text)``.
    """
    cfg_dict, filename, text = state
    # Use the base-class setter: Config.__setattr__ writes into
    # ``_cfg_dict`` instead of the instance.
    base = super(Config, self)
    for attr_name, attr_value in (('_cfg_dict', cfg_dict),
                                  ('_filename', filename),
                                  ('_text', text)):
        base.__setattr__(attr_name, attr_value)
def dump(self, file: str = None):
    """Serialize the config to a file or to a string.

    With ``file`` given, the config is written to that path; the output
    format is taken from the text after the last ``.`` in ``file``, and a
    ``.py`` target receives ``self.pretty_text``.

    Without ``file``, a string is returned: ``self.pretty_text`` when
    ``self.filename`` is None or ends with ``.py``, otherwise the result of
    ``maas_lib.fileio.dump`` using the extension of ``self.filename`` as
    format.

    Examples:
        >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0),
        ...     item3=True, item4='test')
        >>> cfg = Config(cfg_dict=cfg_dict)
        >>> dump_file = "a.py"
        >>> cfg.dump(dump_file)

    Args:
        file (str, optional): Path of the output file where the config
            will be dumped. Defaults to None.

    Returns:
        The string form when ``file`` is None, None after writing a ``.py``
        file, otherwise whatever ``maas_lib.fileio.dump`` returns.
    """
    from maas_lib.fileio import dump

    plain_cfg = super(Config, self).__getattribute__('_cfg_dict').to_dict()

    if file is not None:
        if file.endswith('.py'):
            with open(file, 'w', encoding='utf-8') as out:
                out.write(self.pretty_text)
            return None
        target_format = file.split('.')[-1]
        return dump(plain_cfg, file=file, file_format=target_format)

    if self.filename is None or self.filename.endswith('.py'):
        return self.pretty_text
    source_format = self.filename.split('.')[-1]
    return dump(plain_cfg, file_format=source_format)
def merge_from_dict(self, options, allow_list_keys=True):
    """Merge a dict of dotted-key options into the config.

    Each key of ``options`` is split on ``.`` and expanded into nested
    dicts; the nested result is then merged into the current config via
    ``Config._merge_a_into_b`` and stored as the new ``_cfg_dict``.

    Examples:
        >>> options = {'model.backbone.depth': 50,
        ...            'model.backbone.with_cp': True}
        >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
        >>> cfg.merge_from_dict(options)

        >>> # Merge list element
        >>> cfg = Config(dict(pipeline=[
        ...     dict(type='Resize'), dict(type='RandomDistortion')]))
        >>> options = dict(pipeline={'0': dict(type='MyResize')})
        >>> cfg.merge_from_dict(options, allow_list_keys=True)

    Args:
        options (dict): dict of configs to merge from.
        allow_list_keys (bool): Forwarded to ``Config._merge_a_into_b``.
            If True, int string keys (e.g. '0', '1') are allowed in
            ``options`` and will replace the element of the corresponding
            index in the config if the config is a list. Default: True.
    """
    nested_options = {}
    for dotted_key, value in options.items():
        *parents, leaf = dotted_key.split('.')
        node = nested_options
        # Walk (and create on demand) the intermediate levels.
        for part in parents:
            node.setdefault(part, ConfigDict())
            node = node[part]
        node[leaf] = value

    current_cfg = super(Config, self).__getattribute__('_cfg_dict')
    merged_cfg = Config._merge_a_into_b(
        nested_options, current_cfg, allow_list_keys=allow_list_keys)
    super(Config, self).__setattr__('_cfg_dict', merged_cfg)
def to_dict(self) -> Dict:
    """Return the result of ``to_dict()`` on the wrapped config dict."""
    plain = self._cfg_dict.to_dict()
    return plain
def to_args(self, parse_fn, use_hyphen=True):
    """Convert the top-level config entries to CLI args and parse them.

    Conversion rules per ``key: value`` entry:

    * ``bool``: ``True`` emits the bare flag ``--key``; ``False`` emits
      nothing (the flag is simply absent).
    * ``int`` / ``float`` / ``str``: emits ``--key str(value)``.
    * ``list`` of the scalar types above: emits ``--key`` followed by one
      token per element (an empty list emits only ``--key``).

    Args:
        parse_fn: a function object, which takes args as input,
            such as ['--foo', 'FOO'] and return parsed args, an
            example is given as follows
            including literal blocks::

                def parse_fn(args):
                    parser = argparse.ArgumentParser(prog='PROG')
                    parser.add_argument('-x')
                    parser.add_argument('--foo')
                    return parser.parse_args(args)

        use_hyphen (bool, optional): if set true, underscores in the key
            name are converted to hyphens in the argument name.

    Return:
        args: whatever ``parse_fn`` returns for the generated arg list.

    Raises:
        ValueError: if a value (or a list element) is not of a supported
            type.
    """
    scalar_types = (int, str, float, bool)
    args = []
    for k, v in self._cfg_dict.items():
        arg_name = f'--{k}'
        if use_hyphen:
            arg_name = arg_name.replace('_', '-')

        # bool must be tested before int: bool is a subclass of int, so a
        # False value would otherwise be emitted as "--flag False".
        if isinstance(v, bool):
            if v:
                args.append(arg_name)
        elif isinstance(v, (int, str, float)):
            args.append(arg_name)
            args.append(str(v))
        elif isinstance(v, list):
            for element in v:
                if not isinstance(element, scalar_types):
                    raise ValueError(
                        'Element type in list is expected to be either '
                        f'int, str, float or bool, but got {element!r} '
                        f'of type {type(element)}')
            args.append(arg_name)
            args.extend(str(element) for element in v)
        else:
            raise ValueError(
                'type in config file which supported to be '
                'converted to args should be either bool, '
                f'int, str, float or list of them but got type {v}')

    return parse_fn(args)
| @@ -0,0 +1,34 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
class Fields:
    """String identifiers of the application fields."""
    image = 'image'
    video = 'video'
    nlp = 'nlp'
    audio = 'audio'
    multi_modal = 'multi_modal'
class Tasks(object):
    """ Names for tasks supported by maas lib.

    Holds the standard task name to use for identifying different tasks.
    This should be used to register models, pipelines, trainers.
    """
    # vision tasks
    image_classification = 'image-classification'
    # Misspelled name kept as an alias so existing callers keep working.
    image_classfication = image_classification
    object_detection = 'object-detection'

    # nlp tasks
    sentiment_analysis = 'sentiment-analysis'
    fill_mask = 'fill-mask'
class InputFields:
    """Names of the input data fields used in pipeline inputs."""
    img = 'img'
    text = 'text'
    audio = 'audio'
| @@ -0,0 +1,45 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import logging | |||||
| from typing import Optional | |||||
# Names of loggers that get_logger has already configured.
init_loggers = {}


def get_logger(log_file: Optional[str] = None,
               log_level: int = logging.INFO,
               file_mode: str = 'w'):
    """Return the package-level logger, configuring it on first use.

    The logger is named after the first component of this module's dotted
    name. Once it has been configured, later calls return it unchanged and
    ignore all arguments.

    Args:
        log_file: Log filename, if specified, file handler will be added to
            logger
        log_level: Logging level applied to the logger and its handlers.
        file_mode: Mode used to open ``log_file``; defaults to 'w'.
    """
    name = __name__.split('.')[0]
    logger = logging.getLogger(name)
    if name in init_loggers:
        return logger

    handlers = [logging.StreamHandler()]

    # TODO @wenmeng.zwm add logger setting for distributed environment
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file, file_mode))

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    init_loggers[name] = True
    return logger
| @@ -0,0 +1,90 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import ast | |||||
| import os | |||||
| import os.path as osp | |||||
| import sys | |||||
| import types | |||||
| from importlib import import_module | |||||
| from maas_lib.utils.logger import get_logger | |||||
logger = get_logger()


def import_modules_from_file(py_file: str):
    """ Import module from a certain file

    The directory of ``py_file`` is put at the front of ``sys.path`` only
    for the duration of the import and is removed again afterwards, even
    when validation or the import itself fails.

    Args:
        py_file: path to a python file to be imported

    Return:
        tuple: ``(module_name, module)`` where ``module_name`` is the file
        name without its extension.

    Raises:
        SyntaxError: if ``validate_py_syntax`` rejects the file.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        # Bare file name: import relative to the current directory.
        dirname = './'
    module_name = osp.splitext(basefile)[0]
    sys.path.insert(0, dirname)
    try:
        validate_py_syntax(py_file)
        mod = import_module(module_name)
    finally:
        # Remove by value rather than pop(0): the imported module may have
        # modified sys.path itself.
        try:
            sys.path.remove(dirname)
        except ValueError:
            pass
    return module_name, mod
def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, an ImportError is raise. Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single string
        input returns a single module; an empty or None input returns None.

    Raises:
        TypeError: if ``imports`` is neither a str nor a list, or if one of
            its items is not a str.
        ImportError: if a module cannot be imported and
            ``allow_failed_imports`` is False. The original exception is
            re-raised so its message is kept.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # Re-raise the original error; a fresh bare ImportError
                # would lose the message.
                raise
            logger.warning(f'{imp} failed to import and is ignored.')
            imported_tmp = None
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
def validate_py_syntax(filename):
    """Check that ``filename`` parses as Python source.

    Args:
        filename: path of the python file to check.

    Raises:
        SyntaxError: if ``ast.parse`` rejects the file content; the message
            names the file.
    """
    # Setting encoding explicitly to resolve coding issue on windows
    with open(filename, 'r', encoding='utf-8') as source_file:
        source = source_file.read()
    try:
        ast.parse(source)
    except SyntaxError as e:
        message = f'There are syntax errors in config file {filename}: {e}'
        raise SyntaxError(message)
| @@ -0,0 +1,183 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import inspect | |||||
| from maas_lib.utils.logger import get_logger | |||||
default_group = 'default'
logger = get_logger()


class Registry(object):
    """ Registry which support registering modules and group them by a keyname

    If group name is not provided, modules will be registered to default group.
    """

    def __init__(self, name: str):
        self._name = name
        # Maps group name -> {module name -> module class}.
        self._modules = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + f'({self._name})\n'
        for group_name, group in self._modules.items():
            format_str += f'group_name={group_name}, '\
                f'modules={list(group.keys())}\n'
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def modules(self):
        return self._modules

    def list(self):
        """ logging the list of module in current registry
        """
        for group_name, group in self._modules.items():
            logger.info(f'group_name={group_name}')
            for m in group.keys():
                logger.info(f'\t{m}')
            logger.info('')

    def get(self, module_key, group_key=default_group):
        """Return the class registered as ``module_key`` in ``group_key``.

        Returns None when the group or the module is unknown.
        """
        if group_key not in self._modules:
            return None
        else:
            return self._modules[group_key].get(module_key, None)

    def _register_module(self,
                         group_key=default_group,
                         module_name=None,
                         module_cls=None):
        """Validate and store ``module_cls`` under ``group_key``.

        Raises:
            TypeError: if ``module_cls`` is not a class.
            KeyError: if the name is already registered in the group.
        """
        assert isinstance(group_key,
                          str), 'group_key is required and must be str'
        # Validate before touching the registry so that a failed
        # registration does not leave an empty group behind.
        if not inspect.isclass(module_cls):
            raise TypeError(f'module is not a class type: {type(module_cls)}')

        if module_name is None:
            module_name = module_cls.__name__

        group = self._modules.setdefault(group_key, dict())
        if module_name in group:
            raise KeyError(f'{module_name} is already registered in '
                           f'{self._name}[{group_key}]')

        group[module_name] = module_cls

    def register_module(self,
                        group_key: str = default_group,
                        module_name: str = None,
                        module_cls: type = None):
        """ Register module

        Can be called directly with ``module_cls``, or used as a class
        decorator when ``module_cls`` is omitted.

        Example:
            >>> models = Registry('models')
            >>> @models.register_module('image-classification', 'SwinT')
            >>> class SwinTransformer:
            >>>     pass

            >>> # Register into the default group under a custom name
            >>> @models.register_module(module_name='SwinDefault')
            >>> class SwinTransformerDefaultGroup:
            >>>     pass

        Args:
            group_key: Group name of which module will be registered,
                default group name is 'default'
            module_name: Module name; defaults to the class name
            module_cls: Module class object

        Returns:
            ``module_cls`` when it is given, otherwise a decorator that
            registers the decorated class and returns it unchanged.

        Raises:
            TypeError: if ``module_name`` is neither None nor a str.
        """
        if not (module_name is None or isinstance(module_name, str)):
            raise TypeError(f'module_name must be either of None, str, '
                            f'got {type(module_name)}')

        if module_cls is not None:
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls)
            return module_cls

        # if module_cls is None, should return a decorator function
        def _register(module_cls):
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls)
            return module_cls

        return _register
def build_from_cfg(cfg,
                   registry: Registry,
                   group_key: str = default_group,
                   default_args: dict = None) -> object:
    """Build an object from a config dict.

    The ``type`` entry is either a string, which is looked up in
    ``registry`` within ``group_key``, or a class / function object that is
    called directly. All remaining entries (with ``default_args`` filling in
    missing keys) are passed as keyword arguments.

    Example:
        >>> models = Registry('models')
        >>> @models.register_module('image-classification', 'SwinT')
        >>> class SwinTransformer:
        >>>     pass
        >>> swint = build_from_cfg(dict(type='SwinT'), models,
        >>>     'image-classification')
        >>> # Returns an instantiated object
        >>>
        >>> def swin_transformer():
        >>>     pass
        >>> ret = build_from_cfg(dict(type=swin_transformer), models)
        >>> # Returns the result of calling the function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        group_key (str, optional): The name of registry group from which
            module should be searched.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: if ``cfg``, ``registry``, ``default_args`` or the
            resolved ``type`` has an unsupported type.
        KeyError: if no ``type`` is given or it is not registered.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an maas_lib.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Work on a copy so the caller's cfg is not modified by pop() below.
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type, group_key=group_key)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} '
                           f'registry group {group_key}')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name. Chain explicitly so
        # the original traceback stays attached.
        raise type(e)(f'{obj_cls.__name__}: {e}') from e
| @@ -0,0 +1 @@ | |||||
| -r requirements/runtime.txt | |||||
| @@ -0,0 +1,6 @@ | |||||
| docutils==0.16.0 | |||||
| recommonmark | |||||
| sphinx==4.0.2 | |||||
| sphinx-copybutton | |||||
| sphinx_markdown_tables | |||||
| sphinx_rtd_theme==0.5.2 | |||||
| @@ -0,0 +1,5 @@ | |||||
| addict | |||||
| numpy | |||||
| pyyaml | |||||
| requests | |||||
| yapf | |||||
| @@ -0,0 +1,5 @@ | |||||
| expecttest | |||||
| flake8 | |||||
| isort==4.3.21 | |||||
| pre-commit | |||||
| yapf==0.30.0 | |||||
| @@ -0,0 +1,24 @@ | |||||
| [isort] | |||||
| line_length = 79 | |||||
| multi_line_output = 0 | |||||
| known_standard_library = setuptools | |||||
| known_first_party = maas_lib | |||||
| known_third_party = json,yaml | |||||
| no_lines_before = STDLIB,LOCALFOLDER | |||||
| default_section = THIRDPARTY | |||||
| [yapf] | |||||
| BASED_ON_STYLE = pep8 | |||||
| BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true | |||||
| SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true | |||||
| [codespell] | |||||
| skip = *.ipynb | |||||
| quiet-level = 3 | |||||
| ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids | |||||
| [flake8] | |||||
| select = B,C,E,F,P,T4,W,B9 | |||||
| max-line-length = 120 | |||||
| ignore = F401 | |||||
| exclude = docs/src,*.pyi,.git | |||||
| @@ -0,0 +1,70 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import tempfile | |||||
| import unittest | |||||
| from requests import HTTPError | |||||
| from maas_lib.fileio.file import File, HTTPStorage, LocalStorage | |||||
class FileTest(unittest.TestCase):
    """Tests for LocalStorage, HTTPStorage and the File facade.

    The HTTP cases read a fixed public URL and therefore need network
    access.
    """

    def test_local_storage(self):
        storage = LocalStorage()
        # A temporary directory is removed even when an assertion fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_name = os.path.join(tmp_dir, 'local_storage_data')
            binary_content = b'12345'
            storage.write(binary_content, temp_name)
            self.assertEqual(binary_content, storage.read(temp_name))

            content = '12345'
            storage.write_text(content, temp_name)
            self.assertEqual(content, storage.read_text(temp_name))

    def test_http_storage(self):
        storage = HTTPStorage()
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \
            '/data/test/data.txt'
        content = 'this is test data'
        self.assertEqual(content.encode('utf8'), storage.read(url))
        self.assertEqual(content, storage.read_text(url))

        with storage.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(content, infile.read())

        with self.assertRaises(NotImplementedError):
            storage.write('dfad', url)

        with self.assertRaises(HTTPError):
            storage.read(url + 'df')

    def test_file(self):
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com'\
            '/data/test/data.txt'
        content = 'this is test data'
        self.assertEqual(content.encode('utf8'), File.read(url))

        with File.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(content, infile.read())

        with self.assertRaises(NotImplementedError):
            File.write('dfad', url)

        with self.assertRaises(HTTPError):
            File.read(url + 'df')

        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_name = os.path.join(tmp_dir, 'file_data')
            binary_content = b'12345'
            File.write(binary_content, temp_name)
            self.assertEqual(binary_content, File.read(temp_name))


if __name__ == '__main__':
    unittest.main()
| @@ -0,0 +1,32 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import tempfile | |||||
| import unittest | |||||
| from maas_lib.fileio.io import dump, dumps, load | |||||
class FileIOTest(unittest.TestCase):
    """Round-trip tests for dump / dumps / load."""

    def test_format(self, format='json'):
        """Dump and reload a sample object using ``format`` as extension."""
        obj = [1, 2, 3, 'str', {'model': 'resnet'}]
        result_str = dumps(obj, format)

        # A temporary directory is removed even when an assertion fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_name = tmp_dir + '/data.' + format
            dump(obj, temp_name)
            obj_load = load(temp_name)
            self.assertEqual(obj_load, obj)
            with open(temp_name, 'r') as infile:
                self.assertEqual(result_str, infile.read())

            # An extra trailing 's' produces an extension that is expected
            # to be rejected.
            with self.assertRaises(TypeError):
                obj_load = load(temp_name + 's')

            with self.assertRaises(TypeError):
                dump(obj, temp_name + 's')

    def test_yaml(self):
        self.test_format('yaml')


if __name__ == '__main__':
    unittest.main()
| @@ -0,0 +1,53 @@ | |||||
| #!/usr/bin/env python | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import argparse | |||||
| import os | |||||
| import sys | |||||
| import unittest | |||||
| from fnmatch import fnmatch | |||||
| def gather_test_cases(test_dir, pattern, list_tests): | |||||
| case_list = [] | |||||
| for dirpath, dirnames, filenames in os.walk(test_dir): | |||||
| for file in filenames: | |||||
| if fnmatch(file, pattern): | |||||
| case_list.append(file) | |||||
| test_suite = unittest.TestSuite() | |||||
| for case in case_list: | |||||
| test_case = unittest.defaultTestLoader.discover( | |||||
| start_dir=test_dir, pattern=case) | |||||
| test_suite.addTest(test_case) | |||||
| if hasattr(test_case, '__iter__'): | |||||
| for subcase in test_case: | |||||
| if list_tests: | |||||
| print(subcase) | |||||
| else: | |||||
| if list_tests: | |||||
| print(test_case) | |||||
| return test_suite | |||||
def main(args):
    """Gather the test cases and run them unless only listing is requested.

    Exits with status 1 when the run was not successful, i.e. on failures
    as well as on errors, so CI does not pass with erroring tests.

    Args:
        args: parsed arguments providing ``test_dir``, ``pattern`` and
            ``list_tests``.
    """
    runner = unittest.TextTestRunner()
    test_suite = gather_test_cases(
        os.path.abspath(args.test_dir), args.pattern, args.list_tests)
    if not args.list_tests:
        result = runner.run(test_suite)
        # Checking only result.failures would miss tests that raised
        # unexpected exceptions (result.errors).
        if not result.wasSuccessful():
            sys.exit(1)
if __name__ == '__main__':
    # Command line entry: declare the options table-style, then run main().
    parser = argparse.ArgumentParser('test runner')
    option_specs = (
        ('--list_tests', dict(action='store_true', help='list all tests')),
        ('--pattern', dict(default='test_*.py', help='test file pattern')),
        ('--test_dir', dict(default='tests', help='directory to be tested')),
    )
    for option_name, option_kwargs in option_specs:
        parser.add_argument(option_name, **option_kwargs)
    args = parser.parse_args()
    main(args)
| @@ -0,0 +1,85 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import argparse | |||||
| import os.path as osp | |||||
| import tempfile | |||||
| import unittest | |||||
| from pathlib import Path | |||||
| from maas_lib.fileio import dump, load | |||||
| from maas_lib.utils.config import Config | |||||
# Expected content of the example config files under configs/examples.
obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}


class ConfigTest(unittest.TestCase):
    """Tests for loading, dumping and converting Config objects.

    The config paths are relative, so the tests expect to run from the
    repository root.
    """

    def _load_and_check(self, config_file):
        """Load ``config_file`` and verify it matches ``obj``."""
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])
        return cfg

    def test_json(self):
        self._load_and_check('configs/examples/config.json')

    def test_yaml(self):
        self._load_and_check('configs/examples/config.yaml')

    def test_py(self):
        self._load_and_check('configs/examples/config.py')

    def test_dump(self):
        cfg = self._load_and_check('configs/examples/config.py')
        pretty_text = 'a = 1\n' + "b = dict(c=[1, 2, 3], d='dd')\n"

        json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
        yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'
        with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
            self.assertEqual(pretty_text, cfg.dump())
            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(json_str, infile.read())

        with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(yaml_str, infile.read())

    def test_to_dict(self):
        cfg = Config.from_file('configs/examples/config.json')
        d = cfg.to_dict()
        print(d)
        self.assertIsInstance(d, dict)

    def test_to_args(self):

        def parse_fn(args):
            parser = argparse.ArgumentParser(prog='PROG')
            parser.add_argument('--model-dir', default='')
            parser.add_argument('--lr', type=float, default=0.001)
            parser.add_argument('--optimizer', default='')
            parser.add_argument('--weight-decay', type=float, default=1e-7)
            parser.add_argument(
                '--save-checkpoint-epochs', type=int, default=30)
            return parser.parse_args(args)

        cfg = Config.from_file('configs/examples/plain_args.yaml')
        args = cfg.to_args(parse_fn)

        self.assertEqual(args.model_dir, 'path/to/model')
        self.assertAlmostEqual(args.lr, 0.01)
        self.assertAlmostEqual(args.weight_decay, 1e-6)
        self.assertEqual(args.optimizer, 'Adam')
        self.assertEqual(args.save_checkpoint_epochs, 20)


if __name__ == '__main__':
    unittest.main()
| @@ -0,0 +1,91 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from maas_lib.utils.constant import Tasks | |||||
| from maas_lib.utils.registry import Registry, build_from_cfg, default_group | |||||
class RegistryTest(unittest.TestCase):
    """Tests for Registry registration, listing and build_from_cfg."""

    def test_register_class_no_task(self):
        MODELS = Registry('models')
        self.assertEqual(MODELS.name, 'models')
        self.assertEqual(MODELS.modules, {})
        self.assertEqual(len(MODELS.modules), 0)

        @MODELS.register_module(module_name='cls-resnet')
        class ResNetForCls:
            pass

        self.assertIn(default_group, MODELS.modules)
        self.assertIs(MODELS.get('cls-resnet'), ResNetForCls)

    def test_register_class_with_task(self):
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls:
            pass

        self.assertIn(Tasks.image_classfication, MODELS.modules)
        self.assertIs(
            MODELS.get('SwinT', Tasks.image_classfication), SwinTForCls)

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis:
            pass

        self.assertIn(Tasks.sentiment_analysis, MODELS.modules)
        self.assertIs(
            MODELS.get('Bert', Tasks.sentiment_analysis),
            BertForSentimentAnalysis)

        # Without an explicit module name the class name is the key.
        @MODELS.register_module(Tasks.object_detection)
        class DETR:
            pass

        self.assertIn(Tasks.object_detection, MODELS.modules)
        self.assertIs(MODELS.get('DETR', Tasks.object_detection), DETR)

        self.assertEqual(len(MODELS.modules), 3)

    def test_list(self):
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls:
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis:
            pass

        MODELS.list()
        print(MODELS)

    def test_build(self):
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls:
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis:
            pass

        swin = build_from_cfg(
            dict(type='SwinT'), MODELS, Tasks.image_classfication)
        self.assertIsInstance(swin, SwinTForCls)

        bert = build_from_cfg(
            dict(type='Bert'), MODELS, Tasks.sentiment_analysis)
        self.assertIsInstance(bert, BertForSentimentAnalysis)

        # 'Bert' is not registered in the image classification group.
        with self.assertRaises(KeyError):
            build_from_cfg(
                dict(type='Bert'), MODELS, Tasks.image_classfication)


if __name__ == '__main__':
    unittest.main()