From 6991620f59c20ac386a815eb6d842adde3cedd07 Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Fri, 24 Jun 2022 16:43:32 +0800 Subject: [PATCH] [to #42698276]fix: git repo operations supports, gitlab token certification support. --- modelscope/hub/api.py | 38 ++-- modelscope/hub/errors.py | 4 + modelscope/hub/git.py | 225 +++++++++++++++-------- modelscope/hub/repository.py | 216 +++++++--------------- modelscope/hub/utils/_subprocess.py | 40 ---- tests/hub/test_hub_operation.py | 94 ++-------- tests/hub/test_hub_private_repository.py | 76 ++++++++ tests/hub/test_hub_repository.py | 107 +++++++++++ 8 files changed, 444 insertions(+), 356 deletions(-) delete mode 100644 modelscope/hub/utils/_subprocess.py create mode 100644 tests/hub/test_hub_private_repository.py create mode 100644 tests/hub/test_hub_repository.py diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 104eafbd..f4f31280 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -9,9 +9,10 @@ from typing import List, Optional, Tuple, Union import requests from modelscope.utils.logger import get_logger -from .constants import LOGGER_NAME +from .constants import MODELSCOPE_URL_SCHEME from .errors import NotExistError, is_ok, raise_on_error -from .utils.utils import get_endpoint, model_id_to_group_owner_name +from .utils.utils import (get_endpoint, get_gitlab_domain, + model_id_to_group_owner_name) logger = get_logger() @@ -40,9 +41,6 @@ class HubApi: You only have to login once within 30 days. - - TODO: handle cookies expire - """ path = f'{self.endpoint}/api/v1/login' r = requests.post( @@ -94,14 +92,14 @@ class HubApi: 'Path': owner_or_group, 'Name': name, 'ChineseName': chinese_name, - 'Visibility': visibility, + 'Visibility': visibility, # server check 'License': license }, cookies=cookies) r.raise_for_status() raise_on_error(r.json()) - d = r.json() - return d['Data']['Name'] + model_repo_url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' + return model_repo_url def delete_model(self, model_id): """_summary_ @@ -209,25 +207,37 @@ class HubApi: class ModelScopeConfig: path_credential = expanduser('~/.modelscope/credentials') - os.makedirs(path_credential, exist_ok=True) + + @classmethod + def make_sure_credential_path_exist(cls): + os.makedirs(cls.path_credential, exist_ok=True) @classmethod def save_cookies(cls, cookies: CookieJar): + cls.make_sure_credential_path_exist() with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f: pickle.dump(cookies, f) @classmethod def get_cookies(cls): try: - with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f: - return pickle.load(f) + cookies_path = os.path.join(cls.path_credential, 'cookies') + with open(cookies_path, 'rb') as f: + cookies = pickle.load(f) + for cookie in cookies: + if cookie.is_expired(): + logger.warn('Auth is expored, please re-login') + return None + return cookies except FileNotFoundError: - logger.warn("Auth token does not exist, you'll get authentication \ - error when downloading private model files. Please login first" - ) + logger.warn( + "Auth token does not exist, you'll get authentication error when downloading \ + private model files. Please login first") + return None @classmethod def save_token(cls, token: str): + cls.make_sure_credential_path_exist() with open(os.path.join(cls.path_credential, 'token'), 'w+') as f: f.write(token) diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py index 13ea709f..4b39d6e3 100644 --- a/modelscope/hub/errors.py +++ b/modelscope/hub/errors.py @@ -6,6 +6,10 @@ class RequestError(Exception): pass +class GitError(Exception): + pass + + def is_ok(rsp): """ Check the request is ok diff --git a/modelscope/hub/git.py b/modelscope/hub/git.py index 5f079105..37f61814 100644 --- a/modelscope/hub/git.py +++ b/modelscope/hub/git.py @@ -1,82 +1,161 @@ -from threading import local -from tkinter.messagebox import NO -from typing import Union +import subprocess +from typing import List +from xmlrpc.client import Boolean from modelscope.utils.logger import get_logger -from .constants import LOGGER_NAME -from .utils._subprocess import run_subprocess +from .errors import GitError -logger = get_logger +logger = get_logger() -def git_clone( - local_dir: str, - repo_url: str, -): - # TODO: use "git clone" or "git lfs clone" according to git version - # TODO: print stderr when subprocess fails - run_subprocess( - f'git clone {repo_url}'.split(), - local_dir, - True, - ) +class Singleton(type): + _instances = {} + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, + cls).__call__(*args, **kwargs) + return cls._instances[cls] -def git_checkout( - local_dir: str, - revsion: str, -): - run_subprocess(f'git checkout {revsion}'.split(), local_dir) - -def git_add(local_dir: str, ): - run_subprocess( - 'git add .'.split(), - local_dir, - True, - ) - - -def git_commit(local_dir: str, commit_message: str): - run_subprocess( - 'git commit -v -m'.split() + [commit_message], - local_dir, - True, - ) - - -def git_push(local_dir: str, branch: str): - # check current branch - cur_branch = git_current_branch(local_dir) - if cur_branch != branch: - logger.error( - "You're trying to push to a different branch, please double check") - return - - run_subprocess( - f'git push origin {branch}'.split(), - local_dir, - True, - ) - - -def git_current_branch(local_dir: str) -> Union[str, None]: - """ - Get current branch name - - Args: - local_dir(`str`): local model repo directory - - Returns - branch name you're currently on +class GitCommandWrapper(metaclass=Singleton): + """Some git operation wrapper """ - try: - process = run_subprocess( - 'git rev-parse --abbrev-ref HEAD'.split(), - local_dir, - True, - ) - - return str(process.stdout).strip() - except Exception as e: - raise e + default_git_path = 'git' # The default git command line + + def __init__(self, path: str = None): + self.git_path = path or self.default_git_path + + def _run_git_command(self, *args) -> subprocess.CompletedProcess: + """Run git command, if command return 0, return subprocess.response + otherwise raise GitError, message is stdout and stderr. + + Raises: + GitError: Exception with stdout and stderr. + + Returns: + subprocess.CompletedProcess: the command response + """ + logger.info(' '.join(args)) + response = subprocess.run( + [self.git_path, *args], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) # compatible for python3.6 + try: + response.check_returncode() + return response + except subprocess.CalledProcessError as error: + raise GitError( + 'stdout: %s, stderr: %s' % + (response.stdout.decode('utf8'), error.stderr.decode('utf8'))) + + def _add_token(self, token: str, url: str): + if token: + if '//oauth2' not in url: + url = url.replace('//', '//oauth2:%s@' % token) + return url + + def remove_token_from_url(self, url: str): + if url and '//oauth2' in url: + start_index = url.find('oauth2') + end_index = url.find('@') + url = url[:start_index] + url[end_index + 1:] + return url + + def is_lfs_installed(self): + cmd = ['lfs', 'env'] + try: + self._run_git_command(*cmd) + return True + except GitError: + return False + + def clone(self, + repo_base_dir: str, + token: str, + url: str, + repo_name: str, + branch: str = None): + """ git clone command wrapper. + For public project, token can None, private repo, there must token. + + Args: + repo_base_dir (str): The local base dir, the repository will be clone to local_dir/repo_name + token (str): The git token, must be provided for private project. + url (str): The remote url + repo_name (str): The local repository path name. + branch (str, optional): _description_. Defaults to None. + """ + url = self._add_token(token, url) + if branch: + clone_args = '-C %s clone %s %s --branch %s' % (repo_base_dir, url, + repo_name, branch) + else: + clone_args = '-C %s clone %s' % (repo_base_dir, url) + logger.debug(clone_args) + clone_args = clone_args.split(' ') + response = self._run_git_command(*clone_args) + logger.info(response.stdout.decode('utf8')) + return response + + def add(self, + repo_dir: str, + files: List[str] = list(), + all_files: bool = False): + if all_files: + add_args = '-C %s add -A' % repo_dir + elif len(files) > 0: + files_str = ' '.join(files) + add_args = '-C %s add %s' % (repo_dir, files_str) + add_args = add_args.split(' ') + rsp = self._run_git_command(*add_args) + logger.info(rsp.stdout.decode('utf8')) + return rsp + + def commit(self, repo_dir: str, message: str): + """Run git commit command + + Args: + message (str): commit message. + """ + commit_args = ['-C', '%s' % repo_dir, 'commit', '-m', "'%s'" % message] + rsp = self._run_git_command(*commit_args) + logger.info(rsp.stdout.decode('utf8')) + return rsp + + def checkout(self, repo_dir: str, revision: str): + cmds = ['-C', '%s' % repo_dir, 'checkout', '%s' % revision] + return self._run_git_command(*cmds) + + def new_branch(self, repo_dir: str, revision: str): + cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] + return self._run_git_command(*cmds) + + def pull(self, repo_dir: str): + cmds = ['-C', repo_dir, 'pull'] + return self._run_git_command(*cmds) + + def push(self, + repo_dir: str, + token: str, + url: str, + local_branch: str, + remote_branch: str, + force: bool = False): + url = self._add_token(token, url) + + push_args = '-C %s push %s %s:%s' % (repo_dir, url, local_branch, + remote_branch) + if force: + push_args += ' -f' + push_args = push_args.split(' ') + rsp = self._run_git_command(*push_args) + logger.info(rsp.stdout.decode('utf8')) + return rsp + + def get_repo_remote_url(self, repo_dir: str): + cmd_args = '-C %s config --get remote.origin.url' % repo_dir + cmd_args = cmd_args.split(' ') + rsp = self._run_git_command(*cmd_args) + url = rsp.stdout.decode('utf8') + return url.strip() diff --git a/modelscope/hub/repository.py b/modelscope/hub/repository.py index 6367f903..d9322144 100644 --- a/modelscope/hub/repository.py +++ b/modelscope/hub/repository.py @@ -1,173 +1,97 @@ import os -import subprocess -from pathlib import Path -from typing import Optional, Union +from typing import List, Optional +from modelscope.hub.errors import GitError from modelscope.utils.logger import get_logger from .api import ModelScopeConfig from .constants import MODELSCOPE_URL_SCHEME -from .git import git_add, git_checkout, git_clone, git_commit, git_push -from .utils._subprocess import run_subprocess +from .git import GitCommandWrapper from .utils.utils import get_gitlab_domain logger = get_logger() class Repository: + """Representation local model git repository. + """ def __init__( self, - local_dir: str, - clone_from: Optional[str] = None, - auth_token: Optional[str] = None, - private: Optional[bool] = False, + model_dir: str, + clone_from: str, revision: Optional[str] = 'master', + auth_token: Optional[str] = None, + git_path: Optional[str] = None, ): """ Instantiate a Repository object by cloning the remote ModelScopeHub repo Args: - local_dir(`str`): - local directory to store the model files - clone_from(`Optional[str] = None`): + model_dir(`str`): + The model root directory. + clone_from: model id in ModelScope-hub from which git clone - You should ignore this parameter when `local_dir` is already a git repo - auth_token(`Optional[str]`): - token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter - as the token is already saved when you login the first time - private(`Optional[bool]`): - whether the model is private, default to False revision(`Optional[str]`): revision of the model you want to clone from. Can be any of a branch, tag or commit hash + auth_token(`Optional[str]`): + token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter + as the token is already saved when you login the first time, if None, we will use saved token. + git_path:(`Optional[str]`): + The git command line path, if None, we use 'git' """ - logger.info('Instantiating Repository object...') - - # Create local directory if not exist - os.makedirs(local_dir, exist_ok=True) - self.local_dir = os.path.join(os.getcwd(), local_dir) - - self.private = private - - # Check git and git-lfs installation - self.check_git_versions() - - # Retrieve auth token - if not private and isinstance(auth_token, str): - logger.warning( - 'cloning a public repo with a token, which will be ignored') - self.token = None + self.model_dir = model_dir + self.model_base_dir = os.path.dirname(model_dir) + self.model_repo_name = os.path.basename(model_dir) + if auth_token: + self.auth_token = auth_token else: - if isinstance(auth_token, str): - self.token = auth_token - else: - self.token = ModelScopeConfig.get_token() - - if self.token is None: - raise EnvironmentError( - 'Token does not exist, the clone will fail for private repo.' - 'Please login first.') - - # git clone - if clone_from is not None: - self.model_id = clone_from - logger.info('cloning model repo to %s ...', self.local_dir) - git_clone(self.local_dir, self.get_repo_url()) - else: - if is_git_repo(self.local_dir): - logger.debug('[Repository] is a valid git repo') - else: - raise ValueError( - 'If not specifying `clone_from`, you need to pass Repository a' - ' valid git clone.') - - # git checkout - if isinstance(revision, str) and revision != 'master': - git_checkout(revision) - - def push_to_hub(self, - commit_message: str, - revision: Optional[str] = 'master'): - """ - Push changes changes to hub - - Args: - commit_message(`str`): - commit message describing the changes, it's mandatory - revision(`Optional[str]`): - remote branch you want to push to, default to `master` - - - The function complains when local and remote branch are different, please be careful - - - """ - git_add(self.local_dir) - git_commit(self.local_dir, commit_message) - - logger.info('Pushing changes to repo...') - git_push(self.local_dir, revision) - - # TODO: if git push fails, how to retry? - - def check_git_versions(self): - """ - Checks that `git` and `git-lfs` can be run. - - Raises: - `EnvironmentError`: if `git` or `git-lfs` are not installed. - """ - try: - git_version = run_subprocess('git --version'.split(), - self.local_dir).stdout.strip() - except FileNotFoundError: - raise EnvironmentError( - 'Looks like you do not have git installed, please install.') + self.auth_token = ModelScopeConfig.get_token() + + git_wrapper = GitCommandWrapper() + if not git_wrapper.is_lfs_installed(): + logger.error('git lfs is not installed, please install.') + + self.git_wrapper = GitCommandWrapper(git_path) + os.makedirs(self.model_dir, exist_ok=True) + url = self._get_model_id_url(clone_from) + if os.listdir(self.model_dir): # directory not empty. + remote_url = self._get_remote_url() + remote_url = self.git_wrapper.remove_token_from_url(remote_url) + if remote_url and remote_url == url: # need not clone again + return + self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, + self.model_repo_name, revision) + + def _get_model_id_url(self, model_id): + url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' + return url + def _get_remote_url(self): try: - lfs_version = run_subprocess('git-lfs --version'.split(), - self.local_dir).stdout.strip() - except FileNotFoundError: - raise EnvironmentError( - 'Looks like you do not have git-lfs installed, please install.' - ' You can install from https://git-lfs.github.com/.' - ' Then run `git lfs install` (you only have to do this once).') - logger.info(git_version + '\n' + lfs_version) - - def get_repo_url(self) -> str: - """ - Get repo url to clone, according whether the repo is private or not + remote = self.git_wrapper.get_repo_remote_url(self.model_dir) + except GitError: + remote = None + return remote + + def push(self, + commit_message: str, + files: List[str] = list(), + all_files: bool = False, + branch: Optional[str] = 'master', + force: bool = False): + """Push local to remote, this method will do. + git add + git commit + git push + Args: + commit_message (str): commit message + revision (Optional[str], optional): which branch to push. Defaults to 'master'. """ - url = None - - if self.private: - url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}' - else: - url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}' - - if not url: - raise ValueError( - 'Empty repo url, please check clone_from parameter') - - logger.debug('url to clone: %s', str(url)) - - return url - - -def is_git_repo(folder: Union[str, Path]) -> bool: - """ - Check if the folder is the root or part of a git repository - - Args: - folder (`str`): - The folder in which to run the command. - - Returns: - `bool`: `True` if the repository is part of a repository, `False` - otherwise. - """ - folder_exists = os.path.exists(os.path.join(folder, '.git')) - git_branch = subprocess.run( - 'git branch'.split(), - cwd=folder, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - return folder_exists and git_branch.returncode == 0 + url = self.git_wrapper.get_repo_remote_url(self.model_dir) + self.git_wrapper.add(self.model_dir, files, all_files) + self.git_wrapper.commit(self.model_dir, commit_message) + self.git_wrapper.push( + repo_dir=self.model_dir, + token=self.auth_token, + url=url, + local_branch=branch, + remote_branch=branch) diff --git a/modelscope/hub/utils/_subprocess.py b/modelscope/hub/utils/_subprocess.py deleted file mode 100644 index 77e9fc48..00000000 --- a/modelscope/hub/utils/_subprocess.py +++ /dev/null @@ -1,40 +0,0 @@ -import subprocess -from typing import List - - -def run_subprocess(command: List[str], - folder: str, - check=True, - **kwargs) -> subprocess.CompletedProcess: - """ - Method to run subprocesses. Calling this will capture the `stderr` and `stdout`, - please call `subprocess.run` manually in case you would like for them not to - be captured. - - Args: - command (`List[str]`): - The command to execute as a list of strings. - folder (`str`): - The folder in which to run the command. - check (`bool`, *optional*, defaults to `True`): - Setting `check` to `True` will raise a `subprocess.CalledProcessError` - when the subprocess has a non-zero exit code. - kwargs (`Dict[str]`): - Keyword arguments to be passed to the `subprocess.run` underlying command. - - Returns: - `subprocess.CompletedProcess`: The completed process. - """ - if isinstance(command, str): - raise ValueError( - '`run_subprocess` should be called with a list of strings.') - - return subprocess.run( - command, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - check=check, - encoding='utf-8', - cwd=folder, - **kwargs, - ) diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py index d44cd7c1..e0adc013 100644 --- a/tests/hub/test_hub_operation.py +++ b/tests/hub/test_hub_operation.py @@ -1,14 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os -import subprocess import tempfile import unittest import uuid -from modelscope.hub.api import HubApi, ModelScopeConfig +from modelscope.hub.api import HubApi from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository from modelscope.hub.snapshot_download import snapshot_download -from modelscope.hub.utils.utils import get_gitlab_domain USER_NAME = 'maasadmin' PASSWORD = '12345678' @@ -17,40 +16,7 @@ model_chinese_name = '达摩卡通化模型' model_org = 'unittest' DEFAULT_GIT_PATH = 'git' - -class GitError(Exception): - pass - - -# TODO make thest git operation to git library after merge code. -def run_git_command(git_path, *args) -> subprocess.CompletedProcess: - response = subprocess.run([git_path, *args], capture_output=True) - try: - response.check_returncode() - return response.stdout.decode('utf8') - except subprocess.CalledProcessError as error: - raise GitError(error.stderr.decode('utf8')) - - -# for public project, token can None, private repo, there must token. -def clone(local_dir: str, token: str, url: str): - url = url.replace('//', '//oauth2:%s@' % token) - clone_args = '-C %s clone %s' % (local_dir, url) - clone_args = clone_args.split(' ') - stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args) - print('stdout: %s' % stdout) - - -def push(local_dir: str, token: str, url: str): - url = url.replace('//', '//oauth2:%s@' % token) - push_args = '-C %s push %s' % (local_dir, url) - push_args = push_args.split(' ') - stdout = run_git_command(DEFAULT_GIT_PATH, *push_args) - print('stdout: %s' % stdout) - - -sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' -download_model_file_name = 'mnist-12.onnx' +download_model_file_name = 'test.bin' class HubOperationTest(unittest.TestCase): @@ -67,6 +33,13 @@ class HubOperationTest(unittest.TestCase): chinese_name=model_chinese_name, visibility=5, # 1-private, 5-public license='apache-2.0') + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + repo = Repository(self.model_dir, clone_from=self.model_id) + os.chdir(self.model_dir) + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, 'test.bin')) + repo.push('add model', all_files=True) def tearDown(self): os.chdir(self.old_cwd) @@ -83,43 +56,10 @@ class HubOperationTest(unittest.TestCase): else: raise - # Note that this can be done via git operation once model repo - # has been created. Git-Op is the RECOMMENDED model upload approach - def test_model_upload(self): - url = f'http://{get_gitlab_domain()}/{self.model_id}' - print(url) - temporary_dir = tempfile.mkdtemp() - os.chdir(temporary_dir) - cmd_args = 'clone %s' % url - cmd_args = cmd_args.split(' ') - out = run_git_command('git', *cmd_args) - print(out) - repo_dir = os.path.join(temporary_dir, self.model_name) - os.chdir(repo_dir) - os.system('touch file1') - os.system('git add file1') - os.system("git commit -m 'Test'") - token = ModelScopeConfig.get_token() - push(repo_dir, token, url) - def test_download_single_file(self): - url = f'http://{get_gitlab_domain()}/{self.model_id}' - print(url) - temporary_dir = tempfile.mkdtemp() - os.chdir(temporary_dir) - os.system('git clone %s' % url) - repo_dir = os.path.join(temporary_dir, self.model_name) - os.chdir(repo_dir) - os.system('wget %s' % sample_model_url) - os.system('git add .') - os.system("git commit -m 'Add file'") - token = ModelScopeConfig.get_token() - push(repo_dir, token, url) - assert os.path.exists( - os.path.join(temporary_dir, self.model_name, - download_model_file_name)) downloaded_file = model_file_download( model_id=self.model_id, file_path=download_model_file_name) + assert os.path.exists(downloaded_file) mdtime1 = os.path.getmtime(downloaded_file) # download again downloaded_file = model_file_download( @@ -128,18 +68,6 @@ class HubOperationTest(unittest.TestCase): assert mdtime1 == mdtime2 def test_snapshot_download(self): - url = f'http://{get_gitlab_domain()}/{self.model_id}' - print(url) - temporary_dir = tempfile.mkdtemp() - os.chdir(temporary_dir) - os.system('git clone %s' % url) - repo_dir = os.path.join(temporary_dir, self.model_name) - os.chdir(repo_dir) - os.system('wget %s' % sample_model_url) - os.system('git add .') - os.system("git commit -m 'Add file'") - token = ModelScopeConfig.get_token() - push(repo_dir, token, url) snapshot_path = snapshot_download(model_id=self.model_id) downloaded_file_path = os.path.join(snapshot_path, download_model_file_name) diff --git a/tests/hub/test_hub_private_repository.py b/tests/hub/test_hub_private_repository.py new file mode 100644 index 00000000..b6e3536c --- /dev/null +++ b/tests/hub/test_hub_private_repository.py @@ -0,0 +1,76 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import unittest +import uuid + +from modelscope.hub.api import HubApi +from modelscope.hub.errors import GitError +from modelscope.hub.repository import Repository + +USER_NAME = 'maasadmin' +PASSWORD = '12345678' + +USER_NAME2 = 'sdkdev' +model_chinese_name = '达摩卡通化模型' +model_org = 'unittest' +DEFAULT_GIT_PATH = 'git' + +sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' +download_model_file_name = 'mnist-12.onnx' + + +class HubPrivateRepositoryTest(unittest.TestCase): + + def setUp(self): + self.old_cwd = os.getcwd() + self.api = HubApi() + # note this is temporary before official account management is ready + self.token, _ = self.api.login(USER_NAME, PASSWORD) + self.model_name = uuid.uuid4().hex + self.model_id = '%s/%s' % (model_org, self.model_name) + self.api.create_model( + model_id=self.model_id, + chinese_name=model_chinese_name, + visibility=1, # 1-private, 5-public + license='apache-2.0') + + def tearDown(self): + self.api.login(USER_NAME, PASSWORD) + os.chdir(self.old_cwd) + self.api.delete_model(model_id=self.model_id) + + def test_clone_private_repo_no_permission(self): + token, _ = self.api.login(USER_NAME2, PASSWORD) + temporary_dir = tempfile.mkdtemp() + local_dir = os.path.join(temporary_dir, self.model_name) + with self.assertRaises(GitError) as cm: + Repository(local_dir, clone_from=self.model_id, auth_token=token) + + print(cm.exception) + assert not os.path.exists(os.path.join(local_dir, 'README.md')) + + def test_clone_private_repo_has_permission(self): + temporary_dir = tempfile.mkdtemp() + local_dir = os.path.join(temporary_dir, self.model_name) + repo1 = Repository( + local_dir, clone_from=self.model_id, auth_token=self.token) + print(repo1.model_dir) + assert os.path.exists(os.path.join(local_dir, 'README.md')) + + def test_initlize_repo_multiple_times(self): + temporary_dir = tempfile.mkdtemp() + local_dir = os.path.join(temporary_dir, self.model_name) + repo1 = Repository( + local_dir, clone_from=self.model_id, auth_token=self.token) + print(repo1.model_dir) + assert os.path.exists(os.path.join(local_dir, 'README.md')) + repo2 = Repository( + local_dir, clone_from=self.model_id, + auth_token=self.token) # skip clone + print(repo2.model_dir) + assert repo1.model_dir == repo2.model_dir + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_repository.py b/tests/hub/test_hub_repository.py new file mode 100644 index 00000000..7b1cc751 --- /dev/null +++ b/tests/hub/test_hub_repository.py @@ -0,0 +1,107 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import time +import unittest +import uuid +from os.path import expanduser + +from requests import delete + +from modelscope.hub.api import HubApi +from modelscope.hub.errors import NotExistError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository +from modelscope.utils.logger import get_logger + +logger = get_logger() +logger.setLevel('DEBUG') +USER_NAME = 'maasadmin' +PASSWORD = '12345678' + +model_chinese_name = '达摩卡通化模型' +model_org = 'unittest' +DEFAULT_GIT_PATH = 'git' + +download_model_file_name = 'mnist-12.onnx' + + +def delete_credential(): + path_credential = expanduser('~/.modelscope/credentials') + shutil.rmtree(path_credential) + + +def delete_stored_git_credential(user): + credential_path = expanduser('~/.git-credentials') + if os.path.exists(credential_path): + with open(credential_path, 'r+') as f: + lines = f.readlines() + for line in lines: + if user in line: + lines.remove(line) + f.seek(0) + f.write(''.join(lines)) + f.truncate() + + +class HubRepositoryTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + # note this is temporary before official account management is ready + self.api.login(USER_NAME, PASSWORD) + self.model_name = uuid.uuid4().hex + self.model_id = '%s/%s' % (model_org, self.model_name) + self.api.create_model( + model_id=self.model_id, + chinese_name=model_chinese_name, + visibility=5, # 1-private, 5-public + license='apache-2.0') + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + + def tearDown(self): + self.api.delete_model(model_id=self.model_id) + + def test_clone_repo(self): + Repository(self.model_dir, clone_from=self.model_id) + assert os.path.exists(os.path.join(self.model_dir, 'README.md')) + + def test_clone_public_model_without_token(self): + delete_credential() + delete_stored_git_credential(USER_NAME) + Repository(self.model_dir, clone_from=self.model_id) + assert os.path.exists(os.path.join(self.model_dir, 'README.md')) + self.api.login(USER_NAME, PASSWORD) # re-login for delete + + def test_push_all(self): + repo = Repository(self.model_dir, clone_from=self.model_id) + assert os.path.exists(os.path.join(self.model_dir, 'README.md')) + os.chdir(self.model_dir) + os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) + os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) + repo.push('test', all_files=True) + add1 = model_file_download(self.model_id, 'add1.py') + assert os.path.exists(add1) + add2 = model_file_download(self.model_id, 'add2.py') + assert os.path.exists(add2) + + def test_push_files(self): + repo = Repository(self.model_dir, clone_from=self.model_id) + assert os.path.exists(os.path.join(self.model_dir, 'README.md')) + os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) + os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) + os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py')) + repo.push('test', files=['add1.py', 'add2.py'], all_files=False) + add1 = model_file_download(self.model_id, 'add1.py') + assert os.path.exists(add1) + add2 = model_file_download(self.model_id, 'add2.py') + assert os.path.exists(add2) + with self.assertRaises(NotExistError) as cm: + model_file_download(self.model_id, 'add3.py') + print(cm.exception) + + +if __name__ == '__main__': + unittest.main()