修复测试bug
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9186775
* [to #42849800 #42822853 #42822836 #42822791 #42822717 #42820011]fix: test bugs
master
| @@ -9,7 +9,7 @@ import requests | |||
| from modelscope.utils.logger import get_logger | |||
| from .constants import MODELSCOPE_URL_SCHEME | |||
| from .errors import NotExistError, is_ok, raise_on_error | |||
| from .errors import InvalidParameter, NotExistError, is_ok, raise_on_error | |||
| from .utils.utils import (get_endpoint, get_gitlab_domain, | |||
| model_id_to_group_owner_name) | |||
| @@ -61,17 +61,21 @@ class HubApi: | |||
| return d['Data']['AccessToken'], cookies | |||
| def create_model(self, model_id: str, chinese_name: str, visibility: int, | |||
| license: str) -> str: | |||
| def create_model( | |||
| self, | |||
| model_id: str, | |||
| visibility: str, | |||
| license: str, | |||
| chinese_name: Optional[str] = None, | |||
| ) -> str: | |||
| """ | |||
| Create model repo at ModelScopeHub | |||
| Args: | |||
| model_id:(`str`): The model id | |||
| chinese_name(`str`): chinese name of the model | |||
| visibility(`int`): visibility of the model(1-private, 3-internal, 5-public) | |||
| license(`str`): license of the model, candidates can be found at: TBA | |||
| visibility(`int`): visibility of the model(1-private, 5-public), default public. | |||
| license(`str`): license of the model, default none. | |||
| chinese_name(`str`, *optional*): chinese name of the model | |||
| Returns: | |||
| name of the model created | |||
| @@ -79,6 +83,8 @@ class HubApi: | |||
| model_id = {owner}/{name} | |||
| </Tip> | |||
| """ | |||
| if model_id is None: | |||
| raise InvalidParameter('model_id is required!') | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise ValueError('Token does not exist, please login first.') | |||
| @@ -151,11 +157,33 @@ class HubApi: | |||
| else: | |||
| r.raise_for_status() | |||
| def _check_cookie(self, | |||
| use_cookies: Union[bool, | |||
| CookieJar] = False) -> CookieJar: | |||
| cookies = None | |||
| if isinstance(use_cookies, CookieJar): | |||
| cookies = use_cookies | |||
| elif use_cookies: | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise ValueError('Token does not exist, please login first.') | |||
| return cookies | |||
| def get_model_branches_and_tags( | |||
| self, | |||
| model_id: str, | |||
| use_cookies: Union[bool, CookieJar] = False | |||
| ) -> Tuple[List[str], List[str]]: | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| """Get model branch and tags. | |||
| Args: | |||
| model_id (str): The model id | |||
| use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will | |||
| will load cookie from local. Defaults to False. | |||
| Returns: | |||
| Tuple[List[str], List[str]]: _description_ | |||
| """ | |||
| cookies = self._check_cookie(use_cookies) | |||
| path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | |||
| r = requests.get(path, cookies=cookies) | |||
| @@ -169,23 +197,33 @@ class HubApi: | |||
| ] if info['RevisionMap']['Tags'] else [] | |||
| return branches, tags | |||
| def get_model_files( | |||
| self, | |||
| model_id: str, | |||
| revision: Optional[str] = 'master', | |||
| root: Optional[str] = None, | |||
| recursive: Optional[str] = False, | |||
| use_cookies: Union[bool, CookieJar] = False) -> List[dict]: | |||
| def get_model_files(self, | |||
| model_id: str, | |||
| revision: Optional[str] = 'master', | |||
| root: Optional[str] = None, | |||
| recursive: Optional[str] = False, | |||
| use_cookies: Union[bool, CookieJar] = False, | |||
| is_snapshot: Optional[bool] = True) -> List[dict]: | |||
| """List the models files. | |||
| cookies = None | |||
| if isinstance(use_cookies, CookieJar): | |||
| cookies = use_cookies | |||
| elif use_cookies: | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise ValueError('Token does not exist, please login first.') | |||
| Args: | |||
| model_id (str): The model id | |||
| revision (Optional[str], optional): The branch or tag name. Defaults to 'master'. | |||
| root (Optional[str], optional): The root path. Defaults to None. | |||
| recursive (Optional[str], optional): Is recurive list files. Defaults to False. | |||
| use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will | |||
| will load cookie from local. Defaults to False. | |||
| is_snapshot(Optional[bool], optional): when snapshot_download set to True, otherwise False. | |||
| path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}' | |||
| Raises: | |||
| ValueError: If user_cookies is True, but no local cookie. | |||
| Returns: | |||
| List[dict]: Model file list. | |||
| """ | |||
| path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s&Snapshot=%s' % ( | |||
| self.endpoint, model_id, revision, recursive, is_snapshot) | |||
| cookies = self._check_cookie(use_cookies) | |||
| if root is not None: | |||
| path = path + f'&Root={root}' | |||
| @@ -10,6 +10,10 @@ class GitError(Exception): | |||
| pass | |||
| class InvalidParameter(Exception): | |||
| pass | |||
| def is_ok(rsp): | |||
| """ Check the request is ok | |||
| @@ -7,6 +7,7 @@ import tempfile | |||
| import time | |||
| from functools import partial | |||
| from hashlib import sha256 | |||
| from http.cookiejar import CookieJar | |||
| from pathlib import Path | |||
| from typing import BinaryIO, Dict, Optional, Union | |||
| from uuid import uuid4 | |||
| @@ -107,7 +108,9 @@ def model_file_download( | |||
| _api = HubApi() | |||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
| branches, tags = _api.get_model_branches_and_tags(model_id) | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| branches, tags = _api.get_model_branches_and_tags( | |||
| model_id, use_cookies=False if cookies is None else cookies) | |||
| file_to_download_info = None | |||
| is_commit_id = False | |||
| if revision in branches or revision in tags: # The revision is version or tag, | |||
| @@ -117,18 +120,19 @@ def model_file_download( | |||
| model_id=model_id, | |||
| revision=revision, | |||
| recursive=True, | |||
| ) | |||
| use_cookies=False if cookies is None else cookies, | |||
| is_snapshot=False) | |||
| for model_file in model_files: | |||
| if model_file['Type'] == 'tree': | |||
| continue | |||
| if model_file['Path'] == file_path: | |||
| model_file['Branch'] = revision | |||
| if cache.exists(model_file): | |||
| return cache.get_file_by_info(model_file) | |||
| else: | |||
| file_to_download_info = model_file | |||
| break | |||
| if file_to_download_info is None: | |||
| raise NotExistError('The file path: %s not exist in: %s' % | |||
| @@ -141,8 +145,6 @@ def model_file_download( | |||
| return cached_file_path # the file is in cache. | |||
| is_commit_id = True | |||
| # we need to download again | |||
| # TODO: skip using JWT for authorization, use cookie instead | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| url_to_download = get_file_download_url(model_id, file_path, revision) | |||
| file_to_download_info = { | |||
| 'Path': file_path, | |||
| @@ -202,7 +204,7 @@ def http_get_file( | |||
| url: str, | |||
| local_dir: str, | |||
| file_name: str, | |||
| cookies: Dict[str, str], | |||
| cookies: CookieJar, | |||
| headers: Optional[Dict[str, str]] = None, | |||
| ): | |||
| """ | |||
| @@ -217,7 +219,7 @@ def http_get_file( | |||
| local directory where the downloaded file stores | |||
| file_name(`str`): | |||
| name of the file stored in `local_dir` | |||
| cookies(`Dict[str, str]`): | |||
| cookies(`CookieJar`): | |||
| cookies used to authentication the user, which is used for downloading private repos | |||
| headers(`Optional[Dict[str, str]] = None`): | |||
| http headers to carry necessary info when requesting the remote file | |||
| @@ -70,6 +70,14 @@ class GitCommandWrapper(metaclass=Singleton): | |||
| except GitError: | |||
| return False | |||
| def git_lfs_install(self, repo_dir): | |||
| cmd = ['git', '-C', repo_dir, 'lfs', 'install'] | |||
| try: | |||
| self._run_git_command(*cmd) | |||
| return True | |||
| except GitError: | |||
| return False | |||
| def clone(self, | |||
| repo_base_dir: str, | |||
| token: str, | |||
| @@ -1,7 +1,7 @@ | |||
| import os | |||
| from typing import List, Optional | |||
| from modelscope.hub.errors import GitError | |||
| from modelscope.hub.errors import GitError, InvalidParameter | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import ModelScopeConfig | |||
| from .constants import MODELSCOPE_URL_SCHEME | |||
| @@ -49,6 +49,8 @@ class Repository: | |||
| git_wrapper = GitCommandWrapper() | |||
| if not git_wrapper.is_lfs_installed(): | |||
| logger.error('git lfs is not installed, please install.') | |||
| else: | |||
| git_wrapper.git_lfs_install(self.model_dir) # init repo lfs | |||
| self.git_wrapper = GitCommandWrapper(git_path) | |||
| os.makedirs(self.model_dir, exist_ok=True) | |||
| @@ -74,8 +76,6 @@ class Repository: | |||
| def push(self, | |||
| commit_message: str, | |||
| files: List[str] = list(), | |||
| all_files: bool = False, | |||
| branch: Optional[str] = 'master', | |||
| force: bool = False): | |||
| """Push local to remote, this method will do. | |||
| @@ -86,8 +86,12 @@ class Repository: | |||
| commit_message (str): commit message | |||
| revision (Optional[str], optional): which branch to push. Defaults to 'master'. | |||
| """ | |||
| if commit_message is None: | |||
| msg = 'commit_message must be provided!' | |||
| raise InvalidParameter(msg) | |||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
| self.git_wrapper.add(self.model_dir, files, all_files) | |||
| self.git_wrapper.pull(self.model_dir) | |||
| self.git_wrapper.add(self.model_dir, all_files=True) | |||
| self.git_wrapper.commit(self.model_dir, commit_message) | |||
| self.git_wrapper.push( | |||
| repo_dir=self.model_dir, | |||
| @@ -20,8 +20,7 @@ def snapshot_download(model_id: str, | |||
| revision: Optional[str] = 'master', | |||
| cache_dir: Union[str, Path, None] = None, | |||
| user_agent: Optional[Union[Dict, str]] = None, | |||
| local_files_only: Optional[bool] = False, | |||
| private: Optional[bool] = False) -> str: | |||
| local_files_only: Optional[bool] = False) -> str: | |||
| """Download all files of a repo. | |||
| Downloads a whole snapshot of a repo's files at the specified revision. This | |||
| is useful when you want all files from a repo, because you don't know which | |||
| @@ -79,8 +78,10 @@ def snapshot_download(model_id: str, | |||
| # make headers | |||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
| _api = HubApi() | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| # get file list from model repo | |||
| branches, tags = _api.get_model_branches_and_tags(model_id) | |||
| branches, tags = _api.get_model_branches_and_tags( | |||
| model_id, use_cookies=False if cookies is None else cookies) | |||
| if revision not in branches and revision not in tags: | |||
| raise NotExistError('The specified branch or tag : %s not exist!' | |||
| % revision) | |||
| @@ -89,11 +90,8 @@ def snapshot_download(model_id: str, | |||
| model_id=model_id, | |||
| revision=revision, | |||
| recursive=True, | |||
| use_cookies=private) | |||
| cookies = None | |||
| if private: | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| use_cookies=False if cookies is None else cookies, | |||
| is_snapshot=True) | |||
| for model_file in model_files: | |||
| if model_file['Type'] == 'tree': | |||
| @@ -116,7 +114,7 @@ def snapshot_download(model_id: str, | |||
| local_dir=tempfile.gettempdir(), | |||
| file_name=model_file['Name'], | |||
| headers=headers, | |||
| cookies=None if cookies is None else cookies.get_dict()) | |||
| cookies=cookies) | |||
| # put file to cache | |||
| cache.put_file( | |||
| model_file, | |||
| @@ -101,8 +101,9 @@ class FileSystemCache(object): | |||
| Args: | |||
| key (dict): The cache key. | |||
| """ | |||
| self.cached_files.remove(key) | |||
| self.save_cached_files() | |||
| if key in self.cached_files: | |||
| self.cached_files.remove(key) | |||
| self.save_cached_files() | |||
| def exists(self, key): | |||
| for cache_file in self.cached_files: | |||
| @@ -204,6 +205,7 @@ class ModelFileSystemCache(FileSystemCache): | |||
| return orig_path | |||
| else: | |||
| self.remove_key(cached_file) | |||
| break | |||
| return None | |||
| @@ -230,6 +232,7 @@ class ModelFileSystemCache(FileSystemCache): | |||
| cached_key['Revision'].startswith(key['Revision']) | |||
| or key['Revision'].startswith(cached_key['Revision'])): | |||
| is_exists = True | |||
| break | |||
| file_path = os.path.join(self.cache_root_location, | |||
| model_file_info['Path']) | |||
| if is_exists: | |||
| @@ -253,6 +256,7 @@ class ModelFileSystemCache(FileSystemCache): | |||
| cached_file['Path']) | |||
| if os.path.exists(file_path): | |||
| os.remove(file_path) | |||
| break | |||
| def put_file(self, model_file_info, model_file_location): | |||
| """Put model on model_file_location to cache, the model first download to /tmp, and move to cache. | |||
| @@ -31,9 +31,10 @@ def create_model_if_not_exist( | |||
| else: | |||
| api.create_model( | |||
| model_id=model_id, | |||
| chinese_name=chinese_name, | |||
| visibility=visibility, | |||
| license=license) | |||
| license=license, | |||
| chinese_name=chinese_name, | |||
| ) | |||
| print(f'model {model_id} successfully created.') | |||
| return True | |||
| @@ -3,6 +3,7 @@ import os | |||
| import tempfile | |||
| import unittest | |||
| import uuid | |||
| from shutil import rmtree | |||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||
| @@ -23,7 +24,6 @@ download_model_file_name = 'test.bin' | |||
| class HubOperationTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.old_cwd = os.getcwd() | |||
| self.api = HubApi() | |||
| # note this is temporary before official account management is ready | |||
| self.api.login(USER_NAME, PASSWORD) | |||
| @@ -31,19 +31,18 @@ class HubOperationTest(unittest.TestCase): | |||
| self.model_id = '%s/%s' % (model_org, self.model_name) | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| chinese_name=model_chinese_name, | |||
| visibility=ModelVisibility.PUBLIC, | |||
| license=Licenses.APACHE_V2) | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=model_chinese_name, | |||
| ) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||
| os.chdir(self.model_dir) | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, 'test.bin')) | |||
| repo.push('add model', all_files=True) | |||
| % os.path.join(self.model_dir, download_model_file_name)) | |||
| repo.push('add model') | |||
| def tearDown(self): | |||
| os.chdir(self.old_cwd) | |||
| self.api.delete_model(model_id=self.model_id) | |||
| def test_model_repo_creation(self): | |||
| @@ -79,6 +78,35 @@ class HubOperationTest(unittest.TestCase): | |||
| mdtime2 = os.path.getmtime(downloaded_file_path) | |||
| assert mdtime1 == mdtime2 | |||
| def test_download_public_without_login(self): | |||
| rmtree(ModelScopeConfig.path_credential) | |||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||
| downloaded_file_path = os.path.join(snapshot_path, | |||
| download_model_file_name) | |||
| assert os.path.exists(downloaded_file_path) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| downloaded_file = model_file_download( | |||
| model_id=self.model_id, | |||
| file_path=download_model_file_name, | |||
| cache_dir=temporary_dir) | |||
| assert os.path.exists(downloaded_file) | |||
| self.api.login(USER_NAME, PASSWORD) | |||
| def test_snapshot_delete_download_cache_file(self): | |||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||
| downloaded_file_path = os.path.join(snapshot_path, | |||
| download_model_file_name) | |||
| assert os.path.exists(downloaded_file_path) | |||
| os.remove(downloaded_file_path) | |||
| # download again in cache | |||
| file_download_path = model_file_download( | |||
| model_id=self.model_id, file_path='README.md') | |||
| assert os.path.exists(file_download_path) | |||
| # deleted file need download again | |||
| file_download_path = model_file_download( | |||
| model_id=self.model_id, file_path=download_model_file_name) | |||
| assert os.path.exists(file_download_path) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,85 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| import uuid | |||
| from requests.exceptions import HTTPError | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||
| from modelscope.hub.errors import GitError | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.utils.constant import ModelFile | |||
| USER_NAME = 'maasadmin' | |||
| PASSWORD = '12345678' | |||
| USER_NAME2 = 'sdkdev' | |||
| model_chinese_name = '达摩卡通化模型' | |||
| model_org = 'unittest' | |||
| class HubPrivateFileDownloadTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.old_cwd = os.getcwd() | |||
| self.api = HubApi() | |||
| # note this is temporary before official account management is ready | |||
| self.token, _ = self.api.login(USER_NAME, PASSWORD) | |||
| self.model_name = uuid.uuid4().hex | |||
| self.model_id = '%s/%s' % (model_org, self.model_name) | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PRIVATE, # 1-private, 5-public | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=model_chinese_name, | |||
| ) | |||
| def tearDown(self): | |||
| os.chdir(self.old_cwd) | |||
| self.api.delete_model(model_id=self.model_id) | |||
| def test_snapshot_download_private_model(self): | |||
| snapshot_path = snapshot_download(self.model_id) | |||
| assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | |||
| def test_snapshot_download_private_model_no_permission(self): | |||
| self.token, _ = self.api.login(USER_NAME2, PASSWORD) | |||
| with self.assertRaises(HTTPError): | |||
| snapshot_download(self.model_id) | |||
| self.api.login(USER_NAME, PASSWORD) | |||
| def test_download_file_private_model(self): | |||
| file_path = model_file_download(self.model_id, ModelFile.README) | |||
| assert os.path.exists(file_path) | |||
| def test_download_file_private_model_no_permission(self): | |||
| self.token, _ = self.api.login(USER_NAME2, PASSWORD) | |||
| with self.assertRaises(HTTPError): | |||
| model_file_download(self.model_id, ModelFile.README) | |||
| self.api.login(USER_NAME, PASSWORD) | |||
| def test_snapshot_download_local_only(self): | |||
| with self.assertRaises(ValueError): | |||
| snapshot_download(self.model_id, local_files_only=True) | |||
| snapshot_path = snapshot_download(self.model_id) | |||
| assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | |||
| snapshot_path = snapshot_download(self.model_id, local_files_only=True) | |||
| assert os.path.exists(snapshot_path) | |||
| def test_file_download_local_only(self): | |||
| with self.assertRaises(ValueError): | |||
| model_file_download( | |||
| self.model_id, ModelFile.README, local_files_only=True) | |||
| file_path = model_file_download(self.model_id, ModelFile.README) | |||
| assert os.path.exists(file_path) | |||
| file_path = model_file_download( | |||
| self.model_id, ModelFile.README, local_files_only=True) | |||
| assert os.path.exists(file_path) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -5,6 +5,7 @@ import unittest | |||
| import uuid | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||
| from modelscope.hub.errors import GitError | |||
| from modelscope.hub.repository import Repository | |||
| @@ -16,9 +17,6 @@ model_chinese_name = '达摩卡通化模型' | |||
| model_org = 'unittest' | |||
| DEFAULT_GIT_PATH = 'git' | |||
| sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||
| download_model_file_name = 'mnist-12.onnx' | |||
| class HubPrivateRepositoryTest(unittest.TestCase): | |||
| @@ -31,9 +29,10 @@ class HubPrivateRepositoryTest(unittest.TestCase): | |||
| self.model_id = '%s/%s' % (model_org, self.model_name) | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PRIVATE, # 1-private, 5-public | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=model_chinese_name, | |||
| visibility=1, # 1-private, 5-public | |||
| license='apache-2.0') | |||
| ) | |||
| def tearDown(self): | |||
| self.api.login(USER_NAME, PASSWORD) | |||
| @@ -2,7 +2,6 @@ | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import time | |||
| import unittest | |||
| import uuid | |||
| from os.path import expanduser | |||
| @@ -10,6 +9,7 @@ from os.path import expanduser | |||
| from requests import delete | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||
| from modelscope.hub.errors import NotExistError | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.hub.repository import Repository | |||
| @@ -55,9 +55,10 @@ class HubRepositoryTest(unittest.TestCase): | |||
| self.model_id = '%s/%s' % (model_org, self.model_name) | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PUBLIC, # 1-private, 5-public | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=model_chinese_name, | |||
| visibility=5, # 1-private, 5-public | |||
| license='apache-2.0') | |||
| ) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
| @@ -81,27 +82,12 @@ class HubRepositoryTest(unittest.TestCase): | |||
| os.chdir(self.model_dir) | |||
| os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||
| os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||
| repo.push('test', all_files=True) | |||
| repo.push('test') | |||
| add1 = model_file_download(self.model_id, 'add1.py') | |||
| assert os.path.exists(add1) | |||
| add2 = model_file_download(self.model_id, 'add2.py') | |||
| assert os.path.exists(add2) | |||
| def test_push_files(self): | |||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||
| assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
| os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||
| os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||
| os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py')) | |||
| repo.push('test', files=['add1.py', 'add2.py'], all_files=False) | |||
| add1 = model_file_download(self.model_id, 'add1.py') | |||
| assert os.path.exists(add1) | |||
| add2 = model_file_download(self.model_id, 'add2.py') | |||
| assert os.path.exists(add2) | |||
| with self.assertRaises(NotExistError) as cm: | |||
| model_file_download(self.model_id, 'add3.py') | |||
| print(cm.exception) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||