| @@ -203,7 +203,7 @@ class HubApi: | |||||
| root: Optional[str] = None, | root: Optional[str] = None, | ||||
| recursive: Optional[str] = False, | recursive: Optional[str] = False, | ||||
| use_cookies: Union[bool, CookieJar] = False, | use_cookies: Union[bool, CookieJar] = False, | ||||
| is_snapshot: Optional[bool] = True) -> List[dict]: | |||||
| headers: Optional[dict] = {}) -> List[dict]: | |||||
| """List the models files. | """List the models files. | ||||
| Args: | Args: | ||||
| @@ -221,13 +221,13 @@ class HubApi: | |||||
| Returns: | Returns: | ||||
| List[dict]: Model file list. | List[dict]: Model file list. | ||||
| """ | """ | ||||
| path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s&Snapshot=%s' % ( | |||||
| self.endpoint, model_id, revision, recursive, is_snapshot) | |||||
| path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % ( | |||||
| self.endpoint, model_id, revision, recursive) | |||||
| cookies = self._check_cookie(use_cookies) | cookies = self._check_cookie(use_cookies) | ||||
| if root is not None: | if root is not None: | ||||
| path = path + f'&Root={root}' | path = path + f'&Root={root}' | ||||
| r = requests.get(path, cookies=cookies) | |||||
| r = requests.get(path, cookies=cookies, headers=headers) | |||||
| r.raise_for_status() | r.raise_for_status() | ||||
| d = r.json() | d = r.json() | ||||
| @@ -121,7 +121,7 @@ def model_file_download( | |||||
| revision=revision, | revision=revision, | ||||
| recursive=True, | recursive=True, | ||||
| use_cookies=False if cookies is None else cookies, | use_cookies=False if cookies is None else cookies, | ||||
| is_snapshot=False) | |||||
| ) | |||||
| for model_file in model_files: | for model_file in model_files: | ||||
| if model_file['Type'] == 'tree': | if model_file['Type'] == 'tree': | ||||
| @@ -1,3 +1,4 @@ | |||||
| import os | |||||
| import subprocess | import subprocess | ||||
| from typing import List | from typing import List | ||||
| from xmlrpc.client import Boolean | from xmlrpc.client import Boolean | ||||
| @@ -167,3 +168,14 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| rsp = self._run_git_command(*cmd_args) | rsp = self._run_git_command(*cmd_args) | ||||
| url = rsp.stdout.decode('utf8') | url = rsp.stdout.decode('utf8') | ||||
| return url.strip() | return url.strip() | ||||
| def list_lfs_files(self, repo_dir: str): | |||||
| cmd_args = '-C %s lfs ls-files' % repo_dir | |||||
| cmd_args = cmd_args.split(' ') | |||||
| rsp = self._run_git_command(*cmd_args) | |||||
| out = rsp.stdout.decode('utf8').strip() | |||||
| files = [] | |||||
| for line in out.split(os.linesep): | |||||
| files.append(line.split(' ')[-1]) | |||||
| return files | |||||
| @@ -49,8 +49,6 @@ class Repository: | |||||
| git_wrapper = GitCommandWrapper() | git_wrapper = GitCommandWrapper() | ||||
| if not git_wrapper.is_lfs_installed(): | if not git_wrapper.is_lfs_installed(): | ||||
| logger.error('git lfs is not installed, please install.') | logger.error('git lfs is not installed, please install.') | ||||
| else: | |||||
| git_wrapper.git_lfs_install(self.model_dir) # init repo lfs | |||||
| self.git_wrapper = GitCommandWrapper(git_path) | self.git_wrapper = GitCommandWrapper(git_path) | ||||
| os.makedirs(self.model_dir, exist_ok=True) | os.makedirs(self.model_dir, exist_ok=True) | ||||
| @@ -63,8 +61,11 @@ class Repository: | |||||
| self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, | self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, | ||||
| self.model_repo_name, revision) | self.model_repo_name, revision) | ||||
| if git_wrapper.is_lfs_installed(): | |||||
| git_wrapper.git_lfs_install(self.model_dir) # init repo lfs | |||||
| def _get_model_id_url(self, model_id): | def _get_model_id_url(self, model_id): | ||||
| url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||||
| url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}.git' | |||||
| return url | return url | ||||
| def _get_remote_url(self): | def _get_remote_url(self): | ||||
| @@ -86,9 +87,11 @@ class Repository: | |||||
| commit_message (str): commit message | commit_message (str): commit message | ||||
| revision (Optional[str], optional): which branch to push. Defaults to 'master'. | revision (Optional[str], optional): which branch to push. Defaults to 'master'. | ||||
| """ | """ | ||||
| if commit_message is None: | |||||
| if commit_message is None or not isinstance(commit_message, str): | |||||
| msg = 'commit_message must be provided!' | msg = 'commit_message must be provided!' | ||||
| raise InvalidParameter(msg) | raise InvalidParameter(msg) | ||||
| if not isinstance(force, bool): | |||||
| raise InvalidParameter('force must be bool') | |||||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | url = self.git_wrapper.get_repo_remote_url(self.model_dir) | ||||
| self.git_wrapper.pull(self.model_dir) | self.git_wrapper.pull(self.model_dir) | ||||
| self.git_wrapper.add(self.model_dir, all_files=True) | self.git_wrapper.add(self.model_dir, all_files=True) | ||||
| @@ -91,7 +91,7 @@ def snapshot_download(model_id: str, | |||||
| revision=revision, | revision=revision, | ||||
| recursive=True, | recursive=True, | ||||
| use_cookies=False if cookies is None else cookies, | use_cookies=False if cookies is None else cookies, | ||||
| is_snapshot=True) | |||||
| headers={'Snapshot': 'True'}) | |||||
| for model_file in model_files: | for model_file in model_files: | ||||
| if model_file['Type'] == 'tree': | if model_file['Type'] == 'tree': | ||||
| @@ -44,7 +44,7 @@ class DialogIntentPredictionPipeline(Pipeline): | |||||
| pos = np.where(pred == np.max(pred)) | pos = np.where(pred == np.max(pred)) | ||||
| result = { | result = { | ||||
| 'pred': pred, | |||||
| 'prediction': pred, | |||||
| 'label_pos': pos[0], | 'label_pos': pos[0], | ||||
| 'label': self.categories[pos[0][0]] | 'label': self.categories[pos[0][0]] | ||||
| } | } | ||||
| @@ -43,6 +43,6 @@ class DialogModelingPipeline(Pipeline): | |||||
| assert len(sys_rsp) > 2 | assert len(sys_rsp) > 2 | ||||
| sys_rsp = sys_rsp[1:len(sys_rsp) - 1] | sys_rsp = sys_rsp[1:len(sys_rsp) - 1] | ||||
| inputs['sys'] = sys_rsp | |||||
| inputs['response'] = sys_rsp | |||||
| return inputs | return inputs | ||||
| @@ -142,10 +142,10 @@ TASK_OUTPUTS = { | |||||
| # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, | # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, | ||||
| # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, | # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, | ||||
| # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} | # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} | ||||
| Tasks.dialog_intent_prediction: ['pred', 'label_pos', 'label'], | |||||
| Tasks.dialog_intent_prediction: ['prediction', 'label_pos', 'label'], | |||||
| # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] | # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] | ||||
| Tasks.dialog_modeling: ['sys'], | |||||
| Tasks.dialog_modeling: ['response'], | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| @@ -5,6 +5,8 @@ import unittest | |||||
| import uuid | import uuid | ||||
| from shutil import rmtree | from shutil import rmtree | ||||
| import requests | |||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | from modelscope.hub.api import HubApi, ModelScopeConfig | ||||
| from modelscope.hub.constants import Licenses, ModelVisibility | from modelscope.hub.constants import Licenses, ModelVisibility | ||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| @@ -77,6 +79,11 @@ class HubOperationTest(unittest.TestCase): | |||||
| snapshot_path = snapshot_download(model_id=self.model_id) | snapshot_path = snapshot_download(model_id=self.model_id) | ||||
| mdtime2 = os.path.getmtime(downloaded_file_path) | mdtime2 = os.path.getmtime(downloaded_file_path) | ||||
| assert mdtime1 == mdtime2 | assert mdtime1 == mdtime2 | ||||
| model_file_download( | |||||
| model_id=self.model_id, | |||||
| file_path=download_model_file_name) # not add counter | |||||
| download_times = self.get_model_download_times() | |||||
| assert download_times == 2 | |||||
| def test_download_public_without_login(self): | def test_download_public_without_login(self): | ||||
| rmtree(ModelScopeConfig.path_credential) | rmtree(ModelScopeConfig.path_credential) | ||||
| @@ -107,6 +114,16 @@ class HubOperationTest(unittest.TestCase): | |||||
| model_id=self.model_id, file_path=download_model_file_name) | model_id=self.model_id, file_path=download_model_file_name) | ||||
| assert os.path.exists(file_download_path) | assert os.path.exists(file_download_path) | ||||
| def get_model_download_times(self): | |||||
| url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads' | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| r = requests.get(url, cookies=cookies) | |||||
| if r.status_code == 200: | |||||
| return r.json()['Data']['Downloads'] | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||
| @@ -12,6 +12,7 @@ from modelscope.hub.api import HubApi | |||||
| from modelscope.hub.constants import Licenses, ModelVisibility | from modelscope.hub.constants import Licenses, ModelVisibility | ||||
| from modelscope.hub.errors import NotExistError | from modelscope.hub.errors import NotExistError | ||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| from modelscope.hub.git import GitCommandWrapper | |||||
| from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -24,8 +25,6 @@ model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'unittest' | model_org = 'unittest' | ||||
| DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
| download_model_file_name = 'mnist-12.onnx' | |||||
| def delete_credential(): | def delete_credential(): | ||||
| path_credential = expanduser('~/.modelscope/credentials') | path_credential = expanduser('~/.modelscope/credentials') | ||||
| @@ -80,13 +79,22 @@ class HubRepositoryTest(unittest.TestCase): | |||||
| repo = Repository(self.model_dir, clone_from=self.model_id) | repo = Repository(self.model_dir, clone_from=self.model_id) | ||||
| assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | ||||
| os.chdir(self.model_dir) | os.chdir(self.model_dir) | ||||
| lfs_file1 = 'test1.bin' | |||||
| lfs_file2 = 'test2.bin' | |||||
| os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | ||||
| os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | ||||
| os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1)) | |||||
| os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2)) | |||||
| repo.push('test') | repo.push('test') | ||||
| add1 = model_file_download(self.model_id, 'add1.py') | add1 = model_file_download(self.model_id, 'add1.py') | ||||
| assert os.path.exists(add1) | assert os.path.exists(add1) | ||||
| add2 = model_file_download(self.model_id, 'add2.py') | add2 = model_file_download(self.model_id, 'add2.py') | ||||
| assert os.path.exists(add2) | assert os.path.exists(add2) | ||||
| # check lfs files. | |||||
| git_wrapper = GitCommandWrapper() | |||||
| lfs_files = git_wrapper.list_lfs_files(self.model_dir) | |||||
| assert lfs_file1 in lfs_files | |||||
| assert lfs_file2 in lfs_files | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -116,7 +116,7 @@ class DialogModelingTest(unittest.TestCase): | |||||
| 'user_input': user, | 'user_input': user, | ||||
| 'history': result | 'history': result | ||||
| }) | }) | ||||
| print('sys : {}'.format(result['sys'])) | |||||
| print('response : {}'.format(result['response'])) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
| @@ -140,7 +140,7 @@ class DialogModelingTest(unittest.TestCase): | |||||
| 'user_input': user, | 'user_input': user, | ||||
| 'history': result | 'history': result | ||||
| }) | }) | ||||
| print('sys : {}'.format(result['sys'])) | |||||
| print('response : {}'.format(result['response'])) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||