| @@ -1,6 +1,6 @@ | |||||
| repos: | repos: | ||||
| - repo: https://gitlab.com/pycqa/flake8.git | - repo: https://gitlab.com/pycqa/flake8.git | ||||
| rev: 3.8.3 | |||||
| rev: 4.0.0 | |||||
| hooks: | hooks: | ||||
| - id: flake8 | - id: flake8 | ||||
| exclude: thirdparty/|examples/ | exclude: thirdparty/|examples/ | ||||
| @@ -1,6 +1,6 @@ | |||||
| repos: | repos: | ||||
| - repo: /home/admin/pre-commit/flake8 | - repo: /home/admin/pre-commit/flake8 | ||||
| rev: 3.8.3 | |||||
| rev: 4.0.0 | |||||
| hooks: | hooks: | ||||
| - id: flake8 | - id: flake8 | ||||
| exclude: thirdparty/|examples/ | exclude: thirdparty/|examples/ | ||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f | |||||
| size 38954 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4 | |||||
| size 93676 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb | |||||
| size 472012 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b | |||||
| size 71244 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61 | |||||
| size 181964 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425 | |||||
| size 112078 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281 | |||||
| size 200556 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b | |||||
| size 162636 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61 | |||||
| size 160204 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747 | |||||
| size 959 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0 | |||||
| size 151572 | |||||
| @@ -1,3 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | version https://git-lfs.github.com/spec/v1 | ||||
| oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 | |||||
| size 62231 | |||||
| oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b | |||||
| size 61741 | |||||
| @@ -1,3 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | version https://git-lfs.github.com/spec/v1 | ||||
| oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a | |||||
| size 62235 | |||||
| oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41 | |||||
| size 61745 | |||||
| @@ -23,12 +23,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||||
| def generate_dummy_inputs(self, | def generate_dummy_inputs(self, | ||||
| shape: Tuple = None, | shape: Tuple = None, | ||||
| pair: bool = False, | |||||
| **kwargs) -> Dict[str, Any]: | **kwargs) -> Dict[str, Any]: | ||||
| """Generate dummy inputs for model exportation to onnx or other formats by tracing. | """Generate dummy inputs for model exportation to onnx or other formats by tracing. | ||||
| @param shape: A tuple of input shape which should have at most two dimensions. | @param shape: A tuple of input shape which should have at most two dimensions. | ||||
| shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | ||||
| shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. | shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. | ||||
| @param pair: Generate sentence pairs or single sentences for dummy inputs. | |||||
| @return: Dummy inputs. | @return: Dummy inputs. | ||||
| """ | """ | ||||
| @@ -55,7 +57,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||||
| **sequence_length | **sequence_length | ||||
| }) | }) | ||||
| preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | ||||
| if preprocessor.pair: | |||||
| if pair: | |||||
| first_sequence = preprocessor.tokenizer.unk_token | first_sequence = preprocessor.tokenizer.unk_token | ||||
| second_sequence = preprocessor.tokenizer.unk_token | second_sequence = preprocessor.tokenizer.unk_token | ||||
| else: | else: | ||||
| @@ -1,8 +1,11 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| # yapf: disable | |||||
| import datetime | |||||
| import os | import os | ||||
| import pickle | import pickle | ||||
| import shutil | import shutil | ||||
| import tempfile | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from http import HTTPStatus | from http import HTTPStatus | ||||
| from http.cookiejar import CookieJar | from http.cookiejar import CookieJar | ||||
| @@ -16,17 +19,25 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||||
| API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | ||||
| API_RESPONSE_FIELD_MESSAGE, | API_RESPONSE_FIELD_MESSAGE, | ||||
| API_RESPONSE_FIELD_USERNAME, | API_RESPONSE_FIELD_USERNAME, | ||||
| DEFAULT_CREDENTIALS_PATH) | |||||
| DEFAULT_CREDENTIALS_PATH, Licenses, | |||||
| ModelVisibility) | |||||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||||
| NotLoginException, RequestError, | |||||
| datahub_raise_on_error, | |||||
| handle_http_post_error, | |||||
| handle_http_response, is_ok, raise_on_error) | |||||
| from modelscope.hub.git import GitCommandWrapper | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.hub.utils.utils import (get_endpoint, | |||||
| model_id_to_group_owner_name) | |||||
| from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | ||||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | ||||
| DEFAULT_MODEL_REVISION, | DEFAULT_MODEL_REVISION, | ||||
| DatasetFormations, DatasetMetaFormats, | DatasetFormations, DatasetMetaFormats, | ||||
| DownloadMode) | |||||
| DownloadMode, ModelFile) | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .errors import (InvalidParameter, NotExistError, RequestError, | |||||
| datahub_raise_on_error, handle_http_post_error, | |||||
| handle_http_response, is_ok, raise_on_error) | |||||
| from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||||
| # yapf: enable | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -169,11 +180,106 @@ class HubApi: | |||||
| else: | else: | ||||
| r.raise_for_status() | r.raise_for_status() | ||||
| def list_model(self, | |||||
| owner_or_group: str, | |||||
| page_number=1, | |||||
| page_size=10) -> dict: | |||||
| """List model in owner or group. | |||||
| def push_model(self, | |||||
| model_id: str, | |||||
| model_dir: str, | |||||
| visibility: int = ModelVisibility.PUBLIC, | |||||
| license: str = Licenses.APACHE_V2, | |||||
| chinese_name: Optional[str] = None, | |||||
| commit_message: Optional[str] = 'upload model', | |||||
| revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||||
| """ | |||||
| Upload model from a given directory to given repository. A valid model directory | |||||
| must contain a configuration.json file. | |||||
| This function upload the files in given directory to given repository. If the | |||||
| given repository is not exists in remote, it will automatically create it with | |||||
| given visibility, license and chinese_name parameters. If the revision is also | |||||
| not exists in remote repository, it will create a new branch for it. | |||||
| This function must be called before calling HubApi's login with a valid token | |||||
| which can be obtained from ModelScope's website. | |||||
| Args: | |||||
| model_id (`str`): | |||||
| The model id to be uploaded, caller must have write permission for it. | |||||
| model_dir(`str`): | |||||
| The Absolute Path of the finetune result. | |||||
| visibility(`int`, defaults to `0`): | |||||
| Visibility of the new created model(1-private, 5-public). If the model is | |||||
| not exists in ModelScope, this function will create a new model with this | |||||
| visibility and this parameter is required. You can ignore this parameter | |||||
| if you make sure the model's existence. | |||||
| license(`str`, defaults to `None`): | |||||
| License of the new created model(see License). If the model is not exists | |||||
| in ModelScope, this function will create a new model with this license | |||||
| and this parameter is required. You can ignore this parameter if you | |||||
| make sure the model's existence. | |||||
| chinese_name(`str`, *optional*, defaults to `None`): | |||||
| chinese name of the new created model. | |||||
| commit_message(`str`, *optional*, defaults to `None`): | |||||
| commit message of the push request. | |||||
| revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION): | |||||
| which branch to push. If the branch is not exists, It will create a new | |||||
| branch and push to it. | |||||
| """ | |||||
| if model_id is None: | |||||
| raise InvalidParameter('model_id cannot be empty!') | |||||
| if model_dir is None: | |||||
| raise InvalidParameter('model_dir cannot be empty!') | |||||
| if not os.path.exists(model_dir) or os.path.isfile(model_dir): | |||||
| raise InvalidParameter('model_dir must be a valid directory.') | |||||
| cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||||
| if not os.path.exists(cfg_file): | |||||
| raise ValueError(f'{model_dir} must contain a configuration.json.') | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException('Must login before upload!') | |||||
| files_to_save = os.listdir(model_dir) | |||||
| try: | |||||
| self.get_model(model_id=model_id) | |||||
| except Exception: | |||||
| if visibility is None or license is None: | |||||
| raise InvalidParameter( | |||||
| 'visibility and license cannot be empty if want to create new repo' | |||||
| ) | |||||
| logger.info('Create new model %s' % model_id) | |||||
| self.create_model( | |||||
| model_id=model_id, | |||||
| visibility=visibility, | |||||
| license=license, | |||||
| chinese_name=chinese_name) | |||||
| tmp_dir = tempfile.mkdtemp() | |||||
| git_wrapper = GitCommandWrapper() | |||||
| try: | |||||
| repo = Repository(model_dir=tmp_dir, clone_from=model_id) | |||||
| branches = git_wrapper.get_remote_branches(tmp_dir) | |||||
| if revision not in branches: | |||||
| logger.info('Create new branch %s' % revision) | |||||
| git_wrapper.new_branch(tmp_dir, revision) | |||||
| git_wrapper.checkout(tmp_dir, revision) | |||||
| for f in files_to_save: | |||||
| if f[0] != '.': | |||||
| src = os.path.join(model_dir, f) | |||||
| if os.path.isdir(src): | |||||
| shutil.copytree(src, os.path.join(tmp_dir, f)) | |||||
| else: | |||||
| shutil.copy(src, tmp_dir) | |||||
| if not commit_message: | |||||
| date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | |||||
| commit_message = '[automsg] push model %s to hub at %s' % ( | |||||
| model_id, date) | |||||
| repo.push(commit_message=commit_message, branch=revision) | |||||
| except Exception: | |||||
| raise | |||||
| finally: | |||||
| shutil.rmtree(tmp_dir, ignore_errors=True) | |||||
| def list_models(self, | |||||
| owner_or_group: str, | |||||
| page_number=1, | |||||
| page_size=10) -> dict: | |||||
| """List models in owner or group. | |||||
| Args: | Args: | ||||
| owner_or_group(`str`): owner or group. | owner_or_group(`str`): owner or group. | ||||
| @@ -390,11 +496,13 @@ class HubApi: | |||||
| return resp['Data'] | return resp['Data'] | ||||
| def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | ||||
| is_recursive, is_filter_dir, revision, | |||||
| cookies): | |||||
| is_recursive, is_filter_dir, revision): | |||||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | ||||
| f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | ||||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies: | |||||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||||
| resp = requests.get(url=url, cookies=cookies) | resp = requests.get(url=url, cookies=cookies) | ||||
| resp = resp.json() | resp = resp.json() | ||||
| @@ -11,13 +11,12 @@ from typing import Dict, Optional, Union | |||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| import requests | import requests | ||||
| from filelock import FileLock | |||||
| from tqdm import tqdm | from tqdm import tqdm | ||||
| from modelscope import __version__ | from modelscope import __version__ | ||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | |||||
| from .constants import FILE_HASH | from .constants import FILE_HASH | ||||
| from .errors import FileDownloadError, NotExistError | from .errors import FileDownloadError, NotExistError | ||||
| from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
| @@ -1,13 +1,10 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| import re | |||||
| import subprocess | import subprocess | ||||
| from typing import List | from typing import List | ||||
| from xmlrpc.client import Boolean | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import ModelScopeConfig | |||||
| from .errors import GitError | from .errors import GitError | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -132,6 +129,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| return response | return response | ||||
| def add_user_info(self, repo_base_dir, repo_name): | def add_user_info(self, repo_base_dir, repo_name): | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| user_name, user_email = ModelScopeConfig.get_user_info() | user_name, user_email = ModelScopeConfig.get_user_info() | ||||
| if user_name and user_email: | if user_name and user_email: | ||||
| # config user.name and user.email if exist | # config user.name and user.email if exist | ||||
| @@ -184,8 +182,11 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| info = [ | info = [ | ||||
| line.strip() | line.strip() | ||||
| for line in rsp.stdout.decode('utf8').strip().split(os.linesep) | for line in rsp.stdout.decode('utf8').strip().split(os.linesep) | ||||
| ][1:] | |||||
| return ['/'.join(line.split('/')[1:]) for line in info] | |||||
| ] | |||||
| if len(info) == 1: | |||||
| return ['/'.join(info[0].split('/')[1:])] | |||||
| else: | |||||
| return ['/'.join(line.split('/')[1:]) for line in info[1:]] | |||||
| def pull(self, repo_dir: str): | def pull(self, repo_dir: str): | ||||
| cmds = ['-C', repo_dir, 'pull'] | cmds = ['-C', repo_dir, 'pull'] | ||||
| @@ -7,7 +7,6 @@ from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | ||||
| DEFAULT_MODEL_REVISION) | DEFAULT_MODEL_REVISION) | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import ModelScopeConfig | |||||
| from .git import GitCommandWrapper | from .git import GitCommandWrapper | ||||
| from .utils.utils import get_endpoint | from .utils.utils import get_endpoint | ||||
| @@ -47,6 +46,7 @@ class Repository: | |||||
| err_msg = 'a non-default value of revision cannot be empty.' | err_msg = 'a non-default value of revision cannot be empty.' | ||||
| raise InvalidParameter(err_msg) | raise InvalidParameter(err_msg) | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| if auth_token: | if auth_token: | ||||
| self.auth_token = auth_token | self.auth_token = auth_token | ||||
| else: | else: | ||||
| @@ -166,7 +166,7 @@ class DatasetRepository: | |||||
| err_msg = 'a non-default value of revision cannot be empty.' | err_msg = 'a non-default value of revision cannot be empty.' | ||||
| raise InvalidParameter(err_msg) | raise InvalidParameter(err_msg) | ||||
| self.revision = revision | self.revision = revision | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| if auth_token: | if auth_token: | ||||
| self.auth_token = auth_token | self.auth_token = auth_token | ||||
| else: | else: | ||||
| @@ -5,9 +5,9 @@ import tempfile | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | |||||
| from .constants import FILE_HASH | from .constants import FILE_HASH | ||||
| from .errors import NotExistError | from .errors import NotExistError | ||||
| from .file_download import (get_file_download_url, http_get_file, | from .file_download import (get_file_download_url, http_get_file, | ||||
| @@ -1,117 +0,0 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import datetime | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import uuid | |||||
| from typing import Dict, Optional | |||||
| from uuid import uuid4 | |||||
| from filelock import FileLock | |||||
| from modelscope import __version__ | |||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.hub.errors import InvalidParameter, NotLoginException | |||||
| from modelscope.hub.git import GitCommandWrapper | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| def upload_folder(model_id: str, | |||||
| model_dir: str, | |||||
| visibility: int = 0, | |||||
| license: str = None, | |||||
| chinese_name: Optional[str] = None, | |||||
| commit_message: Optional[str] = None, | |||||
| revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||||
| """ | |||||
| Upload model from a given directory to given repository. A valid model directory | |||||
| must contain a configuration.json file. | |||||
| This function upload the files in given directory to given repository. If the | |||||
| given repository is not exists in remote, it will automatically create it with | |||||
| given visibility, license and chinese_name parameters. If the revision is also | |||||
| not exists in remote repository, it will create a new branch for it. | |||||
| This function must be called before calling HubApi's login with a valid token | |||||
| which can be obtained from ModelScope's website. | |||||
| Args: | |||||
| model_id (`str`): | |||||
| The model id to be uploaded, caller must have write permission for it. | |||||
| model_dir(`str`): | |||||
| The Absolute Path of the finetune result. | |||||
| visibility(`int`, defaults to `0`): | |||||
| Visibility of the new created model(1-private, 5-public). If the model is | |||||
| not exists in ModelScope, this function will create a new model with this | |||||
| visibility and this parameter is required. You can ignore this parameter | |||||
| if you make sure the model's existence. | |||||
| license(`str`, defaults to `None`): | |||||
| License of the new created model(see License). If the model is not exists | |||||
| in ModelScope, this function will create a new model with this license | |||||
| and this parameter is required. You can ignore this parameter if you | |||||
| make sure the model's existence. | |||||
| chinese_name(`str`, *optional*, defaults to `None`): | |||||
| chinese name of the new created model. | |||||
| commit_message(`str`, *optional*, defaults to `None`): | |||||
| commit message of the push request. | |||||
| revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION): | |||||
| which branch to push. If the branch is not exists, It will create a new | |||||
| branch and push to it. | |||||
| """ | |||||
| if model_id is None: | |||||
| raise InvalidParameter('model_id cannot be empty!') | |||||
| if model_dir is None: | |||||
| raise InvalidParameter('model_dir cannot be empty!') | |||||
| if not os.path.exists(model_dir) or os.path.isfile(model_dir): | |||||
| raise InvalidParameter('model_dir must be a valid directory.') | |||||
| cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||||
| if not os.path.exists(cfg_file): | |||||
| raise ValueError(f'{model_dir} must contain a configuration.json.') | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException('Must login before upload!') | |||||
| files_to_save = os.listdir(model_dir) | |||||
| api = HubApi() | |||||
| try: | |||||
| api.get_model(model_id=model_id) | |||||
| except Exception: | |||||
| if visibility is None or license is None: | |||||
| raise InvalidParameter( | |||||
| 'visibility and license cannot be empty if want to create new repo' | |||||
| ) | |||||
| logger.info('Create new model %s' % model_id) | |||||
| api.create_model( | |||||
| model_id=model_id, | |||||
| visibility=visibility, | |||||
| license=license, | |||||
| chinese_name=chinese_name) | |||||
| tmp_dir = tempfile.mkdtemp() | |||||
| git_wrapper = GitCommandWrapper() | |||||
| try: | |||||
| repo = Repository(model_dir=tmp_dir, clone_from=model_id) | |||||
| branches = git_wrapper.get_remote_branches(tmp_dir) | |||||
| if revision not in branches: | |||||
| logger.info('Create new branch %s' % revision) | |||||
| git_wrapper.new_branch(tmp_dir, revision) | |||||
| git_wrapper.checkout(tmp_dir, revision) | |||||
| for f in files_to_save: | |||||
| if f[0] != '.': | |||||
| src = os.path.join(model_dir, f) | |||||
| if os.path.isdir(src): | |||||
| shutil.copytree(src, os.path.join(tmp_dir, f)) | |||||
| else: | |||||
| shutil.copy(src, tmp_dir) | |||||
| if not commit_message: | |||||
| date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | |||||
| commit_message = '[automsg] push model %s to hub at %s' % ( | |||||
| model_id, date) | |||||
| repo.push(commit_message=commit_message, branch=revision) | |||||
| except Exception: | |||||
| raise | |||||
| finally: | |||||
| shutil.rmtree(tmp_dir, ignore_errors=True) | |||||
| @@ -9,7 +9,9 @@ class Models(object): | |||||
| Model name should only contain model info but not task info. | Model name should only contain model info but not task info. | ||||
| """ | """ | ||||
| # tinynas models | |||||
| tinynas_detection = 'tinynas-detection' | tinynas_detection = 'tinynas-detection' | ||||
| tinynas_damoyolo = 'tinynas-damoyolo' | |||||
| # vision models | # vision models | ||||
| detection = 'detection' | detection = 'detection' | ||||
| @@ -234,7 +236,7 @@ class Pipelines(object): | |||||
| conversational_text_to_sql = 'conversational-text-to-sql' | conversational_text_to_sql = 'conversational-text-to-sql' | ||||
| table_question_answering_pipeline = 'table-question-answering-pipeline' | table_question_answering_pipeline = 'table-question-answering-pipeline' | ||||
| sentence_embedding = 'sentence-embedding' | sentence_embedding = 'sentence-embedding' | ||||
| passage_ranking = 'passage-ranking' | |||||
| text_ranking = 'text-ranking' | |||||
| relation_extraction = 'relation-extraction' | relation_extraction = 'relation-extraction' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| @@ -261,6 +263,7 @@ class Pipelines(object): | |||||
| text_to_image_synthesis = 'text-to-image-synthesis' | text_to_image_synthesis = 'text-to-image-synthesis' | ||||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | video_multi_modal_embedding = 'video-multi-modal-embedding' | ||||
| image_text_retrieval = 'image-text-retrieval' | image_text_retrieval = 'image-text-retrieval' | ||||
| ofa_ocr_recognition = 'ofa-ocr-recognition' | |||||
| class Trainers(object): | class Trainers(object): | ||||
| @@ -295,7 +298,7 @@ class Trainers(object): | |||||
| dialog_intent_trainer = 'dialog-intent-trainer' | dialog_intent_trainer = 'dialog-intent-trainer' | ||||
| nlp_base_trainer = 'nlp-base-trainer' | nlp_base_trainer = 'nlp-base-trainer' | ||||
| nlp_veco_trainer = 'nlp-veco-trainer' | nlp_veco_trainer = 'nlp-veco-trainer' | ||||
| nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer' | |||||
| nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | |||||
| # audio trainers | # audio trainers | ||||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | ||||
| @@ -341,7 +344,7 @@ class Preprocessors(object): | |||||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | ||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||
| sentence_embedding = 'sentence-embedding' | sentence_embedding = 'sentence-embedding' | ||||
| passage_ranking = 'passage-ranking' | |||||
| text_ranking = 'text-ranking' | |||||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | ||||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | ||||
| fill_mask = 'fill-mask' | fill_mask = 'fill-mask' | ||||
| @@ -374,7 +377,7 @@ class Metrics(object): | |||||
| audio_noise_metric = 'audio-noise-metric' | audio_noise_metric = 'audio-noise-metric' | ||||
| # text gen | # text gen | ||||
| bleu = 'bleu' | |||||
| BLEU = 'bleu' | |||||
| # metrics for image denoise task | # metrics for image denoise task | ||||
| image_denoise_metric = 'image-denoise-metric' | image_denoise_metric = 'image-denoise-metric' | ||||
| @@ -396,6 +399,8 @@ class Metrics(object): | |||||
| movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | ||||
| # metric for inpainting task | # metric for inpainting task | ||||
| image_inpainting_metric = 'image-inpainting-metric' | image_inpainting_metric = 'image-inpainting-metric' | ||||
| # metric for ocr | |||||
| NED = 'ned' | |||||
| class Optimizers(object): | class Optimizers(object): | ||||
| @@ -454,9 +459,10 @@ class Datasets(object): | |||||
| """ Names for different datasets. | """ Names for different datasets. | ||||
| """ | """ | ||||
| ClsDataset = 'ClsDataset' | ClsDataset = 'ClsDataset' | ||||
| Face2dKeypointsDataset = 'Face2dKeypointsDataset' | |||||
| Face2dKeypointsDataset = 'FaceKeypointDataset' | |||||
| HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | ||||
| HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' | |||||
| HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' | |||||
| SegDataset = 'SegDataset' | SegDataset = 'SegDataset' | ||||
| DetDataset = 'DetDataset' | DetDataset = 'DetDataset' | ||||
| DetImagesMixDataset = 'DetImagesMixDataset' | DetImagesMixDataset = 'DetImagesMixDataset' | ||||
| PairedDataset = 'PairedDataset' | |||||
| @@ -11,7 +11,7 @@ from .builder import METRICS, MetricKeys | |||||
| EVAL_BLEU_ORDER = 4 | EVAL_BLEU_ORDER = 4 | ||||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.bleu) | |||||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU) | |||||
| class BleuMetric(Metric): | class BleuMetric(Metric): | ||||
| """The metric computation bleu for text generation classes. | """The metric computation bleu for text generation classes. | ||||
| @@ -23,6 +23,7 @@ class MetricKeys(object): | |||||
| BLEU_4 = 'bleu-4' | BLEU_4 = 'bleu-4' | ||||
| ROUGE_1 = 'rouge-1' | ROUGE_1 = 'rouge-1' | ||||
| ROUGE_L = 'rouge-l' | ROUGE_L = 'rouge-l' | ||||
| NED = 'ned' # ocr metric | |||||
| task_default_metrics = { | task_default_metrics = { | ||||
| @@ -32,6 +33,7 @@ task_default_metrics = { | |||||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | Tasks.sentiment_classification: [Metrics.seq_cls_metric], | ||||
| Tasks.token_classification: [Metrics.token_cls_metric], | Tasks.token_classification: [Metrics.token_cls_metric], | ||||
| Tasks.text_generation: [Metrics.text_gen_metric], | Tasks.text_generation: [Metrics.text_gen_metric], | ||||
| Tasks.text_classification: [Metrics.seq_cls_metric], | |||||
| Tasks.image_denoising: [Metrics.image_denoise_metric], | Tasks.image_denoising: [Metrics.image_denoise_metric], | ||||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | ||||
| Tasks.image_portrait_enhancement: | Tasks.image_portrait_enhancement: | ||||
| @@ -2,6 +2,7 @@ | |||||
| # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py | # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py | ||||
| from typing import Dict | from typing import Dict | ||||
| import cv2 | |||||
| import numpy as np | import numpy as np | ||||
| from modelscope.metainfo import Metrics | from modelscope.metainfo import Metrics | ||||
| @@ -37,6 +38,7 @@ class ImagePortraitEnhancementMetric(Metric): | |||||
| def add(self, outputs: Dict, inputs: Dict): | def add(self, outputs: Dict, inputs: Dict): | ||||
| ground_truths = outputs['target'] | ground_truths = outputs['target'] | ||||
| eval_results = outputs['pred'] | eval_results = outputs['pred'] | ||||
| self.preds.extend(eval_results) | self.preds.extend(eval_results) | ||||
| self.targets.extend(ground_truths) | self.targets.extend(ground_truths) | ||||
| @@ -2,6 +2,7 @@ | |||||
| import os | import os | ||||
| import pickle as pkl | import pickle as pkl | ||||
| from threading import Lock | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| @@ -27,6 +28,7 @@ class Voice: | |||||
| self.__am_config = AttrDict(**am_config) | self.__am_config = AttrDict(**am_config) | ||||
| self.__voc_config = AttrDict(**voc_config) | self.__voc_config = AttrDict(**voc_config) | ||||
| self.__model_loaded = False | self.__model_loaded = False | ||||
| self.__lock = Lock() | |||||
| if 'am' not in self.__am_config: | if 'am' not in self.__am_config: | ||||
| raise TtsModelConfigurationException( | raise TtsModelConfigurationException( | ||||
| 'modelscope error: am configuration invalid') | 'modelscope error: am configuration invalid') | ||||
| @@ -71,34 +73,35 @@ class Voice: | |||||
| self.__generator.remove_weight_norm() | self.__generator.remove_weight_norm() | ||||
| def __am_forward(self, symbol_seq): | def __am_forward(self, symbol_seq): | ||||
| with torch.no_grad(): | |||||
| inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||||
| symbol_seq) | |||||
| inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||||
| self.__device) | |||||
| inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||||
| self.__device) | |||||
| inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to( | |||||
| self.__device) | |||||
| inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||||
| self.__device) | |||||
| inputs_ling = torch.stack( | |||||
| [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||||
| dim=-1).unsqueeze(0) | |||||
| inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||||
| self.__device).unsqueeze(0) | |||||
| inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||||
| self.__device).unsqueeze(0) | |||||
| inputs_len = torch.zeros(1).to(self.__device).long( | |||||
| ) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||||
| res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||||
| inputs_spk[:, :-1], inputs_len) | |||||
| postnet_outputs = res['postnet_outputs'] | |||||
| LR_length_rounded = res['LR_length_rounded'] | |||||
| valid_length = int(LR_length_rounded[0].item()) | |||||
| postnet_outputs = postnet_outputs[ | |||||
| 0, :valid_length, :].cpu().numpy() | |||||
| return postnet_outputs | |||||
| with self.__lock: | |||||
| with torch.no_grad(): | |||||
| inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||||
| symbol_seq) | |||||
| inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||||
| self.__device) | |||||
| inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||||
| self.__device) | |||||
| inputs_syllable = torch.from_numpy( | |||||
| inputs_feat_lst[2]).long().to(self.__device) | |||||
| inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||||
| self.__device) | |||||
| inputs_ling = torch.stack( | |||||
| [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||||
| dim=-1).unsqueeze(0) | |||||
| inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||||
| self.__device).unsqueeze(0) | |||||
| inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||||
| self.__device).unsqueeze(0) | |||||
| inputs_len = torch.zeros(1).to(self.__device).long( | |||||
| ) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||||
| res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||||
| inputs_spk[:, :-1], inputs_len) | |||||
| postnet_outputs = res['postnet_outputs'] | |||||
| LR_length_rounded = res['LR_length_rounded'] | |||||
| valid_length = int(LR_length_rounded[0].item()) | |||||
| postnet_outputs = postnet_outputs[ | |||||
| 0, :valid_length, :].cpu().numpy() | |||||
| return postnet_outputs | |||||
| def __vocoder_forward(self, melspec): | def __vocoder_forward(self, melspec): | ||||
| dim0 = list(melspec.shape)[-1] | dim0 = list(melspec.shape)[-1] | ||||
| @@ -118,14 +121,15 @@ class Voice: | |||||
| return audio | return audio | ||||
| def forward(self, symbol_seq): | def forward(self, symbol_seq): | ||||
| if not self.__model_loaded: | |||||
| torch.manual_seed(self.__am_config.seed) | |||||
| if torch.cuda.is_available(): | |||||
| with self.__lock: | |||||
| if not self.__model_loaded: | |||||
| torch.manual_seed(self.__am_config.seed) | torch.manual_seed(self.__am_config.seed) | ||||
| self.__device = torch.device('cuda') | |||||
| else: | |||||
| self.__device = torch.device('cpu') | |||||
| self.__load_am() | |||||
| self.__load_vocoder() | |||||
| self.__model_loaded = True | |||||
| if torch.cuda.is_available(): | |||||
| torch.manual_seed(self.__am_config.seed) | |||||
| self.__device = torch.device('cuda') | |||||
| else: | |||||
| self.__device = torch.device('cpu') | |||||
| self.__load_am() | |||||
| self.__load_vocoder() | |||||
| self.__model_loaded = True | |||||
| return self.__vocoder_forward(self.__am_forward(symbol_seq)) | return self.__vocoder_forward(self.__am_forward(symbol_seq)) | ||||
| @@ -56,9 +56,6 @@ class OneStageDetector(nn.Module): | |||||
| def inference(self, meta): | def inference(self, meta): | ||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| torch.cuda.synchronize() | |||||
| preds = self(meta['img']) | preds = self(meta['img']) | ||||
| torch.cuda.synchronize() | |||||
| results = self.head.post_process(preds, meta) | results = self.head.post_process(preds, meta) | ||||
| torch.cuda.synchronize() | |||||
| return results | return results | ||||
| @@ -35,7 +35,7 @@ class ImagePortraitEnhancement(TorchModel): | |||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| self.size = 512 | |||||
| self.size = 256 | |||||
| self.style_dim = 512 | self.style_dim = 512 | ||||
| self.n_mlp = 8 | self.n_mlp = 8 | ||||
| self.mean_path_length = 0 | self.mean_path_length = 0 | ||||
| @@ -131,9 +131,9 @@ class ImagePortraitEnhancement(TorchModel): | |||||
| return path_penalty, path_mean.detach(), path_lengths | return path_penalty, path_mean.detach(), path_lengths | ||||
| @torch.no_grad() | @torch.no_grad() | ||||
| def _evaluate_postprocess(self, src: Tensor, | |||||
| def _evaluate_postprocess(self, input: Tensor, | |||||
| target: Tensor) -> Dict[str, list]: | target: Tensor) -> Dict[str, list]: | ||||
| preds, _ = self.generator(src) | |||||
| preds, _ = self.generator(input) | |||||
| preds = list(torch.split(preds, 1, 0)) | preds = list(torch.split(preds, 1, 0)) | ||||
| targets = list(torch.split(target, 1, 0)) | targets = list(torch.split(target, 1, 0)) | ||||
| @@ -144,11 +144,11 @@ class ImagePortraitEnhancement(TorchModel): | |||||
| return {'pred': preds, 'target': targets} | return {'pred': preds, 'target': targets} | ||||
| def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor: | |||||
| def _train_forward_d(self, input: Tensor, target: Tensor) -> Tensor: | |||||
| self.requires_grad(self.generator, False) | self.requires_grad(self.generator, False) | ||||
| self.requires_grad(self.discriminator, True) | self.requires_grad(self.discriminator, True) | ||||
| preds, _ = self.generator(src) | |||||
| preds, _ = self.generator(input) | |||||
| fake_pred = self.discriminator(preds) | fake_pred = self.discriminator(preds) | ||||
| real_pred = self.discriminator(target) | real_pred = self.discriminator(target) | ||||
| @@ -156,27 +156,27 @@ class ImagePortraitEnhancement(TorchModel): | |||||
| return d_loss | return d_loss | ||||
| def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor: | |||||
| src.requires_grad = True | |||||
| def _train_forward_d_r1(self, input: Tensor, target: Tensor) -> Tensor: | |||||
| input.requires_grad = True | |||||
| target.requires_grad = True | target.requires_grad = True | ||||
| real_pred = self.discriminator(target) | real_pred = self.discriminator(target) | ||||
| r1_loss = self.d_r1_loss(real_pred, target) | r1_loss = self.d_r1_loss(real_pred, target) | ||||
| return r1_loss | return r1_loss | ||||
| def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor: | |||||
| def _train_forward_g(self, input: Tensor, target: Tensor) -> Tensor: | |||||
| self.requires_grad(self.generator, True) | self.requires_grad(self.generator, True) | ||||
| self.requires_grad(self.discriminator, False) | self.requires_grad(self.discriminator, False) | ||||
| preds, _ = self.generator(src) | |||||
| preds, _ = self.generator(input) | |||||
| fake_pred = self.discriminator(preds) | fake_pred = self.discriminator(preds) | ||||
| g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src) | |||||
| g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, input) | |||||
| return g_loss | return g_loss | ||||
| def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor: | |||||
| fake_img, latents = self.generator(src, return_latents=True) | |||||
| def _train_forward_g_path(self, input: Tensor, target: Tensor) -> Tensor: | |||||
| fake_img, latents = self.generator(input, return_latents=True) | |||||
| path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( | path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( | ||||
| fake_img, latents, self.mean_path_length) | fake_img, latents, self.mean_path_length) | ||||
| @@ -184,8 +184,8 @@ class ImagePortraitEnhancement(TorchModel): | |||||
| return path_loss | return path_loss | ||||
| @torch.no_grad() | @torch.no_grad() | ||||
| def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: | |||||
| return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)} | |||||
| def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: | |||||
| return {'outputs': (self.generator(input)[0] * 0.5 + 0.5).clamp(0, 1)} | |||||
| def forward(self, input: Dict[str, | def forward(self, input: Dict[str, | ||||
| Tensor]) -> Dict[str, Union[list, Tensor]]: | Tensor]) -> Dict[str, Union[list, Tensor]]: | ||||
| @@ -1,2 +1,2 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from . import data, models, ops | from . import data, models, ops | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||
| from modelscope.utils.import_utils import LazyImportModule | from modelscope.utils.import_utils import LazyImportModule | ||||
| @@ -1,3 +1,4 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| import math | import math | ||||
| import random | import random | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||
| from modelscope.utils.import_utils import LazyImportModule | from modelscope.utils.import_utils import LazyImportModule | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||
| from modelscope.utils.import_utils import LazyImportModule | from modelscope.utils.import_utils import LazyImportModule | ||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .model_translation import UNet | |||||
| else: | |||||
| _import_structure = { | |||||
| 'image_to_image_translation_model': ['UNet'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -1 +1,2 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from .transforms import * # noqa F403 | from .transforms import * # noqa F403 | ||||
| @@ -1,2 +1,3 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from .autoencoder import * # noqa F403 | from .autoencoder import * # noqa F403 | ||||
| from .clip import * # noqa F403 | from .clip import * # noqa F403 | ||||
| @@ -1,3 +1,4 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from .degradation import * # noqa F403 | from .degradation import * # noqa F403 | ||||
| from .diffusion import * # noqa F403 | from .diffusion import * # noqa F403 | ||||
| from .losses import * # noqa F403 | from .losses import * # noqa F403 | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||
| from modelscope.utils.import_utils import LazyImportModule | from modelscope.utils.import_utils import LazyImportModule | ||||
| @@ -1,3 +1,4 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| @@ -1,3 +1,4 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| @@ -1,3 +1,5 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| @@ -93,7 +93,7 @@ class TextDrivenSeg(TorchModel): | |||||
| """ | """ | ||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| if self.device_id == -1: | if self.device_id == -1: | ||||
| output = self.model(image) | |||||
| output = self.model(image, [text]) | |||||
| else: | else: | ||||
| device = torch.device('cuda', self.device_id) | device = torch.device('cuda', self.device_id) | ||||
| output = self.model(image.to(device), [text]) | output = self.model(image.to(device), [text]) | ||||
| @@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .tinynas_detector import Tinynas_detector | from .tinynas_detector import Tinynas_detector | ||||
| from .tinynas_damoyolo import DamoYolo | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'tinynas_detector': ['TinynasDetector'], | 'tinynas_detector': ['TinynasDetector'], | ||||
| 'tinynas_damoyolo': ['DamoYolo'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -4,6 +4,7 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from modelscope.utils.file_utils import read_file | |||||
| from ..core.base_ops import Focus, SPPBottleneck, get_activation | from ..core.base_ops import Focus, SPPBottleneck, get_activation | ||||
| from ..core.repvgg_block import RepVggBlock | from ..core.repvgg_block import RepVggBlock | ||||
| @@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module): | |||||
| kernel_size, | kernel_size, | ||||
| stride, | stride, | ||||
| force_resproj=False, | force_resproj=False, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(ResConvK1KX, self).__init__() | super(ResConvK1KX, self).__init__() | ||||
| self.stride = stride | self.stride = stride | ||||
| self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | ||||
| self.conv2 = RepVggBlock( | |||||
| btn_c, out_c, kernel_size, stride, act='identity') | |||||
| if not reparam: | |||||
| self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) | |||||
| else: | |||||
| self.conv2 = RepVggBlock( | |||||
| btn_c, out_c, kernel_size, stride, act='identity') | |||||
| if act is None: | if act is None: | ||||
| self.activation_function = torch.relu | self.activation_function = torch.relu | ||||
| @@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module): | |||||
| stride, | stride, | ||||
| num_blocks, | num_blocks, | ||||
| with_spp=False, | with_spp=False, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(SuperResConvK1KX, self).__init__() | super(SuperResConvK1KX, self).__init__() | ||||
| if act is None: | if act is None: | ||||
| self.act = torch.relu | self.act = torch.relu | ||||
| @@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module): | |||||
| this_kernel_size, | this_kernel_size, | ||||
| this_stride, | this_stride, | ||||
| force_resproj, | force_resproj, | ||||
| act=act) | |||||
| act=act, | |||||
| reparam=reparam) | |||||
| self.block_list.append(the_block) | self.block_list.append(the_block) | ||||
| if block_id == 0 and with_spp: | if block_id == 0 and with_spp: | ||||
| self.block_list.append( | self.block_list.append( | ||||
| @@ -248,7 +255,8 @@ class TinyNAS(nn.Module): | |||||
| with_spp=False, | with_spp=False, | ||||
| use_focus=False, | use_focus=False, | ||||
| need_conv1=True, | need_conv1=True, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(TinyNAS, self).__init__() | super(TinyNAS, self).__init__() | ||||
| assert len(out_indices) == len(out_channels) | assert len(out_indices) == len(out_channels) | ||||
| self.out_indices = out_indices | self.out_indices = out_indices | ||||
| @@ -281,7 +289,8 @@ class TinyNAS(nn.Module): | |||||
| block_info['s'], | block_info['s'], | ||||
| block_info['L'], | block_info['L'], | ||||
| spp, | spp, | ||||
| act=act) | |||||
| act=act, | |||||
| reparam=reparam) | |||||
| self.block_list.append(the_block) | self.block_list.append(the_block) | ||||
| elif the_block_class == 'SuperResConvKXKX': | elif the_block_class == 'SuperResConvKXKX': | ||||
| spp = with_spp if idx == len(structure_info) - 1 else False | spp = with_spp if idx == len(structure_info) - 1 else False | ||||
| @@ -325,8 +334,8 @@ class TinyNAS(nn.Module): | |||||
| def load_tinynas_net(backbone_cfg): | def load_tinynas_net(backbone_cfg): | ||||
| # load masternet model to path | # load masternet model to path | ||||
| import ast | import ast | ||||
| struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str]) | |||||
| net_structure_str = read_file(backbone_cfg.structure_file) | |||||
| struct_str = ''.join([x.strip() for x in net_structure_str]) | |||||
| struct_info = ast.literal_eval(struct_str) | struct_info = ast.literal_eval(struct_str) | ||||
| for layer in struct_info: | for layer in struct_info: | ||||
| if 'nbitsA' in layer: | if 'nbitsA' in layer: | ||||
| @@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg): | |||||
| use_focus=backbone_cfg.use_focus, | use_focus=backbone_cfg.use_focus, | ||||
| act=backbone_cfg.act, | act=backbone_cfg.act, | ||||
| need_conv1=backbone_cfg.need_conv1, | need_conv1=backbone_cfg.need_conv1, | ||||
| ) | |||||
| reparam=backbone_cfg.reparam) | |||||
| return model | return model | ||||
| @@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel): | |||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| config_path = osp.join(model_dir, 'airdet_s.py') | |||||
| config_path = osp.join(model_dir, self.config_name) | |||||
| config = parse_config(config_path) | config = parse_config(config_path) | ||||
| self.cfg = config | self.cfg = config | ||||
| model_path = osp.join(model_dir, config.model.name) | model_path = osp.join(model_dir, config.model.name) | ||||
| @@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel): | |||||
| self.conf_thre = config.model.head.nms_conf_thre | self.conf_thre = config.model.head.nms_conf_thre | ||||
| self.nms_thre = config.model.head.nms_iou_thre | self.nms_thre = config.model.head.nms_iou_thre | ||||
| if self.cfg.model.backbone.name == 'TinyNAS': | |||||
| self.cfg.model.backbone.structure_file = osp.join( | |||||
| model_dir, self.cfg.model.backbone.structure_file) | |||||
| self.backbone = build_backbone(self.cfg.model.backbone) | self.backbone = build_backbone(self.cfg.model.backbone) | ||||
| self.neck = build_neck(self.cfg.model.neck) | self.neck = build_neck(self.cfg.model.neck) | ||||
| self.head = build_head(self.cfg.model.head) | self.head = build_head(self.cfg.model.head) | ||||
| @@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module): | |||||
| simOTA_iou_weight=3.0, | simOTA_iou_weight=3.0, | ||||
| octbase=8, | octbase=8, | ||||
| simlqe=False, | simlqe=False, | ||||
| use_lqe=True, | |||||
| **kwargs): | **kwargs): | ||||
| self.simlqe = simlqe | self.simlqe = simlqe | ||||
| self.num_classes = num_classes | self.num_classes = num_classes | ||||
| self.in_channels = in_channels | self.in_channels = in_channels | ||||
| self.strides = strides | self.strides = strides | ||||
| self.use_lqe = use_lqe | |||||
| self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | ||||
| else [feat_channels] * len(self.strides) | else [feat_channels] * len(self.strides) | ||||
| @@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module): | |||||
| groups=self.conv_groups, | groups=self.conv_groups, | ||||
| norm=self.norm, | norm=self.norm, | ||||
| act=self.act)) | act=self.act)) | ||||
| if not self.simlqe: | |||||
| conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)] | |||||
| if self.use_lqe: | |||||
| if not self.simlqe: | |||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * self.total_dim, self.reg_channels, 1) | |||||
| ] | |||||
| else: | |||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||||
| ] | |||||
| conf_vector += [self.relu] | |||||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||||
| reg_conf = nn.Sequential(*conf_vector) | |||||
| else: | else: | ||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||||
| ] | |||||
| conf_vector += [self.relu] | |||||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||||
| reg_conf = nn.Sequential(*conf_vector) | |||||
| reg_conf = None | |||||
| return cls_convs, reg_convs, reg_conf | return cls_convs, reg_convs, reg_conf | ||||
| @@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module): | |||||
| N, C, H, W = bbox_pred.size() | N, C, H, W = bbox_pred.size() | ||||
| prob = F.softmax( | prob = F.softmax( | ||||
| bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | ||||
| if not self.simlqe: | |||||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||||
| if self.add_mean: | |||||
| stat = torch.cat( | |||||
| [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2) | |||||
| if self.use_lqe: | |||||
| if not self.simlqe: | |||||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||||
| if self.add_mean: | |||||
| stat = torch.cat( | |||||
| [prob_topk, | |||||
| prob_topk.mean(dim=2, keepdim=True)], | |||||
| dim=2) | |||||
| else: | |||||
| stat = prob_topk | |||||
| quality_score = reg_conf( | |||||
| stat.reshape(N, 4 * self.total_dim, H, W)) | |||||
| else: | else: | ||||
| stat = prob_topk | |||||
| quality_score = reg_conf( | |||||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||||
| quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W)) | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||||
| else: | else: | ||||
| quality_score = reg_conf( | |||||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() | |||||
| flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | ||||
| flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | ||||
| @@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module): | |||||
| self, | self, | ||||
| depth=1.0, | depth=1.0, | ||||
| width=1.0, | width=1.0, | ||||
| in_features=[2, 3, 4], | |||||
| in_channels=[256, 512, 1024], | in_channels=[256, 512, 1024], | ||||
| out_channels=[256, 512, 1024], | out_channels=[256, 512, 1024], | ||||
| depthwise=False, | depthwise=False, | ||||
| @@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module): | |||||
| block_name='BasicBlock', | block_name='BasicBlock', | ||||
| ): | ): | ||||
| super().__init__() | super().__init__() | ||||
| self.in_features = in_features | |||||
| self.in_channels = in_channels | self.in_channels = in_channels | ||||
| Conv = DWConv if depthwise else BaseConv | Conv = DWConv if depthwise else BaseConv | ||||
| @@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module): | |||||
| """ | """ | ||||
| # backbone | # backbone | ||||
| features = [out_features[f] for f in self.in_features] | |||||
| [x2, x1, x0] = features | |||||
| [x2, x1, x0] = out_features | |||||
| # node x3 | # node x3 | ||||
| x13 = self.bu_conv13(x1) | x13 = self.bu_conv13(x1) | ||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import Tasks | |||||
| from .detector import SingleStageDetector | |||||
| @MODELS.register_module( | |||||
| Tasks.image_object_detection, module_name=Models.tinynas_damoyolo) | |||||
| class DamoYolo(SingleStageDetector): | |||||
| def __init__(self, model_dir, *args, **kwargs): | |||||
| self.config_name = 'damoyolo_s.py' | |||||
| super(DamoYolo, self).__init__(model_dir, *args, **kwargs) | |||||
| @@ -12,5 +12,5 @@ from .detector import SingleStageDetector | |||||
| class TinynasDetector(SingleStageDetector): | class TinynasDetector(SingleStageDetector): | ||||
| def __init__(self, model_dir, *args, **kwargs): | def __init__(self, model_dir, *args, **kwargs): | ||||
| self.config_name = 'airdet_s.py' | |||||
| super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) | super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) | ||||
| @@ -161,7 +161,7 @@ def summary_format(summary, fps): | |||||
| is_summary_frame = False | is_summary_frame = False | ||||
| if is_summary_frame and summary[-1] == 1: | if is_summary_frame and summary[-1] == 1: | ||||
| end_frame = len(frame_idxes) - 1 | |||||
| end_frame = len(summary) - 1 | |||||
| frames_list.append([start_frame, end_frame]) | frames_list.append([start_frame, end_frame]) | ||||
| output = [] | output = [] | ||||
| @@ -1 +1,2 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from .model import DiffusionForTextToImageSynthesis | from .model import DiffusionForTextToImageSynthesis | ||||
| @@ -1 +1,2 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from .gemm_model import GEMMForMultiModalEmbedding | from .gemm_model import GEMMForMultiModalEmbedding | ||||
| @@ -1 +1,2 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from .model import MultiStageDiffusionForTextToImageSynthesis | from .model import MultiStageDiffusionForTextToImageSynthesis | ||||
| @@ -3,8 +3,9 @@ from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| OFA_TASK_KEY_MAPPING = { | OFA_TASK_KEY_MAPPING = { | ||||
| Tasks.ocr_recognition: OutputKeys.TEXT, | |||||
| Tasks.image_captioning: OutputKeys.CAPTION, | Tasks.image_captioning: OutputKeys.CAPTION, | ||||
| Tasks.summarization: OutputKeys.TEXT, | |||||
| Tasks.text_summarization: OutputKeys.TEXT, | |||||
| Tasks.visual_question_answering: OutputKeys.TEXT, | Tasks.visual_question_answering: OutputKeys.TEXT, | ||||
| Tasks.visual_grounding: OutputKeys.BOXES, | Tasks.visual_grounding: OutputKeys.BOXES, | ||||
| Tasks.text_classification: OutputKeys.LABELS, | Tasks.text_classification: OutputKeys.LABELS, | ||||
| @@ -28,12 +28,13 @@ __all__ = ['OfaForAllTasks'] | |||||
| @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) | @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) | ||||
| @MODELS.register_module(Tasks.ocr_recognition, module_name=Models.ofa) | |||||
| @MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa) | @MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa) | ||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.visual_question_answering, module_name=Models.ofa) | Tasks.visual_question_answering, module_name=Models.ofa) | ||||
| @MODELS.register_module(Tasks.visual_entailment, module_name=Models.ofa) | @MODELS.register_module(Tasks.visual_entailment, module_name=Models.ofa) | ||||
| @MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) | @MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) | ||||
| @MODELS.register_module(Tasks.summarization, module_name=Models.ofa) | |||||
| @MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) | |||||
| @MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) | @MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) | ||||
| class OfaForAllTasks(TorchModel): | class OfaForAllTasks(TorchModel): | ||||
| @@ -97,8 +98,9 @@ class OfaForAllTasks(TorchModel): | |||||
| 'traverse': self._traverse_inference, | 'traverse': self._traverse_inference, | ||||
| } | } | ||||
| self.task_inference_mapping = { | self.task_inference_mapping = { | ||||
| Tasks.ocr_recognition: self._text_gen_inference, | |||||
| Tasks.image_captioning: self._text_gen_inference, | Tasks.image_captioning: self._text_gen_inference, | ||||
| Tasks.summarization: self._text_gen_inference, | |||||
| Tasks.text_summarization: self._text_gen_inference, | |||||
| Tasks.visual_grounding: self._visual_grounding_inference, | Tasks.visual_grounding: self._visual_grounding_inference, | ||||
| Tasks.visual_entailment: inference_d[self.gen_type], | Tasks.visual_entailment: inference_d[self.gen_type], | ||||
| Tasks.visual_question_answering: inference_d[self.gen_type], | Tasks.visual_question_answering: inference_d[self.gen_type], | ||||
| @@ -1 +1,2 @@ | |||||
| # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. | |||||
| from .team_model import TEAMForMultiModalSimilarity | from .team_model import TEAMForMultiModalSimilarity | ||||
| @@ -34,8 +34,9 @@ if TYPE_CHECKING: | |||||
| TaskModelForTextGeneration) | TaskModelForTextGeneration) | ||||
| from .token_classification import SbertForTokenClassification | from .token_classification import SbertForTokenClassification | ||||
| from .sentence_embedding import SentenceEmbedding | from .sentence_embedding import SentenceEmbedding | ||||
| from .passage_ranking import PassageRanking | |||||
| from .text_ranking import TextRanking | |||||
| from .T5 import T5ForConditionalGeneration | from .T5 import T5ForConditionalGeneration | ||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'backbones': ['SbertModel'], | 'backbones': ['SbertModel'], | ||||
| @@ -75,7 +76,7 @@ else: | |||||
| 'token_classification': ['SbertForTokenClassification'], | 'token_classification': ['SbertForTokenClassification'], | ||||
| 'table_question_answering': ['TableQuestionAnswering'], | 'table_question_answering': ['TableQuestionAnswering'], | ||||
| 'sentence_embedding': ['SentenceEmbedding'], | 'sentence_embedding': ['SentenceEmbedding'], | ||||
| 'passage_ranking': ['PassageRanking'], | |||||
| 'text_ranking': ['TextRanking'], | |||||
| 'T5': ['T5ForConditionalGeneration'], | 'T5': ['T5ForConditionalGeneration'], | ||||
| } | } | ||||
| @@ -15,7 +15,6 @@ | |||||
| """PyTorch BERT model. """ | """PyTorch BERT model. """ | ||||
| import math | import math | ||||
| import os | |||||
| import warnings | import warnings | ||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| from typing import Optional, Tuple | from typing import Optional, Tuple | ||||
| @@ -41,7 +40,6 @@ from transformers.modeling_utils import (PreTrainedModel, | |||||
| find_pruneable_heads_and_indices, | find_pruneable_heads_and_indices, | ||||
| prune_linear_layer) | prune_linear_layer) | ||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .configuration_bert import BertConfig | from .configuration_bert import BertConfig | ||||
| @@ -50,81 +48,6 @@ logger = get_logger(__name__) | |||||
| _CONFIG_FOR_DOC = 'BertConfig' | _CONFIG_FOR_DOC = 'BertConfig' | ||||
| def load_tf_weights_in_bert(model, config, tf_checkpoint_path): | |||||
| """Load tf checkpoints in a pytorch model.""" | |||||
| try: | |||||
| import re | |||||
| import numpy as np | |||||
| import tensorflow as tf | |||||
| except ImportError: | |||||
| logger.error( | |||||
| 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' | |||||
| 'https://www.tensorflow.org/install/ for installation instructions.' | |||||
| ) | |||||
| raise | |||||
| tf_path = os.path.abspath(tf_checkpoint_path) | |||||
| logger.info(f'Converting TensorFlow checkpoint from {tf_path}') | |||||
| # Load weights from TF model | |||||
| init_vars = tf.train.list_variables(tf_path) | |||||
| names = [] | |||||
| arrays = [] | |||||
| for name, shape in init_vars: | |||||
| logger.info(f'Loading TF weight {name} with shape {shape}') | |||||
| array = tf.train.load_variable(tf_path, name) | |||||
| names.append(name) | |||||
| arrays.append(array) | |||||
| for name, array in zip(names, arrays): | |||||
| name = name.split('/') | |||||
| # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v | |||||
| # which are not required for using pretrained model | |||||
| if any(n in [ | |||||
| 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', | |||||
| 'AdamWeightDecayOptimizer_1', 'global_step' | |||||
| ] for n in name): | |||||
| logger.info(f"Skipping {'/'.join(name)}") | |||||
| continue | |||||
| pointer = model | |||||
| for m_name in name: | |||||
| if re.fullmatch(r'[A-Za-z]+_\d+', m_name): | |||||
| scope_names = re.split(r'_(\d+)', m_name) | |||||
| else: | |||||
| scope_names = [m_name] | |||||
| if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': | |||||
| pointer = getattr(pointer, 'weight') | |||||
| elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': | |||||
| pointer = getattr(pointer, 'bias') | |||||
| elif scope_names[0] == 'output_weights': | |||||
| pointer = getattr(pointer, 'weight') | |||||
| elif scope_names[0] == 'squad': | |||||
| pointer = getattr(pointer, 'classifier') | |||||
| else: | |||||
| try: | |||||
| pointer = getattr(pointer, scope_names[0]) | |||||
| except AttributeError: | |||||
| logger.info(f"Skipping {'/'.join(name)}") | |||||
| continue | |||||
| if len(scope_names) >= 2: | |||||
| num = int(scope_names[1]) | |||||
| pointer = pointer[num] | |||||
| if m_name[-11:] == '_embeddings': | |||||
| pointer = getattr(pointer, 'weight') | |||||
| elif m_name == 'kernel': | |||||
| array = np.transpose(array) | |||||
| try: | |||||
| if pointer.shape != array.shape: | |||||
| raise ValueError( | |||||
| f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' | |||||
| ) | |||||
| except AssertionError as e: | |||||
| e.args += (pointer.shape, array.shape) | |||||
| raise | |||||
| logger.info(f'Initialize PyTorch weight {name}') | |||||
| pointer.data = torch.from_numpy(array) | |||||
| return model | |||||
| class BertEmbeddings(nn.Module): | class BertEmbeddings(nn.Module): | ||||
| """Construct the embeddings from word, position and token_type embeddings.""" | """Construct the embeddings from word, position and token_type embeddings.""" | ||||
| @@ -750,7 +673,6 @@ class BertPreTrainedModel(PreTrainedModel): | |||||
| """ | """ | ||||
| config_class = BertConfig | config_class = BertConfig | ||||
| load_tf_weights = load_tf_weights_in_bert | |||||
| base_model_prefix = 'bert' | base_model_prefix = 'bert' | ||||
| supports_gradient_checkpointing = True | supports_gradient_checkpointing = True | ||||
| _keys_to_ignore_on_load_missing = [r'position_ids'] | _keys_to_ignore_on_load_missing = [r'position_ids'] | ||||
| @@ -10,6 +10,8 @@ from modelscope.utils.constant import Tasks | |||||
| @HEADS.register_module( | @HEADS.register_module( | ||||
| Tasks.information_extraction, module_name=Heads.information_extraction) | Tasks.information_extraction, module_name=Heads.information_extraction) | ||||
| @HEADS.register_module( | |||||
| Tasks.relation_extraction, module_name=Heads.information_extraction) | |||||
| class InformationExtractionHead(TorchHead): | class InformationExtractionHead(TorchHead): | ||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| @@ -14,6 +14,8 @@ from modelscope.utils.constant import Tasks | |||||
| @HEADS.register_module( | @HEADS.register_module( | ||||
| Tasks.token_classification, module_name=Heads.token_classification) | Tasks.token_classification, module_name=Heads.token_classification) | ||||
| @HEADS.register_module( | |||||
| Tasks.part_of_speech, module_name=Heads.token_classification) | |||||
| class TokenClassificationHead(TorchHead): | class TokenClassificationHead(TorchHead): | ||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| @@ -16,6 +16,8 @@ __all__ = ['InformationExtractionModel'] | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.information_extraction, | Tasks.information_extraction, | ||||
| module_name=TaskModels.information_extraction) | module_name=TaskModels.information_extraction) | ||||
| @MODELS.register_module( | |||||
| Tasks.relation_extraction, module_name=TaskModels.information_extraction) | |||||
| class InformationExtractionModel(SingleBackboneTaskModelBase): | class InformationExtractionModel(SingleBackboneTaskModelBase): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -19,6 +19,8 @@ __all__ = ['TokenClassificationModel'] | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.token_classification, module_name=TaskModels.token_classification) | Tasks.token_classification, module_name=TaskModels.token_classification) | ||||
| @MODELS.register_module( | |||||
| Tasks.part_of_speech, module_name=TaskModels.token_classification) | |||||
| class TokenClassificationModel(SingleBackboneTaskModelBase): | class TokenClassificationModel(SingleBackboneTaskModelBase): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -13,18 +13,18 @@ from modelscope.models.nlp.structbert import SbertPreTrainedModel | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| __all__ = ['PassageRanking'] | |||||
| __all__ = ['TextRanking'] | |||||
| @MODELS.register_module(Tasks.passage_ranking, module_name=Models.bert) | |||||
| class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||||
| @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) | |||||
| class TextRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||||
| base_model_prefix: str = 'bert' | base_model_prefix: str = 'bert' | ||||
| supports_gradient_checkpointing = True | supports_gradient_checkpointing = True | ||||
| _keys_to_ignore_on_load_missing = [r'position_ids'] | _keys_to_ignore_on_load_missing = [r'position_ids'] | ||||
| def __init__(self, config, model_dir, *args, **kwargs): | def __init__(self, config, model_dir, *args, **kwargs): | ||||
| if hasattr(config, 'base_model_prefix'): | if hasattr(config, 'base_model_prefix'): | ||||
| PassageRanking.base_model_prefix = config.base_model_prefix | |||||
| TextRanking.base_model_prefix = config.base_model_prefix | |||||
| super().__init__(config, model_dir) | super().__init__(config, model_dir) | ||||
| self.train_batch_size = kwargs.get('train_batch_size', 4) | self.train_batch_size = kwargs.get('train_batch_size', 4) | ||||
| self.register_buffer( | self.register_buffer( | ||||
| @@ -74,7 +74,7 @@ class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||||
| num_labels = kwargs.get('num_labels', 1) | num_labels = kwargs.get('num_labels', 1) | ||||
| model_args = {} if num_labels is None else {'num_labels': num_labels} | model_args = {} if num_labels is None else {'num_labels': num_labels} | ||||
| return super(SbertPreTrainedModel, PassageRanking).from_pretrained( | |||||
| return super(SbertPreTrainedModel, TextRanking).from_pretrained( | |||||
| pretrained_model_name_or_path=kwargs.get('model_dir'), | pretrained_model_name_or_path=kwargs.get('model_dir'), | ||||
| model_dir=kwargs.get('model_dir'), | model_dir=kwargs.get('model_dir'), | ||||
| **model_args) | **model_args) | ||||
| @@ -26,11 +26,16 @@ class EasyCVBaseDataset(object): | |||||
| if self.split_config is not None: | if self.split_config is not None: | ||||
| self._update_data_source(kwargs['data_source']) | self._update_data_source(kwargs['data_source']) | ||||
| def _update_data_root(self, input_dict, data_root): | |||||
| for k, v in input_dict.items(): | |||||
| if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||||
| input_dict.update( | |||||
| {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||||
| elif isinstance(v, dict): | |||||
| self._update_data_root(v, data_root) | |||||
| def _update_data_source(self, data_source): | def _update_data_source(self, data_source): | ||||
| data_root = next(iter(self.split_config.values())) | data_root = next(iter(self.split_config.values())) | ||||
| data_root = data_root.rstrip(osp.sep) | data_root = data_root.rstrip(osp.sep) | ||||
| for k, v in data_source.items(): | |||||
| if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||||
| data_source.update( | |||||
| {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||||
| self._update_data_root(data_source, data_root) | |||||
| @@ -12,14 +12,14 @@ if TYPE_CHECKING: | |||||
| from .movie_scene_segmentation import MovieSceneSegmentationDataset | from .movie_scene_segmentation import MovieSceneSegmentationDataset | ||||
| from .video_summarization_dataset import VideoSummarizationDataset | from .video_summarization_dataset import VideoSummarizationDataset | ||||
| from .image_inpainting import ImageInpaintingDataset | from .image_inpainting import ImageInpaintingDataset | ||||
| from .passage_ranking_dataset import PassageRankingDataset | |||||
| from .text_ranking_dataset import TextRankingDataset | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'base': ['TaskDataset'], | 'base': ['TaskDataset'], | ||||
| 'builder': ['TASK_DATASETS', 'build_task_dataset'], | 'builder': ['TASK_DATASETS', 'build_task_dataset'], | ||||
| 'torch_base_dataset': ['TorchTaskDataset'], | 'torch_base_dataset': ['TorchTaskDataset'], | ||||
| 'passage_ranking_dataset': ['PassageRankingDataset'], | |||||
| 'text_ranking_dataset': ['TextRankingDataset'], | |||||
| 'veco_dataset': ['VecoDataset'], | 'veco_dataset': ['VecoDataset'], | ||||
| 'image_instance_segmentation_coco_dataset': | 'image_instance_segmentation_coco_dataset': | ||||
| ['ImageInstanceSegmentationCocoDataset'], | ['ImageInstanceSegmentationCocoDataset'], | ||||
| @@ -27,6 +27,8 @@ else: | |||||
| 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | ||||
| 'image_inpainting': ['ImageInpaintingDataset'], | 'image_inpainting': ['ImageInpaintingDataset'], | ||||
| 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | ||||
| 'image_portrait_enhancement_dataset': | |||||
| ['ImagePortraitEnhancementDataset'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,23 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .image_portrait_enhancement_dataset import ImagePortraitEnhancementDataset | |||||
| else: | |||||
| _import_structure = { | |||||
| 'image_portrait_enhancement_dataset': | |||||
| ['ImagePortraitEnhancementDataset'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,32 @@ | |||||
| # ------------------------------------------------------------------------ | |||||
| # Modified from BasicSR (https://github.com/xinntao/BasicSR) | |||||
| # Copyright 2018-2020 BasicSR Authors | |||||
| # ------------------------------------------------------------------------ | |||||
| import cv2 | |||||
| import torch | |||||
| def img2tensor(imgs, bgr2rgb=True, float32=True): | |||||
| """Numpy array to tensor. | |||||
| Args: | |||||
| imgs (list[ndarray] | ndarray): Input images. | |||||
| bgr2rgb (bool): Whether to change bgr to rgb. | |||||
| float32 (bool): Whether to change to float32. | |||||
| Returns: | |||||
| list[tensor] | tensor: Tensor images. If returned results only have | |||||
| one element, just return tensor. | |||||
| """ | |||||
| def _totensor(img, bgr2rgb, float32): | |||||
| if img.shape[2] == 3 and bgr2rgb: | |||||
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |||||
| img = torch.from_numpy(img.transpose(2, 0, 1)) | |||||
| if float32: | |||||
| img = img.float() | |||||
| return img | |||||
| if isinstance(imgs, list): | |||||
| return [_totensor(img, bgr2rgb, float32) for img in imgs] | |||||
| else: | |||||
| return _totensor(imgs, bgr2rgb, float32) | |||||
| @@ -0,0 +1,51 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Datasets, Models | |||||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||||
| from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||||
| TorchTaskDataset | |||||
| from modelscope.utils.constant import Tasks | |||||
| from .data_utils import img2tensor | |||||
| def default_loader(path): | |||||
| return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0 | |||||
| @TASK_DATASETS.register_module( | |||||
| Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset) | |||||
| class ImagePortraitEnhancementDataset(TorchTaskDataset): | |||||
| """Paired image dataset for image portrait enhancement. | |||||
| """ | |||||
| def __init__(self, dataset, is_train): | |||||
| self.dataset = dataset | |||||
| self.gt_size = 256 | |||||
| self.is_train = is_train | |||||
| def __len__(self): | |||||
| return len(self.dataset) | |||||
| def __getitem__(self, index): | |||||
| # Load gt and lq images. Dimension order: HWC; channel order: BGR; | |||||
| # image range: [0, 1], float32. | |||||
| item_dict = self.dataset[index] | |||||
| gt_path = item_dict['hq:FILE'] | |||||
| img_gt = default_loader(gt_path) | |||||
| lq_path = item_dict['lq:FILE'] | |||||
| img_lq = default_loader(lq_path) | |||||
| gt_size = self.gt_size | |||||
| img_gt = cv2.resize(img_gt, (gt_size, gt_size)) | |||||
| img_lq = cv2.resize(img_lq, (gt_size, gt_size)) | |||||
| # BGR to RGB, HWC to CHW, numpy to tensor | |||||
| img_gt, img_lq = img2tensor([img_gt, img_lq], | |||||
| bgr2rgb=True, | |||||
| float32=True) | |||||
| return {'input': (img_lq - 0.5) / 0.5, 'target': (img_gt - 0.5) / 0.5} | |||||
| @@ -16,8 +16,8 @@ from .torch_base_dataset import TorchTaskDataset | |||||
| @TASK_DATASETS.register_module( | @TASK_DATASETS.register_module( | ||||
| group_key=Tasks.passage_ranking, module_name=Models.bert) | |||||
| class PassageRankingDataset(TorchTaskDataset): | |||||
| group_key=Tasks.text_ranking, module_name=Models.bert) | |||||
| class TextRankingDataset(TorchTaskDataset): | |||||
| def __init__(self, | def __init__(self, | ||||
| datasets: Union[Any, List[Any]], | datasets: Union[Any, List[Any]], | ||||
| @@ -35,8 +35,8 @@ class PassageRankingDataset(TorchTaskDataset): | |||||
| 'positive_passages') | 'positive_passages') | ||||
| self.neg_sequence = self.dataset_config.get('neg_sequence', | self.neg_sequence = self.dataset_config.get('neg_sequence', | ||||
| 'negative_passages') | 'negative_passages') | ||||
| self.passage_text_fileds = self.dataset_config.get( | |||||
| 'passage_text_fileds', ['title', 'text']) | |||||
| self.text_fileds = self.dataset_config.get('text_fileds', | |||||
| ['title', 'text']) | |||||
| self.qid_field = self.dataset_config.get('qid_field', 'query_id') | self.qid_field = self.dataset_config.get('qid_field', 'query_id') | ||||
| if mode == ModeKeys.TRAIN: | if mode == ModeKeys.TRAIN: | ||||
| train_config = kwargs.get('train', {}) | train_config = kwargs.get('train', {}) | ||||
| @@ -58,14 +58,14 @@ class PassageRankingDataset(TorchTaskDataset): | |||||
| pos_sequences = group[self.pos_sequence] | pos_sequences = group[self.pos_sequence] | ||||
| pos_sequences = [ | pos_sequences = [ | ||||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||||
| ' '.join([ele[key] for key in self.text_fileds]) | |||||
| for ele in pos_sequences | for ele in pos_sequences | ||||
| ] | ] | ||||
| labels.extend([1] * len(pos_sequences)) | labels.extend([1] * len(pos_sequences)) | ||||
| neg_sequences = group[self.neg_sequence] | neg_sequences = group[self.neg_sequence] | ||||
| neg_sequences = [ | neg_sequences = [ | ||||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||||
| ' '.join([ele[key] for key in self.text_fileds]) | |||||
| for ele in neg_sequences | for ele in neg_sequences | ||||
| ] | ] | ||||
| @@ -88,13 +88,13 @@ class PassageRankingDataset(TorchTaskDataset): | |||||
| pos_sequences = group[self.pos_sequence] | pos_sequences = group[self.pos_sequence] | ||||
| pos_sequences = [ | pos_sequences = [ | ||||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||||
| ' '.join([ele[key] for key in self.text_fileds]) | |||||
| for ele in pos_sequences | for ele in pos_sequences | ||||
| ] | ] | ||||
| neg_sequences = group[self.neg_sequence] | neg_sequences = group[self.neg_sequence] | ||||
| neg_sequences = [ | neg_sequences = [ | ||||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||||
| ' '.join([ele[key] for key in self.text_fileds]) | |||||
| for ele in neg_sequences | for ele in neg_sequences | ||||
| ] | ] | ||||
| @@ -7,7 +7,7 @@ from typing import Any, Mapping, Optional, Sequence, Union | |||||
| from datasets.builder import DatasetBuilder | from datasets.builder import DatasetBuilder | ||||
| from modelscope.hub.api import HubApi | from modelscope.hub.api import HubApi | ||||
| from modelscope.utils.constant import DEFAULT_DATASET_REVISION, DownloadParams | |||||
| from modelscope.utils.constant import DEFAULT_DATASET_REVISION | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder | from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder | ||||
| @@ -95,15 +95,13 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, | |||||
| res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...] | res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...] | ||||
| """ | """ | ||||
| res = [] | res = [] | ||||
| cookies = hub_api.check_cookies_upload_data(use_cookies=True) | |||||
| objects = hub_api.list_oss_dataset_objects( | objects = hub_api.list_oss_dataset_objects( | ||||
| dataset_name=dataset_name, | dataset_name=dataset_name, | ||||
| namespace=namespace, | namespace=namespace, | ||||
| max_limit=max_limit, | max_limit=max_limit, | ||||
| is_recursive=is_recursive, | is_recursive=is_recursive, | ||||
| is_filter_dir=True, | is_filter_dir=True, | ||||
| revision=version, | |||||
| cookies=cookies) | |||||
| revision=version) | |||||
| for item in objects: | for item in objects: | ||||
| object_key = item.get('Key') | object_key = item.get('Key') | ||||
| @@ -174,7 +172,7 @@ def get_dataset_files(subset_split_into: dict, | |||||
| modelscope_api = HubApi() | modelscope_api = HubApi() | ||||
| objects = list_dataset_objects( | objects = list_dataset_objects( | ||||
| hub_api=modelscope_api, | hub_api=modelscope_api, | ||||
| max_limit=DownloadParams.MAX_LIST_OBJECTS_NUM.value, | |||||
| max_limit=-1, | |||||
| is_recursive=True, | is_recursive=True, | ||||
| dataset_name=dataset_name, | dataset_name=dataset_name, | ||||
| namespace=namespace, | namespace=namespace, | ||||
| @@ -506,7 +506,7 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.text_error_correction: [OutputKeys.OUTPUT], | Tasks.text_error_correction: [OutputKeys.OUTPUT], | ||||
| Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], | Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], | ||||
| Tasks.passage_ranking: [OutputKeys.SCORES], | |||||
| Tasks.text_ranking: [OutputKeys.SCORES], | |||||
| # text generation result for single sample | # text generation result for single sample | ||||
| # { | # { | ||||
| @@ -661,6 +661,7 @@ TASK_OUTPUTS = { | |||||
| # "caption": "this is an image caption text." | # "caption": "this is an image caption text." | ||||
| # } | # } | ||||
| Tasks.image_captioning: [OutputKeys.CAPTION], | Tasks.image_captioning: [OutputKeys.CAPTION], | ||||
| Tasks.ocr_recognition: [OutputKeys.TEXT], | |||||
| # visual grounding result for single sample | # visual grounding result for single sample | ||||
| # { | # { | ||||
| @@ -162,7 +162,7 @@ TASK_INPUTS = { | |||||
| 'source_sentence': InputType.LIST, | 'source_sentence': InputType.LIST, | ||||
| 'sentences_to_compare': InputType.LIST, | 'sentences_to_compare': InputType.LIST, | ||||
| }, | }, | ||||
| Tasks.passage_ranking: (InputType.TEXT, InputType.TEXT), | |||||
| Tasks.text_ranking: (InputType.TEXT, InputType.TEXT), | |||||
| Tasks.text_generation: | Tasks.text_generation: | ||||
| InputType.TEXT, | InputType.TEXT, | ||||
| Tasks.fill_mask: | Tasks.fill_mask: | ||||
| @@ -47,22 +47,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| if isinstance(audio_in, str): | if isinstance(audio_in, str): | ||||
| # load pcm data from url if audio_in is url str | # load pcm data from url if audio_in is url str | ||||
| self.audio_in = load_bytes_from_url(audio_in) | |||||
| self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) | |||||
| elif isinstance(audio_in, bytes): | elif isinstance(audio_in, bytes): | ||||
| # load pcm data from wav data if audio_in is wave format | # load pcm data from wav data if audio_in is wave format | ||||
| self.audio_in = extract_pcm_from_wav(audio_in) | |||||
| self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) | |||||
| else: | else: | ||||
| self.audio_in = audio_in | self.audio_in = audio_in | ||||
| # set the sample_rate of audio_in if checking_audio_fs is valid | |||||
| if checking_audio_fs is not None: | |||||
| self.audio_fs = checking_audio_fs | |||||
| if recog_type is None or audio_format is None: | if recog_type is None or audio_format is None: | ||||
| self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( | ||||
| audio_in=self.audio_in, | audio_in=self.audio_in, | ||||
| recog_type=recog_type, | recog_type=recog_type, | ||||
| audio_format=audio_format) | audio_format=audio_format) | ||||
| if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None: | |||||
| self.audio_fs = asr_utils.sample_rate_checking( | |||||
| if hasattr(asr_utils, 'sample_rate_checking'): | |||||
| checking_audio_fs = asr_utils.sample_rate_checking( | |||||
| self.audio_in, self.audio_format) | self.audio_in, self.audio_format) | ||||
| if checking_audio_fs is not None: | |||||
| self.audio_fs = checking_audio_fs | |||||
| if self.preprocessor is None: | if self.preprocessor is None: | ||||
| self.preprocessor = WavToScp() | self.preprocessor = WavToScp() | ||||
| @@ -80,7 +86,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| logger.info(f"Decoding with {inputs['audio_format']} files ...") | logger.info(f"Decoding with {inputs['audio_format']} files ...") | ||||
| data_cmd: Sequence[Tuple[str, str]] | |||||
| data_cmd: Sequence[Tuple[str, str, str]] | |||||
| if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': | if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': | ||||
| data_cmd = ['speech', 'sound'] | data_cmd = ['speech', 'sound'] | ||||
| elif inputs['audio_format'] == 'kaldi_ark': | elif inputs['audio_format'] == 'kaldi_ark': | ||||
| @@ -88,6 +94,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): | |||||
| elif inputs['audio_format'] == 'tfrecord': | elif inputs['audio_format'] == 'tfrecord': | ||||
| data_cmd = ['speech', 'tfrecord'] | data_cmd = ['speech', 'tfrecord'] | ||||
| if inputs.__contains__('mvn_file'): | |||||
| data_cmd.append(inputs['mvn_file']) | |||||
| # generate asr inference command | # generate asr inference command | ||||
| cmd = { | cmd = { | ||||
| 'model_type': inputs['model_type'], | 'model_type': inputs['model_type'], | ||||
| @@ -51,10 +51,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| if isinstance(audio_in, str): | if isinstance(audio_in, str): | ||||
| # load pcm data from url if audio_in is url str | # load pcm data from url if audio_in is url str | ||||
| audio_in = load_bytes_from_url(audio_in) | |||||
| audio_in, audio_fs = load_bytes_from_url(audio_in) | |||||
| elif isinstance(audio_in, bytes): | elif isinstance(audio_in, bytes): | ||||
| # load pcm data from wav data if audio_in is wave format | # load pcm data from wav data if audio_in is wave format | ||||
| audio_in = extract_pcm_from_wav(audio_in) | |||||
| audio_in, audio_fs = extract_pcm_from_wav(audio_in) | |||||
| output = self.preprocessor.forward(self.model.forward(), audio_in) | output = self.preprocessor.forward(self.model.forward(), audio_in) | ||||
| output = self.forward(output) | output = self.forward(output) | ||||
| @@ -433,6 +433,8 @@ def collate_fn(data, device): | |||||
| if isinstance(data, dict) or isinstance(data, Mapping): | if isinstance(data, dict) or isinstance(data, Mapping): | ||||
| return type(data)({k: collate_fn(v, device) for k, v in data.items()}) | return type(data)({k: collate_fn(v, device) for k, v in data.items()}) | ||||
| elif isinstance(data, (tuple, list)): | elif isinstance(data, (tuple, list)): | ||||
| if 0 == len(data): | |||||
| return torch.Tensor([]) | |||||
| if isinstance(data[0], (int, float)): | if isinstance(data[0], (int, float)): | ||||
| return default_collate(data).to(device) | return default_collate(data).to(device) | ||||
| else: | else: | ||||
| @@ -20,17 +20,22 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.sentence_embedding: | Tasks.sentence_embedding: | ||||
| (Pipelines.sentence_embedding, | (Pipelines.sentence_embedding, | ||||
| 'damo/nlp_corom_sentence-embedding_english-base'), | 'damo/nlp_corom_sentence-embedding_english-base'), | ||||
| Tasks.passage_ranking: (Pipelines.passage_ranking, | |||||
| 'damo/nlp_corom_passage-ranking_english-base'), | |||||
| Tasks.text_ranking: (Pipelines.text_ranking, | |||||
| 'damo/nlp_corom_passage-ranking_english-base'), | |||||
| Tasks.word_segmentation: | Tasks.word_segmentation: | ||||
| (Pipelines.word_segmentation, | (Pipelines.word_segmentation, | ||||
| 'damo/nlp_structbert_word-segmentation_chinese-base'), | 'damo/nlp_structbert_word-segmentation_chinese-base'), | ||||
| Tasks.part_of_speech: (Pipelines.part_of_speech, | |||||
| 'damo/nlp_structbert_part-of-speech_chinese-base'), | |||||
| Tasks.token_classification: | Tasks.token_classification: | ||||
| (Pipelines.part_of_speech, | (Pipelines.part_of_speech, | ||||
| 'damo/nlp_structbert_part-of-speech_chinese-base'), | 'damo/nlp_structbert_part-of-speech_chinese-base'), | ||||
| Tasks.named_entity_recognition: | Tasks.named_entity_recognition: | ||||
| (Pipelines.named_entity_recognition, | (Pipelines.named_entity_recognition, | ||||
| 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), | 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), | ||||
| Tasks.relation_extraction: | |||||
| (Pipelines.relation_extraction, | |||||
| 'damo/nlp_bert_relation-extraction_chinese-base'), | |||||
| Tasks.information_extraction: | Tasks.information_extraction: | ||||
| (Pipelines.relation_extraction, | (Pipelines.relation_extraction, | ||||
| 'damo/nlp_bert_relation-extraction_chinese-base'), | 'damo/nlp_bert_relation-extraction_chinese-base'), | ||||
| @@ -143,6 +143,13 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints | max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints | ||||
| for i, frame in enumerate(video_frames): | for i, frame in enumerate(video_frames): | ||||
| kps_2d = self.human_body_2d_kps_detector(frame) | kps_2d = self.human_body_2d_kps_detector(frame) | ||||
| if [] == kps_2d.get('boxes'): | |||||
| res = { | |||||
| 'success': False, | |||||
| 'msg': f'fail to detect person at image frame {i}' | |||||
| } | |||||
| return res | |||||
| box = kps_2d['boxes'][ | box = kps_2d['boxes'][ | ||||
| 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox | 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox | ||||
| pose = kps_2d['keypoints'][0] # keypoints: [15, 2] | pose = kps_2d['keypoints'][0] # keypoints: [15, 2] | ||||
| @@ -180,7 +187,15 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| return res | return res | ||||
| def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | ||||
| res = {OutputKeys.KEYPOINTS: [], OutputKeys.TIMESTAMPS: []} | |||||
| output_video_path = kwargs.get('output_video', None) | |||||
| if output_video_path is None: | |||||
| output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name | |||||
| res = { | |||||
| OutputKeys.KEYPOINTS: [], | |||||
| OutputKeys.TIMESTAMPS: [], | |||||
| OutputKeys.OUTPUT_VIDEO: output_video_path | |||||
| } | |||||
| if not input['success']: | if not input['success']: | ||||
| pass | pass | ||||
| @@ -189,10 +204,6 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| pred_3d_pose = poses.data.cpu().numpy()[ | pred_3d_pose = poses.data.cpu().numpy()[ | ||||
| 0] # [frame_num, joint_num, joint_dim] | 0] # [frame_num, joint_num, joint_dim] | ||||
| output_video_path = kwargs.get('output_video', None) | |||||
| if output_video_path is None: | |||||
| output_video_path = tempfile.NamedTemporaryFile( | |||||
| suffix='.mp4').name | |||||
| if 'render' in self.keypoint_model_3d.cfg.keys(): | if 'render' in self.keypoint_model_3d.cfg.keys(): | ||||
| self.render_prediction(pred_3d_pose, output_video_path) | self.render_prediction(pred_3d_pose, output_video_path) | ||||
| res[OutputKeys.OUTPUT_VIDEO] = output_video_path | res[OutputKeys.OUTPUT_VIDEO] = output_video_path | ||||
| @@ -61,6 +61,8 @@ class FaceImageGenerationPipeline(Pipeline): | |||||
| return input | return input | ||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | ||||
| if isinstance(input, str): | |||||
| input = int(input) | |||||
| assert isinstance(input, int) | assert isinstance(input, int) | ||||
| torch.manual_seed(input) | torch.manual_seed(input) | ||||
| torch.cuda.manual_seed(input) | torch.cuda.manual_seed(input) | ||||
| @@ -53,6 +53,7 @@ class ImageReidPersonPipeline(Pipeline): | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | ||||
| img = input['img'] | img = input['img'] | ||||
| img_embedding = self.model(img) | img_embedding = self.model(img) | ||||
| img_embedding = img_embedding.detach().cpu().numpy() | |||||
| return {OutputKeys.IMG_EMBEDDING: img_embedding} | return {OutputKeys.IMG_EMBEDDING: img_embedding} | ||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| @@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import LoadImage | from modelscope.preprocessors import LoadImage | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.cv.image_utils import \ | |||||
| show_image_object_detection_auto_result | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline): | |||||
| bboxes, scores, labels = self.model.postprocess(inputs['data']) | bboxes, scores, labels = self.model.postprocess(inputs['data']) | ||||
| if bboxes is None: | if bboxes is None: | ||||
| return None | |||||
| outputs = { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.LABELS: labels, | |||||
| OutputKeys.BOXES: bboxes | |||||
| } | |||||
| outputs = { | |||||
| OutputKeys.SCORES: [], | |||||
| OutputKeys.LABELS: [], | |||||
| OutputKeys.BOXES: [] | |||||
| } | |||||
| else: | |||||
| outputs = { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.LABELS: labels, | |||||
| OutputKeys.BOXES: bboxes | |||||
| } | |||||
| return outputs | return outputs | ||||
| def show_result(self, img_path, result, save_path=None): | |||||
| show_image_object_detection_auto_result(img_path, result, save_path) | |||||
| @@ -11,6 +11,8 @@ from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @PIPELINES.register_module( | |||||
| Tasks.image_text_retrieval, module_name=Pipelines.multi_modal_embedding) | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) | Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) | ||||
| class MultiModalEmbeddingPipeline(Pipeline): | class MultiModalEmbeddingPipeline(Pipeline): | ||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.multi_modal import OfaForAllTasks | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Model, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import OfaPreprocessor, Preprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.ocr_recognition, module_name=Pipelines.ofa_ocr_recognition) | |||||
| class OcrRecognitionPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[Model, str], | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| **kwargs): | |||||
| """ | |||||
| use `model` and `preprocessor` to create a ocr recognition pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||||
| 'model must be a single str or OfaForAllTasks' | |||||
| if isinstance(model, str): | |||||
| pipe_model = Model.from_pretrained(model) | |||||
| elif isinstance(model, Model): | |||||
| pipe_model = model | |||||
| else: | |||||
| raise NotImplementedError | |||||
| pipe_model.model.eval() | |||||
| if preprocessor is None: | |||||
| if isinstance(pipe_model, OfaForAllTasks): | |||||
| preprocessor = OfaPreprocessor(pipe_model.model_dir) | |||||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -17,7 +17,7 @@ if TYPE_CHECKING: | |||||
| from .fill_mask_ponet_pipeline import FillMaskPonetPipeline | from .fill_mask_ponet_pipeline import FillMaskPonetPipeline | ||||
| from .information_extraction_pipeline import InformationExtractionPipeline | from .information_extraction_pipeline import InformationExtractionPipeline | ||||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | ||||
| from .passage_ranking_pipeline import PassageRankingPipeline | |||||
| from .text_ranking_pipeline import TextRankingPipeline | |||||
| from .sentence_embedding_pipeline import SentenceEmbeddingPipeline | from .sentence_embedding_pipeline import SentenceEmbeddingPipeline | ||||
| from .sequence_classification_pipeline import SequenceClassificationPipeline | from .sequence_classification_pipeline import SequenceClassificationPipeline | ||||
| from .summarization_pipeline import SummarizationPipeline | from .summarization_pipeline import SummarizationPipeline | ||||
| @@ -51,7 +51,7 @@ else: | |||||
| 'information_extraction_pipeline': ['InformationExtractionPipeline'], | 'information_extraction_pipeline': ['InformationExtractionPipeline'], | ||||
| 'named_entity_recognition_pipeline': | 'named_entity_recognition_pipeline': | ||||
| ['NamedEntityRecognitionPipeline'], | ['NamedEntityRecognitionPipeline'], | ||||
| 'passage_ranking_pipeline': ['PassageRankingPipeline'], | |||||
| 'text_ranking_pipeline': ['TextRankingPipeline'], | |||||
| 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'], | 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'], | ||||
| 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], | 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], | ||||
| 'summarization_pipeline': ['SummarizationPipeline'], | 'summarization_pipeline': ['SummarizationPipeline'], | ||||
| @@ -17,6 +17,8 @@ __all__ = ['InformationExtractionPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.information_extraction, module_name=Pipelines.relation_extraction) | Tasks.information_extraction, module_name=Pipelines.relation_extraction) | ||||
| @PIPELINES.register_module( | |||||
| Tasks.relation_extraction, module_name=Pipelines.relation_extraction) | |||||
| class InformationExtractionPipeline(Pipeline): | class InformationExtractionPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -13,7 +13,7 @@ logger = get_logger() | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.summarization, module_name=Pipelines.text_generation) | |||||
| Tasks.text_summarization, module_name=Pipelines.text_generation) | |||||
| class SummarizationPipeline(Pipeline): | class SummarizationPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -72,6 +72,7 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| action = self.action_ops[result['action']] | action = self.action_ops[result['action']] | ||||
| headers = table['header_name'] | headers = table['header_name'] | ||||
| current_sql = result['sql'] | current_sql = result['sql'] | ||||
| current_sql['from'] = [table['table_id']] | |||||
| if history_sql is None: | if history_sql is None: | ||||
| return current_sql | return current_sql | ||||
| @@ -216,10 +217,11 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| else: | else: | ||||
| return current_sql | return current_sql | ||||
| def sql_dict_to_str(self, result, table): | |||||
| def sql_dict_to_str(self, result, tables): | |||||
| """ | """ | ||||
| convert sql struct to string | convert sql struct to string | ||||
| """ | """ | ||||
| table = tables[result['sql']['from'][0]] | |||||
| header_names = table['header_name'] + ['空列'] | header_names = table['header_name'] + ['空列'] | ||||
| header_ids = table['header_id'] + ['null'] | header_ids = table['header_id'] + ['null'] | ||||
| sql = result['sql'] | sql = result['sql'] | ||||
| @@ -279,42 +281,43 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| """ | """ | ||||
| result = inputs['result'] | result = inputs['result'] | ||||
| history_sql = inputs['history_sql'] | history_sql = inputs['history_sql'] | ||||
| result['sql'] = self.post_process_multi_turn( | |||||
| history_sql=history_sql, | |||||
| result=result, | |||||
| table=self.db.tables[result['table_id']]) | |||||
| result['sql']['from'] = [result['table_id']] | |||||
| sql = self.sql_dict_to_str( | |||||
| result=result, table=self.db.tables[result['table_id']]) | |||||
| try: | |||||
| result['sql'] = self.post_process_multi_turn( | |||||
| history_sql=history_sql, | |||||
| result=result, | |||||
| table=self.db.tables[result['table_id']]) | |||||
| except Exception: | |||||
| result['sql'] = history_sql | |||||
| sql = self.sql_dict_to_str(result=result, tables=self.db.tables) | |||||
| # add sqlite | # add sqlite | ||||
| if self.db.is_use_sqlite: | if self.db.is_use_sqlite: | ||||
| try: | try: | ||||
| cursor = self.db.connection_obj.cursor().execute(sql.query) | cursor = self.db.connection_obj.cursor().execute(sql.query) | ||||
| names = [{ | |||||
| 'name': | |||||
| description[0], | |||||
| 'label': | |||||
| self.db.tables[result['table_id']]['headerid2name'].get( | |||||
| description[0], description[0]) | |||||
| } for description in cursor.description] | |||||
| cells = [] | |||||
| header_ids, header_names = [], [] | |||||
| for description in cursor.description: | |||||
| header_ids.append(self.db.tables[result['table_id']] | |||||
| ['headerid2name'].get( | |||||
| description[0], description[0])) | |||||
| header_names.append(description[0]) | |||||
| rows = [] | |||||
| for res in cursor.fetchall(): | for res in cursor.fetchall(): | ||||
| row = {} | |||||
| for name, cell in zip(names, res): | |||||
| row[name['name']] = cell | |||||
| cells.append(row) | |||||
| tabledata = {'headers': names, 'cells': cells} | |||||
| rows.append(list(res)) | |||||
| tabledata = { | |||||
| 'header_id': header_ids, | |||||
| 'header_name': header_names, | |||||
| 'rows': rows | |||||
| } | |||||
| except Exception: | except Exception: | ||||
| tabledata = {'headers': [], 'cells': []} | |||||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | |||||
| else: | else: | ||||
| tabledata = {'headers': [], 'cells': []} | |||||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | |||||
| output = { | output = { | ||||
| OutputKeys.SQL_STRING: sql.string, | OutputKeys.SQL_STRING: sql.string, | ||||
| OutputKeys.SQL_QUERY: sql.query, | OutputKeys.SQL_QUERY: sql.query, | ||||
| OutputKeys.HISTORY: result['sql'], | OutputKeys.HISTORY: result['sql'], | ||||
| OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False), | |||||
| OutputKeys.QUERT_RESULT: tabledata, | |||||
| } | } | ||||
| return output | return output | ||||
| @@ -9,15 +9,15 @@ from modelscope.models import Model | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.pipelines.base import Pipeline | from modelscope.pipelines.base import Pipeline | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import PassageRankingPreprocessor, Preprocessor | |||||
| from modelscope.preprocessors import Preprocessor, TextRankingPreprocessor | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| __all__ = ['PassageRankingPipeline'] | |||||
| __all__ = ['TextRankingPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.passage_ranking, module_name=Pipelines.passage_ranking) | |||||
| class PassageRankingPipeline(Pipeline): | |||||
| Tasks.text_ranking, module_name=Pipelines.text_ranking) | |||||
| class TextRankingPipeline(Pipeline): | |||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| @@ -36,7 +36,7 @@ class PassageRankingPipeline(Pipeline): | |||||
| Model) else Model.from_pretrained(model) | Model) else Model.from_pretrained(model) | ||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = PassageRankingPreprocessor( | |||||
| preprocessor = TextRankingPreprocessor( | |||||
| model.model_dir if isinstance(model, Model) else model, | model.model_dir if isinstance(model, Model) else model, | ||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| model.eval() | model.eval() | ||||
| @@ -18,6 +18,8 @@ __all__ = ['TokenClassificationPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.token_classification, module_name=Pipelines.part_of_speech) | Tasks.token_classification, module_name=Pipelines.part_of_speech) | ||||
| @PIPELINES.register_module( | |||||
| Tasks.part_of_speech, module_name=Pipelines.part_of_speech) | |||||
| class TokenClassificationPipeline(Pipeline): | class TokenClassificationPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -21,7 +21,7 @@ if TYPE_CHECKING: | |||||
| FillMaskPoNetPreprocessor, | FillMaskPoNetPreprocessor, | ||||
| NLPPreprocessor, | NLPPreprocessor, | ||||
| NLPTokenizerPreprocessorBase, | NLPTokenizerPreprocessorBase, | ||||
| PassageRankingPreprocessor, | |||||
| TextRankingPreprocessor, | |||||
| RelationExtractionPreprocessor, | RelationExtractionPreprocessor, | ||||
| SentenceEmbeddingPreprocessor, | SentenceEmbeddingPreprocessor, | ||||
| SequenceClassificationPreprocessor, | SequenceClassificationPreprocessor, | ||||
| @@ -62,7 +62,7 @@ else: | |||||
| 'FillMaskPoNetPreprocessor', | 'FillMaskPoNetPreprocessor', | ||||
| 'NLPPreprocessor', | 'NLPPreprocessor', | ||||
| 'NLPTokenizerPreprocessorBase', | 'NLPTokenizerPreprocessorBase', | ||||
| 'PassageRankingPreprocessor', | |||||
| 'TextRankingPreprocessor', | |||||
| 'RelationExtractionPreprocessor', | 'RelationExtractionPreprocessor', | ||||
| 'SentenceEmbeddingPreprocessor', | 'SentenceEmbeddingPreprocessor', | ||||
| 'SequenceClassificationPreprocessor', | 'SequenceClassificationPreprocessor', | ||||
| @@ -133,6 +133,12 @@ class WavToScp(Preprocessor): | |||||
| else: | else: | ||||
| inputs['asr_model_config'] = asr_model_config | inputs['asr_model_config'] = asr_model_config | ||||
| if inputs['model_config'].__contains__('mvn_file'): | |||||
| mvn_file = os.path.join(inputs['model_workspace'], | |||||
| inputs['model_config']['mvn_file']) | |||||
| assert os.path.exists(mvn_file), 'mvn_file does not exist' | |||||
| inputs['mvn_file'] = mvn_file | |||||
| elif inputs['model_type'] == Frameworks.tf: | elif inputs['model_type'] == Frameworks.tf: | ||||
| assert inputs['model_config'].__contains__( | assert inputs['model_config'].__contains__( | ||||
| 'vocab_file'), 'vocab_file does not exist' | 'vocab_file'), 'vocab_file does not exist' | ||||
| @@ -16,6 +16,7 @@ from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| from .ofa import * # noqa | from .ofa import * # noqa | ||||
| from .ofa.utils.collate import collate_fn | from .ofa.utils.collate import collate_fn | ||||
| from .ofa.utils.constant import OFA_TASK_KEY_MAPPING | |||||
| __all__ = [ | __all__ = [ | ||||
| 'OfaPreprocessor', | 'OfaPreprocessor', | ||||
| @@ -40,6 +41,7 @@ class OfaPreprocessor(Preprocessor): | |||||
| """ | """ | ||||
| super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
| preprocess_mapping = { | preprocess_mapping = { | ||||
| Tasks.ocr_recognition: OfaOcrRecognitionPreprocessor, | |||||
| Tasks.image_captioning: OfaImageCaptioningPreprocessor, | Tasks.image_captioning: OfaImageCaptioningPreprocessor, | ||||
| Tasks.visual_grounding: OfaVisualGroundingPreprocessor, | Tasks.visual_grounding: OfaVisualGroundingPreprocessor, | ||||
| Tasks.visual_question_answering: | Tasks.visual_question_answering: | ||||
| @@ -47,26 +49,16 @@ class OfaPreprocessor(Preprocessor): | |||||
| Tasks.visual_entailment: OfaVisualEntailmentPreprocessor, | Tasks.visual_entailment: OfaVisualEntailmentPreprocessor, | ||||
| Tasks.image_classification: OfaImageClassificationPreprocessor, | Tasks.image_classification: OfaImageClassificationPreprocessor, | ||||
| Tasks.text_classification: OfaTextClassificationPreprocessor, | Tasks.text_classification: OfaTextClassificationPreprocessor, | ||||
| Tasks.summarization: OfaSummarizationPreprocessor, | |||||
| Tasks.text_summarization: OfaSummarizationPreprocessor, | |||||
| Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor | Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor | ||||
| } | } | ||||
| input_key_mapping = { | |||||
| Tasks.image_captioning: ['image'], | |||||
| Tasks.image_classification: ['image'], | |||||
| Tasks.summarization: ['text'], | |||||
| Tasks.text_classification: ['text', 'text2'], | |||||
| Tasks.visual_grounding: ['image', 'text'], | |||||
| Tasks.visual_question_answering: ['image', 'text'], | |||||
| Tasks.visual_entailment: ['image', 'text', 'text2'], | |||||
| Tasks.text_to_image_synthesis: ['text'] | |||||
| } | |||||
| model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | ||||
| model_dir) | model_dir) | ||||
| self.cfg = Config.from_file( | self.cfg = Config.from_file( | ||||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | osp.join(model_dir, ModelFile.CONFIGURATION)) | ||||
| self.preprocess = preprocess_mapping[self.cfg.task]( | self.preprocess = preprocess_mapping[self.cfg.task]( | ||||
| cfg=self.cfg, model_dir=model_dir, mode=mode) | cfg=self.cfg, model_dir=model_dir, mode=mode) | ||||
| self.keys = input_key_mapping[self.cfg.task] | |||||
| self.keys = OFA_TASK_KEY_MAPPING[self.cfg.task] | |||||
| self.tokenizer = self.preprocess.tokenizer | self.tokenizer = self.preprocess.tokenizer | ||||
| if kwargs.get('no_collate', None): | if kwargs.get('no_collate', None): | ||||
| self.no_collate = True | self.no_collate = True | ||||
| @@ -11,7 +11,7 @@ if TYPE_CHECKING: | |||||
| FillMaskPoNetPreprocessor, | FillMaskPoNetPreprocessor, | ||||
| NLPPreprocessor, | NLPPreprocessor, | ||||
| NLPTokenizerPreprocessorBase, | NLPTokenizerPreprocessorBase, | ||||
| PassageRankingPreprocessor, | |||||
| TextRankingPreprocessor, | |||||
| RelationExtractionPreprocessor, | RelationExtractionPreprocessor, | ||||
| SentenceEmbeddingPreprocessor, | SentenceEmbeddingPreprocessor, | ||||
| SequenceClassificationPreprocessor, | SequenceClassificationPreprocessor, | ||||
| @@ -33,7 +33,7 @@ else: | |||||
| 'FillMaskPoNetPreprocessor', | 'FillMaskPoNetPreprocessor', | ||||
| 'NLPPreprocessor', | 'NLPPreprocessor', | ||||
| 'NLPTokenizerPreprocessorBase', | 'NLPTokenizerPreprocessorBase', | ||||
| 'PassageRankingPreprocessor', | |||||
| 'TextRankingPreprocessor', | |||||
| 'RelationExtractionPreprocessor', | 'RelationExtractionPreprocessor', | ||||
| 'SentenceEmbeddingPreprocessor', | 'SentenceEmbeddingPreprocessor', | ||||
| 'SequenceClassificationPreprocessor', | 'SequenceClassificationPreprocessor', | ||||
| @@ -1,9 +1,10 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| import os.path as osp | import os.path as osp | ||||
| import re | import re | ||||
| from typing import Any, Dict, Iterable, Optional, Tuple, Union | |||||
| from typing import Any, Dict, Optional, Tuple, Union | |||||
| import json | |||||
| import numpy as np | import numpy as np | ||||
| import sentencepiece as spm | import sentencepiece as spm | ||||
| import torch | import torch | ||||
| @@ -13,8 +14,7 @@ from modelscope.metainfo import Models, Preprocessors | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.preprocessors.base import Preprocessor | from modelscope.preprocessors.base import Preprocessor | ||||
| from modelscope.preprocessors.builder import PREPROCESSORS | from modelscope.preprocessors.builder import PREPROCESSORS | ||||
| from modelscope.utils.config import (Config, ConfigFields, | |||||
| use_task_specific_params) | |||||
| from modelscope.utils.config import Config, ConfigFields | |||||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | ||||
| from modelscope.utils.hub import get_model_type, parse_label_mapping | from modelscope.utils.hub import get_model_type, parse_label_mapping | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -29,7 +29,7 @@ __all__ = [ | |||||
| 'NLPPreprocessor', | 'NLPPreprocessor', | ||||
| 'FillMaskPoNetPreprocessor', | 'FillMaskPoNetPreprocessor', | ||||
| 'NLPTokenizerPreprocessorBase', | 'NLPTokenizerPreprocessorBase', | ||||
| 'PassageRankingPreprocessor', | |||||
| 'TextRankingPreprocessor', | |||||
| 'RelationExtractionPreprocessor', | 'RelationExtractionPreprocessor', | ||||
| 'SentenceEmbeddingPreprocessor', | 'SentenceEmbeddingPreprocessor', | ||||
| 'SequenceClassificationPreprocessor', | 'SequenceClassificationPreprocessor', | ||||
| @@ -83,6 +83,15 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||||
| self._mode = mode | self._mode = mode | ||||
| self.label = kwargs.pop('label', OutputKeys.LABEL) | self.label = kwargs.pop('label', OutputKeys.LABEL) | ||||
| self.use_fast = kwargs.pop('use_fast', None) | |||||
| if self.use_fast is None and os.path.isfile( | |||||
| os.path.join(model_dir, 'tokenizer_config.json')): | |||||
| with open(os.path.join(model_dir, 'tokenizer_config.json'), | |||||
| 'r') as f: | |||||
| json_config = json.load(f) | |||||
| self.use_fast = json_config.get('use_fast') | |||||
| self.use_fast = False if self.use_fast is None else self.use_fast | |||||
| self.label2id = None | self.label2id = None | ||||
| if 'label2id' in kwargs: | if 'label2id' in kwargs: | ||||
| self.label2id = kwargs.pop('label2id') | self.label2id = kwargs.pop('label2id') | ||||
| @@ -118,32 +127,23 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||||
| if model_type in (Models.structbert, Models.gpt3, Models.palm, | if model_type in (Models.structbert, Models.gpt3, Models.palm, | ||||
| Models.plug): | Models.plug): | ||||
| from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast | from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast | ||||
| return SbertTokenizer.from_pretrained( | |||||
| model_dir | |||||
| ) if self._mode == ModeKeys.INFERENCE else SbertTokenizerFast.from_pretrained( | |||||
| model_dir) | |||||
| tokenizer = SbertTokenizerFast if self.use_fast else SbertTokenizer | |||||
| return tokenizer.from_pretrained(model_dir) | |||||
| elif model_type == Models.veco: | elif model_type == Models.veco: | ||||
| from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast | from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast | ||||
| return VecoTokenizer.from_pretrained( | |||||
| model_dir | |||||
| ) if self._mode == ModeKeys.INFERENCE else VecoTokenizerFast.from_pretrained( | |||||
| model_dir) | |||||
| tokenizer = VecoTokenizerFast if self.use_fast else VecoTokenizer | |||||
| return tokenizer.from_pretrained(model_dir) | |||||
| elif model_type == Models.deberta_v2: | elif model_type == Models.deberta_v2: | ||||
| from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast | from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast | ||||
| return DebertaV2Tokenizer.from_pretrained( | |||||
| model_dir | |||||
| ) if self._mode == ModeKeys.INFERENCE else DebertaV2TokenizerFast.from_pretrained( | |||||
| model_dir) | |||||
| tokenizer = DebertaV2TokenizerFast if self.use_fast else DebertaV2Tokenizer | |||||
| return tokenizer.from_pretrained(model_dir) | |||||
| elif not self.is_transformer_based_model: | elif not self.is_transformer_based_model: | ||||
| from transformers import BertTokenizer, BertTokenizerFast | from transformers import BertTokenizer, BertTokenizerFast | ||||
| return BertTokenizer.from_pretrained( | |||||
| model_dir | |||||
| ) if self._mode == ModeKeys.INFERENCE else BertTokenizerFast.from_pretrained( | |||||
| model_dir) | |||||
| tokenizer = BertTokenizerFast if self.use_fast else BertTokenizer | |||||
| return tokenizer.from_pretrained(model_dir) | |||||
| else: | else: | ||||
| return AutoTokenizer.from_pretrained( | return AutoTokenizer.from_pretrained( | ||||
| model_dir, | |||||
| use_fast=False if self._mode == ModeKeys.INFERENCE else True) | |||||
| model_dir, use_fast=self.use_fast) | |||||
| def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: | def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: | ||||
| """process the raw input data | """process the raw input data | ||||
| @@ -217,7 +217,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||||
| return isinstance(label, str) or isinstance(label, int) | return isinstance(label, str) or isinstance(label, int) | ||||
| if labels is not None: | if labels is not None: | ||||
| if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \ | |||||
| if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \ | |||||
| and self.label2id is not None: | and self.label2id is not None: | ||||
| output[OutputKeys.LABELS] = [ | output[OutputKeys.LABELS] = [ | ||||
| self.label2id[str(label)] for label in labels | self.label2id[str(label)] for label in labels | ||||
| @@ -245,9 +245,9 @@ class NLPPreprocessor(NLPTokenizerPreprocessorBase): | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=Preprocessors.passage_ranking) | |||||
| class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||||
| """The tokenizer preprocessor used in passage ranking model. | |||||
| Fields.nlp, module_name=Preprocessors.text_ranking) | |||||
| class TextRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||||
| """The tokenizer preprocessor used in text-ranking model. | |||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -314,8 +314,7 @@ class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | ||||
| kwargs['truncation'] = kwargs.get('truncation', True) | kwargs['truncation'] = kwargs.get('truncation', True) | ||||
| kwargs['padding'] = kwargs.get( | |||||
| 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') | |||||
| kwargs['padding'] = kwargs.get('padding', 'max_length') | |||||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | kwargs['max_length'] = kwargs.pop('sequence_length', 128) | ||||
| super().__init__(model_dir, mode=mode, **kwargs) | super().__init__(model_dir, mode=mode, **kwargs) | ||||
| @@ -594,9 +593,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||||
| else: | else: | ||||
| self.is_split_into_words = self.tokenizer.init_kwargs.get( | self.is_split_into_words = self.tokenizer.init_kwargs.get( | ||||
| 'is_split_into_words', False) | 'is_split_into_words', False) | ||||
| if 'label2id' in kwargs: | |||||
| kwargs.pop('label2id') | |||||
| self.tokenize_kwargs = kwargs | |||||
| @type_assert(object, str) | @type_assert(object, str) | ||||
| def __call__(self, data: str) -> Dict[str, Any]: | def __call__(self, data: str) -> Dict[str, Any]: | ||||
| @@ -1,6 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from .image_captioning import OfaImageCaptioningPreprocessor | from .image_captioning import OfaImageCaptioningPreprocessor | ||||
| from .image_classification import OfaImageClassificationPreprocessor | from .image_classification import OfaImageClassificationPreprocessor | ||||
| from .ocr_recognition import OfaOcrRecognitionPreprocessor | |||||
| from .summarization import OfaSummarizationPreprocessor | from .summarization import OfaSummarizationPreprocessor | ||||
| from .text_classification import OfaTextClassificationPreprocessor | from .text_classification import OfaTextClassificationPreprocessor | ||||
| from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor | from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor | ||||
| @@ -6,9 +6,12 @@ from os import path as osp | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| from PIL import Image | |||||
| from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH | from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH | ||||
| from modelscope.preprocessors.image import load_image | |||||
| from modelscope.utils.trie import Trie | from modelscope.utils.trie import Trie | ||||
| from .utils.constant import OFA_TASK_KEY_MAPPING | |||||
| from .utils.random_help import set_torch_seed | from .utils.random_help import set_torch_seed | ||||
| @@ -59,6 +62,14 @@ class OfaBasePreprocessor: | |||||
| self.mean = [0.5, 0.5, 0.5] | self.mean = [0.5, 0.5, 0.5] | ||||
| self.std = [0.5, 0.5, 0.5] | self.std = [0.5, 0.5, 0.5] | ||||
| self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | ||||
| self.column_map = { | |||||
| key: key | |||||
| for key in OFA_TASK_KEY_MAPPING[self.cfg.task] | |||||
| } | |||||
| if hasattr(self.cfg, | |||||
| 'dataset') and self.cfg.dataset.column_map is not None: | |||||
| for k, v in self.cfg.dataset.column_map.items(): | |||||
| self.column_map[k] = v | |||||
| self.transtab = str.maketrans( | self.transtab = str.maketrans( | ||||
| {key: None | {key: None | ||||
| for key in string.punctuation}) | for key in string.punctuation}) | ||||
| @@ -147,3 +158,8 @@ class OfaBasePreprocessor: | |||||
| constraint_prefix_token) | constraint_prefix_token) | ||||
| constraint_mask[i][constraint_nodes] = True | constraint_mask[i][constraint_nodes] = True | ||||
| sample['constraint_mask'] = constraint_mask | sample['constraint_mask'] = constraint_mask | ||||
| def get_img_pil(self, path_or_url_or_pil): | |||||
| image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ | |||||
| else load_image(path_or_url_or_pil) | |||||
| return image | |||||
| @@ -1,12 +1,9 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| from typing import Any, Dict, Union | |||||
| from typing import Any, Dict | |||||
| import torch | import torch | ||||
| from PIL import Image | |||||
| from torchvision import transforms | from torchvision import transforms | ||||
| from modelscope.preprocessors.image import load_image | |||||
| from modelscope.utils.constant import ModeKeys | from modelscope.utils.constant import ModeKeys | ||||
| from .base import OfaBasePreprocessor | from .base import OfaBasePreprocessor | ||||
| @@ -46,7 +43,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | |||||
| def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
| sample = self._build_infer_sample(data) | sample = self._build_infer_sample(data) | ||||
| target = data['text'] | |||||
| target = data[self.column_map['text']] | |||||
| target = target.translate(self.transtab).strip() | target = target.translate(self.transtab).strip() | ||||
| target_token_list = target.strip().split() | target_token_list = target.strip().split() | ||||
| target = ' '.join(target_token_list[:self.max_tgt_length]) | target = ' '.join(target_token_list[:self.max_tgt_length]) | ||||
| @@ -56,8 +53,7 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | |||||
| return sample | return sample | ||||
| def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
| image = data['image'] if isinstance( | |||||
| data['image'], Image.Image) else load_image(data['image']) | |||||
| image = self.get_img_pil(data[self.column_map['image']]) | |||||
| patch_image = self.patch_resize_transform(image) | patch_image = self.patch_resize_transform(image) | ||||
| prompt = self.cfg.model.get('prompt', ' what does the image describe?') | prompt = self.cfg.model.get('prompt', ' what does the image describe?') | ||||
| inputs = self.tokenize_text(prompt) | inputs = self.tokenize_text(prompt) | ||||
| @@ -66,6 +62,6 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | |||||
| 'patch_image': patch_image, | 'patch_image': patch_image, | ||||
| 'patch_mask': torch.tensor([True]) | 'patch_mask': torch.tensor([True]) | ||||
| } | } | ||||
| if 'text' in data: | |||||
| sample['label'] = data['text'] | |||||
| if self.column_map['text'] in data: | |||||
| sample['label'] = data[self.column_map['text']] | |||||
| return sample | return sample | ||||
| @@ -0,0 +1,97 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict | |||||
| import torch | |||||
| from PIL import Image | |||||
| from torchvision import transforms | |||||
| from torchvision.transforms import InterpolationMode | |||||
| from torchvision.transforms import functional as F | |||||
| from modelscope.preprocessors.image import load_image | |||||
| from .base import OfaBasePreprocessor | |||||
| IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |||||
| IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | |||||
| def ocr_resize(img, patch_image_size, is_document=False): | |||||
| img = img.convert('RGB') | |||||
| width, height = img.size | |||||
| if is_document: | |||||
| new_height, new_width = 64, 1920 | |||||
| else: | |||||
| if width >= height: | |||||
| new_width = max(64, patch_image_size) | |||||
| new_height = max(64, int(patch_image_size * (height / width))) | |||||
| top = (patch_image_size - new_height) // 2 | |||||
| bottom = patch_image_size - new_height - top | |||||
| left, right = 0, 0 | |||||
| else: | |||||
| new_height = max(64, patch_image_size) | |||||
| new_width = max(64, int(patch_image_size * (width / height))) | |||||
| left = (patch_image_size - new_width) // 2 | |||||
| right = patch_image_size - new_width - left | |||||
| top, bottom = 0, 0 | |||||
| img_new = F.resize( | |||||
| img, | |||||
| (new_height, new_width), | |||||
| interpolation=InterpolationMode.BICUBIC, | |||||
| ) | |||||
| if is_document: | |||||
| img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1) | |||||
| img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2)) | |||||
| new_width, new_height = img_new.size | |||||
| top = (patch_image_size - new_height) // 2 | |||||
| bottom = patch_image_size - new_height - top | |||||
| left, right = 0, 0 | |||||
| img_new = F.pad( | |||||
| img_new, padding=[left, top, right, bottom], padding_mode='edge') | |||||
| assert img_new.size == (patch_image_size, patch_image_size) | |||||
| return img_new | |||||
| class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): | |||||
| def __init__(self, cfg, model_dir): | |||||
| """preprocess the data | |||||
| Args: | |||||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir) | |||||
| # Initialize transform | |||||
| if self.cfg.model.imagenet_default_mean_and_std: | |||||
| mean = IMAGENET_DEFAULT_MEAN | |||||
| std = IMAGENET_DEFAULT_STD | |||||
| else: | |||||
| mean = [0.5, 0.5, 0.5] | |||||
| std = [0.5, 0.5, 0.5] | |||||
| self.patch_resize_transform = transforms.Compose([ | |||||
| lambda image: ocr_resize( | |||||
| image, | |||||
| self.cfg.model.patch_image_size, | |||||
| is_document=self.cfg.model.is_document), | |||||
| transforms.ToTensor(), | |||||
| transforms.Normalize(mean=mean, std=std), | |||||
| ]) | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| image = data['image'] if isinstance( | |||||
| data['image'], Image.Image) else load_image(data['image']) | |||||
| patch_image = self.patch_resize_transform(image) | |||||
| prompt = self.cfg.model.get('prompt', '图片上的文字是什么?') | |||||
| inputs = self.get_inputs(prompt) | |||||
| sample = { | |||||
| 'source': inputs, | |||||
| 'patch_image': patch_image, | |||||
| 'patch_mask': torch.tensor([True]) | |||||
| } | |||||
| return sample | |||||
| @@ -0,0 +1,13 @@ | |||||
| from modelscope.utils.constant import Tasks | |||||
| OFA_TASK_KEY_MAPPING = { | |||||
| Tasks.ocr_recognition: ['image'], | |||||
| Tasks.image_captioning: ['image'], | |||||
| Tasks.image_classification: ['image'], | |||||
| Tasks.text_summarization: ['text'], | |||||
| Tasks.text_classification: ['text', 'text2'], | |||||
| Tasks.visual_grounding: ['image', 'text'], | |||||
| Tasks.visual_question_answering: ['image', 'text'], | |||||
| Tasks.visual_entailment: ['image', 'text', 'text2'], | |||||
| Tasks.text_to_image_synthesis: ['text'] | |||||
| } | |||||
| @@ -13,7 +13,7 @@ class Database: | |||||
| tokenizer, | tokenizer, | ||||
| table_file_path, | table_file_path, | ||||
| syn_dict_file_path, | syn_dict_file_path, | ||||
| is_use_sqlite=False): | |||||
| is_use_sqlite=True): | |||||
| self.tokenizer = tokenizer | self.tokenizer = tokenizer | ||||
| self.is_use_sqlite = is_use_sqlite | self.is_use_sqlite = is_use_sqlite | ||||
| if self.is_use_sqlite: | if self.is_use_sqlite: | ||||