Browse Source

Merge remote-tracking branch 'origin' into nlp/space/dst

master
ly119399 3 years ago
parent
commit
94522aa721
7 changed files with 52 additions and 12 deletions
  1. +4
    -4
      modelscope/hub/api.py
  2. +1
    -1
      modelscope/hub/file_download.py
  3. +12
    -0
      modelscope/hub/git.py
  4. +7
    -4
      modelscope/hub/repository.py
  5. +1
    -1
      modelscope/hub/snapshot_download.py
  6. +17
    -0
      tests/hub/test_hub_operation.py
  7. +10
    -2
      tests/hub/test_hub_repository.py

+ 4
- 4
modelscope/hub/api.py View File

@@ -203,7 +203,7 @@ class HubApi:
root: Optional[str] = None, root: Optional[str] = None,
recursive: Optional[str] = False, recursive: Optional[str] = False,
use_cookies: Union[bool, CookieJar] = False, use_cookies: Union[bool, CookieJar] = False,
is_snapshot: Optional[bool] = True) -> List[dict]:
headers: Optional[dict] = {}) -> List[dict]:
"""List the models files. """List the models files.


Args: Args:
@@ -221,13 +221,13 @@ class HubApi:
Returns: Returns:
List[dict]: Model file list. List[dict]: Model file list.
""" """
path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s&Snapshot=%s' % (
self.endpoint, model_id, revision, recursive, is_snapshot)
path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % (
self.endpoint, model_id, revision, recursive)
cookies = self._check_cookie(use_cookies) cookies = self._check_cookie(use_cookies)
if root is not None: if root is not None:
path = path + f'&Root={root}' path = path + f'&Root={root}'


r = requests.get(path, cookies=cookies)
r = requests.get(path, cookies=cookies, headers=headers)


r.raise_for_status() r.raise_for_status()
d = r.json() d = r.json()


+ 1
- 1
modelscope/hub/file_download.py View File

@@ -121,7 +121,7 @@ def model_file_download(
revision=revision, revision=revision,
recursive=True, recursive=True,
use_cookies=False if cookies is None else cookies, use_cookies=False if cookies is None else cookies,
is_snapshot=False)
)


for model_file in model_files: for model_file in model_files:
if model_file['Type'] == 'tree': if model_file['Type'] == 'tree':


+ 12
- 0
modelscope/hub/git.py View File

@@ -1,3 +1,4 @@
import os
import subprocess import subprocess
from typing import List from typing import List
from xmlrpc.client import Boolean from xmlrpc.client import Boolean
@@ -167,3 +168,14 @@ class GitCommandWrapper(metaclass=Singleton):
rsp = self._run_git_command(*cmd_args) rsp = self._run_git_command(*cmd_args)
url = rsp.stdout.decode('utf8') url = rsp.stdout.decode('utf8')
return url.strip() return url.strip()

def list_lfs_files(self, repo_dir: str):
cmd_args = '-C %s lfs ls-files' % repo_dir
cmd_args = cmd_args.split(' ')
rsp = self._run_git_command(*cmd_args)
out = rsp.stdout.decode('utf8').strip()
files = []
for line in out.split(os.linesep):
files.append(line.split(' ')[-1])

return files

+ 7
- 4
modelscope/hub/repository.py View File

@@ -49,8 +49,6 @@ class Repository:
git_wrapper = GitCommandWrapper() git_wrapper = GitCommandWrapper()
if not git_wrapper.is_lfs_installed(): if not git_wrapper.is_lfs_installed():
logger.error('git lfs is not installed, please install.') logger.error('git lfs is not installed, please install.')
else:
git_wrapper.git_lfs_install(self.model_dir) # init repo lfs


self.git_wrapper = GitCommandWrapper(git_path) self.git_wrapper = GitCommandWrapper(git_path)
os.makedirs(self.model_dir, exist_ok=True) os.makedirs(self.model_dir, exist_ok=True)
@@ -63,8 +61,11 @@ class Repository:
self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, self.git_wrapper.clone(self.model_base_dir, self.auth_token, url,
self.model_repo_name, revision) self.model_repo_name, revision)


if git_wrapper.is_lfs_installed():
git_wrapper.git_lfs_install(self.model_dir) # init repo lfs

def _get_model_id_url(self, model_id): def _get_model_id_url(self, model_id):
url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}'
url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}.git'
return url return url


def _get_remote_url(self): def _get_remote_url(self):
@@ -86,9 +87,11 @@ class Repository:
commit_message (str): commit message commit_message (str): commit message
revision (Optional[str], optional): which branch to push. Defaults to 'master'. revision (Optional[str], optional): which branch to push. Defaults to 'master'.
""" """
if commit_message is None:
if commit_message is None or not isinstance(commit_message, str):
msg = 'commit_message must be provided!' msg = 'commit_message must be provided!'
raise InvalidParameter(msg) raise InvalidParameter(msg)
if not isinstance(force, bool):
raise InvalidParameter('force must be bool')
url = self.git_wrapper.get_repo_remote_url(self.model_dir) url = self.git_wrapper.get_repo_remote_url(self.model_dir)
self.git_wrapper.pull(self.model_dir) self.git_wrapper.pull(self.model_dir)
self.git_wrapper.add(self.model_dir, all_files=True) self.git_wrapper.add(self.model_dir, all_files=True)


+ 1
- 1
modelscope/hub/snapshot_download.py View File

@@ -91,7 +91,7 @@ def snapshot_download(model_id: str,
revision=revision, revision=revision,
recursive=True, recursive=True,
use_cookies=False if cookies is None else cookies, use_cookies=False if cookies is None else cookies,
is_snapshot=True)
headers={'Snapshot': 'True'})


for model_file in model_files: for model_file in model_files:
if model_file['Type'] == 'tree': if model_file['Type'] == 'tree':


+ 17
- 0
tests/hub/test_hub_operation.py View File

@@ -5,6 +5,8 @@ import unittest
import uuid import uuid
from shutil import rmtree from shutil import rmtree


import requests

from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download from modelscope.hub.file_download import model_file_download
@@ -77,6 +79,11 @@ class HubOperationTest(unittest.TestCase):
snapshot_path = snapshot_download(model_id=self.model_id) snapshot_path = snapshot_download(model_id=self.model_id)
mdtime2 = os.path.getmtime(downloaded_file_path) mdtime2 = os.path.getmtime(downloaded_file_path)
assert mdtime1 == mdtime2 assert mdtime1 == mdtime2
model_file_download(
model_id=self.model_id,
file_path=download_model_file_name) # not add counter
download_times = self.get_model_download_times()
assert download_times == 2


def test_download_public_without_login(self): def test_download_public_without_login(self):
rmtree(ModelScopeConfig.path_credential) rmtree(ModelScopeConfig.path_credential)
@@ -107,6 +114,16 @@ class HubOperationTest(unittest.TestCase):
model_id=self.model_id, file_path=download_model_file_name) model_id=self.model_id, file_path=download_model_file_name)
assert os.path.exists(file_download_path) assert os.path.exists(file_download_path)


def get_model_download_times(self):
url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads'
cookies = ModelScopeConfig.get_cookies()
r = requests.get(url, cookies=cookies)
if r.status_code == 200:
return r.json()['Data']['Downloads']
else:
r.raise_for_status()
return None



if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

+ 10
- 2
tests/hub/test_hub_repository.py View File

@@ -12,6 +12,7 @@ from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import NotExistError from modelscope.hub.errors import NotExistError
from modelscope.hub.file_download import model_file_download from modelscope.hub.file_download import model_file_download
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.repository import Repository from modelscope.hub.repository import Repository
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger


@@ -24,8 +25,6 @@ model_chinese_name = '达摩卡通化模型'
model_org = 'unittest' model_org = 'unittest'
DEFAULT_GIT_PATH = 'git' DEFAULT_GIT_PATH = 'git'


download_model_file_name = 'mnist-12.onnx'



def delete_credential(): def delete_credential():
path_credential = expanduser('~/.modelscope/credentials') path_credential = expanduser('~/.modelscope/credentials')
@@ -80,13 +79,22 @@ class HubRepositoryTest(unittest.TestCase):
repo = Repository(self.model_dir, clone_from=self.model_id) repo = Repository(self.model_dir, clone_from=self.model_id)
assert os.path.exists(os.path.join(self.model_dir, 'README.md')) assert os.path.exists(os.path.join(self.model_dir, 'README.md'))
os.chdir(self.model_dir) os.chdir(self.model_dir)
lfs_file1 = 'test1.bin'
lfs_file2 = 'test2.bin'
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1))
os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2))
repo.push('test') repo.push('test')
add1 = model_file_download(self.model_id, 'add1.py') add1 = model_file_download(self.model_id, 'add1.py')
assert os.path.exists(add1) assert os.path.exists(add1)
add2 = model_file_download(self.model_id, 'add2.py') add2 = model_file_download(self.model_id, 'add2.py')
assert os.path.exists(add2) assert os.path.exists(add2)
# check lfs files.
git_wrapper = GitCommandWrapper()
lfs_files = git_wrapper.list_lfs_files(self.model_dir)
assert lfs_file1 in lfs_files
assert lfs_file2 in lfs_files




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save