Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10463992master
| @@ -1,8 +1,11 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # yapf: disable | |||
| import datetime | |||
| import os | |||
| import pickle | |||
| import shutil | |||
| import tempfile | |||
| from collections import defaultdict | |||
| from http import HTTPStatus | |||
| from http.cookiejar import CookieJar | |||
| @@ -16,17 +19,25 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | |||
| API_RESPONSE_FIELD_MESSAGE, | |||
| API_RESPONSE_FIELD_USERNAME, | |||
| DEFAULT_CREDENTIALS_PATH) | |||
| DEFAULT_CREDENTIALS_PATH, Licenses, | |||
| ModelVisibility) | |||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
| NotLoginException, RequestError, | |||
| datahub_raise_on_error, | |||
| handle_http_post_error, | |||
| handle_http_response, is_ok, raise_on_error) | |||
| from modelscope.hub.git import GitCommandWrapper | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.hub.utils.utils import (get_endpoint, | |||
| model_id_to_group_owner_name) | |||
| from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | |||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
| DEFAULT_MODEL_REVISION, | |||
| DatasetFormations, DatasetMetaFormats, | |||
| DownloadMode) | |||
| DownloadMode, ModelFile) | |||
| from modelscope.utils.logger import get_logger | |||
| from .errors import (InvalidParameter, NotExistError, RequestError, | |||
| datahub_raise_on_error, handle_http_post_error, | |||
| handle_http_response, is_ok, raise_on_error) | |||
| from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||
| # yapf: enable | |||
| logger = get_logger() | |||
| @@ -169,11 +180,106 @@ class HubApi: | |||
| else: | |||
| r.raise_for_status() | |||
| def list_model(self, | |||
| owner_or_group: str, | |||
| page_number=1, | |||
| page_size=10) -> dict: | |||
| """List model in owner or group. | |||
| def push_model(self, | |||
| model_id: str, | |||
| model_dir: str, | |||
| visibility: int = ModelVisibility.PUBLIC, | |||
| license: str = Licenses.APACHE_V2, | |||
| chinese_name: Optional[str] = None, | |||
| commit_message: Optional[str] = 'upload model', | |||
| revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||
| """ | |||
| Upload model from a given directory to given repository. A valid model directory | |||
| must contain a configuration.json file. | |||
| This function upload the files in given directory to given repository. If the | |||
| given repository is not exists in remote, it will automatically create it with | |||
| given visibility, license and chinese_name parameters. If the revision is also | |||
| not exists in remote repository, it will create a new branch for it. | |||
| This function must be called before calling HubApi's login with a valid token | |||
| which can be obtained from ModelScope's website. | |||
| Args: | |||
| model_id (`str`): | |||
| The model id to be uploaded, caller must have write permission for it. | |||
| model_dir(`str`): | |||
| The Absolute Path of the finetune result. | |||
| visibility(`int`, defaults to `0`): | |||
| Visibility of the new created model(1-private, 5-public). If the model is | |||
| not exists in ModelScope, this function will create a new model with this | |||
| visibility and this parameter is required. You can ignore this parameter | |||
| if you make sure the model's existence. | |||
| license(`str`, defaults to `None`): | |||
| License of the new created model(see License). If the model is not exists | |||
| in ModelScope, this function will create a new model with this license | |||
| and this parameter is required. You can ignore this parameter if you | |||
| make sure the model's existence. | |||
| chinese_name(`str`, *optional*, defaults to `None`): | |||
| chinese name of the new created model. | |||
| commit_message(`str`, *optional*, defaults to `None`): | |||
| commit message of the push request. | |||
| revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION): | |||
| which branch to push. If the branch is not exists, It will create a new | |||
| branch and push to it. | |||
| """ | |||
| if model_id is None: | |||
| raise InvalidParameter('model_id cannot be empty!') | |||
| if model_dir is None: | |||
| raise InvalidParameter('model_dir cannot be empty!') | |||
| if not os.path.exists(model_dir) or os.path.isfile(model_dir): | |||
| raise InvalidParameter('model_dir must be a valid directory.') | |||
| cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||
| if not os.path.exists(cfg_file): | |||
| raise ValueError(f'{model_dir} must contain a configuration.json.') | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException('Must login before upload!') | |||
| files_to_save = os.listdir(model_dir) | |||
| try: | |||
| self.get_model(model_id=model_id) | |||
| except Exception: | |||
| if visibility is None or license is None: | |||
| raise InvalidParameter( | |||
| 'visibility and license cannot be empty if want to create new repo' | |||
| ) | |||
| logger.info('Create new model %s' % model_id) | |||
| self.create_model( | |||
| model_id=model_id, | |||
| visibility=visibility, | |||
| license=license, | |||
| chinese_name=chinese_name) | |||
| tmp_dir = tempfile.mkdtemp() | |||
| git_wrapper = GitCommandWrapper() | |||
| try: | |||
| repo = Repository(model_dir=tmp_dir, clone_from=model_id) | |||
| branches = git_wrapper.get_remote_branches(tmp_dir) | |||
| if revision not in branches: | |||
| logger.info('Create new branch %s' % revision) | |||
| git_wrapper.new_branch(tmp_dir, revision) | |||
| git_wrapper.checkout(tmp_dir, revision) | |||
| for f in files_to_save: | |||
| if f[0] != '.': | |||
| src = os.path.join(model_dir, f) | |||
| if os.path.isdir(src): | |||
| shutil.copytree(src, os.path.join(tmp_dir, f)) | |||
| else: | |||
| shutil.copy(src, tmp_dir) | |||
| if not commit_message: | |||
| date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | |||
| commit_message = '[automsg] push model %s to hub at %s' % ( | |||
| model_id, date) | |||
| repo.push(commit_message=commit_message, branch=revision) | |||
| except Exception: | |||
| raise | |||
| finally: | |||
| shutil.rmtree(tmp_dir, ignore_errors=True) | |||
| def list_models(self, | |||
| owner_or_group: str, | |||
| page_number=1, | |||
| page_size=10) -> dict: | |||
| """List models in owner or group. | |||
| Args: | |||
| owner_or_group(`str`): owner or group. | |||
| @@ -11,13 +11,12 @@ from typing import Dict, Optional, Union | |||
| from uuid import uuid4 | |||
| import requests | |||
| from filelock import FileLock | |||
| from tqdm import tqdm | |||
| from modelscope import __version__ | |||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import HubApi, ModelScopeConfig | |||
| from .constants import FILE_HASH | |||
| from .errors import FileDownloadError, NotExistError | |||
| from .utils.caching import ModelFileSystemCache | |||
| @@ -1,13 +1,10 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import re | |||
| import subprocess | |||
| from typing import List | |||
| from xmlrpc.client import Boolean | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import ModelScopeConfig | |||
| from .errors import GitError | |||
| logger = get_logger() | |||
| @@ -132,6 +129,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||
| return response | |||
| def add_user_info(self, repo_base_dir, repo_name): | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| user_name, user_email = ModelScopeConfig.get_user_info() | |||
| if user_name and user_email: | |||
| # config user.name and user.email if exist | |||
| @@ -7,7 +7,6 @@ from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
| DEFAULT_MODEL_REVISION) | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import ModelScopeConfig | |||
| from .git import GitCommandWrapper | |||
| from .utils.utils import get_endpoint | |||
| @@ -47,6 +46,7 @@ class Repository: | |||
| err_msg = 'a non-default value of revision cannot be empty.' | |||
| raise InvalidParameter(err_msg) | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| if auth_token: | |||
| self.auth_token = auth_token | |||
| else: | |||
| @@ -166,7 +166,7 @@ class DatasetRepository: | |||
| err_msg = 'a non-default value of revision cannot be empty.' | |||
| raise InvalidParameter(err_msg) | |||
| self.revision = revision | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| if auth_token: | |||
| self.auth_token = auth_token | |||
| else: | |||
| @@ -5,9 +5,9 @@ import tempfile | |||
| from pathlib import Path | |||
| from typing import Dict, Optional, Union | |||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import HubApi, ModelScopeConfig | |||
| from .constants import FILE_HASH | |||
| from .errors import NotExistError | |||
| from .file_download import (get_file_download_url, http_get_file, | |||
| @@ -1,117 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import datetime | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import uuid | |||
| from typing import Dict, Optional | |||
| from uuid import uuid4 | |||
| from filelock import FileLock | |||
| from modelscope import __version__ | |||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||
| from modelscope.hub.errors import InvalidParameter, NotLoginException | |||
| from modelscope.hub.git import GitCommandWrapper | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| def upload_folder(model_id: str, | |||
| model_dir: str, | |||
| visibility: int = 0, | |||
| license: str = None, | |||
| chinese_name: Optional[str] = None, | |||
| commit_message: Optional[str] = None, | |||
| revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||
| """ | |||
| Upload model from a given directory to given repository. A valid model directory | |||
| must contain a configuration.json file. | |||
| This function upload the files in given directory to given repository. If the | |||
| given repository is not exists in remote, it will automatically create it with | |||
| given visibility, license and chinese_name parameters. If the revision is also | |||
| not exists in remote repository, it will create a new branch for it. | |||
| This function must be called before calling HubApi's login with a valid token | |||
| which can be obtained from ModelScope's website. | |||
| Args: | |||
| model_id (`str`): | |||
| The model id to be uploaded, caller must have write permission for it. | |||
| model_dir(`str`): | |||
| The Absolute Path of the finetune result. | |||
| visibility(`int`, defaults to `0`): | |||
| Visibility of the new created model(1-private, 5-public). If the model is | |||
| not exists in ModelScope, this function will create a new model with this | |||
| visibility and this parameter is required. You can ignore this parameter | |||
| if you make sure the model's existence. | |||
| license(`str`, defaults to `None`): | |||
| License of the new created model(see License). If the model is not exists | |||
| in ModelScope, this function will create a new model with this license | |||
| and this parameter is required. You can ignore this parameter if you | |||
| make sure the model's existence. | |||
| chinese_name(`str`, *optional*, defaults to `None`): | |||
| chinese name of the new created model. | |||
| commit_message(`str`, *optional*, defaults to `None`): | |||
| commit message of the push request. | |||
| revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION): | |||
| which branch to push. If the branch is not exists, It will create a new | |||
| branch and push to it. | |||
| """ | |||
| if model_id is None: | |||
| raise InvalidParameter('model_id cannot be empty!') | |||
| if model_dir is None: | |||
| raise InvalidParameter('model_dir cannot be empty!') | |||
| if not os.path.exists(model_dir) or os.path.isfile(model_dir): | |||
| raise InvalidParameter('model_dir must be a valid directory.') | |||
| cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||
| if not os.path.exists(cfg_file): | |||
| raise ValueError(f'{model_dir} must contain a configuration.json.') | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException('Must login before upload!') | |||
| files_to_save = os.listdir(model_dir) | |||
| api = HubApi() | |||
| try: | |||
| api.get_model(model_id=model_id) | |||
| except Exception: | |||
| if visibility is None or license is None: | |||
| raise InvalidParameter( | |||
| 'visibility and license cannot be empty if want to create new repo' | |||
| ) | |||
| logger.info('Create new model %s' % model_id) | |||
| api.create_model( | |||
| model_id=model_id, | |||
| visibility=visibility, | |||
| license=license, | |||
| chinese_name=chinese_name) | |||
| tmp_dir = tempfile.mkdtemp() | |||
| git_wrapper = GitCommandWrapper() | |||
| try: | |||
| repo = Repository(model_dir=tmp_dir, clone_from=model_id) | |||
| branches = git_wrapper.get_remote_branches(tmp_dir) | |||
| if revision not in branches: | |||
| logger.info('Create new branch %s' % revision) | |||
| git_wrapper.new_branch(tmp_dir, revision) | |||
| git_wrapper.checkout(tmp_dir, revision) | |||
| for f in files_to_save: | |||
| if f[0] != '.': | |||
| src = os.path.join(model_dir, f) | |||
| if os.path.isdir(src): | |||
| shutil.copytree(src, os.path.join(tmp_dir, f)) | |||
| else: | |||
| shutil.copy(src, tmp_dir) | |||
| if not commit_message: | |||
| date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | |||
| commit_message = '[automsg] push model %s to hub at %s' % ( | |||
| model_id, date) | |||
| repo.push(commit_message=commit_message, branch=revision) | |||
| except Exception: | |||
| raise | |||
| finally: | |||
| shutil.rmtree(tmp_dir, ignore_errors=True) | |||
| @@ -127,7 +127,7 @@ class HubOperationTest(unittest.TestCase): | |||
| return None | |||
| def test_list_model(self): | |||
| data = self.api.list_model(TEST_MODEL_ORG) | |||
| data = self.api.list_models(TEST_MODEL_ORG) | |||
| assert len(data['Models']) >= 1 | |||
| @@ -9,7 +9,6 @@ from modelscope.hub.api import HubApi | |||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||
| from modelscope.hub.errors import HTTPError, NotLoginException | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.hub.upload import upload_folder | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import test_level | |||
| @@ -54,14 +53,14 @@ class HubUploadTest(unittest.TestCase): | |||
| license=Licenses.APACHE_V2) | |||
| os.system("echo '111'>%s" | |||
| % os.path.join(self.finetune_path, 'add1.py')) | |||
| upload_folder( | |||
| self.api.push_model( | |||
| model_id=self.create_model_name, model_dir=self.finetune_path) | |||
| Repository(model_dir=self.repo_path, clone_from=self.create_model_name) | |||
| assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) | |||
| shutil.rmtree(self.repo_path, ignore_errors=True) | |||
| os.system("echo '222'>%s" | |||
| % os.path.join(self.finetune_path, 'add2.py')) | |||
| upload_folder( | |||
| self.api.push_model( | |||
| model_id=self.create_model_name, | |||
| model_dir=self.finetune_path, | |||
| revision='new_revision/version1') | |||
| @@ -73,7 +72,7 @@ class HubUploadTest(unittest.TestCase): | |||
| shutil.rmtree(self.repo_path, ignore_errors=True) | |||
| os.system("echo '333'>%s" | |||
| % os.path.join(self.finetune_path, 'add3.py')) | |||
| upload_folder( | |||
| self.api.push_model( | |||
| model_id=self.create_model_name, | |||
| model_dir=self.finetune_path, | |||
| revision='new_revision/version2', | |||
| @@ -88,7 +87,7 @@ class HubUploadTest(unittest.TestCase): | |||
| add4_path = os.path.join(self.finetune_path, 'temp') | |||
| os.mkdir(add4_path) | |||
| os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py')) | |||
| upload_folder( | |||
| self.api.push_model( | |||
| model_id=self.create_model_name, | |||
| model_dir=self.finetune_path, | |||
| revision='new_revision/version1') | |||
| @@ -105,7 +104,7 @@ class HubUploadTest(unittest.TestCase): | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| os.system("echo '111'>%s" | |||
| % os.path.join(self.finetune_path, 'add1.py')) | |||
| upload_folder( | |||
| self.api.push_model( | |||
| model_id=self.create_model_name, | |||
| model_dir=self.finetune_path, | |||
| revision='new_model_new_revision', | |||
| @@ -124,7 +123,7 @@ class HubUploadTest(unittest.TestCase): | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| delete_credential() | |||
| with self.assertRaises(NotLoginException): | |||
| upload_folder( | |||
| self.api.push_model( | |||
| model_id=self.create_model_name, | |||
| model_dir=self.finetune_path, | |||
| visibility=ModelVisibility.PUBLIC, | |||
| @@ -135,7 +134,7 @@ class HubUploadTest(unittest.TestCase): | |||
| logger.info('test upload to invalid repo!') | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| with self.assertRaises(HTTPError): | |||
| upload_folder( | |||
| self.api.push_model( | |||
| model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), | |||
| model_dir=self.finetune_path, | |||
| visibility=ModelVisibility.PUBLIC, | |||