| @@ -1,6 +1,6 @@ | |||||
| repos: | repos: | ||||
| - repo: https://gitlab.com/pycqa/flake8.git | - repo: https://gitlab.com/pycqa/flake8.git | ||||
| rev: 3.8.3 | |||||
| rev: 4.0.0 | |||||
| hooks: | hooks: | ||||
| - id: flake8 | - id: flake8 | ||||
| exclude: thirdparty/|examples/ | exclude: thirdparty/|examples/ | ||||
| @@ -1,6 +1,6 @@ | |||||
| repos: | repos: | ||||
| - repo: /home/admin/pre-commit/flake8 | - repo: /home/admin/pre-commit/flake8 | ||||
| rev: 3.8.3 | |||||
| rev: 4.0.0 | |||||
| hooks: | hooks: | ||||
| - id: flake8 | - id: flake8 | ||||
| exclude: thirdparty/|examples/ | exclude: thirdparty/|examples/ | ||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f | |||||
| size 38954 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4 | |||||
| size 93676 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb | |||||
| size 472012 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b | |||||
| size 71244 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61 | |||||
| size 181964 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425 | |||||
| size 112078 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281 | |||||
| size 200556 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b | |||||
| size 162636 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61 | |||||
| size 160204 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0 | |||||
| size 151572 | |||||
| @@ -1,3 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | version https://git-lfs.github.com/spec/v1 | ||||
| oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 | |||||
| size 62231 | |||||
| oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b | |||||
| size 61741 | |||||
| @@ -1,3 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | version https://git-lfs.github.com/spec/v1 | ||||
| oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a | |||||
| size 62235 | |||||
| oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41 | |||||
| size 61745 | |||||
| @@ -23,12 +23,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||||
| def generate_dummy_inputs(self, | def generate_dummy_inputs(self, | ||||
| shape: Tuple = None, | shape: Tuple = None, | ||||
| pair: bool = False, | |||||
| **kwargs) -> Dict[str, Any]: | **kwargs) -> Dict[str, Any]: | ||||
| """Generate dummy inputs for model exportation to onnx or other formats by tracing. | """Generate dummy inputs for model exportation to onnx or other formats by tracing. | ||||
| @param shape: A tuple of input shape which should have at most two dimensions. | @param shape: A tuple of input shape which should have at most two dimensions. | ||||
| shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | ||||
| shape = (8, 128) batch_size=8, sequence_length=128, which will override the config of the preprocessor. | shape = (8, 128) batch_size=8, sequence_length=128, which will override the config of the preprocessor. | ||||
| @param pair: Generate sentence pairs or single sentences for dummy inputs. | |||||
| @return: Dummy inputs. | @return: Dummy inputs. | ||||
| """ | """ | ||||
| @@ -55,7 +57,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||||
| **sequence_length | **sequence_length | ||||
| }) | }) | ||||
| preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | ||||
| if preprocessor.pair: | |||||
| if pair: | |||||
| first_sequence = preprocessor.tokenizer.unk_token | first_sequence = preprocessor.tokenizer.unk_token | ||||
| second_sequence = preprocessor.tokenizer.unk_token | second_sequence = preprocessor.tokenizer.unk_token | ||||
| else: | else: | ||||
| @@ -1,8 +1,11 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| # yapf: disable | |||||
| import datetime | |||||
| import os | import os | ||||
| import pickle | import pickle | ||||
| import shutil | import shutil | ||||
| import tempfile | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from http import HTTPStatus | from http import HTTPStatus | ||||
| from http.cookiejar import CookieJar | from http.cookiejar import CookieJar | ||||
| @@ -16,17 +19,25 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||||
| API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | ||||
| API_RESPONSE_FIELD_MESSAGE, | API_RESPONSE_FIELD_MESSAGE, | ||||
| API_RESPONSE_FIELD_USERNAME, | API_RESPONSE_FIELD_USERNAME, | ||||
| DEFAULT_CREDENTIALS_PATH) | |||||
| DEFAULT_CREDENTIALS_PATH, Licenses, | |||||
| ModelVisibility) | |||||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||||
| NotLoginException, RequestError, | |||||
| datahub_raise_on_error, | |||||
| handle_http_post_error, | |||||
| handle_http_response, is_ok, raise_on_error) | |||||
| from modelscope.hub.git import GitCommandWrapper | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.hub.utils.utils import (get_endpoint, | |||||
| model_id_to_group_owner_name) | |||||
| from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | ||||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | ||||
| DEFAULT_MODEL_REVISION, | DEFAULT_MODEL_REVISION, | ||||
| DatasetFormations, DatasetMetaFormats, | DatasetFormations, DatasetMetaFormats, | ||||
| DownloadMode) | |||||
| DownloadMode, ModelFile) | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .errors import (InvalidParameter, NotExistError, RequestError, | |||||
| datahub_raise_on_error, handle_http_post_error, | |||||
| handle_http_response, is_ok, raise_on_error) | |||||
| from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||||
| # yapf: enable | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -169,11 +180,106 @@ class HubApi: | |||||
| else: | else: | ||||
| r.raise_for_status() | r.raise_for_status() | ||||
| def list_model(self, | |||||
| owner_or_group: str, | |||||
| page_number=1, | |||||
| page_size=10) -> dict: | |||||
| """List model in owner or group. | |||||
def push_model(self,
               model_id: str,
               model_dir: str,
               visibility: int = ModelVisibility.PUBLIC,
               license: str = Licenses.APACHE_V2,
               chinese_name: Optional[str] = None,
               commit_message: Optional[str] = 'upload model',
               revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """Upload a model from a local directory to a ModelScope model repository.

    A valid model directory must contain a configuration.json file.

    The non-hidden files and sub-directories of `model_dir` are copied into a
    temporary clone of the repository, committed and pushed. If the repository
    cannot be fetched from the remote, it is created with the given
    visibility, license and chinese_name parameters. If `revision` is not one
    of the remote branches, a new branch is created for it.

    HubApi's login must be called with a valid token (which can be obtained
    from ModelScope's website) before calling this function.

    Args:
        model_id (`str`):
            The model id to be uploaded, caller must have write permission for it.
        model_dir(`str`):
            The absolute path of the finetune result.
        visibility(`int`, defaults to `ModelVisibility.PUBLIC`):
            Visibility of the newly created model(1-private, 5-public). Only used
            when the model does not exist yet and has to be created; it may be
            ignored if the model is known to exist.
        license(`str`, defaults to `Licenses.APACHE_V2`):
            License of the newly created model(see Licenses). Only used when the
            model does not exist yet and has to be created; it may be ignored
            if the model is known to exist.
        chinese_name(`str`, *optional*, defaults to `None`):
            Chinese name of the newly created model.
        commit_message(`str`, *optional*, defaults to `'upload model'`):
            Commit message of the push request. If empty or `None`, a
            timestamped message is generated.
        revision (`str`, *optional*, defaults to DEFAULT_MODEL_REVISION):
            Which branch to push. If the branch does not exist, a new branch
            is created and pushed to.

    Raises:
        InvalidParameter: If model_id or model_dir is empty, model_dir is not
            a directory, or the model has to be created but visibility or
            license is `None`.
        ValueError: If model_dir does not contain a configuration.json.
        NotLoginException: If no login cookies are available.
    """
    if not model_id:
        raise InvalidParameter('model_id cannot be empty!')
    if not model_dir:
        raise InvalidParameter('model_dir cannot be empty!')
    if not os.path.isdir(model_dir):
        raise InvalidParameter('model_dir must be a valid directory.')
    cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
    if not os.path.exists(cfg_file):
        raise ValueError(f'{model_dir} must contain a configuration.json.')
    if ModelScopeConfig.get_cookies() is None:
        raise NotLoginException('Must login before upload!')
    files_to_save = os.listdir(model_dir)

    try:
        self.get_model(model_id=model_id)
    except Exception:
        # Any failure to fetch the model is treated as "model does not exist
        # yet"; a genuine connectivity problem will surface again from
        # create_model below.
        if visibility is None or license is None:
            raise InvalidParameter(
                'visibility and license cannot be empty if want to create new repo'
            )
        logger.info('Create new model %s', model_id)
        self.create_model(
            model_id=model_id,
            visibility=visibility,
            license=license,
            chinese_name=chinese_name)

    tmp_dir = tempfile.mkdtemp()
    git_wrapper = GitCommandWrapper()
    try:
        repo = Repository(model_dir=tmp_dir, clone_from=model_id)
        branches = git_wrapper.get_remote_branches(tmp_dir)
        if revision not in branches:
            logger.info('Create new branch %s', revision)
            git_wrapper.new_branch(tmp_dir, revision)
        git_wrapper.checkout(tmp_dir, revision)

        for name in files_to_save:
            # Hidden entries (e.g. a local .git directory) are never uploaded.
            if name.startswith('.'):
                continue
            src = os.path.join(model_dir, name)
            if os.path.isdir(src):
                dst = os.path.join(tmp_dir, name)
                # copytree refuses to copy onto an existing destination, so a
                # directory already tracked in the repository is replaced by
                # the local one instead of making the upload fail.
                if os.path.isdir(dst) and not os.path.islink(dst):
                    shutil.rmtree(dst)
                elif os.path.lexists(dst):
                    os.remove(dst)
                shutil.copytree(src, dst)
            else:
                shutil.copy(src, tmp_dir)

        if not commit_message:
            date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
            commit_message = '[automsg] push model %s to hub at %s' % (
                model_id, date)
        repo.push(commit_message=commit_message, branch=revision)
    finally:
        # Always remove the temporary clone, whether or not the push succeeded.
        shutil.rmtree(tmp_dir, ignore_errors=True)
| def list_models(self, | |||||
| owner_or_group: str, | |||||
| page_number=1, | |||||
| page_size=10) -> dict: | |||||
| """List models in owner or group. | |||||
| Args: | Args: | ||||
| owner_or_group(`str`): owner or group. | owner_or_group(`str`): owner or group. | ||||
| @@ -390,11 +496,13 @@ class HubApi: | |||||
| return resp['Data'] | return resp['Data'] | ||||
| def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | ||||
| is_recursive, is_filter_dir, revision, | |||||
| cookies): | |||||
| is_recursive, is_filter_dir, revision): | |||||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | ||||
| f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | ||||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies: | |||||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||||
| resp = requests.get(url=url, cookies=cookies) | resp = requests.get(url=url, cookies=cookies) | ||||
| resp = resp.json() | resp = resp.json() | ||||
| @@ -11,13 +11,12 @@ from typing import Dict, Optional, Union | |||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| import requests | import requests | ||||
| from filelock import FileLock | |||||
| from tqdm import tqdm | from tqdm import tqdm | ||||
| from modelscope import __version__ | from modelscope import __version__ | ||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | |||||
| from .constants import FILE_HASH | from .constants import FILE_HASH | ||||
| from .errors import FileDownloadError, NotExistError | from .errors import FileDownloadError, NotExistError | ||||
| from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
| @@ -1,13 +1,10 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| import re | |||||
| import subprocess | import subprocess | ||||
| from typing import List | from typing import List | ||||
| from xmlrpc.client import Boolean | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import ModelScopeConfig | |||||
| from .errors import GitError | from .errors import GitError | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -132,6 +129,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| return response | return response | ||||
| def add_user_info(self, repo_base_dir, repo_name): | def add_user_info(self, repo_base_dir, repo_name): | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| user_name, user_email = ModelScopeConfig.get_user_info() | user_name, user_email = ModelScopeConfig.get_user_info() | ||||
| if user_name and user_email: | if user_name and user_email: | ||||
| # config user.name and user.email if exist | # config user.name and user.email if exist | ||||
| @@ -184,8 +182,11 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| info = [ | info = [ | ||||
| line.strip() | line.strip() | ||||
| for line in rsp.stdout.decode('utf8').strip().split(os.linesep) | for line in rsp.stdout.decode('utf8').strip().split(os.linesep) | ||||
| ][1:] | |||||
| return ['/'.join(line.split('/')[1:]) for line in info] | |||||
| ] | |||||
| if len(info) == 1: | |||||
| return ['/'.join(info[0].split('/')[1:])] | |||||
| else: | |||||
| return ['/'.join(line.split('/')[1:]) for line in info[1:]] | |||||
| def pull(self, repo_dir: str): | def pull(self, repo_dir: str): | ||||
| cmds = ['-C', repo_dir, 'pull'] | cmds = ['-C', repo_dir, 'pull'] | ||||
| @@ -7,7 +7,6 @@ from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | ||||
| DEFAULT_MODEL_REVISION) | DEFAULT_MODEL_REVISION) | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import ModelScopeConfig | |||||
| from .git import GitCommandWrapper | from .git import GitCommandWrapper | ||||
| from .utils.utils import get_endpoint | from .utils.utils import get_endpoint | ||||
| @@ -47,6 +46,7 @@ class Repository: | |||||
| err_msg = 'a non-default value of revision cannot be empty.' | err_msg = 'a non-default value of revision cannot be empty.' | ||||
| raise InvalidParameter(err_msg) | raise InvalidParameter(err_msg) | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| if auth_token: | if auth_token: | ||||
| self.auth_token = auth_token | self.auth_token = auth_token | ||||
| else: | else: | ||||
| @@ -166,7 +166,7 @@ class DatasetRepository: | |||||
| err_msg = 'a non-default value of revision cannot be empty.' | err_msg = 'a non-default value of revision cannot be empty.' | ||||
| raise InvalidParameter(err_msg) | raise InvalidParameter(err_msg) | ||||
| self.revision = revision | self.revision = revision | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| if auth_token: | if auth_token: | ||||
| self.auth_token = auth_token | self.auth_token = auth_token | ||||
| else: | else: | ||||
| @@ -5,9 +5,9 @@ import tempfile | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | |||||
| from .constants import FILE_HASH | from .constants import FILE_HASH | ||||
| from .errors import NotExistError | from .errors import NotExistError | ||||
| from .file_download import (get_file_download_url, http_get_file, | from .file_download import (get_file_download_url, http_get_file, | ||||
| @@ -1,117 +0,0 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import datetime | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import uuid | |||||
| from typing import Dict, Optional | |||||
| from uuid import uuid4 | |||||
| from filelock import FileLock | |||||
| from modelscope import __version__ | |||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.hub.errors import InvalidParameter, NotLoginException | |||||
| from modelscope.hub.git import GitCommandWrapper | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
def upload_folder(model_id: str,
                  model_dir: str,
                  visibility: int = 0,
                  license: str = None,
                  chinese_name: Optional[str] = None,
                  commit_message: Optional[str] = None,
                  revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """Upload a model from a local directory to a ModelScope model repository.

    A valid model directory must contain a configuration.json file.

    The non-hidden files and sub-directories of `model_dir` are copied into a
    temporary clone of the repository, committed and pushed. If the repository
    cannot be fetched from the remote, it is created with the given
    visibility, license and chinese_name parameters. If `revision` is not one
    of the remote branches, a new branch is created for it.

    HubApi's login must be called with a valid token (which can be obtained
    from ModelScope's website) before calling this function.

    Args:
        model_id (`str`):
            The model id to be uploaded, caller must have write permission for it.
        model_dir(`str`):
            The absolute path of the finetune result.
        visibility(`int`, defaults to `0`):
            Visibility of the newly created model(1-private, 5-public). Required
            when the model does not exist yet and has to be created; it may be
            ignored if the model is known to exist.
        license(`str`, defaults to `None`):
            License of the newly created model(see License). Required when the
            model does not exist yet and has to be created; it may be ignored
            if the model is known to exist.
        chinese_name(`str`, *optional*, defaults to `None`):
            Chinese name of the newly created model.
        commit_message(`str`, *optional*, defaults to `None`):
            Commit message of the push request. If empty or `None`, a
            timestamped message is generated.
        revision (`str`, *optional*, defaults to DEFAULT_MODEL_REVISION):
            Which branch to push. If the branch does not exist, a new branch
            is created and pushed to.

    Raises:
        InvalidParameter: If model_id or model_dir is empty, model_dir is not
            a directory, or the model has to be created but visibility or
            license is `None`.
        ValueError: If model_dir does not contain a configuration.json.
        NotLoginException: If no login cookies are available.
    """
    if not model_id:
        raise InvalidParameter('model_id cannot be empty!')
    if not model_dir:
        raise InvalidParameter('model_dir cannot be empty!')
    if not os.path.isdir(model_dir):
        raise InvalidParameter('model_dir must be a valid directory.')
    cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
    if not os.path.exists(cfg_file):
        raise ValueError(f'{model_dir} must contain a configuration.json.')
    if ModelScopeConfig.get_cookies() is None:
        raise NotLoginException('Must login before upload!')
    files_to_save = os.listdir(model_dir)

    api = HubApi()
    try:
        api.get_model(model_id=model_id)
    except Exception:
        # Any failure to fetch the model is treated as "model does not exist
        # yet"; a genuine connectivity problem will surface again from
        # create_model below.
        if visibility is None or license is None:
            raise InvalidParameter(
                'visibility and license cannot be empty if want to create new repo'
            )
        logger.info('Create new model %s', model_id)
        api.create_model(
            model_id=model_id,
            visibility=visibility,
            license=license,
            chinese_name=chinese_name)

    tmp_dir = tempfile.mkdtemp()
    git_wrapper = GitCommandWrapper()
    try:
        repo = Repository(model_dir=tmp_dir, clone_from=model_id)
        branches = git_wrapper.get_remote_branches(tmp_dir)
        if revision not in branches:
            logger.info('Create new branch %s', revision)
            git_wrapper.new_branch(tmp_dir, revision)
        git_wrapper.checkout(tmp_dir, revision)

        for name in files_to_save:
            # Hidden entries (e.g. a local .git directory) are never uploaded.
            if name.startswith('.'):
                continue
            src = os.path.join(model_dir, name)
            if os.path.isdir(src):
                dst = os.path.join(tmp_dir, name)
                # copytree refuses to copy onto an existing destination, so a
                # directory already tracked in the repository is replaced by
                # the local one instead of making the upload fail.
                if os.path.isdir(dst) and not os.path.islink(dst):
                    shutil.rmtree(dst)
                elif os.path.lexists(dst):
                    os.remove(dst)
                shutil.copytree(src, dst)
            else:
                shutil.copy(src, tmp_dir)

        if not commit_message:
            date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
            commit_message = '[automsg] push model %s to hub at %s' % (
                model_id, date)
        repo.push(commit_message=commit_message, branch=revision)
    finally:
        # Always remove the temporary clone, whether or not the push succeeded.
        shutil.rmtree(tmp_dir, ignore_errors=True)
| @@ -9,7 +9,9 @@ class Models(object): | |||||
| Model name should only contain model info but not task info. | Model name should only contain model info but not task info. | ||||
| """ | """ | ||||
| # tinynas models | |||||
| tinynas_detection = 'tinynas-detection' | tinynas_detection = 'tinynas-detection' | ||||
| tinynas_damoyolo = 'tinynas-damoyolo' | |||||
| # vision models | # vision models | ||||
| detection = 'detection' | detection = 'detection' | ||||
| @@ -454,9 +456,9 @@ class Datasets(object): | |||||
| """ Names for different datasets. | """ Names for different datasets. | ||||
| """ | """ | ||||
| ClsDataset = 'ClsDataset' | ClsDataset = 'ClsDataset' | ||||
| Face2dKeypointsDataset = 'Face2dKeypointsDataset' | |||||
| Face2dKeypointsDataset = 'FaceKeypointDataset' | |||||
| HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | ||||
| HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' | |||||
| HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' | |||||
| SegDataset = 'SegDataset' | SegDataset = 'SegDataset' | ||||
| DetDataset = 'DetDataset' | DetDataset = 'DetDataset' | ||||
| DetImagesMixDataset = 'DetImagesMixDataset' | DetImagesMixDataset = 'DetImagesMixDataset' | ||||
| @@ -32,6 +32,7 @@ task_default_metrics = { | |||||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | Tasks.sentiment_classification: [Metrics.seq_cls_metric], | ||||
| Tasks.token_classification: [Metrics.token_cls_metric], | Tasks.token_classification: [Metrics.token_cls_metric], | ||||
| Tasks.text_generation: [Metrics.text_gen_metric], | Tasks.text_generation: [Metrics.text_gen_metric], | ||||
| Tasks.text_classification: [Metrics.seq_cls_metric], | |||||
| Tasks.image_denoising: [Metrics.image_denoise_metric], | Tasks.image_denoising: [Metrics.image_denoise_metric], | ||||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | ||||
| Tasks.image_portrait_enhancement: | Tasks.image_portrait_enhancement: | ||||
| @@ -2,6 +2,7 @@ | |||||
| import os | import os | ||||
| import pickle as pkl | import pickle as pkl | ||||
| from threading import Lock | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| @@ -27,6 +28,7 @@ class Voice: | |||||
| self.__am_config = AttrDict(**am_config) | self.__am_config = AttrDict(**am_config) | ||||
| self.__voc_config = AttrDict(**voc_config) | self.__voc_config = AttrDict(**voc_config) | ||||
| self.__model_loaded = False | self.__model_loaded = False | ||||
| self.__lock = Lock() | |||||
| if 'am' not in self.__am_config: | if 'am' not in self.__am_config: | ||||
| raise TtsModelConfigurationException( | raise TtsModelConfigurationException( | ||||
| 'modelscope error: am configuration invalid') | 'modelscope error: am configuration invalid') | ||||
| @@ -71,34 +73,35 @@ class Voice: | |||||
| self.__generator.remove_weight_norm() | self.__generator.remove_weight_norm() | ||||
def __am_forward(self, symbol_seq):
    """Run the acoustic model on a symbol sequence and return its output.

    The symbol sequence is encoded by the linguistic unit into a list of
    numpy feature arrays; entries 0-3 are stacked into one linguistic input,
    entry 4 is the emotion input and entry 5 the speaker input. The whole
    pass runs under the instance lock and with gradients disabled.

    @param symbol_seq: Symbol sequence accepted by the linguistic unit's
        encode_symbol_sequence.
    @return: The acoustic model's 'postnet_outputs' for the first batch
        item, truncated to the length reported in 'LR_length_rounded', as a
        numpy array.
    """
    with self.__lock, torch.no_grad():
        feats = self.__ling_unit.encode_symbol_sequence(symbol_seq)
        device = self.__device

        # sy, tone, syllable and ws features, in that order.
        ling_parts = [
            torch.from_numpy(feats[idx]).long().to(device)
            for idx in range(4)
        ]
        inputs_ling = torch.stack(ling_parts, dim=-1).unsqueeze(0)
        inputs_emo = torch.from_numpy(
            feats[4]).long().to(device).unsqueeze(0)
        inputs_spk = torch.from_numpy(
            feats[5]).long().to(device).unsqueeze(0)

        # minus 1 for "~"
        inputs_len = torch.zeros(1).to(device).long() + inputs_emo.size(
            1) - 1

        # The last position is dropped from every input, matching the
        # length computed above.
        net_out = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
                                inputs_spk[:, :-1], inputs_len)
        mel = net_out['postnet_outputs']
        valid_length = int(net_out['LR_length_rounded'][0].item())
        return mel[0, :valid_length, :].cpu().numpy()
| def __vocoder_forward(self, melspec): | def __vocoder_forward(self, melspec): | ||||
| dim0 = list(melspec.shape)[-1] | dim0 = list(melspec.shape)[-1] | ||||
| @@ -118,14 +121,15 @@ class Voice: | |||||
| return audio | return audio | ||||
def forward(self, symbol_seq):
    """Synthesize audio for a symbol sequence, loading the models on first use.

    On the first call the random seed is set from the acoustic model
    configuration, the device is chosen (CUDA when available, otherwise CPU)
    and the acoustic model and vocoder are loaded. Loading happens under the
    instance lock so concurrent first calls load the models only once.

    @param symbol_seq: Symbol sequence passed to the acoustic model.
    @return: The vocoder output for the acoustic model's output.
    """
    with self.__lock:
        if not self.__model_loaded:
            # torch.manual_seed seeds the generators of all devices, so one
            # call is enough; no second call is needed for the CUDA case.
            torch.manual_seed(self.__am_config.seed)
            if torch.cuda.is_available():
                self.__device = torch.device('cuda')
            else:
                self.__device = torch.device('cpu')
            self.__load_am()
            self.__load_vocoder()
            self.__model_loaded = True
    # Inference runs outside this lock section; __am_forward acquires the
    # lock itself.
    return self.__vocoder_forward(self.__am_forward(symbol_seq))
| @@ -93,7 +93,7 @@ class TextDrivenSeg(TorchModel): | |||||
| """ | """ | ||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| if self.device_id == -1: | if self.device_id == -1: | ||||
| output = self.model(image) | |||||
| output = self.model(image, [text]) | |||||
| else: | else: | ||||
| device = torch.device('cuda', self.device_id) | device = torch.device('cuda', self.device_id) | ||||
| output = self.model(image.to(device), [text]) | output = self.model(image.to(device), [text]) | ||||
| @@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .tinynas_detector import Tinynas_detector | from .tinynas_detector import Tinynas_detector | ||||
| from .tinynas_damoyolo import DamoYolo | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'tinynas_detector': ['TinynasDetector'], | 'tinynas_detector': ['TinynasDetector'], | ||||
| 'tinynas_damoyolo': ['DamoYolo'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -4,6 +4,7 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from modelscope.utils.file_utils import read_file | |||||
| from ..core.base_ops import Focus, SPPBottleneck, get_activation | from ..core.base_ops import Focus, SPPBottleneck, get_activation | ||||
| from ..core.repvgg_block import RepVggBlock | from ..core.repvgg_block import RepVggBlock | ||||
| @@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module): | |||||
| kernel_size, | kernel_size, | ||||
| stride, | stride, | ||||
| force_resproj=False, | force_resproj=False, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(ResConvK1KX, self).__init__() | super(ResConvK1KX, self).__init__() | ||||
| self.stride = stride | self.stride = stride | ||||
| self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | ||||
| self.conv2 = RepVggBlock( | |||||
| btn_c, out_c, kernel_size, stride, act='identity') | |||||
| if not reparam: | |||||
| self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) | |||||
| else: | |||||
| self.conv2 = RepVggBlock( | |||||
| btn_c, out_c, kernel_size, stride, act='identity') | |||||
| if act is None: | if act is None: | ||||
| self.activation_function = torch.relu | self.activation_function = torch.relu | ||||
| @@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module): | |||||
| stride, | stride, | ||||
| num_blocks, | num_blocks, | ||||
| with_spp=False, | with_spp=False, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(SuperResConvK1KX, self).__init__() | super(SuperResConvK1KX, self).__init__() | ||||
| if act is None: | if act is None: | ||||
| self.act = torch.relu | self.act = torch.relu | ||||
| @@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module): | |||||
| this_kernel_size, | this_kernel_size, | ||||
| this_stride, | this_stride, | ||||
| force_resproj, | force_resproj, | ||||
| act=act) | |||||
| act=act, | |||||
| reparam=reparam) | |||||
| self.block_list.append(the_block) | self.block_list.append(the_block) | ||||
| if block_id == 0 and with_spp: | if block_id == 0 and with_spp: | ||||
| self.block_list.append( | self.block_list.append( | ||||
| @@ -248,7 +255,8 @@ class TinyNAS(nn.Module): | |||||
| with_spp=False, | with_spp=False, | ||||
| use_focus=False, | use_focus=False, | ||||
| need_conv1=True, | need_conv1=True, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(TinyNAS, self).__init__() | super(TinyNAS, self).__init__() | ||||
| assert len(out_indices) == len(out_channels) | assert len(out_indices) == len(out_channels) | ||||
| self.out_indices = out_indices | self.out_indices = out_indices | ||||
| @@ -281,7 +289,8 @@ class TinyNAS(nn.Module): | |||||
| block_info['s'], | block_info['s'], | ||||
| block_info['L'], | block_info['L'], | ||||
| spp, | spp, | ||||
| act=act) | |||||
| act=act, | |||||
| reparam=reparam) | |||||
| self.block_list.append(the_block) | self.block_list.append(the_block) | ||||
| elif the_block_class == 'SuperResConvKXKX': | elif the_block_class == 'SuperResConvKXKX': | ||||
| spp = with_spp if idx == len(structure_info) - 1 else False | spp = with_spp if idx == len(structure_info) - 1 else False | ||||
| @@ -325,8 +334,8 @@ class TinyNAS(nn.Module): | |||||
| def load_tinynas_net(backbone_cfg): | def load_tinynas_net(backbone_cfg): | ||||
| # load masternet model to path | # load masternet model to path | ||||
| import ast | import ast | ||||
| struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str]) | |||||
| net_structure_str = read_file(backbone_cfg.structure_file) | |||||
| struct_str = ''.join([x.strip() for x in net_structure_str]) | |||||
| struct_info = ast.literal_eval(struct_str) | struct_info = ast.literal_eval(struct_str) | ||||
| for layer in struct_info: | for layer in struct_info: | ||||
| if 'nbitsA' in layer: | if 'nbitsA' in layer: | ||||
| @@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg): | |||||
| use_focus=backbone_cfg.use_focus, | use_focus=backbone_cfg.use_focus, | ||||
| act=backbone_cfg.act, | act=backbone_cfg.act, | ||||
| need_conv1=backbone_cfg.need_conv1, | need_conv1=backbone_cfg.need_conv1, | ||||
| ) | |||||
| reparam=backbone_cfg.reparam) | |||||
| return model | return model | ||||
| @@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel): | |||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| config_path = osp.join(model_dir, 'airdet_s.py') | |||||
| config_path = osp.join(model_dir, self.config_name) | |||||
| config = parse_config(config_path) | config = parse_config(config_path) | ||||
| self.cfg = config | self.cfg = config | ||||
| model_path = osp.join(model_dir, config.model.name) | model_path = osp.join(model_dir, config.model.name) | ||||
| @@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel): | |||||
| self.conf_thre = config.model.head.nms_conf_thre | self.conf_thre = config.model.head.nms_conf_thre | ||||
| self.nms_thre = config.model.head.nms_iou_thre | self.nms_thre = config.model.head.nms_iou_thre | ||||
| if self.cfg.model.backbone.name == 'TinyNAS': | |||||
| self.cfg.model.backbone.structure_file = osp.join( | |||||
| model_dir, self.cfg.model.backbone.structure_file) | |||||
| self.backbone = build_backbone(self.cfg.model.backbone) | self.backbone = build_backbone(self.cfg.model.backbone) | ||||
| self.neck = build_neck(self.cfg.model.neck) | self.neck = build_neck(self.cfg.model.neck) | ||||
| self.head = build_head(self.cfg.model.head) | self.head = build_head(self.cfg.model.head) | ||||
| @@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module): | |||||
| simOTA_iou_weight=3.0, | simOTA_iou_weight=3.0, | ||||
| octbase=8, | octbase=8, | ||||
| simlqe=False, | simlqe=False, | ||||
| use_lqe=True, | |||||
| **kwargs): | **kwargs): | ||||
| self.simlqe = simlqe | self.simlqe = simlqe | ||||
| self.num_classes = num_classes | self.num_classes = num_classes | ||||
| self.in_channels = in_channels | self.in_channels = in_channels | ||||
| self.strides = strides | self.strides = strides | ||||
| self.use_lqe = use_lqe | |||||
| self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | ||||
| else [feat_channels] * len(self.strides) | else [feat_channels] * len(self.strides) | ||||
| @@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module): | |||||
| groups=self.conv_groups, | groups=self.conv_groups, | ||||
| norm=self.norm, | norm=self.norm, | ||||
| act=self.act)) | act=self.act)) | ||||
| if not self.simlqe: | |||||
| conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)] | |||||
| if self.use_lqe: | |||||
| if not self.simlqe: | |||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * self.total_dim, self.reg_channels, 1) | |||||
| ] | |||||
| else: | |||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||||
| ] | |||||
| conf_vector += [self.relu] | |||||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||||
| reg_conf = nn.Sequential(*conf_vector) | |||||
| else: | else: | ||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||||
| ] | |||||
| conf_vector += [self.relu] | |||||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||||
| reg_conf = nn.Sequential(*conf_vector) | |||||
| reg_conf = None | |||||
| return cls_convs, reg_convs, reg_conf | return cls_convs, reg_convs, reg_conf | ||||
| @@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module): | |||||
| N, C, H, W = bbox_pred.size() | N, C, H, W = bbox_pred.size() | ||||
| prob = F.softmax( | prob = F.softmax( | ||||
| bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | ||||
| if not self.simlqe: | |||||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||||
| if self.add_mean: | |||||
| stat = torch.cat( | |||||
| [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2) | |||||
| if self.use_lqe: | |||||
| if not self.simlqe: | |||||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||||
| if self.add_mean: | |||||
| stat = torch.cat( | |||||
| [prob_topk, | |||||
| prob_topk.mean(dim=2, keepdim=True)], | |||||
| dim=2) | |||||
| else: | |||||
| stat = prob_topk | |||||
| quality_score = reg_conf( | |||||
| stat.reshape(N, 4 * self.total_dim, H, W)) | |||||
| else: | else: | ||||
| stat = prob_topk | |||||
| quality_score = reg_conf( | |||||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||||
| quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W)) | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||||
| else: | else: | ||||
| quality_score = reg_conf( | |||||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() | |||||
| flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | ||||
| flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | ||||
| @@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module): | |||||
| self, | self, | ||||
| depth=1.0, | depth=1.0, | ||||
| width=1.0, | width=1.0, | ||||
| in_features=[2, 3, 4], | |||||
| in_channels=[256, 512, 1024], | in_channels=[256, 512, 1024], | ||||
| out_channels=[256, 512, 1024], | out_channels=[256, 512, 1024], | ||||
| depthwise=False, | depthwise=False, | ||||
| @@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module): | |||||
| block_name='BasicBlock', | block_name='BasicBlock', | ||||
| ): | ): | ||||
| super().__init__() | super().__init__() | ||||
| self.in_features = in_features | |||||
| self.in_channels = in_channels | self.in_channels = in_channels | ||||
| Conv = DWConv if depthwise else BaseConv | Conv = DWConv if depthwise else BaseConv | ||||
| @@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module): | |||||
| """ | """ | ||||
| # backbone | # backbone | ||||
| features = [out_features[f] for f in self.in_features] | |||||
| [x2, x1, x0] = features | |||||
| [x2, x1, x0] = out_features | |||||
| # node x3 | # node x3 | ||||
| x13 = self.bu_conv13(x1) | x13 = self.bu_conv13(x1) | ||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import Tasks | |||||
| from .detector import SingleStageDetector | |||||
| @MODELS.register_module( | |||||
| Tasks.image_object_detection, module_name=Models.tinynas_damoyolo) | |||||
| class DamoYolo(SingleStageDetector): | |||||
| def __init__(self, model_dir, *args, **kwargs): | |||||
| self.config_name = 'damoyolo_s.py' | |||||
| super(DamoYolo, self).__init__(model_dir, *args, **kwargs) | |||||
| @@ -12,5 +12,5 @@ from .detector import SingleStageDetector | |||||
| class TinynasDetector(SingleStageDetector): | class TinynasDetector(SingleStageDetector): | ||||
| def __init__(self, model_dir, *args, **kwargs): | def __init__(self, model_dir, *args, **kwargs): | ||||
| self.config_name = 'airdet_s.py' | |||||
| super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) | super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) | ||||
| @@ -15,7 +15,6 @@ | |||||
| """PyTorch BERT model. """ | """PyTorch BERT model. """ | ||||
| import math | import math | ||||
| import os | |||||
| import warnings | import warnings | ||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from typing import Optional, Tuple | from typing import Optional, Tuple | ||||
| @@ -41,7 +40,6 @@ from transformers.modeling_utils import (PreTrainedModel, | |||||
| find_pruneable_heads_and_indices, | find_pruneable_heads_and_indices, | ||||
| prune_linear_layer) | prune_linear_layer) | ||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .configuration_bert import BertConfig | from .configuration_bert import BertConfig | ||||
| @@ -50,81 +48,6 @@ logger = get_logger(__name__) | |||||
| _CONFIG_FOR_DOC = 'BertConfig' | _CONFIG_FOR_DOC = 'BertConfig' | ||||
| def load_tf_weights_in_bert(model, config, tf_checkpoint_path): | |||||
| """Load tf checkpoints in a pytorch model.""" | |||||
| try: | |||||
| import re | |||||
| import numpy as np | |||||
| import tensorflow as tf | |||||
| except ImportError: | |||||
| logger.error( | |||||
| 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' | |||||
| 'https://www.tensorflow.org/install/ for installation instructions.' | |||||
| ) | |||||
| raise | |||||
| tf_path = os.path.abspath(tf_checkpoint_path) | |||||
| logger.info(f'Converting TensorFlow checkpoint from {tf_path}') | |||||
| # Load weights from TF model | |||||
| init_vars = tf.train.list_variables(tf_path) | |||||
| names = [] | |||||
| arrays = [] | |||||
| for name, shape in init_vars: | |||||
| logger.info(f'Loading TF weight {name} with shape {shape}') | |||||
| array = tf.train.load_variable(tf_path, name) | |||||
| names.append(name) | |||||
| arrays.append(array) | |||||
| for name, array in zip(names, arrays): | |||||
| name = name.split('/') | |||||
| # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v | |||||
| # which are not required for using pretrained model | |||||
| if any(n in [ | |||||
| 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', | |||||
| 'AdamWeightDecayOptimizer_1', 'global_step' | |||||
| ] for n in name): | |||||
| logger.info(f"Skipping {'/'.join(name)}") | |||||
| continue | |||||
| pointer = model | |||||
| for m_name in name: | |||||
| if re.fullmatch(r'[A-Za-z]+_\d+', m_name): | |||||
| scope_names = re.split(r'_(\d+)', m_name) | |||||
| else: | |||||
| scope_names = [m_name] | |||||
| if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': | |||||
| pointer = getattr(pointer, 'weight') | |||||
| elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': | |||||
| pointer = getattr(pointer, 'bias') | |||||
| elif scope_names[0] == 'output_weights': | |||||
| pointer = getattr(pointer, 'weight') | |||||
| elif scope_names[0] == 'squad': | |||||
| pointer = getattr(pointer, 'classifier') | |||||
| else: | |||||
| try: | |||||
| pointer = getattr(pointer, scope_names[0]) | |||||
| except AttributeError: | |||||
| logger.info(f"Skipping {'/'.join(name)}") | |||||
| continue | |||||
| if len(scope_names) >= 2: | |||||
| num = int(scope_names[1]) | |||||
| pointer = pointer[num] | |||||
| if m_name[-11:] == '_embeddings': | |||||
| pointer = getattr(pointer, 'weight') | |||||
| elif m_name == 'kernel': | |||||
| array = np.transpose(array) | |||||
| try: | |||||
| if pointer.shape != array.shape: | |||||
| raise ValueError( | |||||
| f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' | |||||
| ) | |||||
| except AssertionError as e: | |||||
| e.args += (pointer.shape, array.shape) | |||||
| raise | |||||
| logger.info(f'Initialize PyTorch weight {name}') | |||||
| pointer.data = torch.from_numpy(array) | |||||
| return model | |||||
| class BertEmbeddings(nn.Module): | class BertEmbeddings(nn.Module): | ||||
| """Construct the embeddings from word, position and token_type embeddings.""" | """Construct the embeddings from word, position and token_type embeddings.""" | ||||
| @@ -750,7 +673,6 @@ class BertPreTrainedModel(PreTrainedModel): | |||||
| """ | """ | ||||
| config_class = BertConfig | config_class = BertConfig | ||||
| load_tf_weights = load_tf_weights_in_bert | |||||
| base_model_prefix = 'bert' | base_model_prefix = 'bert' | ||||
| supports_gradient_checkpointing = True | supports_gradient_checkpointing = True | ||||
| _keys_to_ignore_on_load_missing = [r'position_ids'] | _keys_to_ignore_on_load_missing = [r'position_ids'] | ||||
| @@ -26,11 +26,16 @@ class EasyCVBaseDataset(object): | |||||
| if self.split_config is not None: | if self.split_config is not None: | ||||
| self._update_data_source(kwargs['data_source']) | self._update_data_source(kwargs['data_source']) | ||||
| def _update_data_root(self, input_dict, data_root): | |||||
| for k, v in input_dict.items(): | |||||
| if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||||
| input_dict.update( | |||||
| {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||||
| elif isinstance(v, dict): | |||||
| self._update_data_root(v, data_root) | |||||
| def _update_data_source(self, data_source): | def _update_data_source(self, data_source): | ||||
| data_root = next(iter(self.split_config.values())) | data_root = next(iter(self.split_config.values())) | ||||
| data_root = data_root.rstrip(osp.sep) | data_root = data_root.rstrip(osp.sep) | ||||
| for k, v in data_source.items(): | |||||
| if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||||
| data_source.update( | |||||
| {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||||
| self._update_data_root(data_source, data_root) | |||||
| @@ -7,7 +7,7 @@ from typing import Any, Mapping, Optional, Sequence, Union | |||||
| from datasets.builder import DatasetBuilder | from datasets.builder import DatasetBuilder | ||||
| from modelscope.hub.api import HubApi | from modelscope.hub.api import HubApi | ||||
| from modelscope.utils.constant import DEFAULT_DATASET_REVISION, DownloadParams | |||||
| from modelscope.utils.constant import DEFAULT_DATASET_REVISION | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder | from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder | ||||
| @@ -95,15 +95,13 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, | |||||
| res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...] | res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...] | ||||
| """ | """ | ||||
| res = [] | res = [] | ||||
| cookies = hub_api.check_cookies_upload_data(use_cookies=True) | |||||
| objects = hub_api.list_oss_dataset_objects( | objects = hub_api.list_oss_dataset_objects( | ||||
| dataset_name=dataset_name, | dataset_name=dataset_name, | ||||
| namespace=namespace, | namespace=namespace, | ||||
| max_limit=max_limit, | max_limit=max_limit, | ||||
| is_recursive=is_recursive, | is_recursive=is_recursive, | ||||
| is_filter_dir=True, | is_filter_dir=True, | ||||
| revision=version, | |||||
| cookies=cookies) | |||||
| revision=version) | |||||
| for item in objects: | for item in objects: | ||||
| object_key = item.get('Key') | object_key = item.get('Key') | ||||
| @@ -174,7 +172,7 @@ def get_dataset_files(subset_split_into: dict, | |||||
| modelscope_api = HubApi() | modelscope_api = HubApi() | ||||
| objects = list_dataset_objects( | objects = list_dataset_objects( | ||||
| hub_api=modelscope_api, | hub_api=modelscope_api, | ||||
| max_limit=DownloadParams.MAX_LIST_OBJECTS_NUM.value, | |||||
| max_limit=-1, | |||||
| is_recursive=True, | is_recursive=True, | ||||
| dataset_name=dataset_name, | dataset_name=dataset_name, | ||||
| namespace=namespace, | namespace=namespace, | ||||
| @@ -47,22 +47,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| if isinstance(audio_in, str): | if isinstance(audio_in, str): | ||||
| # load pcm data from url if audio_in is url str | # load pcm data from url if audio_in is url str | ||||
| self.audio_in = load_bytes_from_url(audio_in) | |||||
| self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) | |||||
| elif isinstance(audio_in, bytes): | elif isinstance(audio_in, bytes): | ||||
| # load pcm data from wav data if audio_in is wave format | # load pcm data from wav data if audio_in is wave format | ||||
| self.audio_in = extract_pcm_from_wav(audio_in) | |||||
| self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) | |||||
| else: | else: | ||||
| self.audio_in = audio_in | self.audio_in = audio_in | ||||
| # set the sample_rate of audio_in if checking_audio_fs is valid | |||||
| if checking_audio_fs is not None: | |||||
| self.audio_fs = checking_audio_fs | |||||
| if recog_type is None or audio_format is None: | if recog_type is None or audio_format is None: | ||||
| self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | ||||
| audio_in=self.audio_in, | audio_in=self.audio_in, | ||||
| recog_type=recog_type, | recog_type=recog_type, | ||||
| audio_format=audio_format) | audio_format=audio_format) | ||||
| if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None: | |||||
| self.audio_fs = asr_utils.sample_rate_checking( | |||||
| if hasattr(asr_utils, 'sample_rate_checking'): | |||||
| checking_audio_fs = asr_utils.sample_rate_checking( | |||||
| self.audio_in, self.audio_format) | self.audio_in, self.audio_format) | ||||
| if checking_audio_fs is not None: | |||||
| self.audio_fs = checking_audio_fs | |||||
| if self.preprocessor is None: | if self.preprocessor is None: | ||||
| self.preprocessor = WavToScp() | self.preprocessor = WavToScp() | ||||
| @@ -80,7 +86,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| logger.info(f"Decoding with {inputs['audio_format']} files ...") | logger.info(f"Decoding with {inputs['audio_format']} files ...") | ||||
| data_cmd: Sequence[Tuple[str, str]] | |||||
| data_cmd: Sequence[Tuple[str, str, str]] | |||||
| if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': | if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': | ||||
| data_cmd = ['speech', 'sound'] | data_cmd = ['speech', 'sound'] | ||||
| elif inputs['audio_format'] == 'kaldi_ark': | elif inputs['audio_format'] == 'kaldi_ark': | ||||
| @@ -88,6 +94,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| elif inputs['audio_format'] == 'tfrecord': | elif inputs['audio_format'] == 'tfrecord': | ||||
| data_cmd = ['speech', 'tfrecord'] | data_cmd = ['speech', 'tfrecord'] | ||||
| if inputs.__contains__('mvn_file'): | |||||
| data_cmd.append(inputs['mvn_file']) | |||||
| # generate asr inference command | # generate asr inference command | ||||
| cmd = { | cmd = { | ||||
| 'model_type': inputs['model_type'], | 'model_type': inputs['model_type'], | ||||
| @@ -51,10 +51,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| if isinstance(audio_in, str): | if isinstance(audio_in, str): | ||||
| # load pcm data from url if audio_in is url str | # load pcm data from url if audio_in is url str | ||||
| audio_in = load_bytes_from_url(audio_in) | |||||
| audio_in, audio_fs = load_bytes_from_url(audio_in) | |||||
| elif isinstance(audio_in, bytes): | elif isinstance(audio_in, bytes): | ||||
| # load pcm data from wav data if audio_in is wave format | # load pcm data from wav data if audio_in is wave format | ||||
| audio_in = extract_pcm_from_wav(audio_in) | |||||
| audio_in, audio_fs = extract_pcm_from_wav(audio_in) | |||||
| output = self.preprocessor.forward(self.model.forward(), audio_in) | output = self.preprocessor.forward(self.model.forward(), audio_in) | ||||
| output = self.forward(output) | output = self.forward(output) | ||||
| @@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import LoadImage | from modelscope.preprocessors import LoadImage | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.cv.image_utils import \ | |||||
| show_image_object_detection_auto_result | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline): | |||||
| bboxes, scores, labels = self.model.postprocess(inputs['data']) | bboxes, scores, labels = self.model.postprocess(inputs['data']) | ||||
| if bboxes is None: | if bboxes is None: | ||||
| return None | |||||
| outputs = { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.LABELS: labels, | |||||
| OutputKeys.BOXES: bboxes | |||||
| } | |||||
| outputs = { | |||||
| OutputKeys.SCORES: [], | |||||
| OutputKeys.LABELS: [], | |||||
| OutputKeys.BOXES: [] | |||||
| } | |||||
| else: | |||||
| outputs = { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.LABELS: labels, | |||||
| OutputKeys.BOXES: bboxes | |||||
| } | |||||
| return outputs | return outputs | ||||
| def show_result(self, img_path, result, save_path=None): | |||||
| show_image_object_detection_auto_result(img_path, result, save_path) | |||||
| @@ -133,6 +133,12 @@ class WavToScp(Preprocessor): | |||||
| else: | else: | ||||
| inputs['asr_model_config'] = asr_model_config | inputs['asr_model_config'] = asr_model_config | ||||
| if inputs['model_config'].__contains__('mvn_file'): | |||||
| mvn_file = os.path.join(inputs['model_workspace'], | |||||
| inputs['model_config']['mvn_file']) | |||||
| assert os.path.exists(mvn_file), 'mvn_file does not exist' | |||||
| inputs['mvn_file'] = mvn_file | |||||
| elif inputs['model_type'] == Frameworks.tf: | elif inputs['model_type'] == Frameworks.tf: | ||||
| assert inputs['model_config'].__contains__( | assert inputs['model_config'].__contains__( | ||||
| 'vocab_file'), 'vocab_file does not exist' | 'vocab_file'), 'vocab_file does not exist' | ||||
| @@ -2,7 +2,7 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| import re | import re | ||||
| from typing import Any, Dict, Iterable, Optional, Tuple, Union | |||||
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | |||||
| import numpy as np | import numpy as np | ||||
| import sentencepiece as spm | import sentencepiece as spm | ||||
| @@ -217,7 +217,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||||
| return isinstance(label, str) or isinstance(label, int) | return isinstance(label, str) or isinstance(label, int) | ||||
| if labels is not None: | if labels is not None: | ||||
| if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \ | |||||
| if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \ | |||||
| and self.label2id is not None: | and self.label2id is not None: | ||||
| output[OutputKeys.LABELS] = [ | output[OutputKeys.LABELS] = [ | ||||
| self.label2id[str(label)] for label in labels | self.label2id[str(label)] for label in labels | ||||
| @@ -314,8 +314,7 @@ class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | ||||
| kwargs['truncation'] = kwargs.get('truncation', True) | kwargs['truncation'] = kwargs.get('truncation', True) | ||||
| kwargs['padding'] = kwargs.get( | |||||
| 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') | |||||
| kwargs['padding'] = kwargs.get('padding', 'max_length') | |||||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | kwargs['max_length'] = kwargs.pop('sequence_length', 128) | ||||
| super().__init__(model_dir, mode=mode, **kwargs) | super().__init__(model_dir, mode=mode, **kwargs) | ||||
| @@ -1,5 +1,10 @@ | |||||
| import math | import math | ||||
| import os | |||||
| import random | import random | ||||
| import uuid | |||||
| from os.path import exists | |||||
| from tempfile import TemporaryDirectory | |||||
| from urllib.parse import urlparse | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| @@ -9,6 +14,7 @@ import torchvision.transforms._transforms_video as transforms | |||||
| from decord import VideoReader | from decord import VideoReader | ||||
| from torchvision.transforms import Compose | from torchvision.transforms import Compose | ||||
| from modelscope.hub.file_download import http_get_file | |||||
| from modelscope.metainfo import Preprocessors | from modelscope.metainfo import Preprocessors | ||||
| from modelscope.utils.constant import Fields, ModeKeys | from modelscope.utils.constant import Fields, ModeKeys | ||||
| from modelscope.utils.type_assert import type_assert | from modelscope.utils.type_assert import type_assert | ||||
| @@ -30,7 +36,22 @@ def ReadVideoData(cfg, | |||||
| Returns: | Returns: | ||||
| data (Tensor): the normalized video clips for model inputs | data (Tensor): the normalized video clips for model inputs | ||||
| """ | """ | ||||
| data = _decode_video(cfg, video_path, num_temporal_views_override) | |||||
| url_parsed = urlparse(video_path) | |||||
| if url_parsed.scheme in ('file', '') and exists( | |||||
| url_parsed.path): # Possibly a local file | |||||
| data = _decode_video(cfg, video_path, num_temporal_views_override) | |||||
| else: | |||||
| with TemporaryDirectory() as temporary_cache_dir: | |||||
| random_str = uuid.uuid4().hex | |||||
| http_get_file( | |||||
| url=video_path, | |||||
| local_dir=temporary_cache_dir, | |||||
| file_name=random_str, | |||||
| cookies=None) | |||||
| temp_file_path = os.path.join(temporary_cache_dir, random_str) | |||||
| data = _decode_video(cfg, temp_file_path, | |||||
| num_temporal_views_override) | |||||
| if num_spatial_crops_override is not None: | if num_spatial_crops_override is not None: | ||||
| num_spatial_crops = num_spatial_crops_override | num_spatial_crops = num_spatial_crops_override | ||||
| transform = kinetics400_tranform(cfg, num_spatial_crops_override) | transform = kinetics400_tranform(cfg, num_spatial_crops_override) | ||||
| @@ -47,7 +47,7 @@ class LrSchedulerHook(Hook): | |||||
| return lr | return lr | ||||
| def before_train_iter(self, trainer): | def before_train_iter(self, trainer): | ||||
| if not self.by_epoch: | |||||
| if not self.by_epoch and trainer.iter > 0: | |||||
| if self.warmup_lr_scheduler is not None: | if self.warmup_lr_scheduler is not None: | ||||
| self.warmup_lr_scheduler.step() | self.warmup_lr_scheduler.step() | ||||
| else: | else: | ||||
| @@ -656,7 +656,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| # TODO: support MsDataset load for cv | # TODO: support MsDataset load for cv | ||||
| if hasattr(data_cfg, 'name'): | if hasattr(data_cfg, 'name'): | ||||
| dataset = MsDataset.load( | dataset = MsDataset.load( | ||||
| dataset_name=data_cfg.name, | |||||
| dataset_name=data_cfg.pop('name'), | |||||
| **data_cfg, | **data_cfg, | ||||
| ) | ) | ||||
| cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | ||||
| @@ -57,6 +57,7 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]): | |||||
| def extract_pcm_from_wav(wav: bytes) -> bytes: | def extract_pcm_from_wav(wav: bytes) -> bytes: | ||||
| data = wav | data = wav | ||||
| sample_rate = None | |||||
| if len(data) > 44: | if len(data) > 44: | ||||
| frame_len = 44 | frame_len = 44 | ||||
| file_len = len(data) | file_len = len(data) | ||||
| @@ -70,29 +71,33 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: | |||||
| 'Subchunk1ID'] == 'fmt ': | 'Subchunk1ID'] == 'fmt ': | ||||
| header_fields['SubChunk1Size'] = struct.unpack( | header_fields['SubChunk1Size'] = struct.unpack( | ||||
| '<I', data[16:20])[0] | '<I', data[16:20])[0] | ||||
| header_fields['SampleRate'] = struct.unpack('<I', | |||||
| data[24:28])[0] | |||||
| sample_rate = header_fields['SampleRate'] | |||||
| if header_fields['SubChunk1Size'] == 16: | if header_fields['SubChunk1Size'] == 16: | ||||
| frame_len = 44 | frame_len = 44 | ||||
| elif header_fields['SubChunk1Size'] == 18: | elif header_fields['SubChunk1Size'] == 18: | ||||
| frame_len = 46 | frame_len = 46 | ||||
| else: | else: | ||||
| return data | |||||
| return data, sample_rate | |||||
| data = wav[frame_len:file_len] | data = wav[frame_len:file_len] | ||||
| except Exception: | except Exception: | ||||
| # no treatment | # no treatment | ||||
| pass | pass | ||||
| return data | |||||
| return data, sample_rate | |||||
| def load_bytes_from_url(url: str) -> Union[bytes, str]: | def load_bytes_from_url(url: str) -> Union[bytes, str]: | ||||
| sample_rate = None | |||||
| result = urlparse(url) | result = urlparse(url) | ||||
| if result.scheme is not None and len(result.scheme) > 0: | if result.scheme is not None and len(result.scheme) > 0: | ||||
| storage = HTTPStorage() | storage = HTTPStorage() | ||||
| data = storage.read(url) | data = storage.read(url) | ||||
| data = extract_pcm_from_wav(data) | |||||
| data, sample_rate = extract_pcm_from_wav(data) | |||||
| else: | else: | ||||
| data = url | data = url | ||||
| return data | |||||
| return data, sample_rate | |||||
| @@ -231,13 +231,6 @@ class DownloadMode(enum.Enum): | |||||
| FORCE_REDOWNLOAD = 'force_redownload' | FORCE_REDOWNLOAD = 'force_redownload' | ||||
| class DownloadParams(enum.Enum): | |||||
| """ | |||||
| Parameters for downloading dataset. | |||||
| """ | |||||
| MAX_LIST_OBJECTS_NUM = 50000 | |||||
| class DatasetFormations(enum.Enum): | class DatasetFormations(enum.Enum): | ||||
| """ How a dataset is organized and interpreted | """ How a dataset is organized and interpreted | ||||
| """ | """ | ||||
| @@ -61,8 +61,8 @@ def device_placement(framework, device_name='gpu:0'): | |||||
| if framework == Frameworks.tf: | if framework == Frameworks.tf: | ||||
| import tensorflow as tf | import tensorflow as tf | ||||
| if device_type == Devices.gpu and not tf.test.is_gpu_available(): | if device_type == Devices.gpu and not tf.test.is_gpu_available(): | ||||
| logger.warning( | |||||
| 'tensorflow cuda is not available, using cpu instead.') | |||||
| logger.debug( | |||||
| 'tensorflow: cuda is not available, using cpu instead.') | |||||
| device_type = Devices.cpu | device_type = Devices.cpu | ||||
| if device_type == Devices.cpu: | if device_type == Devices.cpu: | ||||
| with tf.device('/CPU:0'): | with tf.device('/CPU:0'): | ||||
| @@ -78,7 +78,8 @@ def device_placement(framework, device_name='gpu:0'): | |||||
| if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
| torch.cuda.set_device(f'cuda:{device_id}') | torch.cuda.set_device(f'cuda:{device_id}') | ||||
| else: | else: | ||||
| logger.warning('cuda is not available, using cpu instead.') | |||||
| logger.debug( | |||||
| 'pytorch: cuda is not available, using cpu instead.') | |||||
| yield | yield | ||||
| else: | else: | ||||
| yield | yield | ||||
| @@ -96,9 +97,7 @@ def create_device(device_name): | |||||
| if device_type == Devices.gpu: | if device_type == Devices.gpu: | ||||
| use_cuda = True | use_cuda = True | ||||
| if not torch.cuda.is_available(): | if not torch.cuda.is_available(): | ||||
| logger.warning( | |||||
| 'cuda is not available, create gpu device failed, using cpu instead.' | |||||
| ) | |||||
| logger.info('cuda is not available, using cpu instead.') | |||||
| use_cuda = False | use_cuda = False | ||||
| if use_cuda: | if use_cuda: | ||||
| @@ -1,6 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import inspect | import inspect | ||||
| import os | |||||
| from pathlib import Path | from pathlib import Path | ||||
| @@ -35,3 +36,10 @@ def get_default_cache_dir(): | |||||
| """ | """ | ||||
| default_cache_dir = Path.home().joinpath('.cache', 'modelscope') | default_cache_dir = Path.home().joinpath('.cache', 'modelscope') | ||||
| return default_cache_dir | return default_cache_dir | ||||
| def read_file(path): | |||||
| with open(path, 'r') as f: | |||||
| text = f.read() | |||||
| return text | |||||
| @@ -176,7 +176,7 @@ def build_from_cfg(cfg, | |||||
| raise TypeError('default_args must be a dict or None, ' | raise TypeError('default_args must be a dict or None, ' | ||||
| f'but got {type(default_args)}') | f'but got {type(default_args)}') | ||||
| # dynamic load installation reqruiements for this module | |||||
| # dynamic load installation requirements for this module | |||||
| from modelscope.utils.import_utils import LazyImportModule | from modelscope.utils.import_utils import LazyImportModule | ||||
| sig = (registry.name.upper(), group_key, cfg['type']) | sig = (registry.name.upper(), group_key, cfg['type']) | ||||
| LazyImportModule.import_module(sig) | LazyImportModule.import_module(sig) | ||||
| @@ -193,8 +193,11 @@ def build_from_cfg(cfg, | |||||
| if isinstance(obj_type, str): | if isinstance(obj_type, str): | ||||
| obj_cls = registry.get(obj_type, group_key=group_key) | obj_cls = registry.get(obj_type, group_key=group_key) | ||||
| if obj_cls is None: | if obj_cls is None: | ||||
| raise KeyError(f'{obj_type} is not in the {registry.name}' | |||||
| f' registry group {group_key}') | |||||
| raise KeyError( | |||||
| f'{obj_type} is not in the {registry.name}' | |||||
| f' registry group {group_key}. Please make' | |||||
| f' sure the correct version of ModelScope library is used.' | |||||
| ) | |||||
| obj_cls.group_key = group_key | obj_cls.group_key = group_key | ||||
| elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | ||||
| obj_cls = obj_type | obj_cls = obj_type | ||||
| @@ -65,7 +65,8 @@ class RegressTool: | |||||
| def monitor_module_single_forward(self, | def monitor_module_single_forward(self, | ||||
| module: nn.Module, | module: nn.Module, | ||||
| file_name: str, | file_name: str, | ||||
| compare_fn=None): | |||||
| compare_fn=None, | |||||
| **kwargs): | |||||
| """Monitor a pytorch module in a single forward. | """Monitor a pytorch module in a single forward. | ||||
| @param module: A torch module | @param module: A torch module | ||||
| @@ -107,7 +108,7 @@ class RegressTool: | |||||
| baseline = os.path.join(tempfile.gettempdir(), name) | baseline = os.path.join(tempfile.gettempdir(), name) | ||||
| self.load(baseline, name) | self.load(baseline, name) | ||||
| with open(baseline, 'rb') as f: | with open(baseline, 'rb') as f: | ||||
| baseline_json = pickle.load(f) | |||||
| base = pickle.load(f) | |||||
| class NumpyEncoder(json.JSONEncoder): | class NumpyEncoder(json.JSONEncoder): | ||||
| """Special json encoder for numpy types | """Special json encoder for numpy types | ||||
| @@ -122,9 +123,9 @@ class RegressTool: | |||||
| return obj.tolist() | return obj.tolist() | ||||
| return json.JSONEncoder.default(self, obj) | return json.JSONEncoder.default(self, obj) | ||||
| print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}') | |||||
| print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}') | |||||
| print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') | print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') | ||||
| if not compare_io_and_print(baseline_json, io_json, compare_fn): | |||||
| if not compare_io_and_print(base, io_json, compare_fn, **kwargs): | |||||
| raise ValueError('Result not match!') | raise ValueError('Result not match!') | ||||
| @contextlib.contextmanager | @contextlib.contextmanager | ||||
| @@ -136,7 +137,8 @@ class RegressTool: | |||||
| ignore_keys=None, | ignore_keys=None, | ||||
| compare_random=True, | compare_random=True, | ||||
| reset_dropout=True, | reset_dropout=True, | ||||
| lazy_stop_callback=None): | |||||
| lazy_stop_callback=None, | |||||
| **kwargs): | |||||
| """Monitor a pytorch module's backward data and cfg data within a step of the optimizer. | """Monitor a pytorch module's backward data and cfg data within a step of the optimizer. | ||||
| This is usually useful when you try to change some dangerous code | This is usually useful when you try to change some dangerous code | ||||
| @@ -265,14 +267,15 @@ class RegressTool: | |||||
| baseline_json = pickle.load(f) | baseline_json = pickle.load(f) | ||||
| if level == 'strict' and not compare_io_and_print( | if level == 'strict' and not compare_io_and_print( | ||||
| baseline_json['forward'], io_json, compare_fn): | |||||
| baseline_json['forward'], io_json, compare_fn, **kwargs): | |||||
| raise RuntimeError('Forward not match!') | raise RuntimeError('Forward not match!') | ||||
| if not compare_backward_and_print( | if not compare_backward_and_print( | ||||
| baseline_json['backward'], | baseline_json['backward'], | ||||
| bw_json, | bw_json, | ||||
| compare_fn=compare_fn, | compare_fn=compare_fn, | ||||
| ignore_keys=ignore_keys, | ignore_keys=ignore_keys, | ||||
| level=level): | |||||
| level=level, | |||||
| **kwargs): | |||||
| raise RuntimeError('Backward not match!') | raise RuntimeError('Backward not match!') | ||||
| cfg_opt1 = { | cfg_opt1 = { | ||||
| 'optimizer': baseline_json['optimizer'], | 'optimizer': baseline_json['optimizer'], | ||||
| @@ -286,7 +289,8 @@ class RegressTool: | |||||
| 'cfg': summary['cfg'], | 'cfg': summary['cfg'], | ||||
| 'state': None if not compare_random else summary['state'] | 'state': None if not compare_random else summary['state'] | ||||
| } | } | ||||
| if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn): | |||||
| if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn, | |||||
| **kwargs): | |||||
| raise RuntimeError('Cfg or optimizers not match!') | raise RuntimeError('Cfg or optimizers not match!') | ||||
| @@ -303,7 +307,8 @@ class MsRegressTool(RegressTool): | |||||
| compare_fn=None, | compare_fn=None, | ||||
| ignore_keys=None, | ignore_keys=None, | ||||
| compare_random=True, | compare_random=True, | ||||
| lazy_stop_callback=None): | |||||
| lazy_stop_callback=None, | |||||
| **kwargs): | |||||
| if lazy_stop_callback is None: | if lazy_stop_callback is None: | ||||
| @@ -319,7 +324,7 @@ class MsRegressTool(RegressTool): | |||||
| trainer.register_hook(EarlyStopHook()) | trainer.register_hook(EarlyStopHook()) | ||||
| def _train_loop(trainer, *args, **kwargs): | |||||
| def _train_loop(trainer, *args_train, **kwargs_train): | |||||
| with self.monitor_module_train( | with self.monitor_module_train( | ||||
| trainer, | trainer, | ||||
| file_name, | file_name, | ||||
| @@ -327,9 +332,11 @@ class MsRegressTool(RegressTool): | |||||
| compare_fn=compare_fn, | compare_fn=compare_fn, | ||||
| ignore_keys=ignore_keys, | ignore_keys=ignore_keys, | ||||
| compare_random=compare_random, | compare_random=compare_random, | ||||
| lazy_stop_callback=lazy_stop_callback): | |||||
| lazy_stop_callback=lazy_stop_callback, | |||||
| **kwargs): | |||||
| try: | try: | ||||
| return trainer.train_loop_origin(*args, **kwargs) | |||||
| return trainer.train_loop_origin(*args_train, | |||||
| **kwargs_train) | |||||
| except MsRegressTool.EarlyStopError: | except MsRegressTool.EarlyStopError: | ||||
| pass | pass | ||||
| @@ -530,7 +537,8 @@ def compare_arguments_nested(print_content, | |||||
| ) | ) | ||||
| return False | return False | ||||
| if not all([ | if not all([ | ||||
| compare_arguments_nested(None, sub_arg1, sub_arg2) | |||||
| compare_arguments_nested( | |||||
| None, sub_arg1, sub_arg2, rtol=rtol, atol=atol) | |||||
| for sub_arg1, sub_arg2 in zip(arg1, arg2) | for sub_arg1, sub_arg2 in zip(arg1, arg2) | ||||
| ]): | ]): | ||||
| if print_content is not None: | if print_content is not None: | ||||
| @@ -551,7 +559,8 @@ def compare_arguments_nested(print_content, | |||||
| print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') | print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') | ||||
| return False | return False | ||||
| if not all([ | if not all([ | ||||
| compare_arguments_nested(None, arg1[key], arg2[key]) | |||||
| compare_arguments_nested( | |||||
| None, arg1[key], arg2[key], rtol=rtol, atol=atol) | |||||
| for key in keys1 | for key in keys1 | ||||
| ]): | ]): | ||||
| if print_content is not None: | if print_content is not None: | ||||
| @@ -574,7 +583,7 @@ def compare_arguments_nested(print_content, | |||||
| raise ValueError(f'type not supported: {type1}') | raise ValueError(f'type not supported: {type1}') | ||||
| def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||||
| def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs): | |||||
| if compare_fn is None: | if compare_fn is None: | ||||
| def compare_fn(*args, **kwargs): | def compare_fn(*args, **kwargs): | ||||
| @@ -602,10 +611,10 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||||
| else: | else: | ||||
| match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
| f'unmatched module {key} input args', v1input['args'], | f'unmatched module {key} input args', v1input['args'], | ||||
| v2input['args']) and match | |||||
| v2input['args'], **kwargs) and match | |||||
| match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
| f'unmatched module {key} input kwargs', v1input['kwargs'], | f'unmatched module {key} input kwargs', v1input['kwargs'], | ||||
| v2input['kwargs']) and match | |||||
| v2input['kwargs'], **kwargs) and match | |||||
| v1output = numpify_tensor_nested(v1['output']) | v1output = numpify_tensor_nested(v1['output']) | ||||
| v2output = numpify_tensor_nested(v2['output']) | v2output = numpify_tensor_nested(v2['output']) | ||||
| res = compare_fn(v1output, v2output, key, 'output') | res = compare_fn(v1output, v2output, key, 'output') | ||||
| @@ -615,8 +624,11 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||||
| ) | ) | ||||
| match = match and res | match = match and res | ||||
| else: | else: | ||||
| match = compare_arguments_nested(f'unmatched module {key} outputs', | |||||
| v1output, v2output) and match | |||||
| match = compare_arguments_nested( | |||||
| f'unmatched module {key} outputs', | |||||
| arg1=v1output, | |||||
| arg2=v2output, | |||||
| **kwargs) and match | |||||
| return match | return match | ||||
| @@ -624,7 +636,8 @@ def compare_backward_and_print(baseline_json, | |||||
| bw_json, | bw_json, | ||||
| level, | level, | ||||
| ignore_keys=None, | ignore_keys=None, | ||||
| compare_fn=None): | |||||
| compare_fn=None, | |||||
| **kwargs): | |||||
| if compare_fn is None: | if compare_fn is None: | ||||
| def compare_fn(*args, **kwargs): | def compare_fn(*args, **kwargs): | ||||
| @@ -653,18 +666,26 @@ def compare_backward_and_print(baseline_json, | |||||
| data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ | data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ | ||||
| 'grad'], bw_json[key]['data_after'] | 'grad'], bw_json[key]['data_after'] | ||||
| match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
| f'unmatched module {key} tensor data', data1, data2) and match | |||||
| f'unmatched module {key} tensor data', | |||||
| arg1=data1, | |||||
| arg2=data2, | |||||
| **kwargs) and match | |||||
| if level == 'strict': | if level == 'strict': | ||||
| match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
| f'unmatched module {key} grad data', grad1, | |||||
| grad2) and match | |||||
| f'unmatched module {key} grad data', | |||||
| arg1=grad1, | |||||
| arg2=grad2, | |||||
| **kwargs) and match | |||||
| match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
| f'unmatched module {key} data after step', data_after1, | f'unmatched module {key} data after step', data_after1, | ||||
| data_after2) and match | |||||
| data_after2, **kwargs) and match | |||||
| return match | return match | ||||
| def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||||
| def compare_cfg_and_optimizers(baseline_json, | |||||
| cfg_json, | |||||
| compare_fn=None, | |||||
| **kwargs): | |||||
| if compare_fn is None: | if compare_fn is None: | ||||
| def compare_fn(*args, **kwargs): | def compare_fn(*args, **kwargs): | ||||
| @@ -686,12 +707,12 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||||
| print( | print( | ||||
| f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" | f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" | ||||
| ) | ) | ||||
| match = compare_arguments_nested('unmatched optimizer defaults', | |||||
| optimizer1['defaults'], | |||||
| optimizer2['defaults']) and match | |||||
| match = compare_arguments_nested('unmatched optimizer state_dict', | |||||
| optimizer1['state_dict'], | |||||
| optimizer2['state_dict']) and match | |||||
| match = compare_arguments_nested( | |||||
| 'unmatched optimizer defaults', optimizer1['defaults'], | |||||
| optimizer2['defaults'], **kwargs) and match | |||||
| match = compare_arguments_nested( | |||||
| 'unmatched optimizer state_dict', optimizer1['state_dict'], | |||||
| optimizer2['state_dict'], **kwargs) and match | |||||
| res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') | res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') | ||||
| if res is not None: | if res is not None: | ||||
| @@ -703,16 +724,17 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||||
| print( | print( | ||||
| f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" | f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" | ||||
| ) | ) | ||||
| match = compare_arguments_nested('unmatched lr_scheduler state_dict', | |||||
| lr_scheduler1['state_dict'], | |||||
| lr_scheduler2['state_dict']) and match | |||||
| match = compare_arguments_nested( | |||||
| 'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'], | |||||
| lr_scheduler2['state_dict'], **kwargs) and match | |||||
| res = compare_fn(cfg1, cfg2, None, 'cfg') | res = compare_fn(cfg1, cfg2, None, 'cfg') | ||||
| if res is not None: | if res is not None: | ||||
| print(f'cfg compared with user compare_fn with result:{res}\n') | print(f'cfg compared with user compare_fn with result:{res}\n') | ||||
| match = match and res | match = match and res | ||||
| else: | else: | ||||
| match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match | |||||
| match = compare_arguments_nested( | |||||
| 'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match | |||||
| res = compare_fn(state1, state2, None, 'state') | res = compare_fn(state1, state2, None, 'state') | ||||
| if res is not None: | if res is not None: | ||||
| @@ -721,6 +743,6 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||||
| match = match and res | match = match and res | ||||
| else: | else: | ||||
| match = compare_arguments_nested('unmatched random state', state1, | match = compare_arguments_nested('unmatched random state', state1, | ||||
| state2) and match | |||||
| state2, **kwargs) and match | |||||
| return match | return match | ||||
| @@ -19,7 +19,7 @@ moviepy>=1.0.3 | |||||
| networkx>=2.5 | networkx>=2.5 | ||||
| numba | numba | ||||
| onnxruntime>=1.10 | onnxruntime>=1.10 | ||||
| pai-easycv>=0.6.3.7 | |||||
| pai-easycv>=0.6.3.9 | |||||
| pandas | pandas | ||||
| psutil | psutil | ||||
| regex | regex | ||||
| @@ -127,7 +127,7 @@ class HubOperationTest(unittest.TestCase): | |||||
| return None | return None | ||||
| def test_list_model(self): | def test_list_model(self): | ||||
| data = self.api.list_model(TEST_MODEL_ORG) | |||||
| data = self.api.list_models(TEST_MODEL_ORG) | |||||
| assert len(data['Models']) >= 1 | assert len(data['Models']) >= 1 | ||||
| @@ -7,12 +7,12 @@ import uuid | |||||
| from modelscope.hub.api import HubApi | from modelscope.hub.api import HubApi | ||||
| from modelscope.hub.constants import Licenses, ModelVisibility | from modelscope.hub.constants import Licenses, ModelVisibility | ||||
| from modelscope.hub.errors import HTTPError, NotLoginException | |||||
| from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
| from modelscope.hub.upload import upload_folder | |||||
| from modelscope.utils.constant import ModelFile | from modelscope.utils.constant import ModelFile | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| from .test_utils import TEST_ACCESS_TOKEN1, delete_credential | |||||
| from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -22,7 +22,7 @@ class HubUploadTest(unittest.TestCase): | |||||
| def setUp(self): | def setUp(self): | ||||
| logger.info('SetUp') | logger.info('SetUp') | ||||
| self.api = HubApi() | self.api = HubApi() | ||||
| self.user = os.environ.get('TEST_MODEL_ORG', 'citest') | |||||
| self.user = TEST_MODEL_ORG | |||||
| logger.info(self.user) | logger.info(self.user) | ||||
| self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload', | self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload', | ||||
| uuid.uuid4().hex) | uuid.uuid4().hex) | ||||
| @@ -39,7 +39,10 @@ class HubUploadTest(unittest.TestCase): | |||||
| def tearDown(self): | def tearDown(self): | ||||
| logger.info('TearDown') | logger.info('TearDown') | ||||
| shutil.rmtree(self.model_dir, ignore_errors=True) | shutil.rmtree(self.model_dir, ignore_errors=True) | ||||
| self.api.delete_model(model_id=self.create_model_name) | |||||
| try: | |||||
| self.api.delete_model(model_id=self.create_model_name) | |||||
| except Exception: | |||||
| pass | |||||
| def test_upload_exits_repo_master(self): | def test_upload_exits_repo_master(self): | ||||
| logger.info('basic test for upload!') | logger.info('basic test for upload!') | ||||
| @@ -50,14 +53,14 @@ class HubUploadTest(unittest.TestCase): | |||||
| license=Licenses.APACHE_V2) | license=Licenses.APACHE_V2) | ||||
| os.system("echo '111'>%s" | os.system("echo '111'>%s" | ||||
| % os.path.join(self.finetune_path, 'add1.py')) | % os.path.join(self.finetune_path, 'add1.py')) | ||||
| upload_folder( | |||||
| self.api.push_model( | |||||
| model_id=self.create_model_name, model_dir=self.finetune_path) | model_id=self.create_model_name, model_dir=self.finetune_path) | ||||
| Repository(model_dir=self.repo_path, clone_from=self.create_model_name) | Repository(model_dir=self.repo_path, clone_from=self.create_model_name) | ||||
| assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) | assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) | ||||
| shutil.rmtree(self.repo_path, ignore_errors=True) | shutil.rmtree(self.repo_path, ignore_errors=True) | ||||
| os.system("echo '222'>%s" | os.system("echo '222'>%s" | ||||
| % os.path.join(self.finetune_path, 'add2.py')) | % os.path.join(self.finetune_path, 'add2.py')) | ||||
| upload_folder( | |||||
| self.api.push_model( | |||||
| model_id=self.create_model_name, | model_id=self.create_model_name, | ||||
| model_dir=self.finetune_path, | model_dir=self.finetune_path, | ||||
| revision='new_revision/version1') | revision='new_revision/version1') | ||||
| @@ -69,7 +72,7 @@ class HubUploadTest(unittest.TestCase): | |||||
| shutil.rmtree(self.repo_path, ignore_errors=True) | shutil.rmtree(self.repo_path, ignore_errors=True) | ||||
| os.system("echo '333'>%s" | os.system("echo '333'>%s" | ||||
| % os.path.join(self.finetune_path, 'add3.py')) | % os.path.join(self.finetune_path, 'add3.py')) | ||||
| upload_folder( | |||||
| self.api.push_model( | |||||
| model_id=self.create_model_name, | model_id=self.create_model_name, | ||||
| model_dir=self.finetune_path, | model_dir=self.finetune_path, | ||||
| revision='new_revision/version2', | revision='new_revision/version2', | ||||
| @@ -84,7 +87,7 @@ class HubUploadTest(unittest.TestCase): | |||||
| add4_path = os.path.join(self.finetune_path, 'temp') | add4_path = os.path.join(self.finetune_path, 'temp') | ||||
| os.mkdir(add4_path) | os.mkdir(add4_path) | ||||
| os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py')) | os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py')) | ||||
| upload_folder( | |||||
| self.api.push_model( | |||||
| model_id=self.create_model_name, | model_id=self.create_model_name, | ||||
| model_dir=self.finetune_path, | model_dir=self.finetune_path, | ||||
| revision='new_revision/version1') | revision='new_revision/version1') | ||||
| @@ -101,7 +104,7 @@ class HubUploadTest(unittest.TestCase): | |||||
| self.api.login(TEST_ACCESS_TOKEN1) | self.api.login(TEST_ACCESS_TOKEN1) | ||||
| os.system("echo '111'>%s" | os.system("echo '111'>%s" | ||||
| % os.path.join(self.finetune_path, 'add1.py')) | % os.path.join(self.finetune_path, 'add1.py')) | ||||
| upload_folder( | |||||
| self.api.push_model( | |||||
| model_id=self.create_model_name, | model_id=self.create_model_name, | ||||
| model_dir=self.finetune_path, | model_dir=self.finetune_path, | ||||
| revision='new_model_new_revision', | revision='new_model_new_revision', | ||||
| @@ -119,48 +122,23 @@ class HubUploadTest(unittest.TestCase): | |||||
| logger.info('test upload without login!') | logger.info('test upload without login!') | ||||
| self.api.login(TEST_ACCESS_TOKEN1) | self.api.login(TEST_ACCESS_TOKEN1) | ||||
| delete_credential() | delete_credential() | ||||
| try: | |||||
| upload_folder( | |||||
| model_id=self.create_model_name, | |||||
| model_dir=self.finetune_path, | |||||
| visibility=ModelVisibility.PUBLIC, | |||||
| license=Licenses.APACHE_V2) | |||||
| except Exception as e: | |||||
| logger.info(e) | |||||
| self.api.login(TEST_ACCESS_TOKEN1) | |||||
| upload_folder( | |||||
| with self.assertRaises(NotLoginException): | |||||
| self.api.push_model( | |||||
| model_id=self.create_model_name, | model_id=self.create_model_name, | ||||
| model_dir=self.finetune_path, | model_dir=self.finetune_path, | ||||
| visibility=ModelVisibility.PUBLIC, | visibility=ModelVisibility.PUBLIC, | ||||
| license=Licenses.APACHE_V2) | license=Licenses.APACHE_V2) | ||||
| Repository( | |||||
| model_dir=self.repo_path, clone_from=self.create_model_name) | |||||
| assert os.path.exists( | |||||
| os.path.join(self.repo_path, 'configuration.json')) | |||||
| shutil.rmtree(self.repo_path, ignore_errors=True) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_upload_invalid_repo(self): | def test_upload_invalid_repo(self): | ||||
| logger.info('test upload to invalid repo!') | logger.info('test upload to invalid repo!') | ||||
| self.api.login(TEST_ACCESS_TOKEN1) | self.api.login(TEST_ACCESS_TOKEN1) | ||||
| try: | |||||
| upload_folder( | |||||
| with self.assertRaises(HTTPError): | |||||
| self.api.push_model( | |||||
| model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), | model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), | ||||
| model_dir=self.finetune_path, | model_dir=self.finetune_path, | ||||
| visibility=ModelVisibility.PUBLIC, | visibility=ModelVisibility.PUBLIC, | ||||
| license=Licenses.APACHE_V2) | license=Licenses.APACHE_V2) | ||||
| except Exception as e: | |||||
| logger.info(e) | |||||
| upload_folder( | |||||
| model_id=self.create_model_name, | |||||
| model_dir=self.finetune_path, | |||||
| visibility=ModelVisibility.PUBLIC, | |||||
| license=Licenses.APACHE_V2) | |||||
| Repository( | |||||
| model_dir=self.repo_path, clone_from=self.create_model_name) | |||||
| assert os.path.exists( | |||||
| os.path.join(self.repo_path, 'configuration.json')) | |||||
| shutil.rmtree(self.repo_path, ignore_errors=True) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -52,7 +52,8 @@ class MsDatasetTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_ms_csv_basic(self): | def test_ms_csv_basic(self): | ||||
| ms_ds_train = MsDataset.load( | ms_ds_train = MsDataset.load( | ||||
| 'afqmc_small', namespace='userxiaoming', split='train') | |||||
| 'clue', subset_name='afqmc', | |||||
| split='train').to_hf_dataset().select(range(5)) | |||||
| print(next(iter(ms_ds_train))) | print(next(iter(ms_ds_train))) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| @@ -45,6 +45,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| 'checking_item': OutputKeys.TEXT, | 'checking_item': OutputKeys.TEXT, | ||||
| 'example': 'wav_example' | 'example': 'wav_example' | ||||
| }, | }, | ||||
| 'test_run_with_url_pytorch': { | |||||
| 'checking_item': OutputKeys.TEXT, | |||||
| 'example': 'wav_example' | |||||
| }, | |||||
| 'test_run_with_url_tf': { | 'test_run_with_url_tf': { | ||||
| 'checking_item': OutputKeys.TEXT, | 'checking_item': OutputKeys.TEXT, | ||||
| 'example': 'wav_example' | 'example': 'wav_example' | ||||
| @@ -74,6 +78,170 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| } | } | ||||
| } | } | ||||
| all_models_info = [ | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', | |||||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': 'speech_paraformer_asr_nat-aishell1-pytorch', | |||||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', | |||||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1', | |||||
| 'wav_path': 'data/test/audios/asr_example_8K.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_8K.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_8K.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_cn_en.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_cn_en.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_8K.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_en.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_en.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_ru.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_ru.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_es.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_es.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_ko.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_ko.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_ja.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_ja.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online', | |||||
| 'wav_path': 'data/test/audios/asr_example_id.wav' | |||||
| }, | |||||
| { | |||||
| 'model_group': 'damo', | |||||
| 'model_id': | |||||
| 'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline', | |||||
| 'wav_path': 'data/test/audios/asr_example_id.wav' | |||||
| }, | |||||
| ] | |||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' | ||||
| self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1' | self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1' | ||||
| @@ -90,7 +258,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| def run_pipeline(self, | def run_pipeline(self, | ||||
| model_id: str, | model_id: str, | ||||
| audio_in: Union[str, bytes], | audio_in: Union[str, bytes], | ||||
| sr: int = 16000) -> Dict[str, Any]: | |||||
| sr: int = None) -> Dict[str, Any]: | |||||
| inference_16k_pipline = pipeline( | inference_16k_pipline = pipeline( | ||||
| task=Tasks.auto_speech_recognition, model=model_id) | task=Tasks.auto_speech_recognition, model=model_id) | ||||
| @@ -136,33 +304,26 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| return audio, fs | return audio, fs | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_wav_pytorch(self): | |||||
| """run with single waveform file | |||||
| def test_run_with_pcm(self): | |||||
| """run with wav data | |||||
| """ | """ | ||||
| logger.info('Run ASR test with waveform file (pytorch)...') | |||||
| logger.info('Run ASR test with wav data (tensorflow)...') | |||||
| wav_file_path = os.path.join(os.getcwd(), WAV_FILE) | |||||
| audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||||
| rec_result = self.run_pipeline( | rec_result = self.run_pipeline( | ||||
| model_id=self.am_pytorch_model_id, audio_in=wav_file_path) | |||||
| self.check_result('test_run_with_wav_pytorch', rec_result) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_pcm_pytorch(self): | |||||
| """run with wav data | |||||
| """ | |||||
| model_id=self.am_tf_model_id, audio_in=audio, sr=sr) | |||||
| self.check_result('test_run_with_pcm_tf', rec_result) | |||||
| logger.info('Run ASR test with wav data (pytorch)...') | logger.info('Run ASR test with wav data (pytorch)...') | ||||
| audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||||
| rec_result = self.run_pipeline( | rec_result = self.run_pipeline( | ||||
| model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr) | model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr) | ||||
| self.check_result('test_run_with_pcm_pytorch', rec_result) | self.check_result('test_run_with_pcm_pytorch', rec_result) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_wav_tf(self): | |||||
| def test_run_with_wav(self): | |||||
| """run with single waveform file | """run with single waveform file | ||||
| """ | """ | ||||
| @@ -174,21 +335,14 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| model_id=self.am_tf_model_id, audio_in=wav_file_path) | model_id=self.am_tf_model_id, audio_in=wav_file_path) | ||||
| self.check_result('test_run_with_wav_tf', rec_result) | self.check_result('test_run_with_wav_tf', rec_result) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_pcm_tf(self): | |||||
| """run with wav data | |||||
| """ | |||||
| logger.info('Run ASR test with wav data (tensorflow)...') | |||||
| audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) | |||||
| logger.info('Run ASR test with waveform file (pytorch)...') | |||||
| rec_result = self.run_pipeline( | rec_result = self.run_pipeline( | ||||
| model_id=self.am_tf_model_id, audio_in=audio, sr=sr) | |||||
| self.check_result('test_run_with_pcm_tf', rec_result) | |||||
| model_id=self.am_pytorch_model_id, audio_in=wav_file_path) | |||||
| self.check_result('test_run_with_wav_pytorch', rec_result) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_url_tf(self): | |||||
| def test_run_with_url(self): | |||||
| """run with single url file | """run with single url file | ||||
| """ | """ | ||||
| @@ -198,6 +352,12 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| model_id=self.am_tf_model_id, audio_in=URL_FILE) | model_id=self.am_tf_model_id, audio_in=URL_FILE) | ||||
| self.check_result('test_run_with_url_tf', rec_result) | self.check_result('test_run_with_url_tf', rec_result) | ||||
| logger.info('Run ASR test with url file (pytorch)...') | |||||
| rec_result = self.run_pipeline( | |||||
| model_id=self.am_pytorch_model_id, audio_in=URL_FILE) | |||||
| self.check_result('test_run_with_url_pytorch', rec_result) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run_with_wav_dataset_pytorch(self): | def test_run_with_wav_dataset_pytorch(self): | ||||
| """run with datasets, and audio format is waveform | """run with datasets, and audio format is waveform | ||||
| @@ -217,7 +377,6 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| data.text # hypothesis text | data.text # hypothesis text | ||||
| """ | """ | ||||
| logger.info('Run ASR test with waveform dataset (pytorch)...') | |||||
| logger.info('Downloading waveform testsets file ...') | logger.info('Downloading waveform testsets file ...') | ||||
| dataset_path = download_and_untar( | dataset_path = download_and_untar( | ||||
| @@ -225,40 +384,38 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||||
| LITTLE_TESTSETS_URL, self.workspace) | LITTLE_TESTSETS_URL, self.workspace) | ||||
| dataset_path = os.path.join(dataset_path, 'wav', 'test') | dataset_path = os.path.join(dataset_path, 'wav', 'test') | ||||
| logger.info('Run ASR test with waveform dataset (tensorflow)...') | |||||
| rec_result = self.run_pipeline( | |||||
| model_id=self.am_tf_model_id, audio_in=dataset_path) | |||||
| self.check_result('test_run_with_wav_dataset_tf', rec_result) | |||||
| logger.info('Run ASR test with waveform dataset (pytorch)...') | |||||
| rec_result = self.run_pipeline( | rec_result = self.run_pipeline( | ||||
| model_id=self.am_pytorch_model_id, audio_in=dataset_path) | model_id=self.am_pytorch_model_id, audio_in=dataset_path) | ||||
| self.check_result('test_run_with_wav_dataset_pytorch', rec_result) | self.check_result('test_run_with_wav_dataset_pytorch', rec_result) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_with_wav_dataset_tf(self): | |||||
| """run with datasets, and audio format is waveform | |||||
| datasets directory: | |||||
| <dataset_path> | |||||
| wav | |||||
| test # testsets | |||||
| xx.wav | |||||
| ... | |||||
| dev # devsets | |||||
| yy.wav | |||||
| ... | |||||
| train # trainsets | |||||
| zz.wav | |||||
| ... | |||||
| transcript | |||||
| data.text # hypothesis text | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_all_models(self): | |||||
| """run with all models | |||||
| """ | """ | ||||
| logger.info('Run ASR test with waveform dataset (tensorflow)...') | |||||
| logger.info('Downloading waveform testsets file ...') | |||||
| dataset_path = download_and_untar( | |||||
| os.path.join(self.workspace, LITTLE_TESTSETS_FILE), | |||||
| LITTLE_TESTSETS_URL, self.workspace) | |||||
| dataset_path = os.path.join(dataset_path, 'wav', 'test') | |||||
| rec_result = self.run_pipeline( | |||||
| model_id=self.am_tf_model_id, audio_in=dataset_path) | |||||
| self.check_result('test_run_with_wav_dataset_tf', rec_result) | |||||
| logger.info('Run ASR test with all models') | |||||
| for item in self.all_models_info: | |||||
| model_id = item['model_group'] + '/' + item['model_id'] | |||||
| wav_path = item['wav_path'] | |||||
| rec_result = self.run_pipeline( | |||||
| model_id=model_id, audio_in=wav_path) | |||||
| if rec_result.__contains__(OutputKeys.TEXT): | |||||
| logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' ' | |||||
| + ColorCodes.YELLOW | |||||
| + str(rec_result[OutputKeys.TEXT]) | |||||
| + ColorCodes.END) | |||||
| else: | |||||
| logger.info(ColorCodes.MAGENTA + str(rec_result) | |||||
| + ColorCodes.END) | |||||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | @unittest.skip('demo compatibility test is only enabled on a needed-basis') | ||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||
| @@ -26,6 +26,20 @@ class TranslationTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| pipeline_ins = pipeline(self.task, model=model_id) | pipeline_ins = pipeline(self.task, model=model_id) | ||||
| print(pipeline_ins(input=inputs)) | print(pipeline_ins(input=inputs)) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_name_for_en2fr(self): | |||||
| model_id = 'damo/nlp_csanmt_translation_en2fr' | |||||
| inputs = 'When I was in my 20s, I saw my very first psychotherapy client.' | |||||
| pipeline_ins = pipeline(self.task, model=model_id) | |||||
| print(pipeline_ins(input=inputs)) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_name_for_fr2en(self): | |||||
| model_id = 'damo/nlp_csanmt_translation_fr2en' | |||||
| inputs = "Quand j'avais la vingtaine, j'ai vu mes tout premiers clients comme psychothérapeute." | |||||
| pipeline_ins = pipeline(self.task, model=model_id) | |||||
| print(pipeline_ins(input=inputs)) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run_with_default_model(self): | def test_run_with_default_model(self): | ||||
| inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' | inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' | ||||
| @@ -4,22 +4,45 @@ import unittest | |||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| class TinynasObjectDetectionTest(unittest.TestCase): | |||||
| class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def setUp(self) -> None: | |||||
| self.task = Tasks.image_object_detection | |||||
| self.model_id = 'damo/cv_tinynas_object-detection_damoyolo' | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run(self): | |||||
| def test_run_airdet(self): | |||||
| tinynas_object_detection = pipeline( | tinynas_object_detection = pipeline( | ||||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | Tasks.image_object_detection, model='damo/cv_tinynas_detection') | ||||
| result = tinynas_object_detection( | result = tinynas_object_detection( | ||||
| 'data/test/images/image_detection.jpg') | 'data/test/images/image_detection.jpg') | ||||
| print(result) | print(result) | ||||
| @unittest.skip('will be enabled after damoyolo officially released') | |||||
| def test_run_damoyolo(self): | |||||
| tinynas_object_detection = pipeline( | |||||
| Tasks.image_object_detection, | |||||
| model='damo/cv_tinynas_object-detection_damoyolo') | |||||
| result = tinynas_object_detection( | |||||
| 'data/test/images/image_detection.jpg') | |||||
| print(result) | |||||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | @unittest.skip('demo compatibility test is only enabled on a needed-basis') | ||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||
| self.test_demo() | |||||
| self.compatibility_check() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_image_object_detection_auto_pipeline(self): | |||||
| test_image = 'data/test/images/image_detection.jpg' | |||||
| tinynas_object_detection = pipeline( | |||||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | |||||
| result = tinynas_object_detection(test_image) | |||||
| tinynas_object_detection.show_result(test_image, result, | |||||
| 'demo_ret.jpg') | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -0,0 +1,71 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import glob | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| import torch | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import DownloadMode, LogKeys, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | |||||
| @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||||
| class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase): | |||||
| model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' | |||||
| def setUp(self): | |||||
| self.logger = get_logger() | |||||
| self.logger.info(('Testing %s.%s' % | |||||
| (type(self).__name__, self._testMethodName))) | |||||
| def _train(self, tmp_dir): | |||||
| cfg_options = {'train.max_epochs': 2} | |||||
| trainer_name = Trainers.easycv | |||||
| train_dataset = MsDataset.load( | |||||
| dataset_name='face_2d_keypoints_dataset', | |||||
| namespace='modelscope', | |||||
| split='train', | |||||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) | |||||
| eval_dataset = MsDataset.load( | |||||
| dataset_name='face_2d_keypoints_dataset', | |||||
| namespace='modelscope', | |||||
| split='train', | |||||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| train_dataset=train_dataset, | |||||
| eval_dataset=eval_dataset, | |||||
| work_dir=tmp_dir, | |||||
| cfg_options=cfg_options) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| trainer.train() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer_single_gpu(self): | |||||
| temp_file_dir = tempfile.TemporaryDirectory() | |||||
| tmp_dir = temp_file_dir.name | |||||
| if not os.path.exists(tmp_dir): | |||||
| os.makedirs(tmp_dir) | |||||
| self._train(tmp_dir) | |||||
| results_files = os.listdir(tmp_dir) | |||||
| json_files = glob.glob(os.path.join(tmp_dir, '*.log.json')) | |||||
| self.assertEqual(len(json_files), 1) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
| temp_file_dir.cleanup() | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -16,7 +16,8 @@ from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ | |||||
| calculate_fisher | calculate_fisher | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from modelscope.utils.data_utils import to_device | from modelscope.utils.data_utils import to_device | ||||
| from modelscope.utils.regress_test_utils import MsRegressTool | |||||
| from modelscope.utils.regress_test_utils import (MsRegressTool, | |||||
| compare_arguments_nested) | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -41,6 +42,33 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| def test_trainer_repeatable(self): | def test_trainer_repeatable(self): | ||||
| import torch # noqa | import torch # noqa | ||||
| def compare_fn(value1, value2, key, type): | |||||
| # Ignore the differences between optimizers of two torch versions | |||||
| if type != 'optimizer': | |||||
| return None | |||||
| match = (value1['type'] == value2['type']) | |||||
| shared_defaults = set(value1['defaults'].keys()).intersection( | |||||
| set(value2['defaults'].keys())) | |||||
| match = all([ | |||||
| compare_arguments_nested(f'Optimizer defaults {key} not match', | |||||
| value1['defaults'][key], | |||||
| value2['defaults'][key]) | |||||
| for key in shared_defaults | |||||
| ]) and match | |||||
| match = (len(value1['state_dict']['param_groups']) == len( | |||||
| value2['state_dict']['param_groups'])) and match | |||||
| for group1, group2 in zip(value1['state_dict']['param_groups'], | |||||
| value2['state_dict']['param_groups']): | |||||
| shared_keys = set(group1.keys()).intersection( | |||||
| set(group2.keys())) | |||||
| match = all([ | |||||
| compare_arguments_nested( | |||||
| f'Optimizer param_groups {key} not match', group1[key], | |||||
| group2[key]) for key in shared_keys | |||||
| ]) and match | |||||
| return match | |||||
| def cfg_modify_fn(cfg): | def cfg_modify_fn(cfg): | ||||
| cfg.task = 'nli' | cfg.task = 'nli' | ||||
| cfg['preprocessor'] = {'type': 'nli-tokenizer'} | cfg['preprocessor'] = {'type': 'nli-tokenizer'} | ||||
| @@ -98,7 +126,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | name=Trainers.nlp_base_trainer, default_args=kwargs) | ||||
| with self.regress_tool.monitor_ms_train( | with self.regress_tool.monitor_ms_train( | ||||
| trainer, 'sbert-base-tnews', level='strict'): | |||||
| trainer, 'sbert-base-tnews', level='strict', | |||||
| compare_fn=compare_fn): | |||||
| trainer.train() | trainer.train() | ||||
| def finetune(self, | def finetune(self, | ||||
| @@ -51,7 +51,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | shutil.rmtree(self.tmp_dir, ignore_errors=True) | ||||
| super().tearDown() | super().tearDown() | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_trainer(self): | def test_trainer(self): | ||||
| kwargs = dict( | kwargs = dict( | ||||
| model=self.model_id, | model=self.model_id, | ||||
| @@ -65,7 +65,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||||
| for i in range(2): | for i in range(2): | ||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | self.assertIn(f'epoch_{i+1}.pth', results_files) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_trainer_with_model_and_args(self): | def test_trainer_with_model_and_args(self): | ||||
| model = NAFNetForImageDenoise.from_pretrained(self.cache_path) | model = NAFNetForImageDenoise.from_pretrained(self.cache_path) | ||||
| kwargs = dict( | kwargs = dict( | ||||
| @@ -29,7 +29,8 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
| os.makedirs(self.tmp_dir) | os.makedirs(self.tmp_dir) | ||||
| self.dataset = MsDataset.load( | self.dataset = MsDataset.load( | ||||
| 'afqmc_small', namespace='userxiaoming', split='train') | |||||
| 'clue', subset_name='afqmc', | |||||
| split='train').to_hf_dataset().select(range(2)) | |||||
| def tearDown(self): | def tearDown(self): | ||||
| shutil.rmtree(self.tmp_dir) | shutil.rmtree(self.tmp_dir) | ||||
| @@ -73,7 +74,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
| output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | ||||
| pipeline_sentence_similarity(output_dir) | pipeline_sentence_similarity(output_dir) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 3, 'skip test in current test level') | |||||
| def test_trainer_with_backbone_head(self): | def test_trainer_with_backbone_head(self): | ||||
| model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | ||||
| kwargs = dict( | kwargs = dict( | ||||
| @@ -99,6 +100,8 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
| model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | ||||
| cfg = read_config(model_id, revision='beta') | cfg = read_config(model_id, revision='beta') | ||||
| cfg.train.max_epochs = 20 | cfg.train.max_epochs = 20 | ||||
| cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||||
| cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||||
| cfg.train.work_dir = self.tmp_dir | cfg.train.work_dir = self.tmp_dir | ||||
| cfg_file = os.path.join(self.tmp_dir, 'config.json') | cfg_file = os.path.join(self.tmp_dir, 'config.json') | ||||
| cfg.dump(cfg_file) | cfg.dump(cfg_file) | ||||
| @@ -120,22 +123,24 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
| checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) | checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) | ||||
| self.assertTrue(Metrics.accuracy in eval_results) | self.assertTrue(Metrics.accuracy in eval_results) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_trainer_with_configured_datasets(self): | def test_trainer_with_configured_datasets(self): | ||||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | ||||
| cfg: Config = read_config(model_id) | cfg: Config = read_config(model_id) | ||||
| cfg.train.max_epochs = 20 | cfg.train.max_epochs = 20 | ||||
| cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||||
| cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||||
| cfg.train.work_dir = self.tmp_dir | cfg.train.work_dir = self.tmp_dir | ||||
| cfg.dataset = { | cfg.dataset = { | ||||
| 'train': { | 'train': { | ||||
| 'name': 'afqmc_small', | |||||
| 'name': 'clue', | |||||
| 'subset_name': 'afqmc', | |||||
| 'split': 'train', | 'split': 'train', | ||||
| 'namespace': 'userxiaoming' | |||||
| }, | }, | ||||
| 'val': { | 'val': { | ||||
| 'name': 'afqmc_small', | |||||
| 'name': 'clue', | |||||
| 'subset_name': 'afqmc', | |||||
| 'split': 'train', | 'split': 'train', | ||||
| 'namespace': 'userxiaoming' | |||||
| }, | }, | ||||
| } | } | ||||
| cfg_file = os.path.join(self.tmp_dir, 'config.json') | cfg_file = os.path.join(self.tmp_dir, 'config.json') | ||||
| @@ -159,6 +164,11 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | ||||
| cfg: Config = read_config(model_id) | cfg: Config = read_config(model_id) | ||||
| cfg.train.max_epochs = 3 | cfg.train.max_epochs = 3 | ||||
| cfg.preprocessor.first_sequence = 'sentence1' | |||||
| cfg.preprocessor.second_sequence = 'sentence2' | |||||
| cfg.preprocessor.label = 'label' | |||||
| cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||||
| cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||||
| cfg.train.work_dir = self.tmp_dir | cfg.train.work_dir = self.tmp_dir | ||||
| cfg_file = os.path.join(self.tmp_dir, 'config.json') | cfg_file = os.path.join(self.tmp_dir, 'config.json') | ||||
| cfg.dump(cfg_file) | cfg.dump(cfg_file) | ||||
| @@ -0,0 +1,19 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| class CompatibilityTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| def tearDown(self): | |||||
| super().tearDown() | |||||
| def test_xtcocotools(self): | |||||
| from xtcocotools.coco import COCO | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||