| @@ -9,9 +9,10 @@ from typing import List, Optional, Tuple, Union | |||
| import requests | |||
| from modelscope.utils.logger import get_logger | |||
| from .constants import LOGGER_NAME | |||
| from .constants import MODELSCOPE_URL_SCHEME | |||
| from .errors import NotExistError, is_ok, raise_on_error | |||
| from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||
| from .utils.utils import (get_endpoint, get_gitlab_domain, | |||
| model_id_to_group_owner_name) | |||
| logger = get_logger() | |||
| @@ -40,9 +41,6 @@ class HubApi: | |||
| <Tip> | |||
| You only have to login once within 30 days. | |||
| </Tip> | |||
| TODO: handle cookies expire | |||
| """ | |||
| path = f'{self.endpoint}/api/v1/login' | |||
| r = requests.post( | |||
| @@ -94,14 +92,14 @@ class HubApi: | |||
| 'Path': owner_or_group, | |||
| 'Name': name, | |||
| 'ChineseName': chinese_name, | |||
| 'Visibility': visibility, | |||
| 'Visibility': visibility, # server check | |||
| 'License': license | |||
| }, | |||
| cookies=cookies) | |||
| r.raise_for_status() | |||
| raise_on_error(r.json()) | |||
| d = r.json() | |||
| return d['Data']['Name'] | |||
| model_repo_url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||
| return model_repo_url | |||
| def delete_model(self, model_id): | |||
| """_summary_ | |||
| @@ -209,25 +207,37 @@ class HubApi: | |||
| class ModelScopeConfig: | |||
| path_credential = expanduser('~/.modelscope/credentials') | |||
| os.makedirs(path_credential, exist_ok=True) | |||
| @classmethod | |||
| def make_sure_credential_path_exist(cls): | |||
| os.makedirs(cls.path_credential, exist_ok=True) | |||
| @classmethod | |||
| def save_cookies(cls, cookies: CookieJar): | |||
| cls.make_sure_credential_path_exist() | |||
| with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f: | |||
| pickle.dump(cookies, f) | |||
| @classmethod | |||
| def get_cookies(cls): | |||
| try: | |||
| with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f: | |||
| return pickle.load(f) | |||
| cookies_path = os.path.join(cls.path_credential, 'cookies') | |||
| with open(cookies_path, 'rb') as f: | |||
| cookies = pickle.load(f) | |||
| for cookie in cookies: | |||
| if cookie.is_expired(): | |||
| logger.warn('Auth is expored, please re-login') | |||
| return None | |||
| return cookies | |||
| except FileNotFoundError: | |||
| logger.warn("Auth token does not exist, you'll get authentication \ | |||
| error when downloading private model files. Please login first" | |||
| ) | |||
| logger.warn( | |||
| "Auth token does not exist, you'll get authentication error when downloading \ | |||
| private model files. Please login first") | |||
| return None | |||
| @classmethod | |||
| def save_token(cls, token: str): | |||
| cls.make_sure_credential_path_exist() | |||
| with open(os.path.join(cls.path_credential, 'token'), 'w+') as f: | |||
| f.write(token) | |||
| @@ -6,6 +6,10 @@ class RequestError(Exception): | |||
| pass | |||
| class GitError(Exception): | |||
| pass | |||
| def is_ok(rsp): | |||
| """ Check the request is ok | |||
| @@ -1,82 +1,161 @@ | |||
| from threading import local | |||
| from tkinter.messagebox import NO | |||
| from typing import Union | |||
| import subprocess | |||
| from typing import List | |||
| from xmlrpc.client import Boolean | |||
| from modelscope.utils.logger import get_logger | |||
| from .constants import LOGGER_NAME | |||
| from .utils._subprocess import run_subprocess | |||
| from .errors import GitError | |||
| logger = get_logger | |||
| logger = get_logger() | |||
| def git_clone( | |||
| local_dir: str, | |||
| repo_url: str, | |||
| ): | |||
| # TODO: use "git clone" or "git lfs clone" according to git version | |||
| # TODO: print stderr when subprocess fails | |||
| run_subprocess( | |||
| f'git clone {repo_url}'.split(), | |||
| local_dir, | |||
| True, | |||
| ) | |||
| class Singleton(type): | |||
| _instances = {} | |||
| def __call__(cls, *args, **kwargs): | |||
| if cls not in cls._instances: | |||
| cls._instances[cls] = super(Singleton, | |||
| cls).__call__(*args, **kwargs) | |||
| return cls._instances[cls] | |||
| def git_checkout( | |||
| local_dir: str, | |||
| revsion: str, | |||
| ): | |||
| run_subprocess(f'git checkout {revsion}'.split(), local_dir) | |||
| def git_add(local_dir: str, ): | |||
| run_subprocess( | |||
| 'git add .'.split(), | |||
| local_dir, | |||
| True, | |||
| ) | |||
| def git_commit(local_dir: str, commit_message: str): | |||
| run_subprocess( | |||
| 'git commit -v -m'.split() + [commit_message], | |||
| local_dir, | |||
| True, | |||
| ) | |||
| def git_push(local_dir: str, branch: str): | |||
| # check current branch | |||
| cur_branch = git_current_branch(local_dir) | |||
| if cur_branch != branch: | |||
| logger.error( | |||
| "You're trying to push to a different branch, please double check") | |||
| return | |||
| run_subprocess( | |||
| f'git push origin {branch}'.split(), | |||
| local_dir, | |||
| True, | |||
| ) | |||
| def git_current_branch(local_dir: str) -> Union[str, None]: | |||
| """ | |||
| Get current branch name | |||
| Args: | |||
| local_dir(`str`): local model repo directory | |||
| Returns | |||
| branch name you're currently on | |||
| class GitCommandWrapper(metaclass=Singleton): | |||
| """Some git operation wrapper | |||
| """ | |||
| try: | |||
| process = run_subprocess( | |||
| 'git rev-parse --abbrev-ref HEAD'.split(), | |||
| local_dir, | |||
| True, | |||
| ) | |||
| return str(process.stdout).strip() | |||
| except Exception as e: | |||
| raise e | |||
| default_git_path = 'git' # The default git command line | |||
| def __init__(self, path: str = None): | |||
| self.git_path = path or self.default_git_path | |||
| def _run_git_command(self, *args) -> subprocess.CompletedProcess: | |||
| """Run git command, if command return 0, return subprocess.response | |||
| otherwise raise GitError, message is stdout and stderr. | |||
| Raises: | |||
| GitError: Exception with stdout and stderr. | |||
| Returns: | |||
| subprocess.CompletedProcess: the command response | |||
| """ | |||
| logger.info(' '.join(args)) | |||
| response = subprocess.run( | |||
| [self.git_path, *args], | |||
| stdout=subprocess.PIPE, | |||
| stderr=subprocess.PIPE) # compatible for python3.6 | |||
| try: | |||
| response.check_returncode() | |||
| return response | |||
| except subprocess.CalledProcessError as error: | |||
| raise GitError( | |||
| 'stdout: %s, stderr: %s' % | |||
| (response.stdout.decode('utf8'), error.stderr.decode('utf8'))) | |||
| def _add_token(self, token: str, url: str): | |||
| if token: | |||
| if '//oauth2' not in url: | |||
| url = url.replace('//', '//oauth2:%s@' % token) | |||
| return url | |||
| def remove_token_from_url(self, url: str): | |||
| if url and '//oauth2' in url: | |||
| start_index = url.find('oauth2') | |||
| end_index = url.find('@') | |||
| url = url[:start_index] + url[end_index + 1:] | |||
| return url | |||
| def is_lfs_installed(self): | |||
| cmd = ['lfs', 'env'] | |||
| try: | |||
| self._run_git_command(*cmd) | |||
| return True | |||
| except GitError: | |||
| return False | |||
| def clone(self, | |||
| repo_base_dir: str, | |||
| token: str, | |||
| url: str, | |||
| repo_name: str, | |||
| branch: str = None): | |||
| """ git clone command wrapper. | |||
| For public project, token can None, private repo, there must token. | |||
| Args: | |||
| repo_base_dir (str): The local base dir, the repository will be clone to local_dir/repo_name | |||
| token (str): The git token, must be provided for private project. | |||
| url (str): The remote url | |||
| repo_name (str): The local repository path name. | |||
| branch (str, optional): _description_. Defaults to None. | |||
| """ | |||
| url = self._add_token(token, url) | |||
| if branch: | |||
| clone_args = '-C %s clone %s %s --branch %s' % (repo_base_dir, url, | |||
| repo_name, branch) | |||
| else: | |||
| clone_args = '-C %s clone %s' % (repo_base_dir, url) | |||
| logger.debug(clone_args) | |||
| clone_args = clone_args.split(' ') | |||
| response = self._run_git_command(*clone_args) | |||
| logger.info(response.stdout.decode('utf8')) | |||
| return response | |||
| def add(self, | |||
| repo_dir: str, | |||
| files: List[str] = list(), | |||
| all_files: bool = False): | |||
| if all_files: | |||
| add_args = '-C %s add -A' % repo_dir | |||
| elif len(files) > 0: | |||
| files_str = ' '.join(files) | |||
| add_args = '-C %s add %s' % (repo_dir, files_str) | |||
| add_args = add_args.split(' ') | |||
| rsp = self._run_git_command(*add_args) | |||
| logger.info(rsp.stdout.decode('utf8')) | |||
| return rsp | |||
| def commit(self, repo_dir: str, message: str): | |||
| """Run git commit command | |||
| Args: | |||
| message (str): commit message. | |||
| """ | |||
| commit_args = ['-C', '%s' % repo_dir, 'commit', '-m', "'%s'" % message] | |||
| rsp = self._run_git_command(*commit_args) | |||
| logger.info(rsp.stdout.decode('utf8')) | |||
| return rsp | |||
| def checkout(self, repo_dir: str, revision: str): | |||
| cmds = ['-C', '%s' % repo_dir, 'checkout', '%s' % revision] | |||
| return self._run_git_command(*cmds) | |||
| def new_branch(self, repo_dir: str, revision: str): | |||
| cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] | |||
| return self._run_git_command(*cmds) | |||
| def pull(self, repo_dir: str): | |||
| cmds = ['-C', repo_dir, 'pull'] | |||
| return self._run_git_command(*cmds) | |||
| def push(self, | |||
| repo_dir: str, | |||
| token: str, | |||
| url: str, | |||
| local_branch: str, | |||
| remote_branch: str, | |||
| force: bool = False): | |||
| url = self._add_token(token, url) | |||
| push_args = '-C %s push %s %s:%s' % (repo_dir, url, local_branch, | |||
| remote_branch) | |||
| if force: | |||
| push_args += ' -f' | |||
| push_args = push_args.split(' ') | |||
| rsp = self._run_git_command(*push_args) | |||
| logger.info(rsp.stdout.decode('utf8')) | |||
| return rsp | |||
| def get_repo_remote_url(self, repo_dir: str): | |||
| cmd_args = '-C %s config --get remote.origin.url' % repo_dir | |||
| cmd_args = cmd_args.split(' ') | |||
| rsp = self._run_git_command(*cmd_args) | |||
| url = rsp.stdout.decode('utf8') | |||
| return url.strip() | |||
| @@ -1,173 +1,97 @@ | |||
| import os | |||
| import subprocess | |||
| from pathlib import Path | |||
| from typing import Optional, Union | |||
| from typing import List, Optional | |||
| from modelscope.hub.errors import GitError | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import ModelScopeConfig | |||
| from .constants import MODELSCOPE_URL_SCHEME | |||
| from .git import git_add, git_checkout, git_clone, git_commit, git_push | |||
| from .utils._subprocess import run_subprocess | |||
| from .git import GitCommandWrapper | |||
| from .utils.utils import get_gitlab_domain | |||
| logger = get_logger() | |||
| class Repository: | |||
| """Representation local model git repository. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| local_dir: str, | |||
| clone_from: Optional[str] = None, | |||
| auth_token: Optional[str] = None, | |||
| private: Optional[bool] = False, | |||
| model_dir: str, | |||
| clone_from: str, | |||
| revision: Optional[str] = 'master', | |||
| auth_token: Optional[str] = None, | |||
| git_path: Optional[str] = None, | |||
| ): | |||
| """ | |||
| Instantiate a Repository object by cloning the remote ModelScopeHub repo | |||
| Args: | |||
| local_dir(`str`): | |||
| local directory to store the model files | |||
| clone_from(`Optional[str] = None`): | |||
| model_dir(`str`): | |||
| The model root directory. | |||
| clone_from: | |||
| model id in ModelScope-hub from which git clone | |||
| You should ignore this parameter when `local_dir` is already a git repo | |||
| auth_token(`Optional[str]`): | |||
| token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||
| as the token is already saved when you login the first time | |||
| private(`Optional[bool]`): | |||
| whether the model is private, default to False | |||
| revision(`Optional[str]`): | |||
| revision of the model you want to clone from. Can be any of a branch, tag or commit hash | |||
| auth_token(`Optional[str]`): | |||
| token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||
| as the token is already saved when you login the first time, if None, we will use saved token. | |||
| git_path:(`Optional[str]`): | |||
| The git command line path, if None, we use 'git' | |||
| """ | |||
| logger.info('Instantiating Repository object...') | |||
| # Create local directory if not exist | |||
| os.makedirs(local_dir, exist_ok=True) | |||
| self.local_dir = os.path.join(os.getcwd(), local_dir) | |||
| self.private = private | |||
| # Check git and git-lfs installation | |||
| self.check_git_versions() | |||
| # Retrieve auth token | |||
| if not private and isinstance(auth_token, str): | |||
| logger.warning( | |||
| 'cloning a public repo with a token, which will be ignored') | |||
| self.token = None | |||
| self.model_dir = model_dir | |||
| self.model_base_dir = os.path.dirname(model_dir) | |||
| self.model_repo_name = os.path.basename(model_dir) | |||
| if auth_token: | |||
| self.auth_token = auth_token | |||
| else: | |||
| if isinstance(auth_token, str): | |||
| self.token = auth_token | |||
| else: | |||
| self.token = ModelScopeConfig.get_token() | |||
| if self.token is None: | |||
| raise EnvironmentError( | |||
| 'Token does not exist, the clone will fail for private repo.' | |||
| 'Please login first.') | |||
| # git clone | |||
| if clone_from is not None: | |||
| self.model_id = clone_from | |||
| logger.info('cloning model repo to %s ...', self.local_dir) | |||
| git_clone(self.local_dir, self.get_repo_url()) | |||
| else: | |||
| if is_git_repo(self.local_dir): | |||
| logger.debug('[Repository] is a valid git repo') | |||
| else: | |||
| raise ValueError( | |||
| 'If not specifying `clone_from`, you need to pass Repository a' | |||
| ' valid git clone.') | |||
| # git checkout | |||
| if isinstance(revision, str) and revision != 'master': | |||
| git_checkout(revision) | |||
| def push_to_hub(self, | |||
| commit_message: str, | |||
| revision: Optional[str] = 'master'): | |||
| """ | |||
| Push changes changes to hub | |||
| Args: | |||
| commit_message(`str`): | |||
| commit message describing the changes, it's mandatory | |||
| revision(`Optional[str]`): | |||
| remote branch you want to push to, default to `master` | |||
| <Tip> | |||
| The function complains when local and remote branch are different, please be careful | |||
| </Tip> | |||
| """ | |||
| git_add(self.local_dir) | |||
| git_commit(self.local_dir, commit_message) | |||
| logger.info('Pushing changes to repo...') | |||
| git_push(self.local_dir, revision) | |||
| # TODO: if git push fails, how to retry? | |||
| def check_git_versions(self): | |||
| """ | |||
| Checks that `git` and `git-lfs` can be run. | |||
| Raises: | |||
| `EnvironmentError`: if `git` or `git-lfs` are not installed. | |||
| """ | |||
| try: | |||
| git_version = run_subprocess('git --version'.split(), | |||
| self.local_dir).stdout.strip() | |||
| except FileNotFoundError: | |||
| raise EnvironmentError( | |||
| 'Looks like you do not have git installed, please install.') | |||
| self.auth_token = ModelScopeConfig.get_token() | |||
| git_wrapper = GitCommandWrapper() | |||
| if not git_wrapper.is_lfs_installed(): | |||
| logger.error('git lfs is not installed, please install.') | |||
| self.git_wrapper = GitCommandWrapper(git_path) | |||
| os.makedirs(self.model_dir, exist_ok=True) | |||
| url = self._get_model_id_url(clone_from) | |||
| if os.listdir(self.model_dir): # directory not empty. | |||
| remote_url = self._get_remote_url() | |||
| remote_url = self.git_wrapper.remove_token_from_url(remote_url) | |||
| if remote_url and remote_url == url: # need not clone again | |||
| return | |||
| self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, | |||
| self.model_repo_name, revision) | |||
| def _get_model_id_url(self, model_id): | |||
| url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||
| return url | |||
| def _get_remote_url(self): | |||
| try: | |||
| lfs_version = run_subprocess('git-lfs --version'.split(), | |||
| self.local_dir).stdout.strip() | |||
| except FileNotFoundError: | |||
| raise EnvironmentError( | |||
| 'Looks like you do not have git-lfs installed, please install.' | |||
| ' You can install from https://git-lfs.github.com/.' | |||
| ' Then run `git lfs install` (you only have to do this once).') | |||
| logger.info(git_version + '\n' + lfs_version) | |||
| def get_repo_url(self) -> str: | |||
| """ | |||
| Get repo url to clone, according whether the repo is private or not | |||
| remote = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
| except GitError: | |||
| remote = None | |||
| return remote | |||
| def push(self, | |||
| commit_message: str, | |||
| files: List[str] = list(), | |||
| all_files: bool = False, | |||
| branch: Optional[str] = 'master', | |||
| force: bool = False): | |||
| """Push local to remote, this method will do. | |||
| git add | |||
| git commit | |||
| git push | |||
| Args: | |||
| commit_message (str): commit message | |||
| revision (Optional[str], optional): which branch to push. Defaults to 'master'. | |||
| """ | |||
| url = None | |||
| if self.private: | |||
| url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}' | |||
| else: | |||
| url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}' | |||
| if not url: | |||
| raise ValueError( | |||
| 'Empty repo url, please check clone_from parameter') | |||
| logger.debug('url to clone: %s', str(url)) | |||
| return url | |||
| def is_git_repo(folder: Union[str, Path]) -> bool: | |||
| """ | |||
| Check if the folder is the root or part of a git repository | |||
| Args: | |||
| folder (`str`): | |||
| The folder in which to run the command. | |||
| Returns: | |||
| `bool`: `True` if the repository is part of a repository, `False` | |||
| otherwise. | |||
| """ | |||
| folder_exists = os.path.exists(os.path.join(folder, '.git')) | |||
| git_branch = subprocess.run( | |||
| 'git branch'.split(), | |||
| cwd=folder, | |||
| stdout=subprocess.PIPE, | |||
| stderr=subprocess.PIPE) | |||
| return folder_exists and git_branch.returncode == 0 | |||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
| self.git_wrapper.add(self.model_dir, files, all_files) | |||
| self.git_wrapper.commit(self.model_dir, commit_message) | |||
| self.git_wrapper.push( | |||
| repo_dir=self.model_dir, | |||
| token=self.auth_token, | |||
| url=url, | |||
| local_branch=branch, | |||
| remote_branch=branch) | |||
| @@ -1,40 +0,0 @@ | |||
| import subprocess | |||
| from typing import List | |||
| def run_subprocess(command: List[str], | |||
| folder: str, | |||
| check=True, | |||
| **kwargs) -> subprocess.CompletedProcess: | |||
| """ | |||
| Method to run subprocesses. Calling this will capture the `stderr` and `stdout`, | |||
| please call `subprocess.run` manually in case you would like for them not to | |||
| be captured. | |||
| Args: | |||
| command (`List[str]`): | |||
| The command to execute as a list of strings. | |||
| folder (`str`): | |||
| The folder in which to run the command. | |||
| check (`bool`, *optional*, defaults to `True`): | |||
| Setting `check` to `True` will raise a `subprocess.CalledProcessError` | |||
| when the subprocess has a non-zero exit code. | |||
| kwargs (`Dict[str]`): | |||
| Keyword arguments to be passed to the `subprocess.run` underlying command. | |||
| Returns: | |||
| `subprocess.CompletedProcess`: The completed process. | |||
| """ | |||
| if isinstance(command, str): | |||
| raise ValueError( | |||
| '`run_subprocess` should be called with a list of strings.') | |||
| return subprocess.run( | |||
| command, | |||
| stderr=subprocess.PIPE, | |||
| stdout=subprocess.PIPE, | |||
| check=check, | |||
| encoding='utf-8', | |||
| cwd=folder, | |||
| **kwargs, | |||
| ) | |||
| @@ -1,14 +1,13 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import subprocess | |||
| import tempfile | |||
| import unittest | |||
| import uuid | |||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.hub.utils.utils import get_gitlab_domain | |||
| USER_NAME = 'maasadmin' | |||
| PASSWORD = '12345678' | |||
| @@ -17,40 +16,7 @@ model_chinese_name = '达摩卡通化模型' | |||
| model_org = 'unittest' | |||
| DEFAULT_GIT_PATH = 'git' | |||
| class GitError(Exception): | |||
| pass | |||
| # TODO make thest git operation to git library after merge code. | |||
| def run_git_command(git_path, *args) -> subprocess.CompletedProcess: | |||
| response = subprocess.run([git_path, *args], capture_output=True) | |||
| try: | |||
| response.check_returncode() | |||
| return response.stdout.decode('utf8') | |||
| except subprocess.CalledProcessError as error: | |||
| raise GitError(error.stderr.decode('utf8')) | |||
| # for public project, token can None, private repo, there must token. | |||
| def clone(local_dir: str, token: str, url: str): | |||
| url = url.replace('//', '//oauth2:%s@' % token) | |||
| clone_args = '-C %s clone %s' % (local_dir, url) | |||
| clone_args = clone_args.split(' ') | |||
| stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args) | |||
| print('stdout: %s' % stdout) | |||
| def push(local_dir: str, token: str, url: str): | |||
| url = url.replace('//', '//oauth2:%s@' % token) | |||
| push_args = '-C %s push %s' % (local_dir, url) | |||
| push_args = push_args.split(' ') | |||
| stdout = run_git_command(DEFAULT_GIT_PATH, *push_args) | |||
| print('stdout: %s' % stdout) | |||
| sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||
| download_model_file_name = 'mnist-12.onnx' | |||
| download_model_file_name = 'test.bin' | |||
| class HubOperationTest(unittest.TestCase): | |||
| @@ -67,6 +33,13 @@ class HubOperationTest(unittest.TestCase): | |||
| chinese_name=model_chinese_name, | |||
| visibility=5, # 1-private, 5-public | |||
| license='apache-2.0') | |||
| temporary_dir = tempfile.mkdtemp() | |||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||
| os.chdir(self.model_dir) | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, 'test.bin')) | |||
| repo.push('add model', all_files=True) | |||
| def tearDown(self): | |||
| os.chdir(self.old_cwd) | |||
| @@ -83,43 +56,10 @@ class HubOperationTest(unittest.TestCase): | |||
| else: | |||
| raise | |||
| # Note that this can be done via git operation once model repo | |||
| # has been created. Git-Op is the RECOMMENDED model upload approach | |||
| def test_model_upload(self): | |||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
| print(url) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| os.chdir(temporary_dir) | |||
| cmd_args = 'clone %s' % url | |||
| cmd_args = cmd_args.split(' ') | |||
| out = run_git_command('git', *cmd_args) | |||
| print(out) | |||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||
| os.chdir(repo_dir) | |||
| os.system('touch file1') | |||
| os.system('git add file1') | |||
| os.system("git commit -m 'Test'") | |||
| token = ModelScopeConfig.get_token() | |||
| push(repo_dir, token, url) | |||
| def test_download_single_file(self): | |||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
| print(url) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| os.chdir(temporary_dir) | |||
| os.system('git clone %s' % url) | |||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||
| os.chdir(repo_dir) | |||
| os.system('wget %s' % sample_model_url) | |||
| os.system('git add .') | |||
| os.system("git commit -m 'Add file'") | |||
| token = ModelScopeConfig.get_token() | |||
| push(repo_dir, token, url) | |||
| assert os.path.exists( | |||
| os.path.join(temporary_dir, self.model_name, | |||
| download_model_file_name)) | |||
| downloaded_file = model_file_download( | |||
| model_id=self.model_id, file_path=download_model_file_name) | |||
| assert os.path.exists(downloaded_file) | |||
| mdtime1 = os.path.getmtime(downloaded_file) | |||
| # download again | |||
| downloaded_file = model_file_download( | |||
| @@ -128,18 +68,6 @@ class HubOperationTest(unittest.TestCase): | |||
| assert mdtime1 == mdtime2 | |||
| def test_snapshot_download(self): | |||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||
| print(url) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| os.chdir(temporary_dir) | |||
| os.system('git clone %s' % url) | |||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||
| os.chdir(repo_dir) | |||
| os.system('wget %s' % sample_model_url) | |||
| os.system('git add .') | |||
| os.system("git commit -m 'Add file'") | |||
| token = ModelScopeConfig.get_token() | |||
| push(repo_dir, token, url) | |||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||
| downloaded_file_path = os.path.join(snapshot_path, | |||
| download_model_file_name) | |||
| @@ -0,0 +1,76 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| import uuid | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.errors import GitError | |||
| from modelscope.hub.repository import Repository | |||
| USER_NAME = 'maasadmin' | |||
| PASSWORD = '12345678' | |||
| USER_NAME2 = 'sdkdev' | |||
| model_chinese_name = '达摩卡通化模型' | |||
| model_org = 'unittest' | |||
| DEFAULT_GIT_PATH = 'git' | |||
| sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||
| download_model_file_name = 'mnist-12.onnx' | |||
| class HubPrivateRepositoryTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.old_cwd = os.getcwd() | |||
| self.api = HubApi() | |||
| # note this is temporary before official account management is ready | |||
| self.token, _ = self.api.login(USER_NAME, PASSWORD) | |||
| self.model_name = uuid.uuid4().hex | |||
| self.model_id = '%s/%s' % (model_org, self.model_name) | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| chinese_name=model_chinese_name, | |||
| visibility=1, # 1-private, 5-public | |||
| license='apache-2.0') | |||
| def tearDown(self): | |||
| self.api.login(USER_NAME, PASSWORD) | |||
| os.chdir(self.old_cwd) | |||
| self.api.delete_model(model_id=self.model_id) | |||
| def test_clone_private_repo_no_permission(self): | |||
| token, _ = self.api.login(USER_NAME2, PASSWORD) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| local_dir = os.path.join(temporary_dir, self.model_name) | |||
| with self.assertRaises(GitError) as cm: | |||
| Repository(local_dir, clone_from=self.model_id, auth_token=token) | |||
| print(cm.exception) | |||
| assert not os.path.exists(os.path.join(local_dir, 'README.md')) | |||
| def test_clone_private_repo_has_permission(self): | |||
| temporary_dir = tempfile.mkdtemp() | |||
| local_dir = os.path.join(temporary_dir, self.model_name) | |||
| repo1 = Repository( | |||
| local_dir, clone_from=self.model_id, auth_token=self.token) | |||
| print(repo1.model_dir) | |||
| assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||
| def test_initlize_repo_multiple_times(self): | |||
| temporary_dir = tempfile.mkdtemp() | |||
| local_dir = os.path.join(temporary_dir, self.model_name) | |||
| repo1 = Repository( | |||
| local_dir, clone_from=self.model_id, auth_token=self.token) | |||
| print(repo1.model_dir) | |||
| assert os.path.exists(os.path.join(local_dir, 'README.md')) | |||
| repo2 = Repository( | |||
| local_dir, clone_from=self.model_id, | |||
| auth_token=self.token) # skip clone | |||
| print(repo2.model_dir) | |||
| assert repo1.model_dir == repo2.model_dir | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,107 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import time | |||
| import unittest | |||
| import uuid | |||
| from os.path import expanduser | |||
| from requests import delete | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.errors import NotExistError | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| logger.setLevel('DEBUG') | |||
| USER_NAME = 'maasadmin' | |||
| PASSWORD = '12345678' | |||
| model_chinese_name = '达摩卡通化模型' | |||
| model_org = 'unittest' | |||
| DEFAULT_GIT_PATH = 'git' | |||
| download_model_file_name = 'mnist-12.onnx' | |||
| def delete_credential(): | |||
| path_credential = expanduser('~/.modelscope/credentials') | |||
| shutil.rmtree(path_credential) | |||
| def delete_stored_git_credential(user): | |||
| credential_path = expanduser('~/.git-credentials') | |||
| if os.path.exists(credential_path): | |||
| with open(credential_path, 'r+') as f: | |||
| lines = f.readlines() | |||
| for line in lines: | |||
| if user in line: | |||
| lines.remove(line) | |||
| f.seek(0) | |||
| f.write(''.join(lines)) | |||
| f.truncate() | |||
| class HubRepositoryTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.api = HubApi() | |||
| # note this is temporary before official account management is ready | |||
| self.api.login(USER_NAME, PASSWORD) | |||
| self.model_name = uuid.uuid4().hex | |||
| self.model_id = '%s/%s' % (model_org, self.model_name) | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| chinese_name=model_chinese_name, | |||
| visibility=5, # 1-private, 5-public | |||
| license='apache-2.0') | |||
| temporary_dir = tempfile.mkdtemp() | |||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
| def tearDown(self): | |||
| self.api.delete_model(model_id=self.model_id) | |||
| def test_clone_repo(self): | |||
| Repository(self.model_dir, clone_from=self.model_id) | |||
| assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
| def test_clone_public_model_without_token(self): | |||
| delete_credential() | |||
| delete_stored_git_credential(USER_NAME) | |||
| Repository(self.model_dir, clone_from=self.model_id) | |||
| assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
| self.api.login(USER_NAME, PASSWORD) # re-login for delete | |||
| def test_push_all(self): | |||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||
| assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
| os.chdir(self.model_dir) | |||
| os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||
| os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||
| repo.push('test', all_files=True) | |||
| add1 = model_file_download(self.model_id, 'add1.py') | |||
| assert os.path.exists(add1) | |||
| add2 = model_file_download(self.model_id, 'add2.py') | |||
| assert os.path.exists(add2) | |||
| def test_push_files(self): | |||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||
| assert os.path.exists(os.path.join(self.model_dir, 'README.md')) | |||
| os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) | |||
| os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) | |||
| os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py')) | |||
| repo.push('test', files=['add1.py', 'add2.py'], all_files=False) | |||
| add1 = model_file_download(self.model_id, 'add1.py') | |||
| assert os.path.exists(add1) | |||
| add2 = model_file_download(self.model_id, 'add2.py') | |||
| assert os.path.exists(add2) | |||
| with self.assertRaises(NotExistError) as cm: | |||
| model_file_download(self.model_id, 'add3.py') | |||
| print(cm.exception) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||