修复测试bug
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9186775
* [to #42849800 #42822853 #42822836 #42822791 #42822717 #42820011]fix: test bugs
master
| @@ -9,7 +9,7 @@ import requests | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .constants import MODELSCOPE_URL_SCHEME | from .constants import MODELSCOPE_URL_SCHEME | ||||
| from .errors import NotExistError, is_ok, raise_on_error | |||||
| from .errors import InvalidParameter, NotExistError, is_ok, raise_on_error | |||||
| from .utils.utils import (get_endpoint, get_gitlab_domain, | from .utils.utils import (get_endpoint, get_gitlab_domain, | ||||
| model_id_to_group_owner_name) | model_id_to_group_owner_name) | ||||
| @@ -61,17 +61,21 @@ class HubApi: | |||||
| return d['Data']['AccessToken'], cookies | return d['Data']['AccessToken'], cookies | ||||
| def create_model(self, model_id: str, chinese_name: str, visibility: int, | |||||
| license: str) -> str: | |||||
| def create_model( | |||||
| self, | |||||
| model_id: str, | |||||
| visibility: str, | |||||
| license: str, | |||||
| chinese_name: Optional[str] = None, | |||||
| ) -> str: | |||||
| """ | """ | ||||
| Create model repo at ModelScopeHub | Create model repo at ModelScopeHub | ||||
| Args: | Args: | ||||
| model_id:(`str`): The model id | model_id:(`str`): The model id | ||||
| chinese_name(`str`): chinese name of the model | |||||
| visibility(`int`): visibility of the model(1-private, 3-internal, 5-public) | |||||
| license(`str`): license of the model, candidates can be found at: TBA | |||||
| visibility(`int`): visibility of the model(1-private, 5-public), default public. | |||||
| license(`str`): license of the model, default none. | |||||
| chinese_name(`str`, *optional*): chinese name of the model | |||||
| Returns: | Returns: | ||||
| name of the model created | name of the model created | ||||
| @@ -79,6 +83,8 @@ class HubApi: | |||||
| model_id = {owner}/{name} | model_id = {owner}/{name} | ||||
| </Tip> | </Tip> | ||||
| """ | """ | ||||
| if model_id is None: | |||||
| raise InvalidParameter('model_id is required!') | |||||
| cookies = ModelScopeConfig.get_cookies() | cookies = ModelScopeConfig.get_cookies() | ||||
| if cookies is None: | if cookies is None: | ||||
| raise ValueError('Token does not exist, please login first.') | raise ValueError('Token does not exist, please login first.') | ||||
| @@ -151,11 +157,33 @@ class HubApi: | |||||
| else: | else: | ||||
| r.raise_for_status() | r.raise_for_status() | ||||
| def _check_cookie(self, | |||||
| use_cookies: Union[bool, | |||||
| CookieJar] = False) -> CookieJar: | |||||
| cookies = None | |||||
| if isinstance(use_cookies, CookieJar): | |||||
| cookies = use_cookies | |||||
| elif use_cookies: | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise ValueError('Token does not exist, please login first.') | |||||
| return cookies | |||||
| def get_model_branches_and_tags( | def get_model_branches_and_tags( | ||||
| self, | self, | ||||
| model_id: str, | model_id: str, | ||||
| use_cookies: Union[bool, CookieJar] = False | |||||
| ) -> Tuple[List[str], List[str]]: | ) -> Tuple[List[str], List[str]]: | ||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| """Get model branch and tags. | |||||
| Args: | |||||
| model_id (str): The model id | |||||
| use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will | |||||
| will load cookie from local. Defaults to False. | |||||
| Returns: | |||||
| Tuple[List[str], List[str]]: _description_ | |||||
| """ | |||||
| cookies = self._check_cookie(use_cookies) | |||||
| path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | ||||
| r = requests.get(path, cookies=cookies) | r = requests.get(path, cookies=cookies) | ||||
| @@ -169,23 +197,33 @@ class HubApi: | |||||
| ] if info['RevisionMap']['Tags'] else [] | ] if info['RevisionMap']['Tags'] else [] | ||||
| return branches, tags | return branches, tags | ||||
| def get_model_files( | |||||
| self, | |||||
| model_id: str, | |||||
| revision: Optional[str] = 'master', | |||||
| root: Optional[str] = None, | |||||
| recursive: Optional[str] = False, | |||||
| use_cookies: Union[bool, CookieJar] = False) -> List[dict]: | |||||
| def get_model_files(self, | |||||
| model_id: str, | |||||
| revision: Optional[str] = 'master', | |||||
| root: Optional[str] = None, | |||||
| recursive: Optional[str] = False, | |||||
| use_cookies: Union[bool, CookieJar] = False, | |||||
| is_snapshot: Optional[bool] = True) -> List[dict]: | |||||
| """List the models files. | |||||
| cookies = None | |||||
| if isinstance(use_cookies, CookieJar): | |||||
| cookies = use_cookies | |||||
| elif use_cookies: | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise ValueError('Token does not exist, please login first.') | |||||
| Args: | |||||
| model_id (str): The model id | |||||
| revision (Optional[str], optional): The branch or tag name. Defaults to 'master'. | |||||
| root (Optional[str], optional): The root path. Defaults to None. | |||||
| recursive (Optional[str], optional): Is recurive list files. Defaults to False. | |||||
| use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will | |||||
| will load cookie from local. Defaults to False. | |||||
| is_snapshot(Optional[bool], optional): when snapshot_download set to True, otherwise False. | |||||
| path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}' | |||||
| Raises: | |||||
| ValueError: If user_cookies is True, but no local cookie. | |||||
| Returns: | |||||
| List[dict]: Model file list. | |||||
| """ | |||||
| path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s&Snapshot=%s' % ( | |||||
| self.endpoint, model_id, revision, recursive, is_snapshot) | |||||
| cookies = self._check_cookie(use_cookies) | |||||
| if root is not None: | if root is not None: | ||||
| path = path + f'&Root={root}' | path = path + f'&Root={root}' | ||||
| @@ -10,6 +10,10 @@ class GitError(Exception): | |||||
| pass | pass | ||||
| class InvalidParameter(Exception): | |||||
| pass | |||||
| def is_ok(rsp): | def is_ok(rsp): | ||||
| """ Check the request is ok | """ Check the request is ok | ||||
| @@ -7,6 +7,7 @@ import tempfile | |||||
| import time | import time | ||||
| from functools import partial | from functools import partial | ||||
| from hashlib import sha256 | from hashlib import sha256 | ||||
| from http.cookiejar import CookieJar | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import BinaryIO, Dict, Optional, Union | from typing import BinaryIO, Dict, Optional, Union | ||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| @@ -107,7 +108,9 @@ def model_file_download( | |||||
| _api = HubApi() | _api = HubApi() | ||||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | ||||
| branches, tags = _api.get_model_branches_and_tags(model_id) | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| branches, tags = _api.get_model_branches_and_tags( | |||||
| model_id, use_cookies=False if cookies is None else cookies) | |||||
| file_to_download_info = None | file_to_download_info = None | ||||
| is_commit_id = False | is_commit_id = False | ||||
| if revision in branches or revision in tags: # The revision is version or tag, | if revision in branches or revision in tags: # The revision is version or tag, | ||||
| @@ -117,18 +120,19 @@ def model_file_download( | |||||
| model_id=model_id, | model_id=model_id, | ||||
| revision=revision, | revision=revision, | ||||
| recursive=True, | recursive=True, | ||||
| ) | |||||
| use_cookies=False if cookies is None else cookies, | |||||
| is_snapshot=False) | |||||
| for model_file in model_files: | for model_file in model_files: | ||||
| if model_file['Type'] == 'tree': | if model_file['Type'] == 'tree': | ||||
| continue | continue | ||||
| if model_file['Path'] == file_path: | if model_file['Path'] == file_path: | ||||
| model_file['Branch'] = revision | |||||
| if cache.exists(model_file): | if cache.exists(model_file): | ||||
| return cache.get_file_by_info(model_file) | return cache.get_file_by_info(model_file) | ||||
| else: | else: | ||||
| file_to_download_info = model_file | file_to_download_info = model_file | ||||
| break | |||||
| if file_to_download_info is None: | if file_to_download_info is None: | ||||
| raise NotExistError('The file path: %s not exist in: %s' % | raise NotExistError('The file path: %s not exist in: %s' % | ||||
| @@ -141,8 +145,6 @@ def model_file_download( | |||||
| return cached_file_path # the file is in cache. | return cached_file_path # the file is in cache. | ||||
| is_commit_id = True | is_commit_id = True | ||||
| # we need to download again | # we need to download again | ||||
| # TODO: skip using JWT for authorization, use cookie instead | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| url_to_download = get_file_download_url(model_id, file_path, revision) | url_to_download = get_file_download_url(model_id, file_path, revision) | ||||
| file_to_download_info = { | file_to_download_info = { | ||||
| 'Path': file_path, | 'Path': file_path, | ||||
| @@ -202,7 +204,7 @@ def http_get_file( | |||||
| url: str, | url: str, | ||||
| local_dir: str, | local_dir: str, | ||||
| file_name: str, | file_name: str, | ||||
| cookies: Dict[str, str], | |||||
| cookies: CookieJar, | |||||
| headers: Optional[Dict[str, str]] = None, | headers: Optional[Dict[str, str]] = None, | ||||
| ): | ): | ||||
| """ | """ | ||||
| @@ -217,7 +219,7 @@ def http_get_file( | |||||
| local directory where the downloaded file stores | local directory where the downloaded file stores | ||||
| file_name(`str`): | file_name(`str`): | ||||
| name of the file stored in `local_dir` | name of the file stored in `local_dir` | ||||
| cookies(`Dict[str, str]`): | |||||
| cookies(`CookieJar`): | |||||
| cookies used to authentication the user, which is used for downloading private repos | cookies used to authentication the user, which is used for downloading private repos | ||||
| headers(`Optional[Dict[str, str]] = None`): | headers(`Optional[Dict[str, str]] = None`): | ||||
| http headers to carry necessary info when requesting the remote file | http headers to carry necessary info when requesting the remote file | ||||
| @@ -70,6 +70,14 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| except GitError: | except GitError: | ||||
| return False | return False | ||||
| def git_lfs_install(self, repo_dir): | |||||
| cmd = ['git', '-C', repo_dir, 'lfs', 'install'] | |||||
| try: | |||||
| self._run_git_command(*cmd) | |||||
| return True | |||||
| except GitError: | |||||
| return False | |||||
| def clone(self, | def clone(self, | ||||
| repo_base_dir: str, | repo_base_dir: str, | ||||
| token: str, | token: str, | ||||
| @@ -1,7 +1,7 @@ | |||||
| import os | import os | ||||
| from typing import List, Optional | from typing import List, Optional | ||||
| from modelscope.hub.errors import GitError | |||||
| from modelscope.hub.errors import GitError, InvalidParameter | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import ModelScopeConfig | from .api import ModelScopeConfig | ||||
| from .constants import MODELSCOPE_URL_SCHEME | from .constants import MODELSCOPE_URL_SCHEME | ||||
| @@ -49,6 +49,8 @@ class Repository: | |||||
| git_wrapper = GitCommandWrapper() | git_wrapper = GitCommandWrapper() | ||||
| if not git_wrapper.is_lfs_installed(): | if not git_wrapper.is_lfs_installed(): | ||||
| logger.error('git lfs is not installed, please install.') | logger.error('git lfs is not installed, please install.') | ||||
| else: | |||||
| git_wrapper.git_lfs_install(self.model_dir) # init repo lfs | |||||
| self.git_wrapper = GitCommandWrapper(git_path) | self.git_wrapper = GitCommandWrapper(git_path) | ||||
| os.makedirs(self.model_dir, exist_ok=True) | os.makedirs(self.model_dir, exist_ok=True) | ||||
| @@ -74,8 +76,6 @@ class Repository: | |||||
| def push(self, | def push(self, | ||||
| commit_message: str, | commit_message: str, | ||||
| files: List[str] = list(), | |||||
| all_files: bool = False, | |||||
| branch: Optional[str] = 'master', | branch: Optional[str] = 'master', | ||||
| force: bool = False): | force: bool = False): | ||||
| """Push local to remote, this method will do. | """Push local to remote, this method will do. | ||||
| @@ -86,8 +86,12 @@ class Repository: | |||||
| commit_message (str): commit message | commit_message (str): commit message | ||||
| revision (Optional[str], optional): which branch to push. Defaults to 'master'. | revision (Optional[str], optional): which branch to push. Defaults to 'master'. | ||||
| """ | """ | ||||
| if commit_message is None: | |||||
| msg = 'commit_message must be provided!' | |||||
| raise InvalidParameter(msg) | |||||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | url = self.git_wrapper.get_repo_remote_url(self.model_dir) | ||||
| self.git_wrapper.add(self.model_dir, files, all_files) | |||||
| self.git_wrapper.pull(self.model_dir) | |||||
| self.git_wrapper.add(self.model_dir, all_files=True) | |||||
| self.git_wrapper.commit(self.model_dir, commit_message) | self.git_wrapper.commit(self.model_dir, commit_message) | ||||
| self.git_wrapper.push( | self.git_wrapper.push( | ||||
| repo_dir=self.model_dir, | repo_dir=self.model_dir, | ||||
| @@ -20,8 +20,7 @@ def snapshot_download(model_id: str, | |||||
| revision: Optional[str] = 'master', | revision: Optional[str] = 'master', | ||||
| cache_dir: Union[str, Path, None] = None, | cache_dir: Union[str, Path, None] = None, | ||||
| user_agent: Optional[Union[Dict, str]] = None, | user_agent: Optional[Union[Dict, str]] = None, | ||||
| local_files_only: Optional[bool] = False, | |||||
| private: Optional[bool] = False) -> str: | |||||
| local_files_only: Optional[bool] = False) -> str: | |||||
| """Download all files of a repo. | """Download all files of a repo. | ||||
| Downloads a whole snapshot of a repo's files at the specified revision. This | Downloads a whole snapshot of a repo's files at the specified revision. This | ||||
| is useful when you want all files from a repo, because you don't know which | is useful when you want all files from a repo, because you don't know which | ||||
| @@ -79,8 +78,10 @@ def snapshot_download(model_id: str, | |||||
| # make headers | # make headers | ||||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | ||||
| _api = HubApi() | _api = HubApi() | ||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| # get file list from model repo | # get file list from model repo | ||||
| branches, tags = _api.get_model_branches_and_tags(model_id) | |||||
| branches, tags = _api.get_model_branches_and_tags( | |||||
| model_id, use_cookies=False if cookies is None else cookies) | |||||
| if revision not in branches and revision not in tags: | if revision not in branches and revision not in tags: | ||||
| raise NotExistError('The specified branch or tag : %s not exist!' | raise NotExistError('The specified branch or tag : %s not exist!' | ||||
| % revision) | % revision) | ||||
| @@ -89,11 +90,8 @@ def snapshot_download(model_id: str, | |||||
| model_id=model_id, | model_id=model_id, | ||||
| revision=revision, | revision=revision, | ||||
| recursive=True, | recursive=True, | ||||
| use_cookies=private) | |||||
| cookies = None | |||||
| if private: | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| use_cookies=False if cookies is None else cookies, | |||||
| is_snapshot=True) | |||||
| for model_file in model_files: | for model_file in model_files: | ||||
| if model_file['Type'] == 'tree': | if model_file['Type'] == 'tree': | ||||
| @@ -116,7 +114,7 @@ def snapshot_download(model_id: str, | |||||
| local_dir=tempfile.gettempdir(), | local_dir=tempfile.gettempdir(), | ||||
| file_name=model_file['Name'], | file_name=model_file['Name'], | ||||
| headers=headers, | headers=headers, | ||||
| cookies=None if cookies is None else cookies.get_dict()) | |||||
| cookies=cookies) | |||||
| # put file to cache | # put file to cache | ||||
| cache.put_file( | cache.put_file( | ||||
| model_file, | model_file, | ||||
| @@ -101,8 +101,9 @@ class FileSystemCache(object): | |||||
| Args: | Args: | ||||
| key (dict): The cache key. | key (dict): The cache key. | ||||
| """ | """ | ||||
| self.cached_files.remove(key) | |||||
| self.save_cached_files() | |||||
| if key in self.cached_files: | |||||
| self.cached_files.remove(key) | |||||
| self.save_cached_files() | |||||
| def exists(self, key): | def exists(self, key): | ||||
| for cache_file in self.cached_files: | for cache_file in self.cached_files: | ||||
| @@ -204,6 +205,7 @@ class ModelFileSystemCache(FileSystemCache): | |||||
| return orig_path | return orig_path | ||||
| else: | else: | ||||
| self.remove_key(cached_file) | self.remove_key(cached_file) | ||||
| break | |||||
| return None | return None | ||||
| @@ -230,6 +232,7 @@ class ModelFileSystemCache(FileSystemCache): | |||||
| cached_key['Revision'].startswith(key['Revision']) | cached_key['Revision'].startswith(key['Revision']) | ||||
| or key['Revision'].startswith(cached_key['Revision'])): | or key['Revision'].startswith(cached_key['Revision'])): | ||||
| is_exists = True | is_exists = True | ||||
| break | |||||
| file_path = os.path.join(self.cache_root_location, | file_path = os.path.join(self.cache_root_location, | ||||
| model_file_info['Path']) | model_file_info['Path']) | ||||
| if is_exists: | if is_exists: | ||||
| @@ -253,6 +256,7 @@ class ModelFileSystemCache(FileSystemCache): | |||||
| cached_file['Path']) | cached_file['Path']) | ||||
| if os.path.exists(file_path): | if os.path.exists(file_path): | ||||
| os.remove(file_path) | os.remove(file_path) | ||||
| break | |||||
| def put_file(self, model_file_info, model_file_location): | def put_file(self, model_file_info, model_file_location): | ||||
| """Put model on model_file_location to cache, the model first download to /tmp, and move to cache. | """Put model on model_file_location to cache, the model first download to /tmp, and move to cache. | ||||
| @@ -31,9 +31,10 @@ def create_model_if_not_exist( | |||||
| else: | else: | ||||
| api.create_model( | api.create_model( | ||||
| model_id=model_id, | model_id=model_id, | ||||
| chinese_name=chinese_name, | |||||
| visibility=visibility, | visibility=visibility, | ||||
| license=license) | |||||
| license=license, | |||||
| chinese_name=chinese_name, | |||||
| ) | |||||
| print(f'model {model_id} successfully created.') | print(f'model {model_id} successfully created.') | ||||
| return True | return True | ||||
| @@ -3,6 +3,7 @@ import os | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| import uuid | import uuid | ||||
| from shutil import rmtree | |||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | from modelscope.hub.api import HubApi, ModelScopeConfig | ||||
| from modelscope.hub.constants import Licenses, ModelVisibility | from modelscope.hub.constants import Licenses, ModelVisibility | ||||
| @@ -23,7 +24,6 @@ download_model_file_name = 'test.bin' | |||||
| class HubOperationTest(unittest.TestCase): | class HubOperationTest(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| self.old_cwd = os.getcwd() | |||||
| self.api = HubApi() | self.api = HubApi() | ||||
| # note this is temporary before official account management is ready | # note this is temporary before official account management is ready | ||||
| self.api.login(USER_NAME, PASSWORD) | self.api.login(USER_NAME, PASSWORD) | ||||
| @@ -31,19 +31,18 @@ class HubOperationTest(unittest.TestCase): | |||||
| self.model_id = '%s/%s' % (model_org, self.model_name) | self.model_id = '%s/%s' % (model_org, self.model_name) | ||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| chinese_name=model_chinese_name, | |||||
| visibility=ModelVisibility.PUBLIC, | visibility=ModelVisibility.PUBLIC, | ||||
| license=Licenses.APACHE_V2) | |||||
| license=Licenses.APACHE_V2, | |||||
| chinese_name=model_chinese_name, | |||||
| ) | |||||
| temporary_dir = tempfile.mkdtemp() | temporary_dir = tempfile.mkdtemp() | ||||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | self.model_dir = os.path.join(temporary_dir, self.model_name) | ||||
| repo = Repository(self.model_dir, clone_from=self.model_id) | repo = Repository(self.model_dir, clone_from=self.model_id) | ||||
| os.chdir(self.model_dir) | |||||
| os.system("echo 'testtest'>%s" | os.system("echo 'testtest'>%s" | ||||
| % os.path.join(self.model_dir, 'test.bin')) | |||||
| repo.push('add model', all_files=True) | |||||
| % os.path.join(self.model_dir, download_model_file_name)) | |||||
| repo.push('add model') | |||||
| def tearDown(self): | def tearDown(self): | ||||
| os.chdir(self.old_cwd) | |||||
| self.api.delete_model(model_id=self.model_id) | self.api.delete_model(model_id=self.model_id) | ||||
| def test_model_repo_creation(self): | def test_model_repo_creation(self): | ||||
| @@ -79,6 +78,35 @@ class HubOperationTest(unittest.TestCase): | |||||
| mdtime2 = os.path.getmtime(downloaded_file_path) | mdtime2 = os.path.getmtime(downloaded_file_path) | ||||
| assert mdtime1 == mdtime2 | assert mdtime1 == mdtime2 | ||||
| def test_download_public_without_login(self): | |||||
| rmtree(ModelScopeConfig.path_credential) | |||||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||||
| downloaded_file_path = os.path.join(snapshot_path, | |||||
| download_model_file_name) | |||||
| assert os.path.exists(downloaded_file_path) | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| downloaded_file = model_file_download( | |||||
| model_id=self.model_id, | |||||
| file_path=download_model_file_name, | |||||
| cache_dir=temporary_dir) | |||||
| assert os.path.exists(downloaded_file) | |||||
| self.api.login(USER_NAME, PASSWORD) | |||||
| def test_snapshot_delete_download_cache_file(self): | |||||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||||
| downloaded_file_path = os.path.join(snapshot_path, | |||||
| download_model_file_name) | |||||
| assert os.path.exists(downloaded_file_path) | |||||
| os.remove(downloaded_file_path) | |||||
| # download again in cache | |||||
| file_download_path = model_file_download( | |||||
| model_id=self.model_id, file_path='README.md') | |||||
| assert os.path.exists(file_download_path) | |||||
| # deleted file need download again | |||||
| file_download_path = model_file_download( | |||||
| model_id=self.model_id, file_path=download_model_file_name) | |||||
| assert os.path.exists(file_download_path) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||
| @@ -0,0 +1,85 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import tempfile | |||||
| import unittest | |||||
| import uuid | |||||
| from requests.exceptions import HTTPError | |||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.errors import GitError | |||||
| from modelscope.hub.file_download import model_file_download | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.utils.constant import ModelFile | |||||
| USER_NAME = 'maasadmin' | |||||
| PASSWORD = '12345678' | |||||
| USER_NAME2 = 'sdkdev' | |||||
| model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'unittest' | |||||
| class HubPrivateFileDownloadTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| self.old_cwd = os.getcwd() | |||||
| self.api = HubApi() | |||||
| # note this is temporary before official account management is ready | |||||
| self.token, _ = self.api.login(USER_NAME, PASSWORD) | |||||
| self.model_name = uuid.uuid4().hex | |||||
| self.model_id = '%s/%s' % (model_org, self.model_name) | |||||
| self.api.create_model( | |||||
| model_id=self.model_id, | |||||
| visibility=ModelVisibility.PRIVATE, # 1-private, 5-public | |||||
| license=Licenses.APACHE_V2, | |||||
| chinese_name=model_chinese_name, | |||||
| ) | |||||
| def tearDown(self): | |||||
| os.chdir(self.old_cwd) | |||||
| self.api.delete_model(model_id=self.model_id) | |||||
| def test_snapshot_download_private_model(self): | |||||
| snapshot_path = snapshot_download(self.model_id) | |||||
| assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | |||||
| def test_snapshot_download_private_model_no_permission(self): | |||||
| self.token, _ = self.api.login(USER_NAME2, PASSWORD) | |||||
| with self.assertRaises(HTTPError): | |||||
| snapshot_download(self.model_id) | |||||
| self.api.login(USER_NAME, PASSWORD) | |||||
| def test_download_file_private_model(self): | |||||
| file_path = model_file_download(self.model_id, ModelFile.README) | |||||
| assert os.path.exists(file_path) | |||||
| def test_download_file_private_model_no_permission(self): | |||||
| self.token, _ = self.api.login(USER_NAME2, PASSWORD) | |||||
| with self.assertRaises(HTTPError): | |||||
| model_file_download(self.model_id, ModelFile.README) | |||||
| self.api.login(USER_NAME, PASSWORD) | |||||
| def test_snapshot_download_local_only(self): | |||||
| with self.assertRaises(ValueError): | |||||
| snapshot_download(self.model_id, local_files_only=True) | |||||
| snapshot_path = snapshot_download(self.model_id) | |||||
| assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | |||||
| snapshot_path = snapshot_download(self.model_id, local_files_only=True) | |||||
| assert os.path.exists(snapshot_path) | |||||
| def test_file_download_local_only(self): | |||||
| with self.assertRaises(ValueError): | |||||
| model_file_download( | |||||
| self.model_id, ModelFile.README, local_files_only=True) | |||||
| file_path = model_file_download(self.model_id, ModelFile.README) | |||||
| assert os.path.exists(file_path) | |||||
| file_path = model_file_download( | |||||
| self.model_id, ModelFile.README, local_files_only=True) | |||||
| assert os.path.exists(file_path) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -5,6 +5,7 @@ import unittest | |||||
| import uuid | import uuid | ||||
| from modelscope.hub.api import HubApi | from modelscope.hub.api import HubApi | ||||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.errors import GitError | from modelscope.hub.errors import GitError | ||||
| from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
| @@ -16,9 +17,6 @@ model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'unittest' | model_org = 'unittest' | ||||
| DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
| sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||||
| download_model_file_name = 'mnist-12.onnx' | |||||
| class HubPrivateRepositoryTest(unittest.TestCase): | class HubPrivateRepositoryTest(unittest.TestCase): | ||||
| @@ -31,9 +29,10 @@ class HubPrivateRepositoryTest(unittest.TestCase): | |||||
| self.model_id = '%s/%s' % (model_org, self.model_name) | self.model_id = '%s/%s' % (model_org, self.model_name) | ||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| visibility=ModelVisibility.PRIVATE, # 1-private, 5-public | |||||
| license=Licenses.APACHE_V2, | |||||
| chinese_name=model_chinese_name, | chinese_name=model_chinese_name, | ||||
| visibility=1, # 1-private, 5-public | |||||
| license='apache-2.0') | |||||
| ) | |||||
| def tearDown(self): | def tearDown(self): | ||||
| self.api.login(USER_NAME, PASSWORD) | self.api.login(USER_NAME, PASSWORD) | ||||
| @@ -2,7 +2,6 @@ | |||||
| import os | import os | ||||
| import shutil | import shutil | ||||
| import tempfile | import tempfile | ||||
| import time | |||||
| import unittest | import unittest | ||||
| import uuid | import uuid | ||||
| from os.path import expanduser | from os.path import expanduser | ||||
| @@ -10,6 +9,7 @@ from os.path import expanduser | |||||
| from requests import delete | from requests import delete | ||||
| from modelscope.hub.api import HubApi | from modelscope.hub.api import HubApi | ||||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.errors import NotExistError | from modelscope.hub.errors import NotExistError | ||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
| @@ -55,9 +55,10 @@ class HubRepositoryTest(unittest.TestCase): | |||||
| self.model_id = '%s/%s' % (model_org, self.model_name) | self.model_id = '%s/%s' % (model_org, self.model_name) | ||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| visibility=ModelVisibility.PUBLIC, # 1-private, 5-public | |||||
| license=Licenses.APACHE_V2, | |||||
| chinese_name=model_chinese_name, | chinese_name=model_chinese_name, | ||||
| visibility=5, # 1-private, 5-public | |||||
| license='apache-2.0') | |||||
| ) | |||||
| temporary_dir = tempfile.mkdtemp() | temporary_dir = tempfile.mkdtemp() | ||||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | self.model_dir = os.path.join(temporary_dir, self.model_name) | ||||
| @@ -81,27 +82,12 @@ class HubRepositoryTest(unittest.TestCase): | |||||
| os.chdir(self.model_dir) | os.chdir(self.model_dir) | ||||
| os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | ||||
| os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | ||||
| repo.push('test', all_files=True) | |||||
| repo.push('test') | |||||
| add1 = model_file_download(self.model_id, 'add1.py') | add1 = model_file_download(self.model_id, 'add1.py') | ||||
| assert os.path.exists(add1) | assert os.path.exists(add1) | ||||
| add2 = model_file_download(self.model_id, 'add2.py') | add2 = model_file_download(self.model_id, 'add2.py') | ||||
| assert os.path.exists(add2) | assert os.path.exists(add2) | ||||
| def test_push_files(self): | |||||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||||
| assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||||
| os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||||
| os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||||
| os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py')) | |||||
| repo.push('test', files=['add1.py', 'add2.py'], all_files=False) | |||||
| add1 = model_file_download(self.model_id, 'add1.py') | |||||
| assert os.path.exists(add1) | |||||
| add2 = model_file_download(self.model_id, 'add2.py') | |||||
| assert os.path.exists(add2) | |||||
| with self.assertRaises(NotExistError) as cm: | |||||
| model_file_download(self.model_id, 'add3.py') | |||||
| print(cm.exception) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||