diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py
index 104eafbd..f4f31280 100644
--- a/modelscope/hub/api.py
+++ b/modelscope/hub/api.py
@@ -9,9 +9,10 @@ from typing import List, Optional, Tuple, Union
import requests
from modelscope.utils.logger import get_logger
-from .constants import LOGGER_NAME
+from .constants import MODELSCOPE_URL_SCHEME
from .errors import NotExistError, is_ok, raise_on_error
-from .utils.utils import get_endpoint, model_id_to_group_owner_name
+from .utils.utils import (get_endpoint, get_gitlab_domain,
+ model_id_to_group_owner_name)
logger = get_logger()
@@ -40,9 +41,6 @@ class HubApi:
You only have to login once within 30 days.
-
- TODO: handle cookies expire
-
"""
path = f'{self.endpoint}/api/v1/login'
r = requests.post(
@@ -94,14 +92,14 @@ class HubApi:
'Path': owner_or_group,
'Name': name,
'ChineseName': chinese_name,
- 'Visibility': visibility,
+ 'Visibility': visibility, # server check
'License': license
},
cookies=cookies)
r.raise_for_status()
raise_on_error(r.json())
- d = r.json()
- return d['Data']['Name']
+ model_repo_url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}'
+ return model_repo_url
def delete_model(self, model_id):
"""_summary_
@@ -209,25 +207,37 @@ class HubApi:
class ModelScopeConfig:
path_credential = expanduser('~/.modelscope/credentials')
- os.makedirs(path_credential, exist_ok=True)
+
+ @classmethod
+ def make_sure_credential_path_exist(cls):
+ os.makedirs(cls.path_credential, exist_ok=True)
@classmethod
def save_cookies(cls, cookies: CookieJar):
+ cls.make_sure_credential_path_exist()
with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f:
pickle.dump(cookies, f)
@classmethod
def get_cookies(cls):
try:
- with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f:
- return pickle.load(f)
+ cookies_path = os.path.join(cls.path_credential, 'cookies')
+ with open(cookies_path, 'rb') as f:
+ cookies = pickle.load(f)
+ for cookie in cookies:
+ if cookie.is_expired():
+ logger.warn('Auth is expored, please re-login')
+ return None
+ return cookies
except FileNotFoundError:
- logger.warn("Auth token does not exist, you'll get authentication \
- error when downloading private model files. Please login first"
- )
+ logger.warn(
+ "Auth token does not exist, you'll get authentication error when downloading \
+ private model files. Please login first")
+ return None
@classmethod
def save_token(cls, token: str):
+ cls.make_sure_credential_path_exist()
with open(os.path.join(cls.path_credential, 'token'), 'w+') as f:
f.write(token)
diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py
index 13ea709f..4b39d6e3 100644
--- a/modelscope/hub/errors.py
+++ b/modelscope/hub/errors.py
@@ -6,6 +6,10 @@ class RequestError(Exception):
pass
+class GitError(Exception):
+ pass
+
+
def is_ok(rsp):
""" Check the request is ok
diff --git a/modelscope/hub/git.py b/modelscope/hub/git.py
index 5f079105..37f61814 100644
--- a/modelscope/hub/git.py
+++ b/modelscope/hub/git.py
@@ -1,82 +1,161 @@
-from threading import local
-from tkinter.messagebox import NO
-from typing import Union
+import subprocess
+from typing import List
+from xmlrpc.client import Boolean
from modelscope.utils.logger import get_logger
-from .constants import LOGGER_NAME
-from .utils._subprocess import run_subprocess
+from .errors import GitError
-logger = get_logger
+logger = get_logger()
-def git_clone(
- local_dir: str,
- repo_url: str,
-):
- # TODO: use "git clone" or "git lfs clone" according to git version
- # TODO: print stderr when subprocess fails
- run_subprocess(
- f'git clone {repo_url}'.split(),
- local_dir,
- True,
- )
+class Singleton(type):
+ _instances = {}
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(Singleton,
+ cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
-def git_checkout(
- local_dir: str,
- revsion: str,
-):
- run_subprocess(f'git checkout {revsion}'.split(), local_dir)
-
-def git_add(local_dir: str, ):
- run_subprocess(
- 'git add .'.split(),
- local_dir,
- True,
- )
-
-
-def git_commit(local_dir: str, commit_message: str):
- run_subprocess(
- 'git commit -v -m'.split() + [commit_message],
- local_dir,
- True,
- )
-
-
-def git_push(local_dir: str, branch: str):
- # check current branch
- cur_branch = git_current_branch(local_dir)
- if cur_branch != branch:
- logger.error(
- "You're trying to push to a different branch, please double check")
- return
-
- run_subprocess(
- f'git push origin {branch}'.split(),
- local_dir,
- True,
- )
-
-
-def git_current_branch(local_dir: str) -> Union[str, None]:
- """
- Get current branch name
-
- Args:
- local_dir(`str`): local model repo directory
-
- Returns
- branch name you're currently on
+class GitCommandWrapper(metaclass=Singleton):
+ """Some git operation wrapper
"""
- try:
- process = run_subprocess(
- 'git rev-parse --abbrev-ref HEAD'.split(),
- local_dir,
- True,
- )
-
- return str(process.stdout).strip()
- except Exception as e:
- raise e
+ default_git_path = 'git' # The default git command line
+
+ def __init__(self, path: str = None):
+ self.git_path = path or self.default_git_path
+
+ def _run_git_command(self, *args) -> subprocess.CompletedProcess:
+ """Run git command, if command return 0, return subprocess.response
+ otherwise raise GitError, message is stdout and stderr.
+
+ Raises:
+ GitError: Exception with stdout and stderr.
+
+ Returns:
+ subprocess.CompletedProcess: the command response
+ """
+ logger.info(' '.join(args))
+ response = subprocess.run(
+ [self.git_path, *args],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE) # compatible for python3.6
+ try:
+ response.check_returncode()
+ return response
+ except subprocess.CalledProcessError as error:
+ raise GitError(
+ 'stdout: %s, stderr: %s' %
+ (response.stdout.decode('utf8'), error.stderr.decode('utf8')))
+
+ def _add_token(self, token: str, url: str):
+ if token:
+ if '//oauth2' not in url:
+ url = url.replace('//', '//oauth2:%s@' % token)
+ return url
+
+ def remove_token_from_url(self, url: str):
+ if url and '//oauth2' in url:
+ start_index = url.find('oauth2')
+ end_index = url.find('@')
+ url = url[:start_index] + url[end_index + 1:]
+ return url
+
+ def is_lfs_installed(self):
+ cmd = ['lfs', 'env']
+ try:
+ self._run_git_command(*cmd)
+ return True
+ except GitError:
+ return False
+
+ def clone(self,
+ repo_base_dir: str,
+ token: str,
+ url: str,
+ repo_name: str,
+ branch: str = None):
+ """ git clone command wrapper.
+ For public project, token can None, private repo, there must token.
+
+ Args:
+ repo_base_dir (str): The local base dir, the repository will be clone to local_dir/repo_name
+ token (str): The git token, must be provided for private project.
+ url (str): The remote url
+ repo_name (str): The local repository path name.
+ branch (str, optional): _description_. Defaults to None.
+ """
+ url = self._add_token(token, url)
+ if branch:
+ clone_args = '-C %s clone %s %s --branch %s' % (repo_base_dir, url,
+ repo_name, branch)
+ else:
+ clone_args = '-C %s clone %s' % (repo_base_dir, url)
+ logger.debug(clone_args)
+ clone_args = clone_args.split(' ')
+ response = self._run_git_command(*clone_args)
+ logger.info(response.stdout.decode('utf8'))
+ return response
+
+ def add(self,
+ repo_dir: str,
+ files: List[str] = list(),
+ all_files: bool = False):
+ if all_files:
+ add_args = '-C %s add -A' % repo_dir
+ elif len(files) > 0:
+ files_str = ' '.join(files)
+ add_args = '-C %s add %s' % (repo_dir, files_str)
+ add_args = add_args.split(' ')
+ rsp = self._run_git_command(*add_args)
+ logger.info(rsp.stdout.decode('utf8'))
+ return rsp
+
+ def commit(self, repo_dir: str, message: str):
+ """Run git commit command
+
+ Args:
+ message (str): commit message.
+ """
+ commit_args = ['-C', '%s' % repo_dir, 'commit', '-m', "'%s'" % message]
+ rsp = self._run_git_command(*commit_args)
+ logger.info(rsp.stdout.decode('utf8'))
+ return rsp
+
+ def checkout(self, repo_dir: str, revision: str):
+ cmds = ['-C', '%s' % repo_dir, 'checkout', '%s' % revision]
+ return self._run_git_command(*cmds)
+
+ def new_branch(self, repo_dir: str, revision: str):
+ cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision]
+ return self._run_git_command(*cmds)
+
+ def pull(self, repo_dir: str):
+ cmds = ['-C', repo_dir, 'pull']
+ return self._run_git_command(*cmds)
+
+ def push(self,
+ repo_dir: str,
+ token: str,
+ url: str,
+ local_branch: str,
+ remote_branch: str,
+ force: bool = False):
+ url = self._add_token(token, url)
+
+ push_args = '-C %s push %s %s:%s' % (repo_dir, url, local_branch,
+ remote_branch)
+ if force:
+ push_args += ' -f'
+ push_args = push_args.split(' ')
+ rsp = self._run_git_command(*push_args)
+ logger.info(rsp.stdout.decode('utf8'))
+ return rsp
+
+ def get_repo_remote_url(self, repo_dir: str):
+ cmd_args = '-C %s config --get remote.origin.url' % repo_dir
+ cmd_args = cmd_args.split(' ')
+ rsp = self._run_git_command(*cmd_args)
+ url = rsp.stdout.decode('utf8')
+ return url.strip()
diff --git a/modelscope/hub/repository.py b/modelscope/hub/repository.py
index 6367f903..d9322144 100644
--- a/modelscope/hub/repository.py
+++ b/modelscope/hub/repository.py
@@ -1,173 +1,97 @@
import os
-import subprocess
-from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional
+from modelscope.hub.errors import GitError
from modelscope.utils.logger import get_logger
from .api import ModelScopeConfig
from .constants import MODELSCOPE_URL_SCHEME
-from .git import git_add, git_checkout, git_clone, git_commit, git_push
-from .utils._subprocess import run_subprocess
+from .git import GitCommandWrapper
from .utils.utils import get_gitlab_domain
logger = get_logger()
class Repository:
+ """Representation local model git repository.
+ """
def __init__(
self,
- local_dir: str,
- clone_from: Optional[str] = None,
- auth_token: Optional[str] = None,
- private: Optional[bool] = False,
+ model_dir: str,
+ clone_from: str,
revision: Optional[str] = 'master',
+ auth_token: Optional[str] = None,
+ git_path: Optional[str] = None,
):
"""
Instantiate a Repository object by cloning the remote ModelScopeHub repo
Args:
- local_dir(`str`):
- local directory to store the model files
- clone_from(`Optional[str] = None`):
+ model_dir(`str`):
+ The model root directory.
+ clone_from:
model id in ModelScope-hub from which git clone
- You should ignore this parameter when `local_dir` is already a git repo
- auth_token(`Optional[str]`):
- token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
- as the token is already saved when you login the first time
- private(`Optional[bool]`):
- whether the model is private, default to False
revision(`Optional[str]`):
revision of the model you want to clone from. Can be any of a branch, tag or commit hash
+ auth_token(`Optional[str]`):
+ token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
+ as the token is already saved when you login the first time, if None, we will use saved token.
+ git_path:(`Optional[str]`):
+ The git command line path, if None, we use 'git'
"""
- logger.info('Instantiating Repository object...')
-
- # Create local directory if not exist
- os.makedirs(local_dir, exist_ok=True)
- self.local_dir = os.path.join(os.getcwd(), local_dir)
-
- self.private = private
-
- # Check git and git-lfs installation
- self.check_git_versions()
-
- # Retrieve auth token
- if not private and isinstance(auth_token, str):
- logger.warning(
- 'cloning a public repo with a token, which will be ignored')
- self.token = None
+ self.model_dir = model_dir
+ self.model_base_dir = os.path.dirname(model_dir)
+ self.model_repo_name = os.path.basename(model_dir)
+ if auth_token:
+ self.auth_token = auth_token
else:
- if isinstance(auth_token, str):
- self.token = auth_token
- else:
- self.token = ModelScopeConfig.get_token()
-
- if self.token is None:
- raise EnvironmentError(
- 'Token does not exist, the clone will fail for private repo.'
- 'Please login first.')
-
- # git clone
- if clone_from is not None:
- self.model_id = clone_from
- logger.info('cloning model repo to %s ...', self.local_dir)
- git_clone(self.local_dir, self.get_repo_url())
- else:
- if is_git_repo(self.local_dir):
- logger.debug('[Repository] is a valid git repo')
- else:
- raise ValueError(
- 'If not specifying `clone_from`, you need to pass Repository a'
- ' valid git clone.')
-
- # git checkout
- if isinstance(revision, str) and revision != 'master':
- git_checkout(revision)
-
- def push_to_hub(self,
- commit_message: str,
- revision: Optional[str] = 'master'):
- """
- Push changes changes to hub
-
- Args:
- commit_message(`str`):
- commit message describing the changes, it's mandatory
- revision(`Optional[str]`):
- remote branch you want to push to, default to `master`
-
-
- The function complains when local and remote branch are different, please be careful
-
-
- """
- git_add(self.local_dir)
- git_commit(self.local_dir, commit_message)
-
- logger.info('Pushing changes to repo...')
- git_push(self.local_dir, revision)
-
- # TODO: if git push fails, how to retry?
-
- def check_git_versions(self):
- """
- Checks that `git` and `git-lfs` can be run.
-
- Raises:
- `EnvironmentError`: if `git` or `git-lfs` are not installed.
- """
- try:
- git_version = run_subprocess('git --version'.split(),
- self.local_dir).stdout.strip()
- except FileNotFoundError:
- raise EnvironmentError(
- 'Looks like you do not have git installed, please install.')
+ self.auth_token = ModelScopeConfig.get_token()
+
+ git_wrapper = GitCommandWrapper()
+ if not git_wrapper.is_lfs_installed():
+ logger.error('git lfs is not installed, please install.')
+
+ self.git_wrapper = GitCommandWrapper(git_path)
+ os.makedirs(self.model_dir, exist_ok=True)
+ url = self._get_model_id_url(clone_from)
+ if os.listdir(self.model_dir): # directory not empty.
+ remote_url = self._get_remote_url()
+ remote_url = self.git_wrapper.remove_token_from_url(remote_url)
+ if remote_url and remote_url == url: # need not clone again
+ return
+ self.git_wrapper.clone(self.model_base_dir, self.auth_token, url,
+ self.model_repo_name, revision)
+
+ def _get_model_id_url(self, model_id):
+ url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}'
+ return url
+ def _get_remote_url(self):
try:
- lfs_version = run_subprocess('git-lfs --version'.split(),
- self.local_dir).stdout.strip()
- except FileNotFoundError:
- raise EnvironmentError(
- 'Looks like you do not have git-lfs installed, please install.'
- ' You can install from https://git-lfs.github.com/.'
- ' Then run `git lfs install` (you only have to do this once).')
- logger.info(git_version + '\n' + lfs_version)
-
- def get_repo_url(self) -> str:
- """
- Get repo url to clone, according whether the repo is private or not
+ remote = self.git_wrapper.get_repo_remote_url(self.model_dir)
+ except GitError:
+ remote = None
+ return remote
+
+ def push(self,
+ commit_message: str,
+ files: List[str] = list(),
+ all_files: bool = False,
+ branch: Optional[str] = 'master',
+ force: bool = False):
+ """Push local to remote, this method will do.
+ git add
+ git commit
+ git push
+ Args:
+ commit_message (str): commit message
+ revision (Optional[str], optional): which branch to push. Defaults to 'master'.
"""
- url = None
-
- if self.private:
- url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}'
- else:
- url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}'
-
- if not url:
- raise ValueError(
- 'Empty repo url, please check clone_from parameter')
-
- logger.debug('url to clone: %s', str(url))
-
- return url
-
-
-def is_git_repo(folder: Union[str, Path]) -> bool:
- """
- Check if the folder is the root or part of a git repository
-
- Args:
- folder (`str`):
- The folder in which to run the command.
-
- Returns:
- `bool`: `True` if the repository is part of a repository, `False`
- otherwise.
- """
- folder_exists = os.path.exists(os.path.join(folder, '.git'))
- git_branch = subprocess.run(
- 'git branch'.split(),
- cwd=folder,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
- return folder_exists and git_branch.returncode == 0
+ url = self.git_wrapper.get_repo_remote_url(self.model_dir)
+ self.git_wrapper.add(self.model_dir, files, all_files)
+ self.git_wrapper.commit(self.model_dir, commit_message)
+ self.git_wrapper.push(
+ repo_dir=self.model_dir,
+ token=self.auth_token,
+ url=url,
+ local_branch=branch,
+ remote_branch=branch)
diff --git a/modelscope/hub/utils/_subprocess.py b/modelscope/hub/utils/_subprocess.py
deleted file mode 100644
index 77e9fc48..00000000
--- a/modelscope/hub/utils/_subprocess.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import subprocess
-from typing import List
-
-
-def run_subprocess(command: List[str],
- folder: str,
- check=True,
- **kwargs) -> subprocess.CompletedProcess:
- """
- Method to run subprocesses. Calling this will capture the `stderr` and `stdout`,
- please call `subprocess.run` manually in case you would like for them not to
- be captured.
-
- Args:
- command (`List[str]`):
- The command to execute as a list of strings.
- folder (`str`):
- The folder in which to run the command.
- check (`bool`, *optional*, defaults to `True`):
- Setting `check` to `True` will raise a `subprocess.CalledProcessError`
- when the subprocess has a non-zero exit code.
- kwargs (`Dict[str]`):
- Keyword arguments to be passed to the `subprocess.run` underlying command.
-
- Returns:
- `subprocess.CompletedProcess`: The completed process.
- """
- if isinstance(command, str):
- raise ValueError(
- '`run_subprocess` should be called with a list of strings.')
-
- return subprocess.run(
- command,
- stderr=subprocess.PIPE,
- stdout=subprocess.PIPE,
- check=check,
- encoding='utf-8',
- cwd=folder,
- **kwargs,
- )
diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py
index d44cd7c1..e0adc013 100644
--- a/tests/hub/test_hub_operation.py
+++ b/tests/hub/test_hub_operation.py
@@ -1,14 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
-import subprocess
import tempfile
import unittest
import uuid
-from modelscope.hub.api import HubApi, ModelScopeConfig
+from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
+from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
-from modelscope.hub.utils.utils import get_gitlab_domain
USER_NAME = 'maasadmin'
PASSWORD = '12345678'
@@ -17,40 +16,7 @@ model_chinese_name = '达摩卡通化模型'
model_org = 'unittest'
DEFAULT_GIT_PATH = 'git'
-
-class GitError(Exception):
- pass
-
-
-# TODO make thest git operation to git library after merge code.
-def run_git_command(git_path, *args) -> subprocess.CompletedProcess:
- response = subprocess.run([git_path, *args], capture_output=True)
- try:
- response.check_returncode()
- return response.stdout.decode('utf8')
- except subprocess.CalledProcessError as error:
- raise GitError(error.stderr.decode('utf8'))
-
-
-# for public project, token can None, private repo, there must token.
-def clone(local_dir: str, token: str, url: str):
- url = url.replace('//', '//oauth2:%s@' % token)
- clone_args = '-C %s clone %s' % (local_dir, url)
- clone_args = clone_args.split(' ')
- stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args)
- print('stdout: %s' % stdout)
-
-
-def push(local_dir: str, token: str, url: str):
- url = url.replace('//', '//oauth2:%s@' % token)
- push_args = '-C %s push %s' % (local_dir, url)
- push_args = push_args.split(' ')
- stdout = run_git_command(DEFAULT_GIT_PATH, *push_args)
- print('stdout: %s' % stdout)
-
-
-sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
-download_model_file_name = 'mnist-12.onnx'
+download_model_file_name = 'test.bin'
class HubOperationTest(unittest.TestCase):
@@ -67,6 +33,13 @@ class HubOperationTest(unittest.TestCase):
chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
+ temporary_dir = tempfile.mkdtemp()
+ self.model_dir = os.path.join(temporary_dir, self.model_name)
+ repo = Repository(self.model_dir, clone_from=self.model_id)
+ os.chdir(self.model_dir)
+ os.system("echo 'testtest'>%s"
+ % os.path.join(self.model_dir, 'test.bin'))
+ repo.push('add model', all_files=True)
def tearDown(self):
os.chdir(self.old_cwd)
@@ -83,43 +56,10 @@ class HubOperationTest(unittest.TestCase):
else:
raise
- # Note that this can be done via git operation once model repo
- # has been created. Git-Op is the RECOMMENDED model upload approach
- def test_model_upload(self):
- url = f'http://{get_gitlab_domain()}/{self.model_id}'
- print(url)
- temporary_dir = tempfile.mkdtemp()
- os.chdir(temporary_dir)
- cmd_args = 'clone %s' % url
- cmd_args = cmd_args.split(' ')
- out = run_git_command('git', *cmd_args)
- print(out)
- repo_dir = os.path.join(temporary_dir, self.model_name)
- os.chdir(repo_dir)
- os.system('touch file1')
- os.system('git add file1')
- os.system("git commit -m 'Test'")
- token = ModelScopeConfig.get_token()
- push(repo_dir, token, url)
-
def test_download_single_file(self):
- url = f'http://{get_gitlab_domain()}/{self.model_id}'
- print(url)
- temporary_dir = tempfile.mkdtemp()
- os.chdir(temporary_dir)
- os.system('git clone %s' % url)
- repo_dir = os.path.join(temporary_dir, self.model_name)
- os.chdir(repo_dir)
- os.system('wget %s' % sample_model_url)
- os.system('git add .')
- os.system("git commit -m 'Add file'")
- token = ModelScopeConfig.get_token()
- push(repo_dir, token, url)
- assert os.path.exists(
- os.path.join(temporary_dir, self.model_name,
- download_model_file_name))
downloaded_file = model_file_download(
model_id=self.model_id, file_path=download_model_file_name)
+ assert os.path.exists(downloaded_file)
mdtime1 = os.path.getmtime(downloaded_file)
# download again
downloaded_file = model_file_download(
@@ -128,18 +68,6 @@ class HubOperationTest(unittest.TestCase):
assert mdtime1 == mdtime2
def test_snapshot_download(self):
- url = f'http://{get_gitlab_domain()}/{self.model_id}'
- print(url)
- temporary_dir = tempfile.mkdtemp()
- os.chdir(temporary_dir)
- os.system('git clone %s' % url)
- repo_dir = os.path.join(temporary_dir, self.model_name)
- os.chdir(repo_dir)
- os.system('wget %s' % sample_model_url)
- os.system('git add .')
- os.system("git commit -m 'Add file'")
- token = ModelScopeConfig.get_token()
- push(repo_dir, token, url)
snapshot_path = snapshot_download(model_id=self.model_id)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
diff --git a/tests/hub/test_hub_private_repository.py b/tests/hub/test_hub_private_repository.py
new file mode 100644
index 00000000..b6e3536c
--- /dev/null
+++ b/tests/hub/test_hub_private_repository.py
@@ -0,0 +1,76 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import tempfile
+import unittest
+import uuid
+
+from modelscope.hub.api import HubApi
+from modelscope.hub.errors import GitError
+from modelscope.hub.repository import Repository
+
+USER_NAME = 'maasadmin'
+PASSWORD = '12345678'
+
+USER_NAME2 = 'sdkdev'
+model_chinese_name = '达摩卡通化模型'
+model_org = 'unittest'
+DEFAULT_GIT_PATH = 'git'
+
+sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
+download_model_file_name = 'mnist-12.onnx'
+
+
+class HubPrivateRepositoryTest(unittest.TestCase):
+
+ def setUp(self):
+ self.old_cwd = os.getcwd()
+ self.api = HubApi()
+ # note this is temporary before official account management is ready
+ self.token, _ = self.api.login(USER_NAME, PASSWORD)
+ self.model_name = uuid.uuid4().hex
+ self.model_id = '%s/%s' % (model_org, self.model_name)
+ self.api.create_model(
+ model_id=self.model_id,
+ chinese_name=model_chinese_name,
+ visibility=1, # 1-private, 5-public
+ license='apache-2.0')
+
+ def tearDown(self):
+ self.api.login(USER_NAME, PASSWORD)
+ os.chdir(self.old_cwd)
+ self.api.delete_model(model_id=self.model_id)
+
+ def test_clone_private_repo_no_permission(self):
+ token, _ = self.api.login(USER_NAME2, PASSWORD)
+ temporary_dir = tempfile.mkdtemp()
+ local_dir = os.path.join(temporary_dir, self.model_name)
+ with self.assertRaises(GitError) as cm:
+ Repository(local_dir, clone_from=self.model_id, auth_token=token)
+
+ print(cm.exception)
+ assert not os.path.exists(os.path.join(local_dir, 'README.md'))
+
+ def test_clone_private_repo_has_permission(self):
+ temporary_dir = tempfile.mkdtemp()
+ local_dir = os.path.join(temporary_dir, self.model_name)
+ repo1 = Repository(
+ local_dir, clone_from=self.model_id, auth_token=self.token)
+ print(repo1.model_dir)
+ assert os.path.exists(os.path.join(local_dir, 'README.md'))
+
+ def test_initlize_repo_multiple_times(self):
+ temporary_dir = tempfile.mkdtemp()
+ local_dir = os.path.join(temporary_dir, self.model_name)
+ repo1 = Repository(
+ local_dir, clone_from=self.model_id, auth_token=self.token)
+ print(repo1.model_dir)
+ assert os.path.exists(os.path.join(local_dir, 'README.md'))
+ repo2 = Repository(
+ local_dir, clone_from=self.model_id,
+ auth_token=self.token) # skip clone
+ print(repo2.model_dir)
+ assert repo1.model_dir == repo2.model_dir
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/hub/test_hub_repository.py b/tests/hub/test_hub_repository.py
new file mode 100644
index 00000000..7b1cc751
--- /dev/null
+++ b/tests/hub/test_hub_repository.py
@@ -0,0 +1,107 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import time
+import unittest
+import uuid
+from os.path import expanduser
+
+from requests import delete
+
+from modelscope.hub.api import HubApi
+from modelscope.hub.errors import NotExistError
+from modelscope.hub.file_download import model_file_download
+from modelscope.hub.repository import Repository
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+logger.setLevel('DEBUG')
+USER_NAME = 'maasadmin'
+PASSWORD = '12345678'
+
+model_chinese_name = '达摩卡通化模型'
+model_org = 'unittest'
+DEFAULT_GIT_PATH = 'git'
+
+download_model_file_name = 'mnist-12.onnx'
+
+
+def delete_credential():
+ path_credential = expanduser('~/.modelscope/credentials')
+ shutil.rmtree(path_credential)
+
+
+def delete_stored_git_credential(user):
+ credential_path = expanduser('~/.git-credentials')
+ if os.path.exists(credential_path):
+ with open(credential_path, 'r+') as f:
+ lines = f.readlines()
+ for line in lines:
+ if user in line:
+ lines.remove(line)
+ f.seek(0)
+ f.write(''.join(lines))
+ f.truncate()
+
+
+class HubRepositoryTest(unittest.TestCase):
+
+ def setUp(self):
+ self.api = HubApi()
+ # note this is temporary before official account management is ready
+ self.api.login(USER_NAME, PASSWORD)
+ self.model_name = uuid.uuid4().hex
+ self.model_id = '%s/%s' % (model_org, self.model_name)
+ self.api.create_model(
+ model_id=self.model_id,
+ chinese_name=model_chinese_name,
+ visibility=5, # 1-private, 5-public
+ license='apache-2.0')
+ temporary_dir = tempfile.mkdtemp()
+ self.model_dir = os.path.join(temporary_dir, self.model_name)
+
+ def tearDown(self):
+ self.api.delete_model(model_id=self.model_id)
+
+ def test_clone_repo(self):
+ Repository(self.model_dir, clone_from=self.model_id)
+ assert os.path.exists(os.path.join(self.model_dir, 'README.md'))
+
+ def test_clone_public_model_without_token(self):
+ delete_credential()
+ delete_stored_git_credential(USER_NAME)
+ Repository(self.model_dir, clone_from=self.model_id)
+ assert os.path.exists(os.path.join(self.model_dir, 'README.md'))
+ self.api.login(USER_NAME, PASSWORD) # re-login for delete
+
+ def test_push_all(self):
+ repo = Repository(self.model_dir, clone_from=self.model_id)
+ assert os.path.exists(os.path.join(self.model_dir, 'README.md'))
+ os.chdir(self.model_dir)
+ os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
+ os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
+ repo.push('test', all_files=True)
+ add1 = model_file_download(self.model_id, 'add1.py')
+ assert os.path.exists(add1)
+ add2 = model_file_download(self.model_id, 'add2.py')
+ assert os.path.exists(add2)
+
+ def test_push_files(self):
+ repo = Repository(self.model_dir, clone_from=self.model_id)
+ assert os.path.exists(os.path.join(self.model_dir, 'README.md'))
+ os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
+ os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
+ os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py'))
+ repo.push('test', files=['add1.py', 'add2.py'], all_files=False)
+ add1 = model_file_download(self.model_id, 'add1.py')
+ assert os.path.exists(add1)
+ add2 = model_file_download(self.model_id, 'add2.py')
+ assert os.path.exists(add2)
+ with self.assertRaises(NotExistError) as cm:
+ model_file_download(self.model_id, 'add3.py')
+ print(cm.exception)
+
+
+if __name__ == '__main__':
+ unittest.main()