Conflicts: modelscope/utils/constant.pymaster
| @@ -0,0 +1,265 @@ | |||||
| import imp | |||||
| import os | |||||
| import pickle | |||||
| import subprocess | |||||
| from http.cookiejar import CookieJar | |||||
| from os.path import expanduser | |||||
| from typing import List, Optional, Tuple, Union | |||||
| import requests | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .constants import LOGGER_NAME | |||||
| from .errors import NotExistError, is_ok, raise_on_error | |||||
| from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||||
| logger = get_logger() | |||||
class HubApi:
    """Thin HTTP client for the ModelScope hub REST API."""

    def __init__(self, endpoint=None):
        """
        Args:
            endpoint(`str`, *optional*): hub endpoint url; when `None`,
                falls back to the configured default from `get_endpoint()`.
        """
        self.endpoint = endpoint if endpoint is not None else get_endpoint()

    def login(
        self,
        user_name: str,
        password: str,
    ) -> Tuple[str, CookieJar]:
        """
        Login with username and password
        Args:
            user_name(`str`): user name on modelscope
            password(`str`): password
        Returns:
            cookies: to authenticate yourself to ModelScope open-api
            gitlab token: to access private repos
        <Tip>
        You only have to login once within 30 days.
        </Tip>
        TODO: handle cookies expire
        """
        # BUG FIX: return annotation was `-> tuple()` (a call, i.e. an empty
        # tuple instance), not a type.
        path = f'{self.endpoint}/api/v1/login'
        r = requests.post(
            path, json={
                'username': user_name,
                'password': password
            })
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)
        token = d['Data']['AccessToken']
        cookies = r.cookies
        # save token and cookie so later calls (and git) are authenticated
        ModelScopeConfig.save_token(token)
        ModelScopeConfig.save_cookies(cookies)
        ModelScopeConfig.write_to_git_credential(user_name, password)
        return token, cookies

    def create_model(self, model_id: str, chinese_name: str, visibility: int,
                     license: str) -> str:
        """
        Create model repo at ModelScopeHub
        Args:
            model_id(`str`): The model id
            chinese_name(`str`): chinese name of the model
            visibility(`int`): visibility of the model(1-private, 3-internal, 5-public)
            license(`str`): license of the model, candidates can be found at: TBA
        Returns:
            name of the model created
        Raises:
            ValueError: if not logged in (no saved cookies).
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
        path = f'{self.endpoint}/api/v1/models'
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        r = requests.post(
            path,
            json={
                'Path': owner_or_group,
                'Name': name,
                'ChineseName': chinese_name,
                'Visibility': visibility,
                'License': license
            },
            cookies=cookies)
        r.raise_for_status()
        raise_on_error(r.json())
        d = r.json()
        return d['Data']['Name']

    def delete_model(self, model_id):
        """
        Delete a model repo at ModelScopeHub.
        Args:
            model_id (str): The model id.
        Raises:
            ValueError: if not logged in (no saved cookies).
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        # Consistency fix: like create_model, refuse an unauthenticated
        # destructive call up front instead of letting the server reject it.
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
        path = f'{self.endpoint}/api/v1/models/{model_id}'
        r = requests.delete(path, cookies=cookies)
        r.raise_for_status()
        raise_on_error(r.json())

    def get_model_url(self, model_id):
        """Return the git clone url of the given model repo."""
        return f'{self.endpoint}/api/v1/models/{model_id}.git'

    def get_model(
        self,
        model_id: str,
        revision: str = 'master',
    ) -> str:
        """
        Get model information at modelscope_hub
        Args:
            model_id(`str`): The model id.
            revision(`str`): revision of model
        Returns:
            The model details information.
        Raises:
            NotExistError: If the model is not exist, will throw NotExistError
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        # BUG FIX: the query string was '?{revision}' (a bare value with no
        # parameter name); use the 'Revision=' key as the other endpoints do.
        path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'
        r = requests.get(path, cookies=cookies)
        if r.status_code == 200:
            if is_ok(r.json()):
                return r.json()['Data']
            else:
                raise NotExistError(r.json()['Message'])
        else:
            r.raise_for_status()

    def get_model_branches_and_tags(
        self,
        model_id: str,
    ) -> Tuple[List[str], List[str]]:
        """
        List branch names and tag names of a model repo.
        Args:
            model_id(`str`): The model id ({owner}/{name}).
        Returns:
            A tuple (branches, tags), each a list of revision names.
        """
        cookies = ModelScopeConfig.get_cookies()
        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
        r = requests.get(path, cookies=cookies)
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)
        info = d['Data']
        # RevisionMap entries may be null rather than empty lists.
        branches = [x['Revision'] for x in info['RevisionMap']['Branches']
                    ] if info['RevisionMap']['Branches'] else []
        tags = [x['Revision'] for x in info['RevisionMap']['Tags']
                ] if info['RevisionMap']['Tags'] else []
        return branches, tags

    def get_model_files(
            self,
            model_id: str,
            revision: Optional[str] = 'master',
            root: Optional[str] = None,
            recursive: Optional[str] = False,
            use_cookies: Union[bool, CookieJar] = False) -> List[dict]:
        """
        List files of a model repo at the given revision.
        Args:
            model_id(`str`): The model id ({owner}/{name}).
            revision(`str`, *optional*): branch/tag/commit, default `master`.
            root(`str`, *optional*): sub-directory to list; repo root if None.
            recursive(*optional*): whether to list files recursively.
            use_cookies(`bool` or `CookieJar`): a CookieJar to use directly,
                or True to load saved cookies (raises if not logged in).
        Returns:
            List of file-info dicts; `.gitignore`/`.gitattributes` excluded.
        """
        cookies = None
        if isinstance(use_cookies, CookieJar):
            cookies = use_cookies
        elif use_cookies:
            cookies = ModelScopeConfig.get_cookies()
            if cookies is None:
                raise ValueError('Token does not exist, please login first.')
        path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}'
        if root is not None:
            path = path + f'&Root={root}'
        r = requests.get(path, cookies=cookies)
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)
        files = []
        for file in d['Data']['Files']:
            # git bookkeeping files are not model content
            if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes':
                continue
            files.append(file)
        return files
class ModelScopeConfig:
    """Local credential store under ``~/.modelscope/credentials``."""

    path_credential = expanduser('~/.modelscope/credentials')
    # NOTE: runs at import time — the credential directory is created eagerly.
    os.makedirs(path_credential, exist_ok=True)

    @classmethod
    def save_cookies(cls, cookies: CookieJar):
        """Persist the authenticated session cookies (pickled to disk)."""
        with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f:
            pickle.dump(cookies, f)

    @classmethod
    def get_cookies(cls):
        """Return the saved CookieJar, or None (with a warning) if absent."""
        try:
            with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            # BUG FIX: logger.warn is a deprecated alias — use warning();
            # also drop the stray indentation the old string continuation
            # embedded into the message.
            logger.warning(
                "Auth token does not exist, you'll get authentication"
                ' error when downloading private model files.'
                ' Please login first')

    @classmethod
    def save_token(cls, token: str):
        """Persist the hub access token to disk."""
        with open(os.path.join(cls.path_credential, 'token'), 'w+') as f:
            f.write(token)

    @classmethod
    def get_token(cls) -> Optional[str]:
        """
        Get token or None if not existent.
        Returns:
            `str` or `None`: The token, `None` if it doesn't exist.
        """
        token = None
        try:
            with open(os.path.join(cls.path_credential, 'token'), 'r') as f:
                token = f.read()
        except FileNotFoundError:
            pass
        return token

    @staticmethod
    def write_to_git_credential(username: str, password: str):
        """Store the hub credentials via ``git credential-store`` so git
        operations against the hub are authenticated."""
        with subprocess.Popen(
                'git credential-store store'.split(),
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
        ) as process:
            input_username = f'username={username.lower()}'
            input_password = f'password={password}'
            process.stdin.write(
                f'url={get_endpoint()}\n{input_username}\n{input_password}\n\n'
                .encode('utf-8'))
            process.stdin.flush()
| @@ -0,0 +1,8 @@ | |||||
# URL scheme prepended when building hub / gitlab urls.
MODELSCOPE_URL_SCHEME = 'http://'
# NOTE(review): hard-coded IP:port defaults — presumably a staging/dev
# environment; confirm before any release.
DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330'
DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102'
# Group used when a model id is given without an {owner}/ prefix.
DEFAULT_MODELSCOPE_GROUP = 'damo'
# Separator between owner/group and model name in a model id.
MODEL_ID_SEPARATOR = '/'
# Name used for the shared hub logger.
LOGGER_NAME = 'ModelScopeHub'
| @@ -0,0 +1,30 @@ | |||||
class NotExistError(Exception):
    """Raised when a requested model, revision or file does not exist on the hub."""
    pass
class RequestError(Exception):
    """Raised when a hub API response reports failure (non-200 Code / Success false)."""
    pass
def is_ok(rsp):
    """ Check the request is ok
    Args:
        rsp (_type_): The request response body
        Failed: {'Code': 10010101004, 'Message': 'get model info failed, err: unauthorized permission',
            'RequestId': '', 'Success': False}
        Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True}
    """
    # Both the numeric code and the explicit success flag must agree.
    code_ok = rsp['Code'] == 200
    return code_ok and rsp['Success']
def raise_on_error(rsp):
    """If response error, raise exception
    Args:
        rsp (_type_): The server response
    Raises:
        RequestError: when the response code is not 200 or Success is false.
    """
    # Guard clause: bail out with the server message on any failure.
    if rsp['Code'] != 200 or not rsp['Success']:
        raise RequestError(rsp['Message'])
    return True
| @@ -0,0 +1,254 @@ | |||||
| import copy | |||||
| import fnmatch | |||||
| import logging | |||||
| import os | |||||
| import sys | |||||
| import tempfile | |||||
| import time | |||||
| from functools import partial | |||||
| from hashlib import sha256 | |||||
| from pathlib import Path | |||||
| from typing import BinaryIO, Dict, Optional, Union | |||||
| from uuid import uuid4 | |||||
| import json | |||||
| import requests | |||||
| from filelock import FileLock | |||||
| from requests.exceptions import HTTPError | |||||
| from tqdm import tqdm | |||||
| from modelscope import __version__ | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .api import HubApi, ModelScopeConfig | |||||
| from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME, | |||||
| MODEL_ID_SEPARATOR) | |||||
| from .errors import NotExistError, RequestError, raise_on_error | |||||
| from .utils.caching import ModelFileSystemCache | |||||
| from .utils.utils import (get_cache_dir, get_endpoint, | |||||
| model_id_to_group_owner_name) | |||||
| SESSION_ID = uuid4().hex | |||||
| logger = get_logger() | |||||
def model_file_download(
        model_id: str,
        file_path: str,
        revision: Optional[str] = 'master',
        cache_dir: Optional[str] = None,
        user_agent: Union[Dict, str, None] = None,
        local_files_only: Optional[bool] = False,
) -> Optional[str]:  # pragma: no cover
    """
    Download from a given URL and cache it if it's not already present in the
    local cache.
    Given a URL, this function looks for the corresponding file in the local
    cache. If it's not there, download it. Then return the path to the cached
    file.
    Args:
        model_id (`str`):
            The model to whom the file to be downloaded belongs.
        file_path(`str`):
            Path of the file to be downloaded, relative to the root of model repo
        revision(`str`, *optional*):
            revision of the model file to be downloaded.
            Can be any of a branch, tag or commit hash, default to `master`
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
            if `False`, download the file anyway even it exists
    Returns:
        Local path (string) of file or if networking is off, last version of
        file cached on disk.
    <Tip>
    Raises the following errors:
        - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
        if `use_auth_token=True` and the token cannot be found.
        - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
        if ETag cannot be determined.
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
        if some parameter value is invalid
    </Tip>
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    group_or_owner, name = model_id_to_group_owner_name(model_id)
    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
    # if local_files_only is `True` and the file already exists in cached_path
    # return the cached path; the cached copy may be stale for this revision.
    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is not None:
            logger.warning(
                "File exists in local cache, but we're not sure it's up to date"
            )
            return cached_file_path
        else:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")
    _api = HubApi()
    headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
    branches, tags = _api.get_model_branches_and_tags(model_id)
    file_to_download_info = None
    is_commit_id = False
    if revision in branches or revision in tags:  # The revision is a branch or tag,
        # we need to confirm the version is up to date:
        # fetch the file list to check if the latest version is cached;
        # if so return it, otherwise download.
        model_files = _api.get_model_files(
            model_id=model_id,
            revision=revision,
            recursive=True,
        )
        for model_file in model_files:
            if model_file['Type'] == 'tree':  # directories carry no content
                continue
            if model_file['Path'] == file_path:
                model_file['Branch'] = revision
                if cache.exists(model_file):
                    return cache.get_file_by_info(model_file)
                else:
                    file_to_download_info = model_file
        if file_to_download_info is None:
            raise NotExistError('The file path: %s not exist in: %s' %
                                (file_path, model_id))
    else:  # the revision is commit id.
        # NOTE(review): any revision that is not a known branch/tag is assumed
        # to be a commit id — its existence is not validated here.
        cached_file_path = cache.get_file_by_path_and_commit_id(
            file_path, revision)
        if cached_file_path is not None:
            logger.info('The specified file is in cache, skip downloading!')
            return cached_file_path  # the file is in cache.
        is_commit_id = True
    # we need to download again
    # TODO: skip using JWT for authorization, use cookie instead
    cookies = ModelScopeConfig.get_cookies()
    url_to_download = get_file_download_url(model_id, file_path, revision)
    file_to_download_info = {
        'Path': file_path,
        'Revision':
        revision if is_commit_id else file_to_download_info['Revision']
    }
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache.get_root_location() + '.lock'
    with FileLock(lock_path):
        # NOTE(review): tempfile._get_candidate_names() is a private CPython
        # API and may change without notice — TODO: replace with
        # tempfile.mkstemp / NamedTemporaryFile.
        temp_file_name = next(tempfile._get_candidate_names())
        http_get_file(
            url_to_download,
            cache_dir,
            temp_file_name,
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict())
        # Move the temp download into the cache and return its final path.
        return cache.put_file(file_to_download_info,
                              os.path.join(cache_dir, temp_file_name))
def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
    """Formats a user-agent string with basic info about a request.
    Args:
        user_agent (`str`, `dict`, *optional*):
            The user agent info in the form of a dictionary or a single string.
    Returns:
        The formatted user-agent string.
    """
    # A caller-supplied dict or string fully replaces the default agent.
    if isinstance(user_agent, dict):
        return '; '.join(f'{k}/{v}' for k, v in user_agent.items())
    if isinstance(user_agent, str):
        return user_agent
    # Default: library version, python version and a per-process session id.
    return (f'modelscope/{__version__}; python/{sys.version.split()[0]};'
            f' session_id/{SESSION_ID}')
def get_file_download_url(model_id: str, file_path: str, revision: str):
    """
    Format file download url according to `model_id`, `revision` and `file_path`.
    e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
    the resulted download url is: https://maas.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
    """
    # Same url layout as before, rendered directly with an f-string.
    return (f'{get_endpoint()}/api/v1/models/{model_id}/repo'
            f'?Revision={revision}&FilePath={file_path}')
def http_get_file(
        url: str,
        local_dir: str,
        file_name: str,
        cookies: Dict[str, str],
        headers: Optional[Dict[str, str]] = None,
):
    """
    Download remote file. Do not gobble up errors.
    This method is only used by snapshot_download, since the behavior is quite different with single file download
    TODO: consolidate with http_get_file() to avoid duplicate code
    Args:
        url(`str`):
            actual download url of the file
        local_dir(`str`):
            local directory where the downloaded file stores
        file_name(`str`):
            name of the file stored in `local_dir`
        cookies(`Dict[str, str]`):
            cookies used to authentication the user, which is used for downloading private repos
        headers(`Optional[Dict[str, str]] = None`):
            http headers to carry necessary info when requesting the remote file
    """
    temp_file_manager = partial(
        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
    with temp_file_manager() as temp_file:
        logger.info('downloading %s to %s', url, temp_file.name)
        # copy so the caller's headers dict is never mutated downstream
        headers = copy.deepcopy(headers)
        r = requests.get(url, stream=True, headers=headers, cookies=cookies)
        r.raise_for_status()
        content_length = r.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        progress = tqdm(
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
            total=total,
            initial=0,
            desc='Downloading',
        )
        # BUG FIX: close the progress bar even when the download raises
        # mid-stream; previously it leaked on any iter_content error.
        try:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            progress.close()
    logger.info('storing %s in cache at %s', url, local_dir)
    # atomic rename of the fully-written temp file to its final name
    os.replace(temp_file.name, os.path.join(local_dir, file_name))
| @@ -0,0 +1,82 @@ | |||||
| from threading import local | |||||
| from tkinter.messagebox import NO | |||||
| from typing import Union | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .constants import LOGGER_NAME | |||||
| from .utils._subprocess import run_subprocess | |||||
# BUG FIX: `logger = get_logger` bound the factory function itself, so calls
# like logger.error(...) in this module would raise AttributeError.
logger = get_logger()
def git_clone(
    local_dir: str,
    repo_url: str,
):
    """Clone the remote repo at *repo_url*, running git inside *local_dir*."""
    # TODO: use "git clone" or "git lfs clone" according to git version
    # TODO: print stderr when subprocess fails
    clone_cmd = 'git clone {}'.format(repo_url).split()
    run_subprocess(clone_cmd, local_dir, True)
def git_checkout(
    local_dir: str,
    revsion: str,
):
    """Check out the given revision in the repo at *local_dir*."""
    # NOTE(review): parameter name 'revsion' is a typo of 'revision', kept
    # unchanged to preserve the public keyword interface.
    checkout_cmd = 'git checkout {}'.format(revsion).split()
    run_subprocess(checkout_cmd, local_dir)
def git_add(local_dir: str, ):
    """Stage every change in *local_dir* (equivalent of ``git add .``)."""
    run_subprocess(['git', 'add', '.'], local_dir, True)
def git_commit(local_dir: str, commit_message: str):
    """Commit the staged changes in *local_dir* with *commit_message*."""
    # message passed as a single argv element so it survives whitespace
    commit_cmd = ['git', 'commit', '-v', '-m', commit_message]
    run_subprocess(commit_cmd, local_dir, True)
def git_push(local_dir: str, branch: str):
    """Push *branch* to origin, refusing when HEAD is on a different branch."""
    # Guard: only push when the checked-out branch matches the target.
    if git_current_branch(local_dir) != branch:
        logger.error(
            "You're trying to push to a different branch, please double check")
        return
    push_cmd = 'git push origin {}'.format(branch).split()
    run_subprocess(push_cmd, local_dir, True)
def git_current_branch(local_dir: str) -> Union[str, None]:
    """
    Get current branch name
    Args:
        local_dir(`str`): local model repo directory
    Returns:
        branch name you're currently on
    """
    # CLEANUP: removed a no-op `try/except Exception as e: raise e` wrapper —
    # it re-raised everything unchanged.
    process = run_subprocess(
        'git rev-parse --abbrev-ref HEAD'.split(),
        local_dir,
        True,
    )
    # NOTE(review): if run_subprocess captures stdout as bytes, str() yields
    # "b'...'" rather than the branch name — confirm run_subprocess decodes
    # its output (e.g. text=True).
    return str(process.stdout).strip()
| @@ -0,0 +1,173 @@ | |||||
| import os | |||||
| import subprocess | |||||
| from pathlib import Path | |||||
| from typing import Optional, Union | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .api import ModelScopeConfig | |||||
| from .constants import MODELSCOPE_URL_SCHEME | |||||
| from .git import git_add, git_checkout, git_clone, git_commit, git_push | |||||
| from .utils._subprocess import run_subprocess | |||||
| from .utils.utils import get_gitlab_domain | |||||
| logger = get_logger() | |||||
class Repository:
    """A local git working copy of a ModelScope hub model repository."""

    def __init__(
        self,
        local_dir: str,
        clone_from: Optional[str] = None,
        auth_token: Optional[str] = None,
        private: Optional[bool] = False,
        revision: Optional[str] = 'master',
    ):
        """
        Instantiate a Repository object by cloning the remote ModelScopeHub repo
        Args:
            local_dir(`str`):
                local directory to store the model files
            clone_from(`Optional[str] = None`):
                model id in ModelScope-hub from which git clone
                You should ignore this parameter when `local_dir` is already a git repo
            auth_token(`Optional[str]`):
                token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
                as the token is already saved when you login the first time
            private(`Optional[bool]`):
                whether the model is private, default to False
            revision(`Optional[str]`):
                revision of the model you want to clone from. Can be any of a branch, tag or commit hash
        Raises:
            EnvironmentError: git/git-lfs missing, or no token for a private repo.
            ValueError: `clone_from` not given and `local_dir` is not a git repo.
        """
        logger.info('Instantiating Repository object...')
        # Create local directory if not exist
        os.makedirs(local_dir, exist_ok=True)
        self.local_dir = os.path.join(os.getcwd(), local_dir)
        self.private = private
        # Check git and git-lfs installation
        self.check_git_versions()
        # Retrieve auth token; only meaningful when cloning a private repo.
        if not private and isinstance(auth_token, str):
            logger.warning(
                'cloning a public repo with a token, which will be ignored')
            self.token = None
        else:
            if isinstance(auth_token, str):
                self.token = auth_token
            else:
                self.token = ModelScopeConfig.get_token()
            if self.token is None:
                raise EnvironmentError(
                    'Token does not exist, the clone will fail for private repo.'
                    'Please login first.')
        # git clone
        if clone_from is not None:
            self.model_id = clone_from
            logger.info('cloning model repo to %s ...', self.local_dir)
            git_clone(self.local_dir, self.get_repo_url())
        else:
            if is_git_repo(self.local_dir):
                logger.debug('[Repository] is a valid git repo')
            else:
                raise ValueError(
                    'If not specifying `clone_from`, you need to pass Repository a'
                    ' valid git clone.')
        # git checkout
        if isinstance(revision, str) and revision != 'master':
            # BUG FIX: git_checkout takes (local_dir, revision); the previous
            # call passed only the revision and would raise TypeError.
            git_checkout(self.local_dir, revision)

    def push_to_hub(self,
                    commit_message: str,
                    revision: Optional[str] = 'master'):
        """
        Push changes changes to hub
        Args:
            commit_message(`str`):
                commit message describing the changes, it's mandatory
            revision(`Optional[str]`):
                remote branch you want to push to, default to `master`
        <Tip>
        The function complains when local and remote branch are different, please be careful
        </Tip>
        """
        git_add(self.local_dir)
        git_commit(self.local_dir, commit_message)
        logger.info('Pushing changes to repo...')
        git_push(self.local_dir, revision)
        # TODO: if git push fails, how to retry?

    def check_git_versions(self):
        """
        Checks that `git` and `git-lfs` can be run.
        Raises:
            `EnvironmentError`: if `git` or `git-lfs` are not installed.
        """
        try:
            git_version = run_subprocess('git --version'.split(),
                                         self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                'Looks like you do not have git installed, please install.')
        try:
            lfs_version = run_subprocess('git-lfs --version'.split(),
                                         self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                'Looks like you do not have git-lfs installed, please install.'
                ' You can install from https://git-lfs.github.com/.'
                ' Then run `git lfs install` (you only have to do this once).')
        logger.info(git_version + '\n' + lfs_version)

    def get_repo_url(self) -> str:
        """
        Get repo url to clone, according whether the repo is private or not
        """
        url = None
        if self.private:
            # private clones authenticate via an oauth2 token in the url
            url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}'
        else:
            url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}'
        if not url:
            raise ValueError(
                'Empty repo url, please check clone_from parameter')
        logger.debug('url to clone: %s', str(url))
        return url
def is_git_repo(folder: Union[str, Path]) -> bool:
    """
    Check if the folder is the root or part of a git repository
    Args:
        folder (`str`):
            The folder in which to run the command.
    Returns:
        `bool`: `True` if the repository is part of a repository, `False`
        otherwise.
    """
    # A repo root carries a `.git` entry, and `git branch` succeeds inside it.
    has_git_dir = os.path.exists(os.path.join(folder, '.git'))
    probe = subprocess.run(
        ['git', 'branch'],
        cwd=folder,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    return has_git_dir and probe.returncode == 0
| @@ -0,0 +1,125 @@ | |||||
| import os | |||||
| import tempfile | |||||
| from glob import glob | |||||
| from pathlib import Path | |||||
| from typing import Dict, Optional, Union | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .api import HubApi, ModelScopeConfig | |||||
| from .constants import DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR | |||||
| from .errors import NotExistError, RequestError, raise_on_error | |||||
| from .file_download import (get_file_download_url, http_get_file, | |||||
| http_user_agent) | |||||
| from .utils.caching import ModelFileSystemCache | |||||
| from .utils.utils import get_cache_dir, model_id_to_group_owner_name | |||||
| logger = get_logger() | |||||
| def snapshot_download(model_id: str, | |||||
| revision: Optional[str] = 'master', | |||||
| cache_dir: Union[str, Path, None] = None, | |||||
| user_agent: Optional[Union[Dict, str]] = None, | |||||
| local_files_only: Optional[bool] = False, | |||||
| private: Optional[bool] = False) -> str: | |||||
| """Download all files of a repo. | |||||
| Downloads a whole snapshot of a repo's files at the specified revision. This | |||||
| is useful when you want all files from a repo, because you don't know which | |||||
| ones you will need a priori. All files are nested inside a folder in order | |||||
| to keep their actual filename relative to that folder. | |||||
| An alternative would be to just clone a repo but this would require that the | |||||
| user always has git and git-lfs installed, and properly configured. | |||||
| Args: | |||||
| model_id (`str`): | |||||
| A user or an organization name and a repo name separated by a `/`. | |||||
| revision (`str`, *optional*): | |||||
| An optional Git revision id which can be a branch name, a tag, or a | |||||
| commit hash. NOTE: currently only branch and tag name is supported | |||||
| cache_dir (`str`, `Path`, *optional*): | |||||
| Path to the folder where cached files are stored. | |||||
| user_agent (`str`, `dict`, *optional*): | |||||
| The user-agent info in the form of a dictionary or a string. | |||||
| local_files_only (`bool`, *optional*, defaults to `False`): | |||||
| If `True`, avoid downloading the file and return the path to the | |||||
| local cached file if it exists. | |||||
| Returns: | |||||
| Local folder path (string) of repo snapshot | |||||
| <Tip> | |||||
| Raises the following errors: | |||||
| - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) | |||||
| if `use_auth_token=True` and the token cannot be found. | |||||
| - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if | |||||
| ETag cannot be determined. | |||||
| - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) | |||||
| if some parameter value is invalid | |||||
| </Tip> | |||||
| """ | |||||
| if cache_dir is None: | |||||
| cache_dir = get_cache_dir() | |||||
| if isinstance(cache_dir, Path): | |||||
| cache_dir = str(cache_dir) | |||||
| group_or_owner, name = model_id_to_group_owner_name(model_id) | |||||
| cache = ModelFileSystemCache(cache_dir, group_or_owner, name) | |||||
| if local_files_only: | |||||
| if len(cache.cached_files) == 0: | |||||
| raise ValueError( | |||||
| 'Cannot find the requested files in the cached path and outgoing' | |||||
| ' traffic has been disabled. To enable model look-ups and downloads' | |||||
| " online, set 'local_files_only' to False.") | |||||
| logger.warn('We can not confirm the cached file is for revision: %s' | |||||
| % revision) | |||||
| return cache.get_root_location( | |||||
| ) # we can not confirm the cached file is for snapshot 'revision' | |||||
| else: | |||||
| # make headers | |||||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||||
| _api = HubApi() | |||||
| # get file list from model repo | |||||
| branches, tags = _api.get_model_branches_and_tags(model_id) | |||||
| if revision not in branches and revision not in tags: | |||||
| raise NotExistError('The specified branch or tag : %s not exist!' | |||||
| % revision) | |||||
| model_files = _api.get_model_files( | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| recursive=True, | |||||
| use_cookies=private) | |||||
| cookies = None | |||||
| if private: | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| for model_file in model_files: | |||||
| if model_file['Type'] == 'tree': | |||||
| continue | |||||
| # Check whether model_file already exists in the cache; if so, skip the download. | |||||
| if cache.exists(model_file): | |||||
| logger.info( | |||||
| 'The specified file is in cache, skip downloading!') | |||||
| continue | |||||
| # get download url | |||||
| url = get_file_download_url( | |||||
| model_id=model_id, | |||||
| file_path=model_file['Path'], | |||||
| revision=revision) | |||||
| # First download to /tmp | |||||
| http_get_file( | |||||
| url=url, | |||||
| local_dir=tempfile.gettempdir(), | |||||
| file_name=model_file['Name'], | |||||
| headers=headers, | |||||
| cookies=None if cookies is None else cookies.get_dict()) | |||||
| # put file to cache | |||||
| cache.put_file( | |||||
| model_file, | |||||
| os.path.join(tempfile.gettempdir(), model_file['Name'])) | |||||
| return os.path.join(cache.get_root_location()) | |||||
| @@ -0,0 +1,40 @@ | |||||
| import subprocess | |||||
| from typing import List | |||||
| def run_subprocess(command: List[str], | |||||
| folder: str, | |||||
| check=True, | |||||
| **kwargs) -> subprocess.CompletedProcess: | |||||
| """ | |||||
| Method to run subprocesses. Calling this will capture the `stderr` and `stdout`, | |||||
| please call `subprocess.run` manually in case you would like for them not to | |||||
| be captured. | |||||
| Args: | |||||
| command (`List[str]`): | |||||
| The command to execute as a list of strings. | |||||
| folder (`str`): | |||||
| The folder in which to run the command. | |||||
| check (`bool`, *optional*, defaults to `True`): | |||||
| Setting `check` to `True` will raise a `subprocess.CalledProcessError` | |||||
| when the subprocess has a non-zero exit code. | |||||
| kwargs (`Dict[str]`): | |||||
| Keyword arguments to be passed to the `subprocess.run` underlying command. | |||||
| Returns: | |||||
| `subprocess.CompletedProcess`: The completed process. | |||||
| """ | |||||
| if isinstance(command, str): | |||||
| raise ValueError( | |||||
| '`run_subprocess` should be called with a list of strings.') | |||||
| return subprocess.run( | |||||
| command, | |||||
| stderr=subprocess.PIPE, | |||||
| stdout=subprocess.PIPE, | |||||
| check=check, | |||||
| encoding='utf-8', | |||||
| cwd=folder, | |||||
| **kwargs, | |||||
| ) | |||||
| @@ -0,0 +1,294 @@ | |||||
| import hashlib | |||||
| import logging | |||||
| import os | |||||
| import pickle | |||||
| import tempfile | |||||
| import time | |||||
| from shutil import move, rmtree | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| class FileSystemCache(object): | |||||
| KEY_FILE_NAME = '.msc' | |||||
| """Local file cache. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| cache_root_location: str, | |||||
| **kwargs, | |||||
| ): | |||||
| """ | |||||
| Parameters | |||||
| ---------- | |||||
| cache_location: str | |||||
| The root location to store files. | |||||
| """ | |||||
| os.makedirs(cache_root_location, exist_ok=True) | |||||
| self.cache_root_location = cache_root_location | |||||
| self.load_cache() | |||||
| def get_root_location(self): | |||||
| return self.cache_root_location | |||||
| def load_cache(self): | |||||
| """Load the cache index from the key file. | |||||
| Reads the list of cached file entries from the '.msc' index file | |||||
| under the cache root, or initializes an empty list if no index exists. | |||||
| """ | |||||
| self.cached_files = [] | |||||
| cache_keys_file_path = os.path.join(self.cache_root_location, | |||||
| FileSystemCache.KEY_FILE_NAME) | |||||
| if os.path.exists(cache_keys_file_path): | |||||
| with open(cache_keys_file_path, 'rb') as f: | |||||
| self.cached_files = pickle.load(f) | |||||
| def save_cached_files(self): | |||||
| """Save cache metadata.""" | |||||
| # save new meta to tmp and move to KEY_FILE_NAME | |||||
| cache_keys_file_path = os.path.join(self.cache_root_location, | |||||
| FileSystemCache.KEY_FILE_NAME) | |||||
| # TODO: Sync file write | |||||
| fd, fn = tempfile.mkstemp() | |||||
| with open(fd, 'wb') as f: | |||||
| pickle.dump(self.cached_files, f) | |||||
| move(fn, cache_keys_file_path) | |||||
| def get_file(self, key): | |||||
| """Check whether the key is in the cache; if it exists, return the file, otherwise return None. | |||||
| Args: | |||||
| key(`str`): The cache key. | |||||
| Returns: | |||||
| If the file exists, the cached file location; otherwise None. | |||||
| Raises: | |||||
| None | |||||
| """ | |||||
| pass | |||||
| def put_file(self, key, location): | |||||
| """Put a file into the cache. | |||||
| Args: | |||||
| key(`str`): The cache key | |||||
| location(`str`): Location of the file; the file will be moved into the cache. | |||||
| Returns: | |||||
| The cached file path of the file. | |||||
| Raises: | |||||
| None | |||||
| """ | |||||
| pass | |||||
| def remove_key(self, key): | |||||
| """Remove the cache key from the index; the file itself is removed separately. | |||||
| Args: | |||||
| key (dict): The cache key. | |||||
| """ | |||||
| self.cached_files.remove(key) | |||||
| self.save_cached_files() | |||||
| def exists(self, key): | |||||
| for cache_file in self.cached_files: | |||||
| if cache_file == key: | |||||
| return True | |||||
| return False | |||||
| def clear_cache(self): | |||||
| """Remove all files and metadata from the cache | |||||
| In the case of multiple cache locations, this clears only the last one, | |||||
| which is assumed to be the read/write one. | |||||
| """ | |||||
| rmtree(self.cache_root_location) | |||||
| self.load_cache() | |||||
| def hash_name(self, key): | |||||
| return hashlib.sha256(key.encode()).hexdigest() | |||||
| class ModelFileSystemCache(FileSystemCache): | |||||
| """Local cache file layout | |||||
| cache_root/owner/model_name/|individual cached files | |||||
| |.mk: file, The cache index file | |||||
| Save only one version for each file. | |||||
| """ | |||||
| def __init__(self, cache_root, owner, name): | |||||
| """Initialize a per-model cache rooted at cache_root/owner/name. | |||||
| Args: | |||||
| cache_root(`str`): The modelscope local cache root(default: ~/.modelscope/cache/models/) | |||||
| owner(`str`): The model owner. | |||||
| name('str'): The name of the model | |||||
| Raises: | |||||
| None | |||||
| <Tip> | |||||
| model_id = {owner}/{name} | |||||
| </Tip> | |||||
| """ | |||||
| super().__init__(os.path.join(cache_root, owner, name)) | |||||
| def get_file_by_path(self, file_path): | |||||
| """Retrieve the cache if there is file match the path. | |||||
| Args: | |||||
| file_path (str): The file path in the model. | |||||
| Returns: | |||||
| path: the full path of the file. | |||||
| """ | |||||
| for cached_file in self.cached_files: | |||||
| if file_path == cached_file['Path']: | |||||
| cached_file_path = os.path.join(self.cache_root_location, | |||||
| cached_file['Path']) | |||||
| if os.path.exists(cached_file_path): | |||||
| return cached_file_path | |||||
| else: | |||||
| self.remove_key(cached_file) | |||||
| return None | |||||
| def get_file_by_path_and_commit_id(self, file_path, commit_id): | |||||
| """Retrieve the cache if there is file match the path. | |||||
| Args: | |||||
| file_path (str): The file path in the model. | |||||
| commit_id (str): The commit id of the file | |||||
| Returns: | |||||
| path: the full path of the file. | |||||
| """ | |||||
| for cached_file in self.cached_files: | |||||
| if file_path == cached_file['Path'] and \ | |||||
| (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])): | |||||
| cached_file_path = os.path.join(self.cache_root_location, | |||||
| cached_file['Path']) | |||||
| if os.path.exists(cached_file_path): | |||||
| return cached_file_path | |||||
| else: | |||||
| self.remove_key(cached_file) | |||||
| return None | |||||
| def get_file_by_info(self, model_file_info): | |||||
| """Check if exist cache file. | |||||
| Args: | |||||
| model_file_info (ModelFileInfo): The file information of the file. | |||||
| Returns: | |||||
| str or None: The full path of the cached file if present, otherwise None. | |||||
| """ | |||||
| cache_key = self.__get_cache_key(model_file_info) | |||||
| for cached_file in self.cached_files: | |||||
| if cached_file == cache_key: | |||||
| orig_path = os.path.join(self.cache_root_location, | |||||
| cached_file['Path']) | |||||
| if os.path.exists(orig_path): | |||||
| return orig_path | |||||
| else: | |||||
| self.remove_key(cached_file) | |||||
| return None | |||||
| def __get_cache_key(self, model_file_info): | |||||
| cache_key = { | |||||
| 'Path': model_file_info['Path'], | |||||
| 'Revision': model_file_info['Revision'], # commit id | |||||
| } | |||||
| return cache_key | |||||
| def exists(self, model_file_info): | |||||
| """Check the file is cached or not. | |||||
| Args: | |||||
| model_file_info (CachedFileInfo): The cached file info | |||||
| Returns: | |||||
| bool: If exists return True otherwise False | |||||
| """ | |||||
| key = self.__get_cache_key(model_file_info) | |||||
| is_exists = False | |||||
| for cached_key in self.cached_files: | |||||
| if cached_key['Path'] == key['Path'] and ( | |||||
| cached_key['Revision'].startswith(key['Revision']) | |||||
| or key['Revision'].startswith(cached_key['Revision'])): | |||||
| is_exists = True | |||||
| file_path = os.path.join(self.cache_root_location, | |||||
| model_file_info['Path']) | |||||
| if is_exists: | |||||
| if os.path.exists(file_path): | |||||
| return True | |||||
| else: | |||||
| self.remove_key( | |||||
| model_file_info) # someone may have manually deleted the file | |||||
| return False | |||||
| def remove_if_exists(self, model_file_info): | |||||
| """If the file exists in the cache, remove it. | |||||
| Args: | |||||
| model_file_info (ModelFileInfo): The model file information from server. | |||||
| """ | |||||
| for cached_file in self.cached_files: | |||||
| if cached_file['Path'] == model_file_info['Path']: | |||||
| self.remove_key(cached_file) | |||||
| file_path = os.path.join(self.cache_root_location, | |||||
| cached_file['Path']) | |||||
| if os.path.exists(file_path): | |||||
| os.remove(file_path) | |||||
| def put_file(self, model_file_info, model_file_location): | |||||
| """Put the model file at model_file_location into the cache; the model is first downloaded to /tmp, then moved to the cache. | |||||
| Args: | |||||
| model_file_info (str): The file description returned by get_model_files | |||||
| sample: | |||||
| { | |||||
| "CommitMessage": "add model\n", | |||||
| "CommittedDate": 1654857567, | |||||
| "CommitterName": "mulin.lyh", | |||||
| "IsLFS": false, | |||||
| "Mode": "100644", | |||||
| "Name": "resnet18.pth", | |||||
| "Path": "resnet18.pth", | |||||
| "Revision": "09b68012b27de0048ba74003690a890af7aff192", | |||||
| "Size": 46827520, | |||||
| "Type": "blob" | |||||
| } | |||||
| model_file_location (str): The location of the temporary file. | |||||
| Raises: | |||||
| None | |||||
| Returns: | |||||
| str: The location of the cached file. | |||||
| """ | |||||
| self.remove_if_exists(model_file_info) # backup old revision | |||||
| cache_key = self.__get_cache_key(model_file_info) | |||||
| cache_full_path = os.path.join( | |||||
| self.cache_root_location, | |||||
| cache_key['Path']) # Branch and Tag do not have same name. | |||||
| cache_file_dir = os.path.dirname(cache_full_path) | |||||
| if not os.path.exists(cache_file_dir): | |||||
| os.makedirs(cache_file_dir, exist_ok=True) | |||||
| # We can't make operation transaction | |||||
| move(model_file_location, cache_full_path) | |||||
| self.cached_files.append(cache_key) | |||||
| self.save_cached_files() | |||||
| return cache_full_path | |||||
| @@ -0,0 +1,39 @@ | |||||
| import os | |||||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||||
| DEFAULT_MODELSCOPE_GITLAB_DOMAIN, | |||||
| DEFAULT_MODELSCOPE_GROUP, | |||||
| MODEL_ID_SEPARATOR, | |||||
| MODELSCOPE_URL_SCHEME) | |||||
| def model_id_to_group_owner_name(model_id): | |||||
| if MODEL_ID_SEPARATOR in model_id: | |||||
| group_or_owner = model_id.split(MODEL_ID_SEPARATOR)[0] | |||||
| name = model_id.split(MODEL_ID_SEPARATOR)[1] | |||||
| else: | |||||
| group_or_owner = DEFAULT_MODELSCOPE_GROUP | |||||
| name = model_id | |||||
| return group_or_owner, name | |||||
| def get_cache_dir(): | |||||
| """ | |||||
| cache dir precedence: | |||||
| function parameter > environment > ~/.cache/modelscope/hub | |||||
| """ | |||||
| default_cache_dir = os.path.expanduser( | |||||
| os.path.join('~/.cache', 'modelscope')) | |||||
| return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir, | |||||
| 'hub')) | |||||
| def get_endpoint(): | |||||
| modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | |||||
| DEFAULT_MODELSCOPE_DOMAIN) | |||||
| return MODELSCOPE_URL_SCHEME + modelscope_domain | |||||
| def get_gitlab_domain(): | |||||
| return os.getenv('MODELSCOPE_GITLAB_DOMAIN', | |||||
| DEFAULT_MODELSCOPE_GITLAB_DOMAIN) | |||||
| @@ -0,0 +1,94 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| class Models(object): | |||||
| """ Names for different models. | |||||
| Holds the standard model name to use for identifying different models. | |||||
| This should be used to register models. | |||||
| Model name should only contain model info but not task info. | |||||
| """ | |||||
| # vision models | |||||
| # nlp models | |||||
| bert = 'bert' | |||||
| palm2_0 = 'palm2.0' | |||||
| structbert = 'structbert' | |||||
| # audio models | |||||
| sambert_hifi_16k = 'sambert-hifi-16k' | |||||
| generic_tts_frontend = 'generic-tts-frontend' | |||||
| hifigan16k = 'hifigan16k' | |||||
| # multi-modal models | |||||
| ofa = 'ofa' | |||||
| class Pipelines(object): | |||||
| """ Names for different pipelines. | |||||
| Holds the standard pipeline name to use for identifying different pipelines. | |||||
| This should be used to register pipelines. | |||||
| For pipelines which support different models and implement the common function, we | |||||
| should use the task name for the pipeline. | |||||
| For pipelines which support only one model, we should use ${Model}-${Task} as the name. | |||||
| """ | |||||
| # vision tasks | |||||
| image_matting = 'unet-image-matting' | |||||
| person_image_cartoon = 'unet-person-image-cartoon' | |||||
| ocr_detection = 'resnet18-ocr-detection' | |||||
| # nlp tasks | |||||
| sentence_similarity = 'sentence-similarity' | |||||
| word_segmentation = 'word-segmentation' | |||||
| text_generation = 'text-generation' | |||||
| sentiment_analysis = 'sentiment-analysis' | |||||
| # audio tasks | |||||
| sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' | |||||
| speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' | |||||
| # multi-modal tasks | |||||
| image_caption = 'image-caption' | |||||
| class Trainers(object): | |||||
| """ Names for different trainers. | |||||
| Holds the standard trainer name to use for identifying different trainers. | |||||
| This should be used to register trainers. | |||||
| For a general Trainer, you can use easynlp-trainer/ofa-trainer/sofa-trainer. | |||||
| For a model specific Trainer, you can use ${ModelName}-${Task}-trainer. | |||||
| """ | |||||
| default = 'Trainer' | |||||
| class Preprocessors(object): | |||||
| """ Names for different preprocessors. | |||||
| Holds the standard preprocessor name to use for identifying different preprocessors. | |||||
| This should be used to register preprocessors. | |||||
| For a general preprocessor, just use the function name as preprocessor name such as | |||||
| resize-image, random-crop | |||||
| For a model-specific preprocessor, use ${modelname}-${function} | |||||
| """ | |||||
| # cv preprocessor | |||||
| load_image = 'load-image' | |||||
| # nlp preprocessor | |||||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | |||||
| palm_text_gen_tokenizer = 'palm-text-gen-tokenizer' | |||||
| sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | |||||
| # audio preprocessor | |||||
| linear_aec_fbank = 'linear-aec-fbank' | |||||
| text_to_tacotron_symbols = 'text-to-tacotron-symbols' | |||||
| # multi-modal | |||||
| ofa_image_caption = 'ofa-image-caption' | |||||
| @@ -6,6 +6,7 @@ import numpy as np | |||||
| import tensorflow as tf | import tensorflow as tf | ||||
| from sklearn.preprocessing import MultiLabelBinarizer | from sklearn.preprocessing import MultiLabelBinarizer | ||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| @@ -26,7 +27,8 @@ def multi_label_symbol_to_sequence(my_classes, my_symbol): | |||||
| return one_hot.fit_transform(sequences) | return one_hot.fit_transform(sequences) | ||||
| @MODELS.register_module(Tasks.text_to_speech, module_name=r'sambert_hifi_16k') | |||||
| @MODELS.register_module( | |||||
| Tasks.text_to_speech, module_name=Models.sambert_hifi_16k) | |||||
| class SambertNetHifi16k(Model): | class SambertNetHifi16k(Model): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -2,6 +2,7 @@ import os | |||||
| import zipfile | import zipfile | ||||
| from typing import Any, Dict, List | from typing import Any, Dict, List | ||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| from modelscope.utils.audio.tts_exceptions import ( | from modelscope.utils.audio.tts_exceptions import ( | ||||
| @@ -13,7 +14,7 @@ __all__ = ['GenericTtsFrontend'] | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.text_to_speech, module_name=r'generic_tts_frontend') | |||||
| Tasks.text_to_speech, module_name=Models.generic_tts_frontend) | |||||
| class GenericTtsFrontend(Model): | class GenericTtsFrontend(Model): | ||||
| def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs): | def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs): | ||||
| @@ -10,6 +10,7 @@ import numpy as np | |||||
| import torch | import torch | ||||
| from scipy.io.wavfile import write | from scipy.io.wavfile import write | ||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| from modelscope.utils.audio.tts_exceptions import \ | from modelscope.utils.audio.tts_exceptions import \ | ||||
| @@ -36,7 +37,7 @@ class AttrDict(dict): | |||||
| self.__dict__ = self | self.__dict__ = self | ||||
| @MODELS.register_module(Tasks.text_to_speech, module_name=r'hifigan16k') | |||||
| @MODELS.register_module(Tasks.text_to_speech, module_name=Models.hifigan16k) | |||||
| class Hifigan16k(Model): | class Hifigan16k(Model): | ||||
| def __init__(self, model_dir, *args, **kwargs): | def __init__(self, model_dir, *args, **kwargs): | ||||
| @@ -4,12 +4,13 @@ import os.path as osp | |||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Dict, Union | from typing import Dict, Union | ||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models.builder import build_model | from modelscope.models.builder import build_model | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import ModelFile | from modelscope.utils.constant import ModelFile | ||||
| from modelscope.utils.hub import get_model_cache_dir | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| @@ -47,21 +48,25 @@ class Model(ABC): | |||||
| if osp.exists(model_name_or_path): | if osp.exists(model_name_or_path): | ||||
| local_model_dir = model_name_or_path | local_model_dir = model_name_or_path | ||||
| else: | else: | ||||
| cache_path = get_model_cache_dir(model_name_or_path) | |||||
| local_model_dir = cache_path if osp.exists( | |||||
| cache_path) else snapshot_download(model_name_or_path) | |||||
| # else: | |||||
| # raise ValueError( | |||||
| # 'Remote model repo {model_name_or_path} does not exists') | |||||
| local_model_dir = snapshot_download(model_name_or_path) | |||||
| logger.info(f'initialize model from {local_model_dir}') | |||||
| cfg = Config.from_file( | cfg = Config.from_file( | ||||
| osp.join(local_model_dir, ModelFile.CONFIGURATION)) | osp.join(local_model_dir, ModelFile.CONFIGURATION)) | ||||
| task_name = cfg.task | task_name = cfg.task | ||||
| model_cfg = cfg.model | model_cfg = cfg.model | ||||
| assert hasattr( | |||||
| cfg, 'pipeline'), 'pipeline config is missing from config file.' | |||||
| pipeline_cfg = cfg.pipeline | |||||
| # TODO @wenmeng.zwm may should manually initialize model after model building | # TODO @wenmeng.zwm may should manually initialize model after model building | ||||
| if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | ||||
| model_cfg.type = model_cfg.model_type | model_cfg.type = model_cfg.model_type | ||||
| model_cfg.model_dir = local_model_dir | model_cfg.model_dir = local_model_dir | ||||
| for k, v in kwargs.items(): | for k, v in kwargs.items(): | ||||
| model_cfg.k = v | model_cfg.k = v | ||||
| return build_model(model_cfg, task_name) | |||||
| model = build_model(model_cfg, task_name) | |||||
| # dynamically add pipeline info to model for pipeline inference | |||||
| model.pipeline = pipeline_cfg | |||||
| return model | |||||
| @@ -3,6 +3,7 @@ from typing import Any, Dict | |||||
| from PIL import Image | from PIL import Image | ||||
| from modelscope.metainfo import Models | |||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from ..base import Model | from ..base import Model | ||||
| from ..builder import MODELS | from ..builder import MODELS | ||||
| @@ -10,8 +11,7 @@ from ..builder import MODELS | |||||
| __all__ = ['OfaForImageCaptioning'] | __all__ = ['OfaForImageCaptioning'] | ||||
| @MODELS.register_module( | |||||
| Tasks.image_captioning, module_name=r'ofa-image-captioning') | |||||
| @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) | |||||
| class OfaForImageCaptioning(Model): | class OfaForImageCaptioning(Model): | ||||
| def __init__(self, model_dir, *args, **kwargs): | def __init__(self, model_dir, *args, **kwargs): | ||||
| @@ -4,6 +4,7 @@ from typing import Any, Dict | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| from modelscope.metainfo import Models | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from ..base import Model | from ..base import Model | ||||
| from ..builder import MODELS | from ..builder import MODELS | ||||
| @@ -11,8 +12,7 @@ from ..builder import MODELS | |||||
| __all__ = ['BertForSequenceClassification'] | __all__ = ['BertForSequenceClassification'] | ||||
| @MODELS.register_module( | |||||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||||
| @MODELS.register_module(Tasks.text_classification, module_name=Models.bert) | |||||
| class BertForSequenceClassification(Model): | class BertForSequenceClassification(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -1,5 +1,6 @@ | |||||
| from typing import Dict | from typing import Dict | ||||
| from modelscope.metainfo import Models | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from ..base import Model, Tensor | from ..base import Model, Tensor | ||||
| from ..builder import MODELS | from ..builder import MODELS | ||||
| @@ -7,7 +8,7 @@ from ..builder import MODELS | |||||
| __all__ = ['PalmForTextGeneration'] | __all__ = ['PalmForTextGeneration'] | ||||
| @MODELS.register_module(Tasks.text_generation, module_name=r'palm2.0') | |||||
| @MODELS.register_module(Tasks.text_generation, module_name=Models.palm2_0) | |||||
| class PalmForTextGeneration(Model): | class PalmForTextGeneration(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -8,6 +8,7 @@ from sofa import SbertModel | |||||
| from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel | from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel | ||||
| from torch import nn | from torch import nn | ||||
| from modelscope.metainfo import Models | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from ..base import Model, Tensor | from ..base import Model, Tensor | ||||
| from ..builder import MODELS | from ..builder import MODELS | ||||
| @@ -38,8 +39,7 @@ class SbertTextClassifier(SbertPreTrainedModel): | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.sentence_similarity, | |||||
| module_name=r'sbert-base-chinese-sentence-similarity') | |||||
| Tasks.sentence_similarity, module_name=Models.structbert) | |||||
| class SbertForSentenceSimilarity(Model): | class SbertForSentenceSimilarity(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -4,6 +4,7 @@ import numpy as np | |||||
| import torch | import torch | ||||
| from sofa import SbertConfig, SbertForTokenClassification | from sofa import SbertConfig, SbertForTokenClassification | ||||
| from modelscope.metainfo import Models | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from ..base import Model, Tensor | from ..base import Model, Tensor | ||||
| from ..builder import MODELS | from ..builder import MODELS | ||||
| @@ -11,9 +12,7 @@ from ..builder import MODELS | |||||
| __all__ = ['StructBertForTokenClassification'] | __all__ = ['StructBertForTokenClassification'] | ||||
| @MODELS.register_module( | |||||
| Tasks.word_segmentation, | |||||
| module_name=r'structbert-chinese-word-segmentation') | |||||
| @MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) | |||||
| class StructBertForTokenClassification(Model): | class StructBertForTokenClassification(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -7,6 +7,7 @@ import scipy.io.wavfile as wav | |||||
| import torch | import torch | ||||
| import yaml | import yaml | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.preprocessors.audio import LinearAECAndFbank | from modelscope.preprocessors.audio import LinearAECAndFbank | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from ..base import Pipeline | from ..base import Pipeline | ||||
| @@ -39,7 +40,8 @@ def initialize_config(module_cfg): | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.speech_signal_process, module_name=r'speech_dfsmn_aec_psm_16k') | |||||
| Tasks.speech_signal_process, | |||||
| module_name=Pipelines.speech_dfsmn_aec_psm_16k) | |||||
| class LinearAECPipeline(Pipeline): | class LinearAECPipeline(Pipeline): | ||||
| r"""AEC Inference Pipeline only support 16000 sample rate. | r"""AEC Inference Pipeline only support 16000 sample rate. | ||||
| @@ -3,6 +3,7 @@ from typing import Any, Dict, List | |||||
| import numpy as np | import numpy as np | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.audio.tts.am import SambertNetHifi16k | from modelscope.models.audio.tts.am import SambertNetHifi16k | ||||
| from modelscope.models.audio.tts.vocoder import Hifigan16k | from modelscope.models.audio.tts.vocoder import Hifigan16k | ||||
| @@ -15,7 +16,7 @@ __all__ = ['TextToSpeechSambertHifigan16kPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.text_to_speech, module_name=r'tts-sambert-hifigan-16k') | |||||
| Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_16k_tts) | |||||
| class TextToSpeechSambertHifigan16kPipeline(Pipeline): | class TextToSpeechSambertHifigan16kPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -4,16 +4,14 @@ import os.path as osp | |||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Any, Dict, Generator, List, Union | from typing import Any, Dict, Generator, List, Union | ||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.preprocessors import Preprocessor | from modelscope.preprocessors import Preprocessor | ||||
| from modelscope.pydatasets import PyDataset | from modelscope.pydatasets import PyDataset | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.hub import get_model_cache_dir | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .outputs import TASK_OUTPUTS | from .outputs import TASK_OUTPUTS | ||||
| from .util import is_model_name | |||||
| from .util import is_model, is_official_hub_path | |||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | ||||
| @@ -29,14 +27,10 @@ class Pipeline(ABC): | |||||
| def initiate_single_model(self, model): | def initiate_single_model(self, model): | ||||
| logger.info(f'initiate model from {model}') | logger.info(f'initiate model from {model}') | ||||
| # TODO @wenmeng.zwm replace model.startswith('damo/') with get_model | |||||
| if isinstance(model, str) and model.startswith('damo/'): | |||||
| if not osp.exists(model): | |||||
| cache_path = get_model_cache_dir(model) | |||||
| model = cache_path if osp.exists( | |||||
| cache_path) else snapshot_download(model) | |||||
| return Model.from_pretrained(model) if is_model_name( | |||||
| model) else model | |||||
| if isinstance(model, str) and is_official_hub_path(model): | |||||
| model = snapshot_download( | |||||
| model) if not osp.exists(model) else model | |||||
| return Model.from_pretrained(model) if is_model(model) else model | |||||
| elif isinstance(model, Model): | elif isinstance(model, Model): | ||||
| return model | return model | ||||
| else: | else: | ||||
| @@ -3,32 +3,39 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import List, Union | from typing import List, Union | ||||
| from attr import has | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.utils.config import Config, ConfigDict | from modelscope.utils.config import Config, ConfigDict | ||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.hub import read_config | |||||
| from modelscope.utils.registry import Registry, build_from_cfg | from modelscope.utils.registry import Registry, build_from_cfg | ||||
| from .base import Pipeline | from .base import Pipeline | ||||
| from .util import is_official_hub_path | |||||
| PIPELINES = Registry('pipelines') | PIPELINES = Registry('pipelines') | ||||
| DEFAULT_MODEL_FOR_PIPELINE = { | DEFAULT_MODEL_FOR_PIPELINE = { | ||||
| # TaskName: (pipeline_module_name, model_repo) | # TaskName: (pipeline_module_name, model_repo) | ||||
| Tasks.word_segmentation: | Tasks.word_segmentation: | ||||
| ('structbert-chinese-word-segmentation', | |||||
| (Pipelines.word_segmentation, | |||||
| 'damo/nlp_structbert_word-segmentation_chinese-base'), | 'damo/nlp_structbert_word-segmentation_chinese-base'), | ||||
| Tasks.sentence_similarity: | Tasks.sentence_similarity: | ||||
| ('sbert-base-chinese-sentence-similarity', | |||||
| (Pipelines.sentence_similarity, | |||||
| 'damo/nlp_structbert_sentence-similarity_chinese-base'), | 'damo/nlp_structbert_sentence-similarity_chinese-base'), | ||||
| Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'), | |||||
| Tasks.text_classification: | |||||
| ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | |||||
| Tasks.text_generation: ('palm2.0', | |||||
| Tasks.image_matting: | |||||
| (Pipelines.image_matting, 'damo/cv_unet_image-matting'), | |||||
| Tasks.text_classification: (Pipelines.sentiment_analysis, | |||||
| 'damo/bert-base-sst2'), | |||||
| Tasks.text_generation: (Pipelines.text_generation, | |||||
| 'damo/nlp_palm2.0_text-generation_chinese-base'), | 'damo/nlp_palm2.0_text-generation_chinese-base'), | ||||
| Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'), | |||||
| Tasks.image_captioning: (Pipelines.image_caption, | |||||
| 'damo/ofa_image-caption_coco_large_en'), | |||||
| Tasks.image_generation: | Tasks.image_generation: | ||||
| ('person-image-cartoon', | |||||
| (Pipelines.person_image_cartoon, | |||||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | 'damo/cv_unet_person-image-cartoon_compound-models'), | ||||
| Tasks.ocr_detection: ('ocr-detection', | |||||
| Tasks.ocr_detection: (Pipelines.ocr_detection, | |||||
| 'damo/cv_resnet18_ocr-detection-line-level_damo'), | 'damo/cv_resnet18_ocr-detection-line-level_damo'), | ||||
| Tasks.fill_mask: ('veco', 'damo/nlp_veco_fill-mask_large') | Tasks.fill_mask: ('veco', 'damo/nlp_veco_fill-mask_large') | ||||
| } | } | ||||
| @@ -87,30 +94,40 @@ def pipeline(task: str = None, | |||||
| if task is None and pipeline_name is None: | if task is None and pipeline_name is None: | ||||
| raise ValueError('task or pipeline_name is required') | raise ValueError('task or pipeline_name is required') | ||||
| assert isinstance(model, (type(None), str, Model, list)), \ | |||||
| f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' | |||||
| if pipeline_name is None: | if pipeline_name is None: | ||||
| # get default pipeline for this task | # get default pipeline for this task | ||||
| if isinstance(model, str) \ | if isinstance(model, str) \ | ||||
| or (isinstance(model, list) and isinstance(model[0], str)): | or (isinstance(model, list) and isinstance(model[0], str)): | ||||
| # if is_model_name(model): | |||||
| if (isinstance(model, str) and model.startswith('damo/')) \ | |||||
| or (isinstance(model, list) and model[0].startswith('damo/')) \ | |||||
| or (isinstance(model, str) and osp.exists(model)): | |||||
| # TODO @wenmeng.zwm add support when model is a str of modelhub address | |||||
| # read pipeline info from modelhub configuration file. | |||||
| pipeline_name, default_model_repo = get_default_pipeline_info( | |||||
| task) | |||||
| if is_official_hub_path(model): | |||||
| # read config file from hub and parse | |||||
| cfg = read_config(model) if isinstance( | |||||
| model, str) else read_config(model[0]) | |||||
| assert hasattr( | |||||
| cfg, | |||||
| 'pipeline'), 'pipeline config is missing from config file.' | |||||
| pipeline_name = cfg.pipeline.type | |||||
| else: | else: | ||||
| # used for test case, when model is str and is not hub path | |||||
| pipeline_name = get_pipeline_by_model_name(task, model) | pipeline_name = get_pipeline_by_model_name(task, model) | ||||
| elif isinstance(model, Model) or \ | |||||
| (isinstance(model, list) and isinstance(model[0], Model)): | |||||
| # get pipeline info from Model object | |||||
| first_model = model[0] if isinstance(model, list) else model | |||||
| if not hasattr(first_model, 'pipeline'): | |||||
| # model is instantiated by user, we should parse config again | |||||
| cfg = read_config(first_model.model_dir) | |||||
| assert hasattr( | |||||
| cfg, | |||||
| 'pipeline'), 'pipeline config is missing from config file.' | |||||
| first_model.pipeline = cfg.pipeline | |||||
| pipeline_name = first_model.pipeline.type | |||||
| else: | else: | ||||
| pipeline_name, default_model_repo = get_default_pipeline_info(task) | pipeline_name, default_model_repo = get_default_pipeline_info(task) | ||||
| if model is None: | |||||
| model = default_model_repo | model = default_model_repo | ||||
| assert isinstance(model, (type(None), str, Model, list)), \ | |||||
| f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' | |||||
| cfg = ConfigDict(type=pipeline_name, model=model) | cfg = ConfigDict(type=pipeline_name, model=model) | ||||
| if kwargs: | if kwargs: | ||||
| @@ -6,6 +6,7 @@ import numpy as np | |||||
| import PIL | import PIL | ||||
| import tensorflow as tf | import tensorflow as tf | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.cartoon.facelib.facer import FaceAna | from modelscope.models.cv.cartoon.facelib.facer import FaceAna | ||||
| from modelscope.models.cv.cartoon.mtcnn_pytorch.src.align_trans import ( | from modelscope.models.cv.cartoon.mtcnn_pytorch.src.align_trans import ( | ||||
| get_reference_facial_points, warp_and_crop_face) | get_reference_facial_points, warp_and_crop_face) | ||||
| @@ -25,7 +26,7 @@ logger = get_logger() | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_generation, module_name='person-image-cartoon') | |||||
| Tasks.image_generation, module_name=Pipelines.person_image_cartoon) | |||||
| class ImageCartoonPipeline(Pipeline): | class ImageCartoonPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| @@ -5,6 +5,7 @@ import cv2 | |||||
| import numpy as np | import numpy as np | ||||
| import PIL | import PIL | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
| from modelscope.preprocessors import load_image | from modelscope.preprocessors import load_image | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| @@ -16,7 +17,7 @@ logger = get_logger() | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_matting, module_name=Tasks.image_matting) | |||||
| Tasks.image_matting, module_name=Pipelines.image_matting) | |||||
| class ImageMattingPipeline(Pipeline): | class ImageMattingPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| @@ -10,6 +10,7 @@ import PIL | |||||
| import tensorflow as tf | import tensorflow as tf | ||||
| import tf_slim as slim | import tf_slim as slim | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
| from modelscope.preprocessors import load_image | from modelscope.preprocessors import load_image | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| @@ -38,7 +39,7 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6, | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.ocr_detection, module_name=Tasks.ocr_detection) | |||||
| Tasks.ocr_detection, module_name=Pipelines.ocr_detection) | |||||
| class OCRDetectionPipeline(Pipeline): | class OCRDetectionPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| @@ -1,5 +1,6 @@ | |||||
| from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor | from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -9,7 +10,8 @@ from ..builder import PIPELINES | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @PIPELINES.register_module(Tasks.image_captioning, module_name='ofa') | |||||
| @PIPELINES.register_module( | |||||
| Tasks.image_captioning, module_name=Pipelines.image_caption) | |||||
| class ImageCaptionPipeline(Pipeline): | class ImageCaptionPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -2,6 +2,7 @@ from typing import Any, Dict, Union | |||||
| import numpy as np | import numpy as np | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.nlp import SbertForSentenceSimilarity | from modelscope.models.nlp import SbertForSentenceSimilarity | ||||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | from modelscope.preprocessors import SequenceClassificationPreprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| @@ -13,8 +14,7 @@ __all__ = ['SentenceSimilarityPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.sentence_similarity, | |||||
| module_name=r'sbert-base-chinese-sentence-similarity') | |||||
| Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity) | |||||
| class SentenceSimilarityPipeline(Pipeline): | class SentenceSimilarityPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -2,6 +2,7 @@ from typing import Any, Dict, Union | |||||
| import numpy as np | import numpy as np | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.nlp import BertForSequenceClassification | from modelscope.models.nlp import BertForSequenceClassification | ||||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | from modelscope.preprocessors import SequenceClassificationPreprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| @@ -13,7 +14,7 @@ __all__ = ['SequenceClassificationPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||||
| Tasks.text_classification, module_name=Pipelines.sentiment_analysis) | |||||
| class SequenceClassificationPipeline(Pipeline): | class SequenceClassificationPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -1,5 +1,6 @@ | |||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.nlp import PalmForTextGeneration | from modelscope.models.nlp import PalmForTextGeneration | ||||
| from modelscope.preprocessors import TextGenerationPreprocessor | from modelscope.preprocessors import TextGenerationPreprocessor | ||||
| @@ -10,7 +11,8 @@ from ..builder import PIPELINES | |||||
| __all__ = ['TextGenerationPipeline'] | __all__ = ['TextGenerationPipeline'] | ||||
| @PIPELINES.register_module(Tasks.text_generation, module_name=r'palm2.0') | |||||
| @PIPELINES.register_module( | |||||
| Tasks.text_generation, module_name=Pipelines.text_generation) | |||||
| class TextGenerationPipeline(Pipeline): | class TextGenerationPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -1,5 +1,6 @@ | |||||
| from typing import Any, Dict, Optional, Union | from typing import Any, Dict, Optional, Union | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.nlp import StructBertForTokenClassification | from modelscope.models.nlp import StructBertForTokenClassification | ||||
| from modelscope.preprocessors import TokenClassifcationPreprocessor | from modelscope.preprocessors import TokenClassifcationPreprocessor | ||||
| @@ -11,8 +12,7 @@ __all__ = ['WordSegmentationPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.word_segmentation, | |||||
| module_name=r'structbert-chinese-word-segmentation') | |||||
| Tasks.word_segmentation, module_name=Pipelines.word_segmentation) | |||||
| class WordSegmentationPipeline(Pipeline): | class WordSegmentationPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -2,8 +2,8 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import List, Union | from typing import List, Union | ||||
| from maas_hub.file_download import model_file_download | |||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.hub.file_download import model_file_download | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import ModelFile | from modelscope.utils.constant import ModelFile | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -20,31 +20,63 @@ def is_config_has_model(cfg_file): | |||||
| return False | return False | ||||
| def is_model_name(model: Union[str, List]): | |||||
| """ whether model is a valid modelhub path | |||||
| def is_official_hub_path(path: Union[str, List]): | |||||
| """ Whether path is a official hub name or a valid local | |||||
| path to official hub directory. | |||||
| """ | |||||
| def is_official_hub_impl(path): | |||||
| if osp.exists(path): | |||||
| cfg_file = osp.join(path, ModelFile.CONFIGURATION) | |||||
| return osp.exists(cfg_file) | |||||
| else: | |||||
| try: | |||||
| _ = HubApi().get_model(path) | |||||
| return True | |||||
| except Exception: | |||||
| return False | |||||
| if isinstance(path, str): | |||||
| return is_official_hub_impl(path) | |||||
| else: | |||||
| results = [is_official_hub_impl(m) for m in path] | |||||
| all_true = all(results) | |||||
| any_true = any(results) | |||||
| if any_true and not all_true: | |||||
| raise ValueError( | |||||
| f'some model are hub address, some are not, model list: {path}' | |||||
| ) | |||||
| return all_true | |||||
| def is_model(path: Union[str, List]): | |||||
| """ whether path is a valid modelhub path and containing model config | |||||
| """ | """ | ||||
| def is_model_name_impl(model): | |||||
| if osp.exists(model): | |||||
| cfg_file = osp.join(model, ModelFile.CONFIGURATION) | |||||
| def is_modelhub_path_impl(path): | |||||
| if osp.exists(path): | |||||
| cfg_file = osp.join(path, ModelFile.CONFIGURATION) | |||||
| if osp.exists(cfg_file): | if osp.exists(cfg_file): | ||||
| return is_config_has_model(cfg_file) | return is_config_has_model(cfg_file) | ||||
| else: | else: | ||||
| return False | return False | ||||
| else: | else: | ||||
| try: | try: | ||||
| cfg_file = model_file_download(model, ModelFile.CONFIGURATION) | |||||
| cfg_file = model_file_download(path, ModelFile.CONFIGURATION) | |||||
| return is_config_has_model(cfg_file) | return is_config_has_model(cfg_file) | ||||
| except Exception: | except Exception: | ||||
| return False | return False | ||||
| if isinstance(model, str): | |||||
| return is_model_name_impl(model) | |||||
| if isinstance(path, str): | |||||
| return is_modelhub_path_impl(path) | |||||
| else: | else: | ||||
| results = [is_model_name_impl(m) for m in model] | |||||
| results = [is_modelhub_path_impl(m) for m in path] | |||||
| all_true = all(results) | all_true = all(results) | ||||
| any_true = any(results) | any_true = any(results) | ||||
| if any_true and not all_true: | if any_true and not all_true: | ||||
| raise ValueError('some model are hub address, some are not') | |||||
| raise ValueError( | |||||
| f'some models are hub address, some are not, model list: {path}' | |||||
| ) | |||||
| return all_true | return all_true | ||||
| @@ -5,11 +5,12 @@ from typing import Dict, Union | |||||
| from PIL import Image, ImageOps | from PIL import Image, ImageOps | ||||
| from modelscope.fileio import File | from modelscope.fileio import File | ||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.utils.constant import Fields | from modelscope.utils.constant import Fields | ||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| @PREPROCESSORS.register_module(Fields.cv) | |||||
| @PREPROCESSORS.register_module(Fields.cv, Preprocessors.load_image) | |||||
| class LoadImage: | class LoadImage: | ||||
| """Load an image from file or url. | """Load an image from file or url. | ||||
| Added or updated keys are "filename", "img", "img_shape", | Added or updated keys are "filename", "img", "img_shape", | ||||
| @@ -4,11 +4,11 @@ from typing import Any, Dict, Union | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from PIL import Image | from PIL import Image | ||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.utils.constant import Fields, ModelFile | from modelscope.utils.constant import Fields, ModelFile | ||||
| from modelscope.utils.hub import get_model_cache_dir | |||||
| from modelscope.utils.type_assert import type_assert | from modelscope.utils.type_assert import type_assert | ||||
| from .base import Preprocessor | from .base import Preprocessor | ||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| @@ -20,7 +20,7 @@ __all__ = [ | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.multi_modal, module_name=r'ofa-image-caption') | |||||
| Fields.multi_modal, module_name=Preprocessors.ofa_image_caption) | |||||
| class OfaImageCaptionPreprocessor(Preprocessor): | class OfaImageCaptionPreprocessor(Preprocessor): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -34,9 +34,7 @@ class OfaImageCaptionPreprocessor(Preprocessor): | |||||
| if osp.exists(model_dir): | if osp.exists(model_dir): | ||||
| local_model_dir = model_dir | local_model_dir = model_dir | ||||
| else: | else: | ||||
| cache_path = get_model_cache_dir(model_dir) | |||||
| local_model_dir = cache_path if osp.exists( | |||||
| cache_path) else snapshot_download(model_dir) | |||||
| local_model_dir = snapshot_download(model_dir) | |||||
| local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE) | local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE) | ||||
| bpe_dir = local_model_dir | bpe_dir = local_model_dir | ||||
| @@ -5,6 +5,7 @@ from typing import Any, Dict, Union | |||||
| from transformers import AutoTokenizer | from transformers import AutoTokenizer | ||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.utils.constant import Fields, InputFields | from modelscope.utils.constant import Fields, InputFields | ||||
| from modelscope.utils.type_assert import type_assert | from modelscope.utils.type_assert import type_assert | ||||
| from .base import Preprocessor | from .base import Preprocessor | ||||
| @@ -32,7 +33,7 @@ class Tokenize(Preprocessor): | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=r'bert-sequence-classification') | |||||
| Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) | |||||
| class SequenceClassificationPreprocessor(Preprocessor): | class SequenceClassificationPreprocessor(Preprocessor): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -125,7 +126,8 @@ class SequenceClassificationPreprocessor(Preprocessor): | |||||
| return rst | return rst | ||||
| @PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm2.0') | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer) | |||||
| class TextGenerationPreprocessor(Preprocessor): | class TextGenerationPreprocessor(Preprocessor): | ||||
| def __init__(self, model_dir: str, tokenizer, *args, **kwargs): | def __init__(self, model_dir: str, tokenizer, *args, **kwargs): | ||||
| @@ -236,7 +238,7 @@ class FillMaskPreprocessor(Preprocessor): | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=r'bert-token-classification') | |||||
| Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer) | |||||
| class TokenClassifcationPreprocessor(Preprocessor): | class TokenClassifcationPreprocessor(Preprocessor): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -3,6 +3,7 @@ import io | |||||
| from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
| from modelscope.fileio import File | from modelscope.fileio import File | ||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.models.audio.tts.frontend import GenericTtsFrontend | from modelscope.models.audio.tts.frontend import GenericTtsFrontend | ||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.utils.audio.tts_exceptions import * # noqa F403 | from modelscope.utils.audio.tts_exceptions import * # noqa F403 | ||||
| @@ -10,11 +11,11 @@ from modelscope.utils.constant import Fields | |||||
| from .base import Preprocessor | from .base import Preprocessor | ||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| __all__ = ['TextToTacotronSymbols', 'text_to_tacotron_symbols'] | |||||
| __all__ = ['TextToTacotronSymbols'] | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.audio, module_name=r'text_to_tacotron_symbols') | |||||
| Fields.audio, module_name=Preprocessors.text_to_tacotron_symbols) | |||||
| class TextToTacotronSymbols(Preprocessor): | class TextToTacotronSymbols(Preprocessor): | ||||
| """extract tacotron symbols from text. | """extract tacotron symbols from text. | ||||
| @@ -1,14 +1,49 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| import os.path as osp | |||||
| from typing import List, Union | |||||
| from maas_hub.constants import MODEL_ID_SEPARATOR | |||||
| from numpy import deprecate | |||||
| from modelscope.hub.file_download import model_file_download | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.hub.utils.utils import get_cache_dir | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile | |||||
| # temp solution before the hub-cache is in place | # temp solution before the hub-cache is in place | ||||
| def get_model_cache_dir(model_id: str, branch: str = 'master'): | |||||
| model_id_expanded = model_id.replace('/', | |||||
| MODEL_ID_SEPARATOR) + '.' + branch | |||||
| default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas')) | |||||
| return os.getenv('MAAS_CACHE', | |||||
| os.path.join(default_cache_dir, 'hub', model_id_expanded)) | |||||
| @deprecate | |||||
| def get_model_cache_dir(model_id: str): | |||||
| return os.path.join(get_cache_dir(), model_id) | |||||
| def read_config(model_id_or_path: str): | |||||
| """ Read config from hub or local path | |||||
| Args: | |||||
| model_id_or_path (str): Model repo name or local directory path. | |||||
| Return: | |||||
| config (:obj:`Config`): config object | |||||
| """ | |||||
| if not os.path.exists(model_id_or_path): | |||||
| local_path = model_file_download(model_id_or_path, | |||||
| ModelFile.CONFIGURATION) | |||||
| else: | |||||
| local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION) | |||||
| return Config.from_file(local_path) | |||||
| def auto_load(model: Union[str, List[str]]): | |||||
| if isinstance(model, str): | |||||
| if not osp.exists(model): | |||||
| model = snapshot_download(model) | |||||
| else: | |||||
| model = [ | |||||
| snapshot_download(m) if not osp.exists(m) else m for m in model | |||||
| ] | |||||
| return model | |||||
| @@ -1,10 +1,10 @@ | |||||
| #tts | #tts | ||||
| h5py==2.10.0 | h5py==2.10.0 | ||||
| #https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl | |||||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl | |||||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl; python_version=='3.6' | |||||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl; python_version=='3.7' | |||||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl; python_version=='3.8' | |||||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl; python_version=='3.9' | |||||
| https://swap.oss-cn-hangzhou.aliyuncs.com/Jiaqi%2Fmaas%2Ftts%2Frequirements%2Fpytorch_wavelets-1.3.0-py3-none-any.whl?Expires=1685688388&OSSAccessKeyId=LTAI4Ffebq4d9jTVDwiSbY4L&Signature=jcQbg5EZ%2Bdys3%2F4BRn3srrKLdIg%3D | https://swap.oss-cn-hangzhou.aliyuncs.com/Jiaqi%2Fmaas%2Ftts%2Frequirements%2Fpytorch_wavelets-1.3.0-py3-none-any.whl?Expires=1685688388&OSSAccessKeyId=LTAI4Ffebq4d9jTVDwiSbY4L&Signature=jcQbg5EZ%2Bdys3%2F4BRn3srrKLdIg%3D | ||||
| #https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl | |||||
| #https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl | |||||
| inflect | inflect | ||||
| keras==2.2.4 | keras==2.2.4 | ||||
| librosa | librosa | ||||
| @@ -12,7 +12,7 @@ lxml | |||||
| matplotlib | matplotlib | ||||
| nara_wpe | nara_wpe | ||||
| numpy==1.18.* | numpy==1.18.* | ||||
| protobuf==3.20.* | |||||
| protobuf>3,<=3.20 | |||||
| ptflops | ptflops | ||||
| PyWavelets>=1.0.0 | PyWavelets>=1.0.0 | ||||
| scikit-learn==0.23.2 | scikit-learn==0.23.2 | ||||
| @@ -1,13 +1,16 @@ | |||||
| addict | addict | ||||
| datasets | datasets | ||||
| easydict | easydict | ||||
| https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl | |||||
| filelock>=3.3.0 | |||||
| numpy | numpy | ||||
| opencv-python-headless | opencv-python-headless | ||||
| Pillow>=6.2.0 | Pillow>=6.2.0 | ||||
| pyyaml | pyyaml | ||||
| requests | requests | ||||
| requests==2.27.1 | |||||
| scipy | scipy | ||||
| setuptools==58.0.4 | |||||
| tokenizers<=0.10.3 | tokenizers<=0.10.3 | ||||
| tqdm>=4.64.0 | |||||
| transformers<=4.16.2 | transformers<=4.16.2 | ||||
| yapf | yapf | ||||
| @@ -0,0 +1,157 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import os.path as osp | |||||
| import subprocess | |||||
| import tempfile | |||||
| import unittest | |||||
| import uuid | |||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.hub.file_download import model_file_download | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.hub.utils.utils import get_gitlab_domain | |||||
| USER_NAME = 'maasadmin' | |||||
| PASSWORD = '12345678' | |||||
| model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'unittest' | |||||
| DEFAULT_GIT_PATH = 'git' | |||||
| class GitError(Exception): | |||||
| pass | |||||
| # TODO make thest git operation to git library after merge code. | |||||
| def run_git_command(git_path, *args) -> subprocess.CompletedProcess: | |||||
| response = subprocess.run([git_path, *args], capture_output=True) | |||||
| try: | |||||
| response.check_returncode() | |||||
| return response.stdout.decode('utf8') | |||||
| except subprocess.CalledProcessError as error: | |||||
| raise GitError(error.stderr.decode('utf8')) | |||||
| # for public project, token can None, private repo, there must token. | |||||
| def clone(local_dir: str, token: str, url: str): | |||||
| url = url.replace('//', '//oauth2:%s@' % token) | |||||
| clone_args = '-C %s clone %s' % (local_dir, url) | |||||
| clone_args = clone_args.split(' ') | |||||
| stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args) | |||||
| print('stdout: %s' % stdout) | |||||
| def push(local_dir: str, token: str, url: str): | |||||
| url = url.replace('//', '//oauth2:%s@' % token) | |||||
| push_args = '-C %s push %s' % (local_dir, url) | |||||
| push_args = push_args.split(' ') | |||||
| stdout = run_git_command(DEFAULT_GIT_PATH, *push_args) | |||||
| print('stdout: %s' % stdout) | |||||
| sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||||
| download_model_file_name = 'mnist-12.onnx' | |||||
| class HubOperationTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| self.old_cwd = os.getcwd() | |||||
| self.api = HubApi() | |||||
| # note this is temporary before official account management is ready | |||||
| self.api.login(USER_NAME, PASSWORD) | |||||
| self.model_name = uuid.uuid4().hex | |||||
| self.model_id = '%s/%s' % (model_org, self.model_name) | |||||
| self.api.create_model( | |||||
| model_id=self.model_id, | |||||
| chinese_name=model_chinese_name, | |||||
| visibility=5, # 1-private, 5-public | |||||
| license='apache-2.0') | |||||
| def tearDown(self): | |||||
| os.chdir(self.old_cwd) | |||||
| self.api.delete_model(model_id=self.model_id) | |||||
| def test_model_repo_creation(self): | |||||
| # change to proper model names before use | |||||
| try: | |||||
| info = self.api.get_model(model_id=self.model_id) | |||||
| assert info['Name'] == self.model_name | |||||
| except KeyError as ke: | |||||
| if ke.args[0] == 'name': | |||||
| print(f'model {self.model_name} already exists, ignore') | |||||
| else: | |||||
| raise | |||||
| # Note that this can be done via git operation once model repo | |||||
| # has been created. Git-Op is the RECOMMENDED model upload approach | |||||
| def test_model_upload(self): | |||||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
| print(url) | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| os.chdir(temporary_dir) | |||||
| cmd_args = 'clone %s' % url | |||||
| cmd_args = cmd_args.split(' ') | |||||
| out = run_git_command('git', *cmd_args) | |||||
| print(out) | |||||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
| os.chdir(repo_dir) | |||||
| os.system('touch file1') | |||||
| os.system('git add file1') | |||||
| os.system("git commit -m 'Test'") | |||||
| token = ModelScopeConfig.get_token() | |||||
| push(repo_dir, token, url) | |||||
| def test_download_single_file(self): | |||||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
| print(url) | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| os.chdir(temporary_dir) | |||||
| os.system('git clone %s' % url) | |||||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
| os.chdir(repo_dir) | |||||
| os.system('wget %s' % sample_model_url) | |||||
| os.system('git add .') | |||||
| os.system("git commit -m 'Add file'") | |||||
| token = ModelScopeConfig.get_token() | |||||
| push(repo_dir, token, url) | |||||
| assert os.path.exists( | |||||
| os.path.join(temporary_dir, self.model_name, | |||||
| download_model_file_name)) | |||||
| downloaded_file = model_file_download( | |||||
| model_id=self.model_id, file_path=download_model_file_name) | |||||
| mdtime1 = os.path.getmtime(downloaded_file) | |||||
| # download again | |||||
| downloaded_file = model_file_download( | |||||
| model_id=self.model_id, file_path=download_model_file_name) | |||||
| mdtime2 = os.path.getmtime(downloaded_file) | |||||
| assert mdtime1 == mdtime2 | |||||
| def test_snapshot_download(self): | |||||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
| print(url) | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| os.chdir(temporary_dir) | |||||
| os.system('git clone %s' % url) | |||||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
| os.chdir(repo_dir) | |||||
| os.system('wget %s' % sample_model_url) | |||||
| os.system('git add .') | |||||
| os.system("git commit -m 'Add file'") | |||||
| token = ModelScopeConfig.get_token() | |||||
| push(repo_dir, token, url) | |||||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||||
| downloaded_file_path = os.path.join(snapshot_path, | |||||
| download_model_file_name) | |||||
| assert os.path.exists(downloaded_file_path) | |||||
| mdtime1 = os.path.getmtime(downloaded_file_path) | |||||
| # download again | |||||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||||
| mdtime2 = os.path.getmtime(downloaded_file_path) | |||||
| assert mdtime1 == mdtime2 | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -10,7 +10,6 @@ from modelscope.fileio import File | |||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.pydatasets import PyDataset | from modelscope.pydatasets import PyDataset | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from modelscope.utils.hub import get_model_cache_dir | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -18,11 +17,6 @@ class ImageMattingTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/cv_unet_image-matting' | self.model_id = 'damo/cv_unet_image-matting' | ||||
| # switch to False if downloading everytime is not desired | |||||
| purge_cache = True | |||||
| if purge_cache: | |||||
| shutil.rmtree( | |||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| @unittest.skip('deprecated, download model from model hub instead') | @unittest.skip('deprecated, download model from model hub instead') | ||||
| def test_run_with_direct_file_download(self): | def test_run_with_direct_file_download(self): | ||||
| @@ -66,7 +60,7 @@ class ImageMattingTest(unittest.TestCase): | |||||
| cv2.imwrite('result.png', result['output_png']) | cv2.imwrite('result.png', result['output_png']) | ||||
| print(f'Output written to {osp.abspath("result.png")}') | print(f'Output written to {osp.abspath("result.png")}') | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_modelscope_dataset(self): | def test_run_with_modelscope_dataset(self): | ||||
| dataset = PyDataset.load('beans', split='train', target='image') | dataset = PyDataset.load('beans', split='train', target='image') | ||||
| img_matting = pipeline(Tasks.image_matting, model=self.model_id) | img_matting = pipeline(Tasks.image_matting, model=self.model_id) | ||||
| @@ -27,7 +27,7 @@ class OCRDetectionTest(unittest.TestCase): | |||||
| print('ocr detection results: ') | print('ocr detection results: ') | ||||
| print(result) | print(result) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_modelhub_default_model(self): | def test_run_modelhub_default_model(self): | ||||
| ocr_detection = pipeline(Tasks.ocr_detection) | ocr_detection = pipeline(Tasks.ocr_detection) | ||||
| self.pipeline_inference(ocr_detection, self.test_image) | self.pipeline_inference(ocr_detection, self.test_image) | ||||
| @@ -2,14 +2,12 @@ | |||||
| import shutil | import shutil | ||||
| import unittest | import unittest | ||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.nlp import SbertForSentenceSimilarity | from modelscope.models.nlp import SbertForSentenceSimilarity | ||||
| from modelscope.pipelines import SentenceSimilarityPipeline, pipeline | from modelscope.pipelines import SentenceSimilarityPipeline, pipeline | ||||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | from modelscope.preprocessors import SequenceClassificationPreprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.hub import get_model_cache_dir | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -18,13 +16,6 @@ class SentenceSimilarityTest(unittest.TestCase): | |||||
| sentence1 = '今天气温比昨天高么?' | sentence1 = '今天气温比昨天高么?' | ||||
| sentence2 = '今天湿度比昨天高么?' | sentence2 = '今天湿度比昨天高么?' | ||||
| def setUp(self) -> None: | |||||
| # switch to False if downloading everytime is not desired | |||||
| purge_cache = True | |||||
| if purge_cache: | |||||
| shutil.rmtree( | |||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run(self): | def test_run(self): | ||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| @@ -3,9 +3,9 @@ import shutil | |||||
| import unittest | import unittest | ||||
| from modelscope.fileio import File | from modelscope.fileio import File | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.hub import get_model_cache_dir | |||||
| NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav' | NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav' | ||||
| FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav' | FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav' | ||||
| @@ -30,11 +30,6 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/speech_dfsmn_aec_psm_16k' | self.model_id = 'damo/speech_dfsmn_aec_psm_16k' | ||||
| # switch to False if downloading everytime is not desired | |||||
| purge_cache = True | |||||
| if purge_cache: | |||||
| shutil.rmtree( | |||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| # A temporary hack to provide c++ lib. Download it first. | # A temporary hack to provide c++ lib. Download it first. | ||||
| download(AEC_LIB_URL, AEC_LIB_FILE) | download(AEC_LIB_URL, AEC_LIB_FILE) | ||||
| @@ -48,7 +43,7 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||||
| aec = pipeline( | aec = pipeline( | ||||
| Tasks.speech_signal_process, | Tasks.speech_signal_process, | ||||
| model=self.model_id, | model=self.model_id, | ||||
| pipeline_name=r'speech_dfsmn_aec_psm_16k') | |||||
| pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k) | |||||
| aec(input, output_path='output.wav') | aec(input, output_path='output.wav') | ||||
| @@ -11,7 +11,6 @@ from modelscope.pipelines import SequenceClassificationPipeline, pipeline | |||||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | from modelscope.preprocessors import SequenceClassificationPreprocessor | ||||
| from modelscope.pydatasets import PyDataset | from modelscope.pydatasets import PyDataset | ||||
| from modelscope.utils.constant import Hubs, Tasks | from modelscope.utils.constant import Hubs, Tasks | ||||
| from modelscope.utils.hub import get_model_cache_dir | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -19,11 +18,6 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/bert-base-sst2' | self.model_id = 'damo/bert-base-sst2' | ||||
| # switch to False if downloading everytime is not desired | |||||
| purge_cache = True | |||||
| if purge_cache: | |||||
| shutil.rmtree( | |||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| def predict(self, pipeline_ins: SequenceClassificationPipeline): | def predict(self, pipeline_ins: SequenceClassificationPipeline): | ||||
| from easynlp.appzoo import load_dataset | from easynlp.appzoo import load_dataset | ||||
| @@ -44,31 +38,6 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| break | break | ||||
| print(r) | print(r) | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run(self): | |||||
| model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ | |||||
| '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' | |||||
| cache_path_str = r'.cache/easynlp/bert-base-sst2.zip' | |||||
| cache_path = Path(cache_path_str) | |||||
| if not cache_path.exists(): | |||||
| cache_path.parent.mkdir(parents=True, exist_ok=True) | |||||
| cache_path.touch(exist_ok=True) | |||||
| with cache_path.open('wb') as ofile: | |||||
| ofile.write(File.read(model_url)) | |||||
| with zipfile.ZipFile(cache_path_str, 'r') as zipf: | |||||
| zipf.extractall(cache_path.parent) | |||||
| path = r'.cache/easynlp/' | |||||
| model = BertForSequenceClassification(path) | |||||
| preprocessor = SequenceClassificationPreprocessor( | |||||
| path, first_sequence='sentence', second_sequence=None) | |||||
| pipeline1 = SequenceClassificationPipeline(model, preprocessor) | |||||
| self.predict(pipeline1) | |||||
| pipeline2 = pipeline( | |||||
| Tasks.text_classification, model=model, preprocessor=preprocessor) | |||||
| print(pipeline2('Hello world!')) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
| model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
| @@ -1,8 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import unittest | import unittest | ||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.nlp import PalmForTextGeneration | from modelscope.models.nlp import PalmForTextGeneration | ||||
| from modelscope.pipelines import TextGenerationPipeline, pipeline | from modelscope.pipelines import TextGenerationPipeline, pipeline | ||||
| @@ -11,6 +11,7 @@ import torch | |||||
| from scipy.io.wavfile import write | from scipy.io.wavfile import write | ||||
| from modelscope.fileio import File | from modelscope.fileio import File | ||||
| from modelscope.metainfo import Pipelines, Preprocessors | |||||
| from modelscope.models import Model, build_model | from modelscope.models import Model, build_model | ||||
| from modelscope.models.audio.tts.am import SambertNetHifi16k | from modelscope.models.audio.tts.am import SambertNetHifi16k | ||||
| from modelscope.models.audio.tts.vocoder import AttrDict, Hifigan16k | from modelscope.models.audio.tts.vocoder import AttrDict, Hifigan16k | ||||
| @@ -32,7 +33,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): | |||||
| voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo' | voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo' | ||||
| cfg_preprocessor = dict( | cfg_preprocessor = dict( | ||||
| type='text_to_tacotron_symbols', | |||||
| type=Preprocessors.text_to_tacotron_symbols, | |||||
| model_name=preprocessor_model_id, | model_name=preprocessor_model_id, | ||||
| lang_type=lang_type) | lang_type=lang_type) | ||||
| preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) | preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) | ||||
| @@ -45,7 +46,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): | |||||
| self.assertTrue(voc is not None) | self.assertTrue(voc is not None) | ||||
| sambert_tts = pipeline( | sambert_tts = pipeline( | ||||
| pipeline_name='tts-sambert-hifigan-16k', | |||||
| pipeline_name=Pipelines.sambert_hifigan_16k_tts, | |||||
| config_file='', | config_file='', | ||||
| model=[am, voc], | model=[am, voc], | ||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| @@ -2,14 +2,12 @@ | |||||
| import shutil | import shutil | ||||
| import unittest | import unittest | ||||
| from maas_hub.snapshot_download import snapshot_download | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.nlp import StructBertForTokenClassification | from modelscope.models.nlp import StructBertForTokenClassification | ||||
| from modelscope.pipelines import WordSegmentationPipeline, pipeline | from modelscope.pipelines import WordSegmentationPipeline, pipeline | ||||
| from modelscope.preprocessors import TokenClassifcationPreprocessor | from modelscope.preprocessors import TokenClassifcationPreprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.hub import get_model_cache_dir | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -17,13 +15,6 @@ class WordSegmentationTest(unittest.TestCase): | |||||
| model_id = 'damo/nlp_structbert_word-segmentation_chinese-base' | model_id = 'damo/nlp_structbert_word-segmentation_chinese-base' | ||||
| sentence = '今天天气不错,适合出去游玩' | sentence = '今天天气不错,适合出去游玩' | ||||
| def setUp(self) -> None: | |||||
| # switch to False if downloading everytime is not desired | |||||
| purge_cache = True | |||||
| if purge_cache: | |||||
| shutil.rmtree( | |||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run_by_direct_model_download(self): | def test_run_by_direct_model_download(self): | ||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| @@ -1,6 +1,7 @@ | |||||
| import shutil | import shutil | ||||
| import unittest | import unittest | ||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.preprocessors import build_preprocessor | from modelscope.preprocessors import build_preprocessor | ||||
| from modelscope.utils.constant import Fields, InputFields | from modelscope.utils.constant import Fields, InputFields | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -14,7 +15,7 @@ class TtsPreprocessorTest(unittest.TestCase): | |||||
| lang_type = 'pinyin' | lang_type = 'pinyin' | ||||
| text = '今天天气不错,我们去散步吧。' | text = '今天天气不错,我们去散步吧。' | ||||
| cfg = dict( | cfg = dict( | ||||
| type='text_to_tacotron_symbols', | |||||
| type=Preprocessors.text_to_tacotron_symbols, | |||||
| model_name='damo/speech_binary_tts_frontend_resource', | model_name='damo/speech_binary_tts_frontend_resource', | ||||
| lang_type=lang_type) | lang_type=lang_type) | ||||
| preprocessor = build_preprocessor(cfg, Fields.audio) | preprocessor = build_preprocessor(cfg, Fields.audio) | ||||
| @@ -33,6 +33,8 @@ class ImgPreprocessor(Preprocessor): | |||||
| class PyDatasetTest(unittest.TestCase): | class PyDatasetTest(unittest.TestCase): | ||||
| @unittest.skipUnless(test_level() >= 2, | |||||
| 'skip test due to dataset api problem') | |||||
| def test_ds_basic(self): | def test_ds_basic(self): | ||||
| ms_ds_full = PyDataset.load('squad') | ms_ds_full = PyDataset.load('squad') | ||||
| ms_ds_full_hf = hfdata.load_dataset('squad') | ms_ds_full_hf = hfdata.load_dataset('squad') | ||||
| @@ -61,7 +61,7 @@ if __name__ == '__main__': | |||||
| parser.add_argument( | parser.add_argument( | ||||
| '--test_dir', default='tests', help='directory to be tested') | '--test_dir', default='tests', help='directory to be tested') | ||||
| parser.add_argument( | parser.add_argument( | ||||
| '--level', default=0, help='2 -- all, 1 -- p1, 0 -- p0') | |||||
| '--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0') | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| set_test_level(args.level) | set_test_level(args.level) | ||||
| logger.info(f'TEST LEVEL: {test_level()}') | logger.info(f'TEST LEVEL: {test_level()}') | ||||
| @@ -1,50 +0,0 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| import unittest | |||||
| from maas_hub.maas_api import MaasApi | |||||
| from maas_hub.repository import Repository | |||||
| USER_NAME = 'maasadmin' | |||||
| PASSWORD = '12345678' | |||||
| class HubOperationTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| self.api = MaasApi() | |||||
| # note this is temporary before official account management is ready | |||||
| self.api.login(USER_NAME, PASSWORD) | |||||
| @unittest.skip('to be used for local test only') | |||||
| def test_model_repo_creation(self): | |||||
| # change to proper model names before use | |||||
| model_name = 'cv_unet_person-image-cartoon_compound-models' | |||||
| model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'damo' | |||||
| try: | |||||
| self.api.create_model( | |||||
| owner=model_org, | |||||
| name=model_name, | |||||
| chinese_name=model_chinese_name, | |||||
| visibility=5, # 1-private, 5-public | |||||
| license='apache-2.0') | |||||
| # TODO: support proper name duplication checking | |||||
| except KeyError as ke: | |||||
| if ke.args[0] == 'name': | |||||
| print(f'model {self.model_name} already exists, ignore') | |||||
| else: | |||||
| raise | |||||
| # Note that this can be done via git operation once model repo | |||||
| # has been created. Git-Op is the RECOMMENDED model upload approach | |||||
| @unittest.skip('to be used for local test only') | |||||
| def test_model_upload(self): | |||||
| local_path = '/path/to/local/model/directory' | |||||
| assert osp.exists(local_path), 'Local model directory not exist.' | |||||
| repo = Repository(local_dir=local_path) | |||||
| repo.push_to_hub(commit_message='Upload model files') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||