mulin.lyh 3 years ago
parent
commit
0d17eb5b39
12 changed files with 235 additions and 78 deletions
  1. +61
    -23
      modelscope/hub/api.py
  2. +4
    -0
      modelscope/hub/errors.py
  3. +9
    -7
      modelscope/hub/file_download.py
  4. +8
    -0
      modelscope/hub/git.py
  5. +8
    -4
      modelscope/hub/repository.py
  6. +7
    -9
      modelscope/hub/snapshot_download.py
  7. +6
    -2
      modelscope/hub/utils/caching.py
  8. +3
    -2
      modelscope/utils/hub.py
  9. +35
    -7
      tests/hub/test_hub_operation.py
  10. +85
    -0
      tests/hub/test_hub_private_files.py
  11. +4
    -5
      tests/hub/test_hub_private_repository.py
  12. +5
    -19
      tests/hub/test_hub_repository.py

+ 61
- 23
modelscope/hub/api.py View File

@@ -9,7 +9,7 @@ import requests


from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from .constants import MODELSCOPE_URL_SCHEME from .constants import MODELSCOPE_URL_SCHEME
from .errors import NotExistError, is_ok, raise_on_error
from .errors import InvalidParameter, NotExistError, is_ok, raise_on_error
from .utils.utils import (get_endpoint, get_gitlab_domain, from .utils.utils import (get_endpoint, get_gitlab_domain,
model_id_to_group_owner_name) model_id_to_group_owner_name)


@@ -61,17 +61,21 @@ class HubApi:


return d['Data']['AccessToken'], cookies return d['Data']['AccessToken'], cookies


def create_model(self, model_id: str, chinese_name: str, visibility: int,
license: str) -> str:
def create_model(
self,
model_id: str,
visibility: str,
license: str,
chinese_name: Optional[str] = None,
) -> str:
""" """
Create model repo at ModelScopeHub Create model repo at ModelScopeHub


Args: Args:
model_id:(`str`): The model id model_id:(`str`): The model id
chinese_name(`str`): chinese name of the model
visibility(`int`): visibility of the model(1-private, 3-internal, 5-public)
license(`str`): license of the model, candidates can be found at: TBA

visibility(`int`): visibility of the model(1-private, 5-public), default public.
license(`str`): license of the model, default none.
chinese_name(`str`, *optional*): chinese name of the model
Returns: Returns:
name of the model created name of the model created


@@ -79,6 +83,8 @@ class HubApi:
model_id = {owner}/{name} model_id = {owner}/{name}
</Tip> </Tip>
""" """
if model_id is None:
raise InvalidParameter('model_id is required!')
cookies = ModelScopeConfig.get_cookies() cookies = ModelScopeConfig.get_cookies()
if cookies is None: if cookies is None:
raise ValueError('Token does not exist, please login first.') raise ValueError('Token does not exist, please login first.')
@@ -151,11 +157,33 @@ class HubApi:
else: else:
r.raise_for_status() r.raise_for_status()


def _check_cookie(self,
use_cookies: Union[bool,
CookieJar] = False) -> CookieJar:
cookies = None
if isinstance(use_cookies, CookieJar):
cookies = use_cookies
elif use_cookies:
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
return cookies

def get_model_branches_and_tags( def get_model_branches_and_tags(
self, self,
model_id: str, model_id: str,
use_cookies: Union[bool, CookieJar] = False
) -> Tuple[List[str], List[str]]: ) -> Tuple[List[str], List[str]]:
cookies = ModelScopeConfig.get_cookies()
"""Get model branch and tags.

Args:
model_id (str): The model id
use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will
will load cookie from local. Defaults to False.
Returns:
Tuple[List[str], List[str]]: _description_
"""
cookies = self._check_cookie(use_cookies)


path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
r = requests.get(path, cookies=cookies) r = requests.get(path, cookies=cookies)
@@ -169,23 +197,33 @@ class HubApi:
] if info['RevisionMap']['Tags'] else [] ] if info['RevisionMap']['Tags'] else []
return branches, tags return branches, tags


def get_model_files(
self,
model_id: str,
revision: Optional[str] = 'master',
root: Optional[str] = None,
recursive: Optional[str] = False,
use_cookies: Union[bool, CookieJar] = False) -> List[dict]:
def get_model_files(self,
model_id: str,
revision: Optional[str] = 'master',
root: Optional[str] = None,
recursive: Optional[str] = False,
use_cookies: Union[bool, CookieJar] = False,
is_snapshot: Optional[bool] = True) -> List[dict]:
"""List the models files.


cookies = None
if isinstance(use_cookies, CookieJar):
cookies = use_cookies
elif use_cookies:
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
Args:
model_id (str): The model id
revision (Optional[str], optional): The branch or tag name. Defaults to 'master'.
root (Optional[str], optional): The root path. Defaults to None.
recursive (Optional[str], optional): Is recurive list files. Defaults to False.
use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will
will load cookie from local. Defaults to False.
is_snapshot(Optional[bool], optional): when snapshot_download set to True, otherwise False.


path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}'
Raises:
ValueError: If user_cookies is True, but no local cookie.

Returns:
List[dict]: Model file list.
"""
path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s&Snapshot=%s' % (
self.endpoint, model_id, revision, recursive, is_snapshot)
cookies = self._check_cookie(use_cookies)
if root is not None: if root is not None:
path = path + f'&Root={root}' path = path + f'&Root={root}'




+ 4
- 0
modelscope/hub/errors.py View File

@@ -10,6 +10,10 @@ class GitError(Exception):
pass pass




class InvalidParameter(Exception):
pass


def is_ok(rsp): def is_ok(rsp):
""" Check the request is ok """ Check the request is ok




+ 9
- 7
modelscope/hub/file_download.py View File

@@ -7,6 +7,7 @@ import tempfile
import time import time
from functools import partial from functools import partial
from hashlib import sha256 from hashlib import sha256
from http.cookiejar import CookieJar
from pathlib import Path from pathlib import Path
from typing import BinaryIO, Dict, Optional, Union from typing import BinaryIO, Dict, Optional, Union
from uuid import uuid4 from uuid import uuid4
@@ -107,7 +108,9 @@ def model_file_download(


_api = HubApi() _api = HubApi()
headers = {'user-agent': http_user_agent(user_agent=user_agent, )} headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
branches, tags = _api.get_model_branches_and_tags(model_id)
cookies = ModelScopeConfig.get_cookies()
branches, tags = _api.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
file_to_download_info = None file_to_download_info = None
is_commit_id = False is_commit_id = False
if revision in branches or revision in tags: # The revision is version or tag, if revision in branches or revision in tags: # The revision is version or tag,
@@ -117,18 +120,19 @@ def model_file_download(
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
recursive=True, recursive=True,
)
use_cookies=False if cookies is None else cookies,
is_snapshot=False)


for model_file in model_files: for model_file in model_files:
if model_file['Type'] == 'tree': if model_file['Type'] == 'tree':
continue continue


if model_file['Path'] == file_path: if model_file['Path'] == file_path:
model_file['Branch'] = revision
if cache.exists(model_file): if cache.exists(model_file):
return cache.get_file_by_info(model_file) return cache.get_file_by_info(model_file)
else: else:
file_to_download_info = model_file file_to_download_info = model_file
break


if file_to_download_info is None: if file_to_download_info is None:
raise NotExistError('The file path: %s not exist in: %s' % raise NotExistError('The file path: %s not exist in: %s' %
@@ -141,8 +145,6 @@ def model_file_download(
return cached_file_path # the file is in cache. return cached_file_path # the file is in cache.
is_commit_id = True is_commit_id = True
# we need to download again # we need to download again
# TODO: skip using JWT for authorization, use cookie instead
cookies = ModelScopeConfig.get_cookies()
url_to_download = get_file_download_url(model_id, file_path, revision) url_to_download = get_file_download_url(model_id, file_path, revision)
file_to_download_info = { file_to_download_info = {
'Path': file_path, 'Path': file_path,
@@ -202,7 +204,7 @@ def http_get_file(
url: str, url: str,
local_dir: str, local_dir: str,
file_name: str, file_name: str,
cookies: Dict[str, str],
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
): ):
""" """
@@ -217,7 +219,7 @@ def http_get_file(
local directory where the downloaded file stores local directory where the downloaded file stores
file_name(`str`): file_name(`str`):
name of the file stored in `local_dir` name of the file stored in `local_dir`
cookies(`Dict[str, str]`):
cookies(`CookieJar`):
cookies used to authentication the user, which is used for downloading private repos cookies used to authentication the user, which is used for downloading private repos
headers(`Optional[Dict[str, str]] = None`): headers(`Optional[Dict[str, str]] = None`):
http headers to carry necessary info when requesting the remote file http headers to carry necessary info when requesting the remote file


+ 8
- 0
modelscope/hub/git.py View File

@@ -70,6 +70,14 @@ class GitCommandWrapper(metaclass=Singleton):
except GitError: except GitError:
return False return False


def git_lfs_install(self, repo_dir):
cmd = ['git', '-C', repo_dir, 'lfs', 'install']
try:
self._run_git_command(*cmd)
return True
except GitError:
return False

def clone(self, def clone(self,
repo_base_dir: str, repo_base_dir: str,
token: str, token: str,


+ 8
- 4
modelscope/hub/repository.py View File

@@ -1,7 +1,7 @@
import os import os
from typing import List, Optional from typing import List, Optional


from modelscope.hub.errors import GitError
from modelscope.hub.errors import GitError, InvalidParameter
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from .api import ModelScopeConfig from .api import ModelScopeConfig
from .constants import MODELSCOPE_URL_SCHEME from .constants import MODELSCOPE_URL_SCHEME
@@ -49,6 +49,8 @@ class Repository:
git_wrapper = GitCommandWrapper() git_wrapper = GitCommandWrapper()
if not git_wrapper.is_lfs_installed(): if not git_wrapper.is_lfs_installed():
logger.error('git lfs is not installed, please install.') logger.error('git lfs is not installed, please install.')
else:
git_wrapper.git_lfs_install(self.model_dir) # init repo lfs


self.git_wrapper = GitCommandWrapper(git_path) self.git_wrapper = GitCommandWrapper(git_path)
os.makedirs(self.model_dir, exist_ok=True) os.makedirs(self.model_dir, exist_ok=True)
@@ -74,8 +76,6 @@ class Repository:


def push(self, def push(self,
commit_message: str, commit_message: str,
files: List[str] = list(),
all_files: bool = False,
branch: Optional[str] = 'master', branch: Optional[str] = 'master',
force: bool = False): force: bool = False):
"""Push local to remote, this method will do. """Push local to remote, this method will do.
@@ -86,8 +86,12 @@ class Repository:
commit_message (str): commit message commit_message (str): commit message
revision (Optional[str], optional): which branch to push. Defaults to 'master'. revision (Optional[str], optional): which branch to push. Defaults to 'master'.
""" """
if commit_message is None:
msg = 'commit_message must be provided!'
raise InvalidParameter(msg)
url = self.git_wrapper.get_repo_remote_url(self.model_dir) url = self.git_wrapper.get_repo_remote_url(self.model_dir)
self.git_wrapper.add(self.model_dir, files, all_files)
self.git_wrapper.pull(self.model_dir)
self.git_wrapper.add(self.model_dir, all_files=True)
self.git_wrapper.commit(self.model_dir, commit_message) self.git_wrapper.commit(self.model_dir, commit_message)
self.git_wrapper.push( self.git_wrapper.push(
repo_dir=self.model_dir, repo_dir=self.model_dir,


+ 7
- 9
modelscope/hub/snapshot_download.py View File

@@ -20,8 +20,7 @@ def snapshot_download(model_id: str,
revision: Optional[str] = 'master', revision: Optional[str] = 'master',
cache_dir: Union[str, Path, None] = None, cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None, user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
private: Optional[bool] = False) -> str:
local_files_only: Optional[bool] = False) -> str:
"""Download all files of a repo. """Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This Downloads a whole snapshot of a repo's files at the specified revision. This
is useful when you want all files from a repo, because you don't know which is useful when you want all files from a repo, because you don't know which
@@ -79,8 +78,10 @@ def snapshot_download(model_id: str,
# make headers # make headers
headers = {'user-agent': http_user_agent(user_agent=user_agent, )} headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
_api = HubApi() _api = HubApi()
cookies = ModelScopeConfig.get_cookies()
# get file list from model repo # get file list from model repo
branches, tags = _api.get_model_branches_and_tags(model_id)
branches, tags = _api.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
if revision not in branches and revision not in tags: if revision not in branches and revision not in tags:
raise NotExistError('The specified branch or tag : %s not exist!' raise NotExistError('The specified branch or tag : %s not exist!'
% revision) % revision)
@@ -89,11 +90,8 @@ def snapshot_download(model_id: str,
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
recursive=True, recursive=True,
use_cookies=private)

cookies = None
if private:
cookies = ModelScopeConfig.get_cookies()
use_cookies=False if cookies is None else cookies,
is_snapshot=True)


for model_file in model_files: for model_file in model_files:
if model_file['Type'] == 'tree': if model_file['Type'] == 'tree':
@@ -116,7 +114,7 @@ def snapshot_download(model_id: str,
local_dir=tempfile.gettempdir(), local_dir=tempfile.gettempdir(),
file_name=model_file['Name'], file_name=model_file['Name'],
headers=headers, headers=headers,
cookies=None if cookies is None else cookies.get_dict())
cookies=cookies)
# put file to cache # put file to cache
cache.put_file( cache.put_file(
model_file, model_file,


+ 6
- 2
modelscope/hub/utils/caching.py View File

@@ -101,8 +101,9 @@ class FileSystemCache(object):
Args: Args:
key (dict): The cache key. key (dict): The cache key.
""" """
self.cached_files.remove(key)
self.save_cached_files()
if key in self.cached_files:
self.cached_files.remove(key)
self.save_cached_files()


def exists(self, key): def exists(self, key):
for cache_file in self.cached_files: for cache_file in self.cached_files:
@@ -204,6 +205,7 @@ class ModelFileSystemCache(FileSystemCache):
return orig_path return orig_path
else: else:
self.remove_key(cached_file) self.remove_key(cached_file)
break


return None return None


@@ -230,6 +232,7 @@ class ModelFileSystemCache(FileSystemCache):
cached_key['Revision'].startswith(key['Revision']) cached_key['Revision'].startswith(key['Revision'])
or key['Revision'].startswith(cached_key['Revision'])): or key['Revision'].startswith(cached_key['Revision'])):
is_exists = True is_exists = True
break
file_path = os.path.join(self.cache_root_location, file_path = os.path.join(self.cache_root_location,
model_file_info['Path']) model_file_info['Path'])
if is_exists: if is_exists:
@@ -253,6 +256,7 @@ class ModelFileSystemCache(FileSystemCache):
cached_file['Path']) cached_file['Path'])
if os.path.exists(file_path): if os.path.exists(file_path):
os.remove(file_path) os.remove(file_path)
break


def put_file(self, model_file_info, model_file_location): def put_file(self, model_file_info, model_file_location):
"""Put model on model_file_location to cache, the model first download to /tmp, and move to cache. """Put model on model_file_location to cache, the model first download to /tmp, and move to cache.


+ 3
- 2
modelscope/utils/hub.py View File

@@ -31,9 +31,10 @@ def create_model_if_not_exist(
else: else:
api.create_model( api.create_model(
model_id=model_id, model_id=model_id,
chinese_name=chinese_name,
visibility=visibility, visibility=visibility,
license=license)
license=license,
chinese_name=chinese_name,
)
print(f'model {model_id} successfully created.') print(f'model {model_id} successfully created.')
return True return True




+ 35
- 7
tests/hub/test_hub_operation.py View File

@@ -3,6 +3,7 @@ import os
import tempfile import tempfile
import unittest import unittest
import uuid import uuid
from shutil import rmtree


from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.constants import Licenses, ModelVisibility
@@ -23,7 +24,6 @@ download_model_file_name = 'test.bin'
class HubOperationTest(unittest.TestCase): class HubOperationTest(unittest.TestCase):


def setUp(self): def setUp(self):
self.old_cwd = os.getcwd()
self.api = HubApi() self.api = HubApi()
# note this is temporary before official account management is ready # note this is temporary before official account management is ready
self.api.login(USER_NAME, PASSWORD) self.api.login(USER_NAME, PASSWORD)
@@ -31,19 +31,18 @@ class HubOperationTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name) self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model( self.api.create_model(
model_id=self.model_id, model_id=self.model_id,
chinese_name=model_chinese_name,
visibility=ModelVisibility.PUBLIC, visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
)
temporary_dir = tempfile.mkdtemp() temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name) self.model_dir = os.path.join(temporary_dir, self.model_name)
repo = Repository(self.model_dir, clone_from=self.model_id) repo = Repository(self.model_dir, clone_from=self.model_id)
os.chdir(self.model_dir)
os.system("echo 'testtest'>%s" os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, 'test.bin'))
repo.push('add model', all_files=True)
% os.path.join(self.model_dir, download_model_file_name))
repo.push('add model')


def tearDown(self): def tearDown(self):
os.chdir(self.old_cwd)
self.api.delete_model(model_id=self.model_id) self.api.delete_model(model_id=self.model_id)


def test_model_repo_creation(self): def test_model_repo_creation(self):
@@ -79,6 +78,35 @@ class HubOperationTest(unittest.TestCase):
mdtime2 = os.path.getmtime(downloaded_file_path) mdtime2 = os.path.getmtime(downloaded_file_path)
assert mdtime1 == mdtime2 assert mdtime1 == mdtime2


def test_download_public_without_login(self):
rmtree(ModelScopeConfig.path_credential)
snapshot_path = snapshot_download(model_id=self.model_id)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
assert os.path.exists(downloaded_file_path)
temporary_dir = tempfile.mkdtemp()
downloaded_file = model_file_download(
model_id=self.model_id,
file_path=download_model_file_name,
cache_dir=temporary_dir)
assert os.path.exists(downloaded_file)
self.api.login(USER_NAME, PASSWORD)

def test_snapshot_delete_download_cache_file(self):
snapshot_path = snapshot_download(model_id=self.model_id)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
assert os.path.exists(downloaded_file_path)
os.remove(downloaded_file_path)
# download again in cache
file_download_path = model_file_download(
model_id=self.model_id, file_path='README.md')
assert os.path.exists(file_download_path)
# deleted file need download again
file_download_path = model_file_download(
model_id=self.model_id, file_path=download_model_file_name)
assert os.path.exists(file_download_path)



if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

+ 85
- 0
tests/hub/test_hub_private_files.py View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import unittest
import uuid

from requests.exceptions import HTTPError

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import GitError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile

USER_NAME = 'maasadmin'
PASSWORD = '12345678'
USER_NAME2 = 'sdkdev'

model_chinese_name = '达摩卡通化模型'
model_org = 'unittest'


class HubPrivateFileDownloadTest(unittest.TestCase):

def setUp(self):
self.old_cwd = os.getcwd()
self.api = HubApi()
# note this is temporary before official account management is ready
self.token, _ = self.api.login(USER_NAME, PASSWORD)
self.model_name = uuid.uuid4().hex
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PRIVATE, # 1-private, 5-public
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
)

def tearDown(self):
os.chdir(self.old_cwd)
self.api.delete_model(model_id=self.model_id)

def test_snapshot_download_private_model(self):
snapshot_path = snapshot_download(self.model_id)
assert os.path.exists(os.path.join(snapshot_path, ModelFile.README))

def test_snapshot_download_private_model_no_permission(self):
self.token, _ = self.api.login(USER_NAME2, PASSWORD)
with self.assertRaises(HTTPError):
snapshot_download(self.model_id)
self.api.login(USER_NAME, PASSWORD)

def test_download_file_private_model(self):
file_path = model_file_download(self.model_id, ModelFile.README)
assert os.path.exists(file_path)

def test_download_file_private_model_no_permission(self):
self.token, _ = self.api.login(USER_NAME2, PASSWORD)
with self.assertRaises(HTTPError):
model_file_download(self.model_id, ModelFile.README)
self.api.login(USER_NAME, PASSWORD)

def test_snapshot_download_local_only(self):
with self.assertRaises(ValueError):
snapshot_download(self.model_id, local_files_only=True)
snapshot_path = snapshot_download(self.model_id)
assert os.path.exists(os.path.join(snapshot_path, ModelFile.README))
snapshot_path = snapshot_download(self.model_id, local_files_only=True)
assert os.path.exists(snapshot_path)

def test_file_download_local_only(self):
with self.assertRaises(ValueError):
model_file_download(
self.model_id, ModelFile.README, local_files_only=True)
file_path = model_file_download(self.model_id, ModelFile.README)
assert os.path.exists(file_path)
file_path = model_file_download(
self.model_id, ModelFile.README, local_files_only=True)
assert os.path.exists(file_path)


if __name__ == '__main__':
unittest.main()

+ 4
- 5
tests/hub/test_hub_private_repository.py View File

@@ -5,6 +5,7 @@ import unittest
import uuid import uuid


from modelscope.hub.api import HubApi from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import GitError from modelscope.hub.errors import GitError
from modelscope.hub.repository import Repository from modelscope.hub.repository import Repository


@@ -16,9 +17,6 @@ model_chinese_name = '达摩卡通化模型'
model_org = 'unittest' model_org = 'unittest'
DEFAULT_GIT_PATH = 'git' DEFAULT_GIT_PATH = 'git'


sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
download_model_file_name = 'mnist-12.onnx'



class HubPrivateRepositoryTest(unittest.TestCase): class HubPrivateRepositoryTest(unittest.TestCase):


@@ -31,9 +29,10 @@ class HubPrivateRepositoryTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name) self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model( self.api.create_model(
model_id=self.model_id, model_id=self.model_id,
visibility=ModelVisibility.PRIVATE, # 1-private, 5-public
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name, chinese_name=model_chinese_name,
visibility=1, # 1-private, 5-public
license='apache-2.0')
)


def tearDown(self): def tearDown(self):
self.api.login(USER_NAME, PASSWORD) self.api.login(USER_NAME, PASSWORD)


+ 5
- 19
tests/hub/test_hub_repository.py View File

@@ -2,7 +2,6 @@
import os import os
import shutil import shutil
import tempfile import tempfile
import time
import unittest import unittest
import uuid import uuid
from os.path import expanduser from os.path import expanduser
@@ -10,6 +9,7 @@ from os.path import expanduser
from requests import delete from requests import delete


from modelscope.hub.api import HubApi from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import NotExistError from modelscope.hub.errors import NotExistError
from modelscope.hub.file_download import model_file_download from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository from modelscope.hub.repository import Repository
@@ -55,9 +55,10 @@ class HubRepositoryTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name) self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model( self.api.create_model(
model_id=self.model_id, model_id=self.model_id,
visibility=ModelVisibility.PUBLIC, # 1-private, 5-public
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name, chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
)
temporary_dir = tempfile.mkdtemp() temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name) self.model_dir = os.path.join(temporary_dir, self.model_name)


@@ -81,27 +82,12 @@ class HubRepositoryTest(unittest.TestCase):
os.chdir(self.model_dir) os.chdir(self.model_dir)
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
repo.push('test', all_files=True)
repo.push('test')
add1 = model_file_download(self.model_id, 'add1.py') add1 = model_file_download(self.model_id, 'add1.py')
assert os.path.exists(add1) assert os.path.exists(add1)
add2 = model_file_download(self.model_id, 'add2.py') add2 = model_file_download(self.model_id, 'add2.py')
assert os.path.exists(add2) assert os.path.exists(add2)


def test_push_files(self):
repo = Repository(self.model_dir, clone_from=self.model_id)
assert os.path.exists(os.path.join(self.model_dir, 'README.md'))
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py'))
repo.push('test', files=['add1.py', 'add2.py'], all_files=False)
add1 = model_file_download(self.model_id, 'add1.py')
assert os.path.exists(add1)
add2 = model_file_download(self.model_id, 'add2.py')
assert os.path.exists(add2)
with self.assertRaises(NotExistError) as cm:
model_file_download(self.model_id, 'add3.py')
print(cm.exception)



if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save