From 0a756f6a0d47ec3dd9a99ceec9a9ad8d456af19f Mon Sep 17 00:00:00 2001 From: "wenmeng.zwm" Date: Tue, 17 May 2022 10:15:00 +0800 Subject: [PATCH] [to #41402703] add basic modules * add constant * add logger module * add registry and builder module * add fileio module * add requirements and setup.cfg * add config module and tests * add citest script Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8718998 --- .dev_scripts/citest.sh | 13 + .gitignore | 126 ++++++++ configs/README.md | 1 + configs/examples/config.json | 7 + configs/examples/config.py | 2 + configs/examples/config.yaml | 4 + configs/examples/plain_args.yaml | 5 + maas_lib/__init__.py | 4 + maas_lib/fileio/__init__.py | 1 + maas_lib/fileio/file.py | 325 ++++++++++++++++++++ maas_lib/fileio/format/__init__.py | 3 + maas_lib/fileio/format/base.py | 20 ++ maas_lib/fileio/format/json.py | 35 +++ maas_lib/fileio/format/yaml.py | 25 ++ maas_lib/fileio/io.py | 127 ++++++++ maas_lib/utils/config.py | 472 +++++++++++++++++++++++++++++ maas_lib/utils/constant.py | 34 +++ maas_lib/utils/logger.py | 45 +++ maas_lib/utils/pymod.py | 90 ++++++ maas_lib/utils/registry.py | 183 +++++++++++ requirements.txt | 1 + requirements/docs.txt | 6 + requirements/runtime.txt | 5 + requirements/tests.txt | 5 + setup.cfg | 24 ++ tests/__init__.py | 0 tests/fileio/__init__.py | 0 tests/fileio/test_file.py | 70 +++++ tests/fileio/test_io.py | 32 ++ tests/run.py | 53 ++++ tests/utils/__init__.py | 0 tests/utils/test_config.py | 85 ++++++ tests/utils/test_registry.py | 91 ++++++ 33 files changed, 1894 insertions(+) create mode 100644 .dev_scripts/citest.sh create mode 100644 .gitignore create mode 100644 configs/README.md create mode 100644 configs/examples/config.json create mode 100644 configs/examples/config.py create mode 100644 configs/examples/config.yaml create mode 100644 configs/examples/plain_args.yaml create mode 100644 maas_lib/fileio/__init__.py create mode 100644 maas_lib/fileio/file.py create mode 
# ---- .dev_scripts/citest.sh ----
# CI entry point: install dependencies, run the linters, then the unit tests.
pip install -r requirements/runtime.txt
pip install -r requirements/tests.txt


# linter test
# use internal project for pre-commit due to the network problem
if ! pre-commit run --all-files; then
    echo "linter test failed, please run 'pre-commit run --all-files' to check"
    # Exit statuses must lie in 0-255; `exit -1` is rejected by POSIX shells
    # such as dash ("Illegal number"), so report a plain failure instead.
    exit 1
fi

PYTHONPATH=. python tests/run.py
diff --git a/configs/examples/config.json b/configs/examples/config.json new file mode 100644 index 00000000..551c7a50 --- /dev/null +++ b/configs/examples/config.json @@ -0,0 +1,7 @@ +{ + "a": 1, + "b" : { + "c": [1,2,3], + "d" : "dd" + } +} diff --git a/configs/examples/config.py b/configs/examples/config.py new file mode 100644 index 00000000..aab2c3c5 --- /dev/null +++ b/configs/examples/config.py @@ -0,0 +1,2 @@ +a = 1 +b = dict(c=[1,2,3], d='dd') diff --git a/configs/examples/config.yaml b/configs/examples/config.yaml new file mode 100644 index 00000000..d69dfed3 --- /dev/null +++ b/configs/examples/config.yaml @@ -0,0 +1,4 @@ +a: 1 +b: + c: [1,2,3] + d: dd diff --git a/configs/examples/plain_args.yaml b/configs/examples/plain_args.yaml new file mode 100644 index 00000000..0698b089 --- /dev/null +++ b/configs/examples/plain_args.yaml @@ -0,0 +1,5 @@ +model_dir: path/to/model +lr: 0.01 +optimizer: Adam +weight_decay: 1e-6 +save_checkpoint_epochs: 20 diff --git a/maas_lib/__init__.py b/maas_lib/__init__.py index e69de29b..0746d0e6 100644 --- a/maas_lib/__init__.py +++ b/maas_lib/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .version import __version__ + +__all__ = ['__version__'] diff --git a/maas_lib/fileio/__init__.py b/maas_lib/fileio/__init__.py new file mode 100644 index 00000000..9b85cb5f --- /dev/null +++ b/maas_lib/fileio/__init__.py @@ -0,0 +1 @@ +from .io import dump, dumps, load diff --git a/maas_lib/fileio/file.py b/maas_lib/fileio/file.py new file mode 100644 index 00000000..70820198 --- /dev/null +++ b/maas_lib/fileio/file.py @@ -0,0 +1,325 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import contextlib +import os +import tempfile +from abc import ABCMeta, abstractmethod +from pathlib import Path +from typing import Generator, Union + +import requests + + +class Storage(metaclass=ABCMeta): + """Abstract class of storage. + + All backends need to implement two apis: ``read()`` and ``read_text()``. 
+ ``read()`` reads the file as a byte stream and ``read_text()`` reads + the file as texts. + """ + + @abstractmethod + def read(self, filepath: str): + pass + + @abstractmethod + def read_text(self, filepath: str): + pass + + @abstractmethod + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + pass + + @abstractmethod + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + pass + + +class LocalStorage(Storage): + """Local hard disk storage""" + + def read(self, filepath: Union[str, Path]) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + with open(filepath, 'rb') as f: + content = f.read() + return content + + def read_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + with open(filepath, 'r', encoding=encoding) as f: + value_buf = f.read() + return value_buf + + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + dirname = os.path.dirname(filepath) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname) + with open(filepath, 'wb') as f: + f.write(obj) + + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. 
+ + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + dirname = os.path.dirname(filepath) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname) + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + @contextlib.contextmanager + def as_local_path( + self, + filepath: Union[str, + Path]) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing.""" + yield filepath + + +class HTTPStorage(Storage): + """HTTP and HTTPS storage.""" + + def read(self, url): + r = requests.get(url) + r.raise_for_status() + return r.content + + def read_text(self, url): + r = requests.get(url) + r.raise_for_status() + return r.text + + @contextlib.contextmanager + def as_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``as_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> storage = HTTPStorage() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with storage.get_local_path('http://path/to/file') as path: + ... 
# do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.read(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def write(self, obj: bytes, url: Union[str, Path]) -> None: + raise NotImplementedError('write is not supported by HTTP Storage') + + def write_text(self, + obj: str, + url: Union[str, Path], + encoding: str = 'utf-8') -> None: + raise NotImplementedError( + 'write_text is not supported by HTTP Storage') + + +class OSSStorage(Storage): + """OSS storage.""" + + def __init__(self, oss_config_file=None): + # read from config file or env var + raise NotImplementedError( + 'OSSStorage.__init__ to be implemented in the future') + + def read(self, filepath): + raise NotImplementedError( + 'OSSStorage.read to be implemented in the future') + + def read_text(self, filepath, encoding='utf-8'): + raise NotImplementedError( + 'OSSStorage.read_text to be implemented in the future') + + @contextlib.contextmanager + def as_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``as_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> storage = OSSStorage() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with storage.get_local_path('http://path/to/file') as path: + ... 
# do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.read(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + raise NotImplementedError( + 'OSSStorage.write to be implemented in the future') + + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + raise NotImplementedError( + 'OSSStorage.write_text to be implemented in the future') + + +G_STORAGES = {} + + +class File(object): + _prefix_to_storage: dict = { + 'oss': OSSStorage, + 'http': HTTPStorage, + 'https': HTTPStorage, + 'local': LocalStorage, + } + + @staticmethod + def _get_storage(uri): + assert isinstance(uri, + str), f'uri should be str type, buf got {type(uri)}' + + if '://' not in uri: + # local path + storage_type = 'local' + else: + prefix, _ = uri.split('://') + storage_type = prefix + + assert storage_type in File._prefix_to_storage, \ + f'Unsupported uri {uri}, valid prefixs: '\ + f'{list(File._prefix_to_storage.keys())}' + + if storage_type not in G_STORAGES: + G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]() + + return G_STORAGES[storage_type] + + @staticmethod + def read(uri: str) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + storage = File._get_storage(uri) + return storage.read(uri) + + @staticmethod + def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. 
+ """ + storage = File._get_storage(uri) + return storage.read_text(uri) + + @staticmethod + def write(obj: bytes, uri: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + storage = File._get_storage(uri) + return storage.write(obj, uri) + + @staticmethod + def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + storage = File._get_storage(uri) + return storage.write_text(obj, uri) + + @contextlib.contextmanager + def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing.""" + storage = File._get_storage(uri) + with storage.as_local_path(uri) as local_path: + yield local_path diff --git a/maas_lib/fileio/format/__init__.py b/maas_lib/fileio/format/__init__.py new file mode 100644 index 00000000..52e64279 --- /dev/null +++ b/maas_lib/fileio/format/__init__.py @@ -0,0 +1,3 @@ +from .base import FormatHandler +from .json import JsonHandler +from .yaml import YamlHandler diff --git a/maas_lib/fileio/format/base.py b/maas_lib/fileio/format/base.py new file mode 100644 index 00000000..6303c3b3 --- /dev/null +++ b/maas_lib/fileio/format/base.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
# ---- maas_lib/fileio/format/base.py ----


class FormatHandler(metaclass=ABCMeta):
    """Interface of a serialization format (load / dump / dumps)."""

    # if `text_mode` is True, files of this format are handled in
    # text mode, otherwise in binary mode
    text_mode = True

    @abstractmethod
    def load(self, file, **kwargs):
        pass

    @abstractmethod
    def dump(self, obj, file, **kwargs):
        pass

    @abstractmethod
    def dumps(self, obj, **kwargs):
        pass


# ---- maas_lib/fileio/format/json.py ----


def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
    It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
    etc.) into plain numbers of plain python built-in types.

    Raises:
        TypeError: If ``obj`` is of any other type.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'{type(obj)} is unsupported for json dump')


class JsonHandler(FormatHandler):
    """JSON format handler; numpy values and sets/ranges are dumped too."""

    def load(self, file, **kwargs):
        """Parse JSON from the file-like ``file``.

        ``kwargs`` are forwarded to :func:`json.load`, matching the base
        class signature.
        """
        return json.load(file, **kwargs)

    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` as JSON into the file-like ``file``."""
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dumps(self, obj, **kwargs):
        """Return ``obj`` serialized as a JSON string."""
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)
# ---- maas_lib/fileio/format/yaml.py ----


class YamlHandler(FormatHandler):
    """YAML format handler built on PyYAML (C implementation if present)."""

    def load(self, file, **kwargs):
        """Parse YAML from the file-like ``file``.

        NOTE(security): the default ``Loader`` is PyYAML's full loader,
        which can construct arbitrary Python objects. Pass
        ``Loader=yaml.SafeLoader`` when the content is not trusted (e.g.
        fetched over HTTP).
        """
        kwargs.setdefault('Loader', Loader)
        return yaml.load(file, **kwargs)

    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` as YAML into the file-like ``file``."""
        kwargs.setdefault('Dumper', Dumper)
        yaml.dump(obj, file, **kwargs)

    def dumps(self, obj, **kwargs):
        """Return ``obj`` serialized as a YAML string."""
        kwargs.setdefault('Dumper', Dumper)
        return yaml.dump(obj, **kwargs)


# ---- maas_lib/fileio/io.py ----

# Maps a format name / file extension to its handler instance.
format_handlers = {
    'json': JsonHandler(),
    'yaml': YamlHandler(),
    'yml': YamlHandler(),
}


def load(file, file_format=None, **kwargs):
    """Load data from json/yaml files.

    This method provides a unified api for loading data from serialized files.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a file-like
            object.
        file_format (str, optional): If not specified, the file format will be
            inferred from the file extension, otherwise use the specified one.
            Currently supported formats include "json", "yaml/yml".
        **kwargs: Forwarded to the format handler's ``load``.

    Examples:
        >>> load('/path/of/your/file')  # file is stored on disk
        >>> load('https://path/of/your/file')  # file is stored on Internet
        >>> load('oss://path/of/your/file')  # file is stored in oss

    Returns:
        The content from the file.

    Raises:
        TypeError: If the format is unsupported or ``file`` is neither a
            path nor a file-like object.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None and isinstance(file, str):
        file_format = file.split('.')[-1]
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if isinstance(file, str):
        # fetch the whole content through File, then parse from memory
        if handler.text_mode:
            with StringIO(File.read_text(file)) as f:
                obj = handler.load(f, **kwargs)
        else:
            with BytesIO(File.read(file)) as f:
                obj = handler.load(f, **kwargs)
    elif hasattr(file, 'read'):
        obj = handler.load(file, **kwargs)
    else:
        raise TypeError('"file" must be a filepath str or a file-object')
    return obj


def dump(obj, file=None, file_format=None, **kwargs):
    """Dump data to json/yaml strings or files.

    This method provides a unified api for dumping data as strings or to files.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, then the object is dumped to a str, otherwise to a file
            specified by the filename or file-like object.
        file_format (str, optional): Same as :func:`load`.
        **kwargs: Forwarded to the format handler.

    Examples:
        >>> dump('hello world', '/path/of/your/file')  # disk
        >>> dump('hello world', 'oss://path/of/your/file')  # oss

    Returns:
        str or None: The serialized string when ``file`` is None,
        otherwise None.

    Raises:
        ValueError: If both ``file`` and ``file_format`` are None.
        TypeError: If the format is unsupported or ``file`` is neither a
            path nor a file-like object.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if isinstance(file, str):
            file_format = file.split('.')[-1]
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if file is None:
        # FormatHandler exposes ``dumps``; there is no ``dump_to_str``
        return handler.dumps(obj, **kwargs)
    elif isinstance(file, str):
        # serialize in memory first, then hand the result to File
        if handler.text_mode:
            with StringIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write_text(f.getvalue(), file)
        else:
            with BytesIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write(f.getvalue(), file)
    elif hasattr(file, 'write'):
        handler.dump(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')


def dumps(obj, format, **kwargs):
    """Dump data to a json/yaml string.

    Args:
        obj (any): The python object to be dumped.
        format (str): Same as ``file_format`` of :func:`load`.
        **kwargs: Forwarded to the format handler's ``dumps``.

    Examples:
        >>> dumps('hello world', 'json')
        >>> dumps('hello world', 'yaml')

    Returns:
        str: The serialized string.

    Raises:
        TypeError: If ``format`` is unsupported.
    """
    if format not in format_handlers:
        raise TypeError(f'Unsupported format: {format}')

    handler = format_handlers[format]
    return handler.dumps(obj, **kwargs)
+ +import ast +import copy +import os +import os.path as osp +import platform +import shutil +import sys +import tempfile +import types +import uuid +from importlib import import_module +from pathlib import Path +from typing import Dict + +import addict +from yapf.yapflib.yapf_api import FormatCode + +from maas_lib.utils.logger import get_logger +from maas_lib.utils.pymod import (import_modules, import_modules_from_file, + validate_py_syntax) + +if platform.system() == 'Windows': + import regex as re # type: ignore +else: + import re # type: ignore + +logger = get_logger() + +BASE_KEY = '_base_' +DELETE_KEY = '_delete_' +DEPRECATION_KEY = '_deprecation_' +RESERVED_KEYS = ['filename', 'text', 'pretty_text'] + + +class ConfigDict(addict.Dict): + """ Dict which support get value through getattr + + Examples: + >>> cdict = ConfigDict({'a':1232}) + >>> print(cdict.a) + 1232 + """ + + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super(ConfigDict, self).__getattr__(name) + except KeyError: + ex = AttributeError(f"'{self.__class__.__name__}' object has no " + f"attribute '{name}'") + except Exception as e: + ex = e + else: + return value + raise ex + + +class Config: + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. 
+ + Example: + >>> cfg = Config(dict(a=1, b=dict(c=[1,2,3], d='dd'))) + >>> cfg.a + 1 + >>> cfg.b + {'c': [1, 2, 3], 'd': 'dd'} + >>> cfg.b.d + 'dd' + >>> cfg = Config.from_file('configs/examples/config.json') + >>> cfg.filename + 'configs/examples/config.json' + >>> cfg.b + {'c': [1, 2, 3], 'd': 'dd'} + >>> cfg = Config.from_file('configs/examples/config.py') + >>> cfg.filename + "configs/examples/config.py" + >>> cfg = Config.from_file('configs/examples/config.yaml') + >>> cfg.filename + "configs/examples/config.yaml" + """ + + @staticmethod + def _file2dict(filename): + filename = osp.abspath(osp.expanduser(filename)) + if not osp.exists(filename): + raise ValueError(f'File does not exists {filename}') + fileExtname = osp.splitext(filename)[1] + if fileExtname not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + + with tempfile.TemporaryDirectory() as tmp_cfg_dir: + tmp_cfg_file = tempfile.NamedTemporaryFile( + dir=tmp_cfg_dir, suffix=fileExtname) + if platform.system() == 'Windows': + tmp_cfg_file.close() + tmp_cfg_name = osp.basename(tmp_cfg_file.name) + shutil.copyfile(filename, tmp_cfg_file.name) + + if filename.endswith('.py'): + module_nanme, mod = import_modules_from_file( + osp.join(tmp_cfg_dir, tmp_cfg_name)) + cfg_dict = {} + for name, value in mod.__dict__.items(): + if not name.startswith('__') and \ + not isinstance(value, types.ModuleType) and \ + not isinstance(value, types.FunctionType): + cfg_dict[name] = value + + # delete imported module + del sys.modules[module_nanme] + elif filename.endswith(('.yml', '.yaml', '.json')): + from maas_lib.fileio import load + cfg_dict = load(tmp_cfg_file.name) + # close temp file + tmp_cfg_file.close() + + cfg_text = filename + '\n' + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + return cfg_dict, cfg_text + + @staticmethod + def from_file(filename): + if 
isinstance(filename, Path): + filename = str(filename) + cfg_dict, cfg_text = Config._file2dict(filename) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def from_string(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + :obj:`Config`: Config obj. + """ + if file_format not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + if file_format != '.py' and 'dict(' in cfg_str: + # check if users specify a wrong suffix for python + logger.warning( + 'Please check "file_format", the file format may be .py') + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix=file_format, + delete=False) as temp_file: + temp_file.write(cfg_str) + # on windows, previous implementation cause error + # see PR 1077 for details + cfg = Config.from_file(temp_file.name) + os.remove(temp_file.name) + return cfg + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError('cfg_dict must be a dict, but ' + f'got {type(cfg_dict)}') + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f'{key} is reserved for config file') + + if isinstance(filename, Path): + filename = str(filename) + + super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) + super(Config, self).__setattr__('_filename', filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, 'r') as f: + text = f.read() + else: + text = '' + super(Config, self).__setattr__('_text', text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split('\n') + if 
len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = '[\n' + v_str += '\n'.join( + f'dict({_indent(_format_dict(v_), indent)}),' + for v_ in v).rstrip(',') + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + ']' + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= \ + (not str(key_name).isidentifier()) + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = '' + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += '{' + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = '' if outest_level or is_last else ',' + if isinstance(v, dict): + v_str = '\n' + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: dict({v_str}' + else: + attr_str = f'{str(k)}=dict({v_str}' + attr_str = _indent(attr_str, indent) + ')' + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += '\n'.join(s) + if use_mapping: + r += '}' + return r + 
+ cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + + return text + + def __repr__(self): + return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__(self): + return (self._cfg_dict, self._filename, self._text) + + def __copy__(self): + cls = self.__class__ + other = cls.__new__(cls) + other.__dict__.update(self.__dict__) + + return other + + def __deepcopy__(self, memo): + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + + for key, value in self.__dict__.items(): + super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) + + return other + + def __setstate__(self, state): + _cfg_dict, _filename, _text = state + super(Config, self).__setattr__('_cfg_dict', _cfg_dict) + super(Config, self).__setattr__('_filename', _filename) + super(Config, self).__setattr__('_text', _text) + + def dump(self, file: str = None): + """Dumps config into a file or returns a string representation of the + config. + + If a file argument is given, saves the config to that file using the + format defined by the file argument extension. + + Otherwise, returns a string representing the config. 
The formatting of + this returned string is defined by the extension of `self.filename`. If + `self.filename` is not defined, returns a string representation of a + dict (lowercased and using ' for strings). + + Examples: + >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0), + ... item3=True, item4='test') + >>> cfg = Config(cfg_dict=cfg_dict) + >>> dump_file = "a.py" + >>> cfg.dump(dump_file) + + Args: + file (str, optional): Path of the output file where the config + will be dumped. Defaults to None. + """ + from maas_lib.fileio import dump + cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() + if file is None: + if self.filename is None or self.filename.endswith('.py'): + return self.pretty_text + else: + file_format = self.filename.split('.')[-1] + return dump(cfg_dict, file_format=file_format) + elif file.endswith('.py'): + with open(file, 'w', encoding='utf-8') as f: + f.write(self.pretty_text) + else: + file_format = file.split('.')[-1] + return dump(cfg_dict, file=file, file_format=file_format) + + def merge_from_dict(self, options, allow_list_keys=True): + """Merge list into cfg_dict. + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'model.backbone.depth': 50, + ... 'model.backbone.with_cp':True} + >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... model=dict(backbone=dict(depth=50, with_cp=True))) + + >>> # Merge list element + >>> cfg = Config(dict(pipeline=[ + ... dict(type='Resize'), dict(type='RandomDistortion')])) + >>> options = dict(pipeline={'0': dict(type='MyResize')}) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... 
def to_args(self, parse_fn, use_hyphen=True):
    """Convert the top-level config entries to an argument list and parse it.

    Every top-level key ``k`` of the config becomes an option ``--k``:

    * ``True`` emits the bare flag, ``False`` emits nothing.
    * int, float and str values emit the option followed by ``str(value)``.
    * a list emits the option followed by one token per element
      (the form consumed by ``argparse`` options declared with ``nargs``).

    Args:
        parse_fn: a function object, which takes args as input,
            such as ['--foo', 'FOO'] and return parsed args, an
            example is given as follows::

                def parse_fn(args):
                    parser = argparse.ArgumentParser(prog='PROG')
                    parser.add_argument('-x')
                    parser.add_argument('--foo')
                    return parser.parse_args(args)
        use_hyphen (bool, optional): if True, underscores in key names
            are converted to hyphens (``model_dir`` -> ``--model-dir``).

    Returns:
        Whatever ``parse_fn`` returns for the generated argument list.

    Raises:
        ValueError: if a value is not a bool, int, str, float or a list
            of those types.
    """
    scalar_types = (int, str, float)
    args = []
    for k, v in self._cfg_dict.items():
        arg_name = f'--{k}'
        if use_hyphen:
            arg_name = arg_name.replace('_', '-')

        if isinstance(v, bool):
            # bool is a subclass of int, so it must be handled before the
            # scalar branch; a False flag is expressed by omitting it.
            if v:
                args.append(arg_name)
        elif isinstance(v, scalar_types):
            args.append(arg_name)
            args.append(str(v))
        elif isinstance(v, list):
            # Validate the elements (not the list itself) before emitting.
            for elem in v:
                if not isinstance(elem, (int, str, float, bool)):
                    raise ValueError(
                        'Element type in list is expected to be either '
                        f'int, str, float or bool, but got {type(elem)}')
            args.append(arg_name)
            args.extend(str(elem) for elem in v)
        else:
            raise ValueError(
                'type in config file which supported to be '
                'converted to args should be either bool, '
                f'int, str, float or list of them but got type {v}')

    return parse_fn(args)
def import_modules_from_file(py_file: str):
    """ Import a module from a certain python file

    The file's directory is temporarily placed at the front of ``sys.path``
    and removed again afterwards, even if validation or the import fails.

    Args:
        py_file: path to a python file to be imported

    Return:
        A tuple ``(module_name, module)`` where ``module_name`` is the file
        name without its extension.

    Raises:
        SyntaxError: if the file does not contain valid python syntax.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        dirname = './'
    module_name = osp.splitext(basefile)[0]

    # Validate before touching sys.path so a bad file has no side effects.
    validate_py_syntax(py_file)

    sys.path.insert(0, dirname)
    try:
        mod = import_module(module_name)
    finally:
        # The imported module may itself have edited sys.path, so remove
        # our entry by value instead of assuming it is still at index 0.
        try:
            sys.path.remove(dirname)
        except ValueError:
            pass
    return module_name, mod


def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, the ImportError is re-raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single module
        is returned when ``imports`` is a str, None when it is empty.

    Raises:
        TypeError: if ``imports`` is not a list/str or an entry is not a str.
        ImportError: if a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if allow_failed_imports:
                logger.warning(f'{imp} failed to import and is ignored.')
                imported_tmp = None
            else:
                # Re-raise the original error so the module name, message
                # and traceback are not lost.
                raise
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported


def validate_py_syntax(filename):
    """Check that ``filename`` parses as python source.

    Args:
        filename: path of the python file to check.

    Raises:
        SyntaxError: if ``ast.parse`` rejects the file content; the message
            names the offending file.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        content = f.read()
    try:
        ast.parse(content)
    except SyntaxError as e:
        raise SyntaxError('There are syntax errors in config '
                          f'file {filename}: {e}') from e
def _register_module(self,
                     group_key=default_group,
                     module_name=None,
                     module_cls=None):
    """Store ``module_cls`` under ``module_name`` in group ``group_key``.

    The class is validated before the group is created, so a rejected
    registration does not leave an empty group in the registry.

    Args:
        group_key: Name of the group to register into; must be a str.
        module_name: Key to register the class under. Defaults to
            ``module_cls.__name__`` when None.
        module_cls: The class object to register.

    Raises:
        TypeError: if ``module_cls`` is not a class.
        KeyError: if ``module_name`` already exists in the group.
    """
    assert isinstance(group_key,
                      str), 'group_key is required and must be str'

    if not inspect.isclass(module_cls):
        raise TypeError(f'module is not a class type: {type(module_cls)}')

    if module_name is None:
        module_name = module_cls.__name__

    group = self._modules.setdefault(group_key, dict())
    if module_name in group:
        raise KeyError(f'{module_name} is already registered in '
                       f'{self._name}[{group_key}]')

    group[module_name] = module_cls
def build_from_cfg(cfg,
                   registry: Registry,
                   group_key: str = default_group,
                   default_args: dict = None) -> object:
    """Build an object from a config dict.

    The ``type`` entry selects what is called: a str is looked up in
    ``registry`` under ``group_key``; a class or function object is called
    directly. All remaining entries are passed as keyword arguments.

    Example:
        >>> models = Registry('models')
        >>> @models.register_module('image-classification', 'SwinT')
        ... class SwinTransformer:
        ...     pass
        >>> swint = build_from_cfg(dict(type='SwinT'), models,
        ...                        'image-classification')
        >>> # Returns an instantiated object
        >>>
        >>> def swin_transformer():
        ...     pass
        >>> result = build_from_cfg(dict(type=swin_transformer), models)
        >>> # Returns the result of calling the function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type",
            unless ``default_args`` provides it.
        registry (:obj:`Registry`): The registry to search the type from.
        group_key (str, optional): The name of registry group from which
            module should be searched.
        default_args (dict, optional): Default initialization arguments;
            entries already present in ``cfg`` take precedence.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: if ``cfg``, ``registry``, ``default_args`` or the
            ``type`` entry has an unsupported type.
        KeyError: if no ``type`` is given, or it is not registered in the
            requested group.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # Check default_args' type before using it in a membership test, so a
    # wrong type reports the explicit message below.
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an maas_lib.Registry object, '
                        f'but got {type(registry)}')

    # Work on a copy so the caller's cfg is not modified by pop/setdefault.
    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type, group_key=group_key)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} '
                           f'registry group {group_key}')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}') from e
class FileTest(unittest.TestCase):
    """Tests for LocalStorage, HTTPStorage and the File facade.

    Local files live inside a ``tempfile.TemporaryDirectory`` so they are
    removed even when an assertion fails. The HTTP tests fetch a fixed
    remote test file and therefore need network access.
    """

    def test_local_storage(self):
        storage = LocalStorage()
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_name = os.path.join(tmp_dir, 'local_storage_data')

            binary_content = b'12345'
            storage.write(binary_content, temp_name)
            self.assertEqual(binary_content, storage.read(temp_name))

            content = '12345'
            storage.write_text(content, temp_name)
            self.assertEqual(content, storage.read_text(temp_name))

    def test_http_storage(self):
        storage = HTTPStorage()
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \
            '/data/test/data.txt'
        content = 'this is test data'
        self.assertEqual(content.encode('utf8'), storage.read(url))
        self.assertEqual(content, storage.read_text(url))

        with storage.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(content, infile.read())

        # HTTP storage is read-only.
        with self.assertRaises(NotImplementedError):
            storage.write('dfad', url)

        # A non-existent remote path surfaces as an HTTP error.
        with self.assertRaises(HTTPError):
            storage.read(url + 'df')

    def test_file(self):
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com'\
            '/data/test/data.txt'
        content = 'this is test data'
        self.assertEqual(content.encode('utf8'), File.read(url))

        with File.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(content, infile.read())

        with self.assertRaises(NotImplementedError):
            File.write('dfad', url)

        with self.assertRaises(HTTPError):
            File.read(url + 'df')

        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_name = os.path.join(tmp_dir, 'file_data')
            binary_content = b'12345'
            File.write(binary_content, temp_name)
            self.assertEqual(binary_content, File.read(temp_name))
class FileIOTest(unittest.TestCase):
    """Round-trip tests for ``dump``/``dumps``/``load``."""

    def test_format(self, format='json'):
        """Dump and reload an object in ``format`` and compare the results.

        The file is written into a temporary directory so nothing is left
        behind in the system temp folder after the test.
        """
        obj = [1, 2, 3, 'str', {'model': 'resnet'}]
        result_str = dumps(obj, format)
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_name = tmp_dir + '/' + 'fileio_data' + '.' + format
            dump(obj, temp_name)
            obj_load = load(temp_name)
            self.assertEqual(obj_load, obj)
            with open(temp_name, 'r') as infile:
                self.assertEqual(result_str, infile.read())

            # An extra trailing 's' turns the extension into an unknown
            # format, which the test expects to be rejected with TypeError.
            with self.assertRaises(TypeError):
                obj_load = load(temp_name + 's')

            with self.assertRaises(TypeError):
                dump(obj, temp_name + 's')

    def test_yaml(self):
        self.test_format('yaml')
+ +import argparse +import os +import sys +import unittest +from fnmatch import fnmatch + + +def gather_test_cases(test_dir, pattern, list_tests): + case_list = [] + for dirpath, dirnames, filenames in os.walk(test_dir): + for file in filenames: + if fnmatch(file, pattern): + case_list.append(file) + + test_suite = unittest.TestSuite() + + for case in case_list: + test_case = unittest.defaultTestLoader.discover( + start_dir=test_dir, pattern=case) + test_suite.addTest(test_case) + if hasattr(test_case, '__iter__'): + for subcase in test_case: + if list_tests: + print(subcase) + else: + if list_tests: + print(test_case) + return test_suite + + +def main(args): + runner = unittest.TextTestRunner() + test_suite = gather_test_cases( + os.path.abspath(args.test_dir), args.pattern, args.list_tests) + if not args.list_tests: + result = runner.run(test_suite) + if len(result.failures) > 0: + sys.exit(1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('test runner') + parser.add_argument( + '--list_tests', action='store_true', help='list all tests') + parser.add_argument( + '--pattern', default='test_*.py', help='test file pattern') + parser.add_argument( + '--test_dir', default='tests', help='directory to be tested') + args = parser.parse_args() + main(args) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py new file mode 100644 index 00000000..31d51311 --- /dev/null +++ b/tests/utils/test_config.py @@ -0,0 +1,85 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import argparse +import os.path as osp +import tempfile +import unittest +from pathlib import Path + +from maas_lib.fileio import dump, load +from maas_lib.utils.config import Config + +obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} + + +class ConfigTest(unittest.TestCase): + + def test_json(self): + config_file = 'configs/examples/config.json' + cfg = Config.from_file(config_file) + self.assertEqual(cfg.a, 1) + self.assertEqual(cfg.b, obj['b']) + + def test_yaml(self): + config_file = 'configs/examples/config.yaml' + cfg = Config.from_file(config_file) + self.assertEqual(cfg.a, 1) + self.assertEqual(cfg.b, obj['b']) + + def test_py(self): + config_file = 'configs/examples/config.py' + cfg = Config.from_file(config_file) + self.assertEqual(cfg.a, 1) + self.assertEqual(cfg.b, obj['b']) + + def test_dump(self): + config_file = 'configs/examples/config.py' + cfg = Config.from_file(config_file) + self.assertEqual(cfg.a, 1) + self.assertEqual(cfg.b, obj['b']) + pretty_text = 'a = 1\n' + pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n" + + json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}' + yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n' + with tempfile.NamedTemporaryFile(suffix='.json') as ofile: + self.assertEqual(pretty_text, cfg.dump()) + cfg.dump(ofile.name) + with open(ofile.name, 'r') as infile: + self.assertEqual(json_str, infile.read()) + + with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile: + cfg.dump(ofile.name) + with open(ofile.name, 'r') as infile: + self.assertEqual(yaml_str, infile.read()) + + def test_to_dict(self): + config_file = 'configs/examples/config.json' + cfg = Config.from_file(config_file) + d = cfg.to_dict() + print(d) + self.assertTrue(isinstance(d, dict)) + + def test_to_args(self): + + def parse_fn(args): + parser = argparse.ArgumentParser(prog='PROG') + parser.add_argument('--model-dir', default='') + parser.add_argument('--lr', type=float, default=0.001) + parser.add_argument('--optimizer', default='') + 
class RegistryTest(unittest.TestCase):
    """Checks registration, lookup, listing and building via Registry."""

    def test_register_class_no_task(self):
        registry = Registry('models')
        self.assertEqual(registry.name, 'models')
        self.assertEqual(registry.modules, {})
        self.assertEqual(len(registry.modules), 0)

        # Without a group key the class goes to the default group.
        @registry.register_module(module_name='cls-resnet')
        class ResNetForCls(object):
            pass

        self.assertIn(default_group, registry.modules)
        self.assertIs(registry.get('cls-resnet'), ResNetForCls)

    def test_register_class_with_task(self):
        registry = Registry('models')

        @registry.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        self.assertIn(Tasks.image_classfication, registry.modules)
        self.assertIs(
            registry.get('SwinT', Tasks.image_classfication), SwinTForCls)

        @registry.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        self.assertIn(Tasks.sentiment_analysis, registry.modules)
        self.assertIs(
            registry.get('Bert', Tasks.sentiment_analysis),
            BertForSentimentAnalysis)

        # No explicit module name: the class name is used as the key.
        @registry.register_module(Tasks.object_detection)
        class DETR(object):
            pass

        self.assertIn(Tasks.object_detection, registry.modules)
        self.assertIs(registry.get('DETR', Tasks.object_detection), DETR)

        self.assertEqual(len(registry.modules), 3)

    def test_list(self):
        registry = Registry('models')

        @registry.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        @registry.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        registry.list()
        print(registry)

    def test_build(self):
        registry = Registry('models')

        @registry.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        @registry.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        swin = build_from_cfg(
            dict(type='SwinT'), registry, Tasks.image_classfication)
        self.assertIsInstance(swin, SwinTForCls)

        bert = build_from_cfg(
            dict(type='Bert'), registry, Tasks.sentiment_analysis)
        self.assertIsInstance(bert, BertForSentimentAnalysis)

        # 'Bert' is registered under another group, so the lookup fails.
        with self.assertRaises(KeyError):
            build_from_cfg(
                dict(type='Bert'), registry, Tasks.image_classfication)