| @@ -1,4 +1,3 @@ | |||||
| import imp | |||||
| import os | import os | ||||
| import pickle | import pickle | ||||
| import subprocess | import subprocess | ||||
| @@ -9,9 +8,10 @@ from typing import List, Optional, Tuple, Union | |||||
| import requests | import requests | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .constants import LOGGER_NAME | |||||
| from .constants import MODELSCOPE_URL_SCHEME | |||||
| from .errors import NotExistError, is_ok, raise_on_error | from .errors import NotExistError, is_ok, raise_on_error | ||||
| from .utils.utils import get_endpoint, model_id_to_group_owner_name | |||||
| from .utils.utils import (get_endpoint, get_gitlab_domain, | |||||
| model_id_to_group_owner_name) | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -40,9 +40,6 @@ class HubApi: | |||||
| <Tip> | <Tip> | ||||
| You only have to login once within 30 days. | You only have to login once within 30 days. | ||||
| </Tip> | </Tip> | ||||
| TODO: handle cookies expire | |||||
| """ | """ | ||||
| path = f'{self.endpoint}/api/v1/login' | path = f'{self.endpoint}/api/v1/login' | ||||
| r = requests.post( | r = requests.post( | ||||
| @@ -94,14 +91,14 @@ class HubApi: | |||||
| 'Path': owner_or_group, | 'Path': owner_or_group, | ||||
| 'Name': name, | 'Name': name, | ||||
| 'ChineseName': chinese_name, | 'ChineseName': chinese_name, | ||||
| 'Visibility': visibility, | |||||
| 'Visibility': visibility, # server check | |||||
| 'License': license | 'License': license | ||||
| }, | }, | ||||
| cookies=cookies) | cookies=cookies) | ||||
| r.raise_for_status() | r.raise_for_status() | ||||
| raise_on_error(r.json()) | raise_on_error(r.json()) | ||||
| d = r.json() | |||||
| return d['Data']['Name'] | |||||
| model_repo_url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||||
| return model_repo_url | |||||
| def delete_model(self, model_id): | def delete_model(self, model_id): | ||||
| """_summary_ | """_summary_ | ||||
| @@ -209,25 +206,37 @@ class HubApi: | |||||
| class ModelScopeConfig: | class ModelScopeConfig: | ||||
| path_credential = expanduser('~/.modelscope/credentials') | path_credential = expanduser('~/.modelscope/credentials') | ||||
| os.makedirs(path_credential, exist_ok=True) | |||||
@classmethod
def make_sure_credential_path_exist(cls):
    """Create the credential directory ``cls.path_credential`` if needed.

    An already existing directory is accepted silently (``exist_ok=True``).
    """
    credential_dir = cls.path_credential
    os.makedirs(credential_dir, exist_ok=True)
@classmethod
def save_cookies(cls, cookies: CookieJar):
    """Pickle ``cookies`` into the ``cookies`` file of the credential directory.

    The credential directory is created first when it does not exist yet.

    Args:
        cookies (CookieJar): the cookie jar to persist.
    """
    cls.make_sure_credential_path_exist()
    cookies_file = os.path.join(cls.path_credential, 'cookies')
    with open(cookies_file, 'wb+') as out:
        pickle.dump(cookies, out)
| @classmethod | @classmethod | ||||
| def get_cookies(cls): | def get_cookies(cls): | ||||
| try: | try: | ||||
| with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f: | |||||
| return pickle.load(f) | |||||
| cookies_path = os.path.join(cls.path_credential, 'cookies') | |||||
| with open(cookies_path, 'rb') as f: | |||||
| cookies = pickle.load(f) | |||||
| for cookie in cookies: | |||||
| if cookie.is_expired(): | |||||
| logger.warn('Auth is expored, please re-login') | |||||
| return None | |||||
| return cookies | |||||
| except FileNotFoundError: | except FileNotFoundError: | ||||
| logger.warn("Auth token does not exist, you'll get authentication \ | |||||
| error when downloading private model files. Please login first" | |||||
| ) | |||||
| logger.warn( | |||||
| "Auth token does not exist, you'll get authentication error when downloading \ | |||||
| private model files. Please login first") | |||||
| return None | |||||
@classmethod
def save_token(cls, token: str):
    """Write ``token`` to the ``token`` file of the credential directory.

    The credential directory is created first when it does not exist yet.

    Args:
        token (str): the token text to store.
    """
    cls.make_sure_credential_path_exist()
    token_file = os.path.join(cls.path_credential, 'token')
    with open(token_file, 'w+') as out:
        out.write(token)
| @@ -6,3 +6,16 @@ DEFAULT_MODELSCOPE_GROUP = 'damo' | |||||
| MODEL_ID_SEPARATOR = '/' | MODEL_ID_SEPARATOR = '/' | ||||
| LOGGER_NAME = 'ModelScopeHub' | LOGGER_NAME = 'ModelScopeHub' | ||||
class Licenses:
    """License name strings, usable as the ``license`` value of a model."""

    APACHE_V2 = 'Apache License 2.0'
    GPL = 'GPL'
    LGPL = 'LGPL'
    MIT = 'MIT'
class ModelVisibility:
    """Integer visibility codes, usable as the ``visibility`` value of a model."""

    PRIVATE = 1
    INTERNAL = 3
    PUBLIC = 5
| @@ -6,6 +6,10 @@ class RequestError(Exception): | |||||
| pass | pass | ||||
class GitError(Exception):
    """Error type for failures of git operations."""
| def is_ok(rsp): | def is_ok(rsp): | ||||
| """ Check the request is ok | """ Check the request is ok | ||||
| @@ -1,82 +1,161 @@ | |||||
| from threading import local | |||||
| from tkinter.messagebox import NO | |||||
| from typing import Union | |||||
| import subprocess | |||||
| from typing import List | |||||
| from xmlrpc.client import Boolean | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .constants import LOGGER_NAME | |||||
| from .utils._subprocess import run_subprocess | |||||
| from .errors import GitError | |||||
| logger = get_logger | |||||
| logger = get_logger() | |||||
| def git_clone( | |||||
| local_dir: str, | |||||
| repo_url: str, | |||||
| ): | |||||
| # TODO: use "git clone" or "git lfs clone" according to git version | |||||
| # TODO: print stderr when subprocess fails | |||||
| run_subprocess( | |||||
| f'git clone {repo_url}'.split(), | |||||
| local_dir, | |||||
| True, | |||||
| ) | |||||
class Singleton(type):
    """Metaclass that keeps a single shared instance per class.

    The first call of a class constructs and stores its instance. Every later
    call returns that stored object; the arguments of later calls are not
    used, so they cannot reconfigure the instance.
    """

    # class -> its one instance, shared by all classes using this metaclass
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return registry[cls]
| def git_checkout( | |||||
| local_dir: str, | |||||
| revsion: str, | |||||
| ): | |||||
| run_subprocess(f'git checkout {revsion}'.split(), local_dir) | |||||
| def git_add(local_dir: str, ): | |||||
| run_subprocess( | |||||
| 'git add .'.split(), | |||||
| local_dir, | |||||
| True, | |||||
| ) | |||||
| def git_commit(local_dir: str, commit_message: str): | |||||
| run_subprocess( | |||||
| 'git commit -v -m'.split() + [commit_message], | |||||
| local_dir, | |||||
| True, | |||||
| ) | |||||
| def git_push(local_dir: str, branch: str): | |||||
| # check current branch | |||||
| cur_branch = git_current_branch(local_dir) | |||||
| if cur_branch != branch: | |||||
| logger.error( | |||||
| "You're trying to push to a different branch, please double check") | |||||
| return | |||||
| run_subprocess( | |||||
| f'git push origin {branch}'.split(), | |||||
| local_dir, | |||||
| True, | |||||
| ) | |||||
| def git_current_branch(local_dir: str) -> Union[str, None]: | |||||
| """ | |||||
| Get current branch name | |||||
| Args: | |||||
| local_dir(`str`): local model repo directory | |||||
| Returns | |||||
| branch name you're currently on | |||||
| class GitCommandWrapper(metaclass=Singleton): | |||||
| """Some git operation wrapper | |||||
| """ | """ | ||||
| try: | |||||
| process = run_subprocess( | |||||
| 'git rev-parse --abbrev-ref HEAD'.split(), | |||||
| local_dir, | |||||
| True, | |||||
| ) | |||||
| return str(process.stdout).strip() | |||||
| except Exception as e: | |||||
| raise e | |||||
| default_git_path = 'git' # The default git command line | |||||
| def __init__(self, path: str = None): | |||||
| self.git_path = path or self.default_git_path | |||||
def _run_git_command(self, *args) -> subprocess.CompletedProcess:
    """Run ``<git_path> <args...>`` and return the completed process.

    The command is started without a shell; each element of ``args`` is one
    argument. stdout and stderr are captured as bytes.

    Args:
        *args (str): the git arguments.

    Raises:
        GitError: git exited with a non-zero status. The message contains
            the decoded stdout and stderr.

    Returns:
        subprocess.CompletedProcess: the command response
    """

    def _redact(arg):
        # 'scheme://oauth2:<token>@host/...' -> 'scheme://oauth2:***@host/...'
        head, marker, tail = arg.partition('//oauth2:')
        if marker and '@' in tail:
            return head + marker + '***' + tail[tail.index('@'):]
        return arg

    # clone/push URLs can embed an access token; keep it out of the log.
    logger.info(' '.join(_redact(arg) for arg in args))
    response = subprocess.run(
        [self.git_path, *args],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)  # compatible for python3.6
    try:
        response.check_returncode()
    except subprocess.CalledProcessError as error:
        raise GitError(
            'stdout: %s, stderr: %s' %
            (response.stdout.decode('utf8'),
             error.stderr.decode('utf8'))) from error
    return response
| def _add_token(self, token: str, url: str): | |||||
| if token: | |||||
| if '//oauth2' not in url: | |||||
| url = url.replace('//', '//oauth2:%s@' % token) | |||||
| return url | |||||
def remove_token_from_url(self, url: str):
    """Strip an embedded ``oauth2:<token>@`` credential from ``url``.

    Args:
        url (str): a remote url, may be None or empty.

    Returns:
        The url without the ``oauth2...@`` part. The input is returned
        unchanged when it is falsy, holds no ``//oauth2`` marker, or has no
        ``@`` after that marker (nothing to strip).
    """
    if not url:
        return url
    marker = url.find('//oauth2')
    if marker == -1:
        return url
    start_index = marker + 2  # first character after the '//'
    # Search for '@' only behind the marker; without one there is no
    # credential part and slicing would mangle the url.
    end_index = url.find('@', start_index)
    if end_index == -1:
        return url
    return url[:start_index] + url[end_index + 1:]
| def is_lfs_installed(self): | |||||
| cmd = ['lfs', 'env'] | |||||
| try: | |||||
| self._run_git_command(*cmd) | |||||
| return True | |||||
| except GitError: | |||||
| return False | |||||
def clone(self,
          repo_base_dir: str,
          token: str,
          url: str,
          repo_name: str,
          branch: str = None):
    """ git clone command wrapper.
    For a public project the token can be None; for a private repo a token
    must be given.

    Args:
        repo_base_dir (str): The local base dir, the repository will be
            cloned to repo_base_dir/repo_name
        token (str): The git token, must be provided for private project.
        url (str): The remote url
        repo_name (str): The local repository path name. When falsy, git
            picks the directory name itself.
        branch (str, optional): Branch passed to ``--branch``. Defaults to
            None, which leaves the choice to git.

    Raises:
        GitError: the git command failed.

    Returns:
        subprocess.CompletedProcess: the command response
    """
    # Log the target before the token is added so it never reaches the log.
    logger.debug('clone %s into %s', self.remove_token_from_url(url),
                 repo_base_dir)
    url = self._add_token(token, url)
    # One list element per argument: paths containing spaces stay intact.
    clone_args = ['-C', repo_base_dir, 'clone']
    if branch:
        clone_args += ['--branch', branch]
    clone_args.append(url)
    if repo_name:
        # Honour repo_name with and without a branch.
        clone_args.append(repo_name)
    response = self._run_git_command(*clone_args)
    logger.info(response.stdout.decode('utf8'))
    return response
| def add(self, | |||||
| repo_dir: str, | |||||
| files: List[str] = list(), | |||||
| all_files: bool = False): | |||||
| if all_files: | |||||
| add_args = '-C %s add -A' % repo_dir | |||||
| elif len(files) > 0: | |||||
| files_str = ' '.join(files) | |||||
| add_args = '-C %s add %s' % (repo_dir, files_str) | |||||
| add_args = add_args.split(' ') | |||||
| rsp = self._run_git_command(*add_args) | |||||
| logger.info(rsp.stdout.decode('utf8')) | |||||
| return rsp | |||||
def commit(self, repo_dir: str, message: str):
    """Run git commit command

    Args:
        repo_dir (str): the local repository directory.
        message (str): commit message.

    Raises:
        GitError: the git command failed.

    Returns:
        subprocess.CompletedProcess: the command response
    """
    # The message is one argv element and no shell is involved, so it must
    # not be wrapped in quotes: they would become part of the message.
    commit_args = ['-C', str(repo_dir), 'commit', '-m', str(message)]
    rsp = self._run_git_command(*commit_args)
    logger.info(rsp.stdout.decode('utf8'))
    return rsp
| def checkout(self, repo_dir: str, revision: str): | |||||
| cmds = ['-C', '%s' % repo_dir, 'checkout', '%s' % revision] | |||||
| return self._run_git_command(*cmds) | |||||
| def new_branch(self, repo_dir: str, revision: str): | |||||
| cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] | |||||
| return self._run_git_command(*cmds) | |||||
| def pull(self, repo_dir: str): | |||||
| cmds = ['-C', repo_dir, 'pull'] | |||||
| return self._run_git_command(*cmds) | |||||
def push(self,
         repo_dir: str,
         token: str,
         url: str,
         local_branch: str,
         remote_branch: str,
         force: bool = False):
    """Run ``git push <url> <local_branch>:<remote_branch>`` in ``repo_dir``.

    Args:
        repo_dir (str): the local repository directory.
        token (str): the git token; added to ``url`` when given.
        url (str): the remote url to push to.
        local_branch (str): source branch of the refspec.
        remote_branch (str): destination branch of the refspec.
        force (bool): add ``-f`` to the push command.

    Raises:
        GitError: the git command failed.

    Returns:
        subprocess.CompletedProcess: the command response
    """
    url = self._add_token(token, url)
    # One list element per argument: a repo_dir with spaces stays intact.
    push_args = ['-C', repo_dir, 'push']
    if force:
        push_args.append('-f')
    push_args += [url, '%s:%s' % (local_branch, remote_branch)]
    rsp = self._run_git_command(*push_args)
    logger.info(rsp.stdout.decode('utf8'))
    return rsp
| def get_repo_remote_url(self, repo_dir: str): | |||||
| cmd_args = '-C %s config --get remote.origin.url' % repo_dir | |||||
| cmd_args = cmd_args.split(' ') | |||||
| rsp = self._run_git_command(*cmd_args) | |||||
| url = rsp.stdout.decode('utf8') | |||||
| return url.strip() | |||||
| @@ -1,173 +1,97 @@ | |||||
| import os | import os | ||||
| import subprocess | |||||
| from pathlib import Path | |||||
| from typing import Optional, Union | |||||
| from typing import List, Optional | |||||
| from modelscope.hub.errors import GitError | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import ModelScopeConfig | from .api import ModelScopeConfig | ||||
| from .constants import MODELSCOPE_URL_SCHEME | from .constants import MODELSCOPE_URL_SCHEME | ||||
| from .git import git_add, git_checkout, git_clone, git_commit, git_push | |||||
| from .utils._subprocess import run_subprocess | |||||
| from .git import GitCommandWrapper | |||||
| from .utils.utils import get_gitlab_domain | from .utils.utils import get_gitlab_domain | ||||
| logger = get_logger() | logger = get_logger() | ||||
| class Repository: | class Repository: | ||||
| """Representation local model git repository. | |||||
| """ | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| local_dir: str, | |||||
| clone_from: Optional[str] = None, | |||||
| auth_token: Optional[str] = None, | |||||
| private: Optional[bool] = False, | |||||
| model_dir: str, | |||||
| clone_from: str, | |||||
| revision: Optional[str] = 'master', | revision: Optional[str] = 'master', | ||||
| auth_token: Optional[str] = None, | |||||
| git_path: Optional[str] = None, | |||||
| ): | ): | ||||
| """ | """ | ||||
| Instantiate a Repository object by cloning the remote ModelScopeHub repo | Instantiate a Repository object by cloning the remote ModelScopeHub repo | ||||
| Args: | Args: | ||||
| local_dir(`str`): | |||||
| local directory to store the model files | |||||
| clone_from(`Optional[str] = None`): | |||||
| model_dir(`str`): | |||||
| The model root directory. | |||||
| clone_from: | |||||
| model id in ModelScope-hub from which git clone | model id in ModelScope-hub from which git clone | ||||
| You should ignore this parameter when `local_dir` is already a git repo | |||||
| auth_token(`Optional[str]`): | |||||
| token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||||
| as the token is already saved when you login the first time | |||||
| private(`Optional[bool]`): | |||||
| whether the model is private, default to False | |||||
| revision(`Optional[str]`): | revision(`Optional[str]`): | ||||
| revision of the model you want to clone from. Can be any of a branch, tag or commit hash | revision of the model you want to clone from. Can be any of a branch, tag or commit hash | ||||
| auth_token(`Optional[str]`): | |||||
| token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||||
| as the token is already saved when you login the first time, if None, we will use saved token. | |||||
| git_path:(`Optional[str]`): | |||||
| The git command line path, if None, we use 'git' | |||||
| """ | """ | ||||
| logger.info('Instantiating Repository object...') | |||||
| # Create local directory if not exist | |||||
| os.makedirs(local_dir, exist_ok=True) | |||||
| self.local_dir = os.path.join(os.getcwd(), local_dir) | |||||
| self.private = private | |||||
| # Check git and git-lfs installation | |||||
| self.check_git_versions() | |||||
| # Retrieve auth token | |||||
| if not private and isinstance(auth_token, str): | |||||
| logger.warning( | |||||
| 'cloning a public repo with a token, which will be ignored') | |||||
| self.token = None | |||||
| self.model_dir = model_dir | |||||
| self.model_base_dir = os.path.dirname(model_dir) | |||||
| self.model_repo_name = os.path.basename(model_dir) | |||||
| if auth_token: | |||||
| self.auth_token = auth_token | |||||
| else: | else: | ||||
| if isinstance(auth_token, str): | |||||
| self.token = auth_token | |||||
| else: | |||||
| self.token = ModelScopeConfig.get_token() | |||||
| if self.token is None: | |||||
| raise EnvironmentError( | |||||
| 'Token does not exist, the clone will fail for private repo.' | |||||
| 'Please login first.') | |||||
| # git clone | |||||
| if clone_from is not None: | |||||
| self.model_id = clone_from | |||||
| logger.info('cloning model repo to %s ...', self.local_dir) | |||||
| git_clone(self.local_dir, self.get_repo_url()) | |||||
| else: | |||||
| if is_git_repo(self.local_dir): | |||||
| logger.debug('[Repository] is a valid git repo') | |||||
| else: | |||||
| raise ValueError( | |||||
| 'If not specifying `clone_from`, you need to pass Repository a' | |||||
| ' valid git clone.') | |||||
| # git checkout | |||||
| if isinstance(revision, str) and revision != 'master': | |||||
| git_checkout(revision) | |||||
| def push_to_hub(self, | |||||
| commit_message: str, | |||||
| revision: Optional[str] = 'master'): | |||||
| """ | |||||
| Push changes changes to hub | |||||
| Args: | |||||
| commit_message(`str`): | |||||
| commit message describing the changes, it's mandatory | |||||
| revision(`Optional[str]`): | |||||
| remote branch you want to push to, default to `master` | |||||
| <Tip> | |||||
| The function complains when local and remote branch are different, please be careful | |||||
| </Tip> | |||||
| """ | |||||
| git_add(self.local_dir) | |||||
| git_commit(self.local_dir, commit_message) | |||||
| logger.info('Pushing changes to repo...') | |||||
| git_push(self.local_dir, revision) | |||||
| # TODO: if git push fails, how to retry? | |||||
| def check_git_versions(self): | |||||
| """ | |||||
| Checks that `git` and `git-lfs` can be run. | |||||
| Raises: | |||||
| `EnvironmentError`: if `git` or `git-lfs` are not installed. | |||||
| """ | |||||
| try: | |||||
| git_version = run_subprocess('git --version'.split(), | |||||
| self.local_dir).stdout.strip() | |||||
| except FileNotFoundError: | |||||
| raise EnvironmentError( | |||||
| 'Looks like you do not have git installed, please install.') | |||||
| self.auth_token = ModelScopeConfig.get_token() | |||||
| git_wrapper = GitCommandWrapper() | |||||
| if not git_wrapper.is_lfs_installed(): | |||||
| logger.error('git lfs is not installed, please install.') | |||||
| self.git_wrapper = GitCommandWrapper(git_path) | |||||
| os.makedirs(self.model_dir, exist_ok=True) | |||||
| url = self._get_model_id_url(clone_from) | |||||
| if os.listdir(self.model_dir): # directory not empty. | |||||
| remote_url = self._get_remote_url() | |||||
| remote_url = self.git_wrapper.remove_token_from_url(remote_url) | |||||
| if remote_url and remote_url == url: # need not clone again | |||||
| return | |||||
| self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, | |||||
| self.model_repo_name, revision) | |||||
| def _get_model_id_url(self, model_id): | |||||
| url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{model_id}' | |||||
| return url | |||||
| def _get_remote_url(self): | |||||
| try: | try: | ||||
| lfs_version = run_subprocess('git-lfs --version'.split(), | |||||
| self.local_dir).stdout.strip() | |||||
| except FileNotFoundError: | |||||
| raise EnvironmentError( | |||||
| 'Looks like you do not have git-lfs installed, please install.' | |||||
| ' You can install from https://git-lfs.github.com/.' | |||||
| ' Then run `git lfs install` (you only have to do this once).') | |||||
| logger.info(git_version + '\n' + lfs_version) | |||||
| def get_repo_url(self) -> str: | |||||
| """ | |||||
| Get repo url to clone, according whether the repo is private or not | |||||
| remote = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||||
| except GitError: | |||||
| remote = None | |||||
| return remote | |||||
| def push(self, | |||||
| commit_message: str, | |||||
| files: List[str] = list(), | |||||
| all_files: bool = False, | |||||
| branch: Optional[str] = 'master', | |||||
| force: bool = False): | |||||
| """Push local to remote, this method will do. | |||||
| git add | |||||
| git commit | |||||
| git push | |||||
| Args: | |||||
| commit_message (str): commit message | |||||
| revision (Optional[str], optional): which branch to push. Defaults to 'master'. | |||||
| """ | """ | ||||
| url = None | |||||
| if self.private: | |||||
| url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}' | |||||
| else: | |||||
| url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}' | |||||
| if not url: | |||||
| raise ValueError( | |||||
| 'Empty repo url, please check clone_from parameter') | |||||
| logger.debug('url to clone: %s', str(url)) | |||||
| return url | |||||
| def is_git_repo(folder: Union[str, Path]) -> bool: | |||||
| """ | |||||
| Check if the folder is the root or part of a git repository | |||||
| Args: | |||||
| folder (`str`): | |||||
| The folder in which to run the command. | |||||
| Returns: | |||||
| `bool`: `True` if the repository is part of a repository, `False` | |||||
| otherwise. | |||||
| """ | |||||
| folder_exists = os.path.exists(os.path.join(folder, '.git')) | |||||
| git_branch = subprocess.run( | |||||
| 'git branch'.split(), | |||||
| cwd=folder, | |||||
| stdout=subprocess.PIPE, | |||||
| stderr=subprocess.PIPE) | |||||
| return folder_exists and git_branch.returncode == 0 | |||||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||||
| self.git_wrapper.add(self.model_dir, files, all_files) | |||||
| self.git_wrapper.commit(self.model_dir, commit_message) | |||||
| self.git_wrapper.push( | |||||
| repo_dir=self.model_dir, | |||||
| token=self.auth_token, | |||||
| url=url, | |||||
| local_branch=branch, | |||||
| remote_branch=branch) | |||||
| @@ -1,40 +0,0 @@ | |||||
| import subprocess | |||||
| from typing import List | |||||
| def run_subprocess(command: List[str], | |||||
| folder: str, | |||||
| check=True, | |||||
| **kwargs) -> subprocess.CompletedProcess: | |||||
| """ | |||||
| Method to run subprocesses. Calling this will capture the `stderr` and `stdout`, | |||||
| please call `subprocess.run` manually in case you would like for them not to | |||||
| be captured. | |||||
| Args: | |||||
| command (`List[str]`): | |||||
| The command to execute as a list of strings. | |||||
| folder (`str`): | |||||
| The folder in which to run the command. | |||||
| check (`bool`, *optional*, defaults to `True`): | |||||
| Setting `check` to `True` will raise a `subprocess.CalledProcessError` | |||||
| when the subprocess has a non-zero exit code. | |||||
| kwargs (`Dict[str]`): | |||||
| Keyword arguments to be passed to the `subprocess.run` underlying command. | |||||
| Returns: | |||||
| `subprocess.CompletedProcess`: The completed process. | |||||
| """ | |||||
| if isinstance(command, str): | |||||
| raise ValueError( | |||||
| '`run_subprocess` should be called with a list of strings.') | |||||
| return subprocess.run( | |||||
| command, | |||||
| stderr=subprocess.PIPE, | |||||
| stdout=subprocess.PIPE, | |||||
| check=check, | |||||
| encoding='utf-8', | |||||
| cwd=folder, | |||||
| **kwargs, | |||||
| ) | |||||
| @@ -2,7 +2,7 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Dict, Union | |||||
| from typing import Dict, Optional, Union | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models.builder import build_model | from modelscope.models.builder import build_model | ||||
| @@ -42,13 +42,18 @@ class Model(ABC): | |||||
| return input | return input | ||||
| @classmethod | @classmethod | ||||
| def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs): | |||||
| """ Instantiate a model from local directory or remote model repo | |||||
| def from_pretrained(cls, | |||||
| model_name_or_path: str, | |||||
| revision: Optional[str] = 'master', | |||||
| *model_args, | |||||
| **kwargs): | |||||
| """ Instantiate a model from local directory or remote model repo. Note | |||||
| that when loading from remote, the model revision can be specified. | |||||
| """ | """ | ||||
| if osp.exists(model_name_or_path): | if osp.exists(model_name_or_path): | ||||
| local_model_dir = model_name_or_path | local_model_dir = model_name_or_path | ||||
| else: | else: | ||||
| local_model_dir = snapshot_download(model_name_or_path) | |||||
| local_model_dir = snapshot_download(model_name_or_path, revision) | |||||
| logger.info(f'initialize model from {local_model_dir}') | logger.info(f'initialize model from {local_model_dir}') | ||||
| cfg = Config.from_file( | cfg = Config.from_file( | ||||
| osp.join(local_model_dir, ModelFile.CONFIGURATION)) | osp.join(local_model_dir, ModelFile.CONFIGURATION)) | ||||
| @@ -2,11 +2,11 @@ import os | |||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| from ....preprocessors.space.fields.intent_field import IntentBPETextField | from ....preprocessors.space.fields.intent_field import IntentBPETextField | ||||
| from ....trainers.nlp.space.trainers.intent_trainer import IntentTrainer | |||||
| from ....utils.config import Config | from ....utils.config import Config | ||||
| from ....utils.constant import Tasks | from ....utils.constant import Tasks | ||||
| from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
| from ...builder import MODELS | from ...builder import MODELS | ||||
| from .application.intent_app import IntentTrainer | |||||
| from .model.generator import Generator | from .model.generator import Generator | ||||
| from .model.model_base import ModelBase | from .model.model_base import ModelBase | ||||
| @@ -2,11 +2,11 @@ import os | |||||
| from typing import Any, Dict, Optional | from typing import Any, Dict, Optional | ||||
| from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField | from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField | ||||
| from ....trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer | |||||
| from ....utils.config import Config | from ....utils.config import Config | ||||
| from ....utils.constant import Tasks | from ....utils.constant import Tasks | ||||
| from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
| from ...builder import MODELS | from ...builder import MODELS | ||||
| from .application.gen_app import MultiWOZTrainer | |||||
| from .model.generator import Generator | from .model.generator import Generator | ||||
| from .model.model_base import ModelBase | from .model.model_base import ModelBase | ||||
| @@ -6,6 +6,7 @@ from typing import List, Optional, Union | |||||
| from requests import HTTPError | from requests import HTTPError | ||||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| @@ -16,8 +17,8 @@ def create_model_if_not_exist( | |||||
| api, | api, | ||||
| model_id: str, | model_id: str, | ||||
| chinese_name: str, | chinese_name: str, | ||||
| visibility: Optional[int] = 5, # 1-private, 5-public | |||||
| license: Optional[str] = 'apache-2.0', | |||||
| visibility: Optional[int] = ModelVisibility.PUBLIC, | |||||
| license: Optional[str] = Licenses.APACHE_V2, | |||||
| revision: Optional[str] = 'master'): | revision: Optional[str] = 'master'): | ||||
| exists = True | exists = True | ||||
| try: | try: | ||||
| @@ -1,9 +1,9 @@ | |||||
| import unittest | import unittest | ||||
| from maas_hub.maas_api import MaasApi | |||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.utils.hub import create_model_if_not_exist | from modelscope.utils.hub import create_model_if_not_exist | ||||
| # note this is temporary before official account management is ready | |||||
| USER_NAME = 'maasadmin' | USER_NAME = 'maasadmin' | ||||
| PASSWORD = '12345678' | PASSWORD = '12345678' | ||||
| @@ -11,8 +11,7 @@ PASSWORD = '12345678' | |||||
| class HubExampleTest(unittest.TestCase): | class HubExampleTest(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| self.api = MaasApi() | |||||
| # note this is temporary before official account management is ready | |||||
| self.api = HubApi() | |||||
| self.api.login(USER_NAME, PASSWORD) | self.api.login(USER_NAME, PASSWORD) | ||||
| @unittest.skip('to be used for local test only') | @unittest.skip('to be used for local test only') | ||||
| @@ -22,7 +21,6 @@ class HubExampleTest(unittest.TestCase): | |||||
| model_chinese_name = '达摩卡通化模型' | model_chinese_name = '达摩卡通化模型' | ||||
| model_org = 'damo' | model_org = 'damo' | ||||
| model_id = '%s/%s' % (model_org, model_name) | model_id = '%s/%s' % (model_org, model_name) | ||||
| created = create_model_if_not_exist(self.api, model_id, | created = create_model_if_not_exist(self.api, model_id, | ||||
| model_chinese_name) | model_chinese_name) | ||||
| if not created: | if not created: | ||||
| @@ -1,14 +1,14 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| import subprocess | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| import uuid | import uuid | ||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | from modelscope.hub.api import HubApi, ModelScopeConfig | ||||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.hub.utils.utils import get_gitlab_domain | |||||
| USER_NAME = 'maasadmin' | USER_NAME = 'maasadmin' | ||||
| PASSWORD = '12345678' | PASSWORD = '12345678' | ||||
| @@ -17,40 +17,7 @@ model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'unittest' | model_org = 'unittest' | ||||
| DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
| class GitError(Exception): | |||||
| pass | |||||
| # TODO make thest git operation to git library after merge code. | |||||
| def run_git_command(git_path, *args) -> subprocess.CompletedProcess: | |||||
| response = subprocess.run([git_path, *args], capture_output=True) | |||||
| try: | |||||
| response.check_returncode() | |||||
| return response.stdout.decode('utf8') | |||||
| except subprocess.CalledProcessError as error: | |||||
| raise GitError(error.stderr.decode('utf8')) | |||||
| # for public project, token can None, private repo, there must token. | |||||
| def clone(local_dir: str, token: str, url: str): | |||||
| url = url.replace('//', '//oauth2:%s@' % token) | |||||
| clone_args = '-C %s clone %s' % (local_dir, url) | |||||
| clone_args = clone_args.split(' ') | |||||
| stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args) | |||||
| print('stdout: %s' % stdout) | |||||
def push(local_dir: str, token: str, url: str):
    """Push to ``url`` by running ``git -C local_dir push`` and print the output.

    Args:
        local_dir: Directory passed to git's ``-C`` option.
        token: Access token embedded in the URL as the password of the
            ``oauth2`` user. When it is None or empty the URL is used
            unchanged instead of embedding the literal text ``None``.
        url: Repository URL to push to.

    Raises:
        GitError: Propagated from ``run_git_command`` when git fails.
    """
    if token:
        # Only the scheme separator receives the credentials; a later '//'
        # inside the path must stay untouched.
        url = url.replace('//', '//oauth2:%s@' % token, 1)
    # Pass arguments as separate items so paths containing spaces survive.
    stdout = run_git_command(DEFAULT_GIT_PATH, '-C', local_dir, 'push', url)
    print('stdout: %s' % stdout)
| sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||||
| download_model_file_name = 'mnist-12.onnx' | |||||
| download_model_file_name = 'test.bin' | |||||
| class HubOperationTest(unittest.TestCase): | class HubOperationTest(unittest.TestCase): | ||||
| @@ -65,8 +32,15 @@ class HubOperationTest(unittest.TestCase): | |||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| chinese_name=model_chinese_name, | chinese_name=model_chinese_name, | ||||
| visibility=5, # 1-private, 5-public | |||||
| license='apache-2.0') | |||||
| visibility=ModelVisibility.PUBLIC, | |||||
| license=Licenses.APACHE_V2) | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||||
| os.chdir(self.model_dir) | |||||
| os.system("echo 'testtest'>%s" | |||||
| % os.path.join(self.model_dir, 'test.bin')) | |||||
| repo.push('add model', all_files=True) | |||||
| def tearDown(self): | def tearDown(self): | ||||
| os.chdir(self.old_cwd) | os.chdir(self.old_cwd) | ||||
| @@ -83,43 +57,10 @@ class HubOperationTest(unittest.TestCase): | |||||
| else: | else: | ||||
| raise | raise | ||||
| # Note that this can be done via git operation once model repo | |||||
| # has been created. Git-Op is the RECOMMENDED model upload approach | |||||
| def test_model_upload(self): | |||||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
| print(url) | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| os.chdir(temporary_dir) | |||||
| cmd_args = 'clone %s' % url | |||||
| cmd_args = cmd_args.split(' ') | |||||
| out = run_git_command('git', *cmd_args) | |||||
| print(out) | |||||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
| os.chdir(repo_dir) | |||||
| os.system('touch file1') | |||||
| os.system('git add file1') | |||||
| os.system("git commit -m 'Test'") | |||||
| token = ModelScopeConfig.get_token() | |||||
| push(repo_dir, token, url) | |||||
| def test_download_single_file(self): | def test_download_single_file(self): | ||||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
| print(url) | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| os.chdir(temporary_dir) | |||||
| os.system('git clone %s' % url) | |||||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
| os.chdir(repo_dir) | |||||
| os.system('wget %s' % sample_model_url) | |||||
| os.system('git add .') | |||||
| os.system("git commit -m 'Add file'") | |||||
| token = ModelScopeConfig.get_token() | |||||
| push(repo_dir, token, url) | |||||
| assert os.path.exists( | |||||
| os.path.join(temporary_dir, self.model_name, | |||||
| download_model_file_name)) | |||||
| downloaded_file = model_file_download( | downloaded_file = model_file_download( | ||||
| model_id=self.model_id, file_path=download_model_file_name) | model_id=self.model_id, file_path=download_model_file_name) | ||||
| assert os.path.exists(downloaded_file) | |||||
| mdtime1 = os.path.getmtime(downloaded_file) | mdtime1 = os.path.getmtime(downloaded_file) | ||||
| # download again | # download again | ||||
| downloaded_file = model_file_download( | downloaded_file = model_file_download( | ||||
| @@ -128,18 +69,6 @@ class HubOperationTest(unittest.TestCase): | |||||
| assert mdtime1 == mdtime2 | assert mdtime1 == mdtime2 | ||||
| def test_snapshot_download(self): | def test_snapshot_download(self): | ||||
| url = f'http://{get_gitlab_domain()}/{self.model_id}' | |||||
| print(url) | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| os.chdir(temporary_dir) | |||||
| os.system('git clone %s' % url) | |||||
| repo_dir = os.path.join(temporary_dir, self.model_name) | |||||
| os.chdir(repo_dir) | |||||
| os.system('wget %s' % sample_model_url) | |||||
| os.system('git add .') | |||||
| os.system("git commit -m 'Add file'") | |||||
| token = ModelScopeConfig.get_token() | |||||
| push(repo_dir, token, url) | |||||
| snapshot_path = snapshot_download(model_id=self.model_id) | snapshot_path = snapshot_download(model_id=self.model_id) | ||||
| downloaded_file_path = os.path.join(snapshot_path, | downloaded_file_path = os.path.join(snapshot_path, | ||||
| download_model_file_name) | download_model_file_name) | ||||
| @@ -0,0 +1,76 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import tempfile | |||||
| import unittest | |||||
| import uuid | |||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.hub.errors import GitError | |||||
| from modelscope.hub.repository import Repository | |||||
| USER_NAME = 'maasadmin' | |||||
| PASSWORD = '12345678' | |||||
| USER_NAME2 = 'sdkdev' | |||||
| model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'unittest' | |||||
| DEFAULT_GIT_PATH = 'git' | |||||
| sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx' | |||||
| download_model_file_name = 'mnist-12.onnx' | |||||
class HubPrivateRepositoryTest(unittest.TestCase):
    """Clone tests against a freshly created private model repository.

    Each test gets its own private model (created in ``setUp`` and deleted in
    ``tearDown``) and clones it into a scratch directory that is removed
    automatically once the test has finished.
    """

    def setUp(self):
        self.old_cwd = os.getcwd()
        self.api = HubApi()
        # NOTE: the hard-coded login is temporary until official account
        # management is ready.
        self.token, _ = self.api.login(USER_NAME, PASSWORD)
        self.model_name = uuid.uuid4().hex
        self.model_id = '%s/%s' % (model_org, self.model_name)
        self.api.create_model(
            model_id=self.model_id,
            chinese_name=model_chinese_name,
            visibility=1,  # 1-private, 5-public
            license='apache-2.0')

    def tearDown(self):
        # Restore the working directory first so that a failing login or
        # delete below cannot leave the process in a changed directory.
        os.chdir(self.old_cwd)
        # A test may have logged in as another user; switch back to the
        # account that created the model before deleting it.
        self.api.login(USER_NAME, PASSWORD)
        self.api.delete_model(model_id=self.model_id)

    def _make_local_dir(self):
        """Return a not-yet-existing clone target inside a scratch directory.

        The scratch directory is registered for removal via ``addCleanup`` so
        cloned repositories do not pile up in the system temp folder.
        """
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        return os.path.join(scratch.name, self.model_name)

    def test_clone_private_repo_no_permission(self):
        # Presumably USER_NAME2 has no access to the private model; the test
        # expects the clone attempt to fail with GitError.
        token, _ = self.api.login(USER_NAME2, PASSWORD)
        local_dir = self._make_local_dir()
        with self.assertRaises(GitError) as cm:
            Repository(local_dir, clone_from=self.model_id, auth_token=token)
        print(cm.exception)
        self.assertFalse(os.path.exists(os.path.join(local_dir, 'README.md')))

    def test_clone_private_repo_has_permission(self):
        local_dir = self._make_local_dir()
        repo1 = Repository(
            local_dir, clone_from=self.model_id, auth_token=self.token)
        print(repo1.model_dir)
        self.assertTrue(os.path.exists(os.path.join(local_dir, 'README.md')))

    def test_initlize_repo_multiple_times(self):
        local_dir = self._make_local_dir()
        repo1 = Repository(
            local_dir, clone_from=self.model_id, auth_token=self.token)
        print(repo1.model_dir)
        self.assertTrue(os.path.exists(os.path.join(local_dir, 'README.md')))
        # Second construction on the same directory (original note: "skip
        # clone"); both objects must report the same model_dir.
        repo2 = Repository(
            local_dir, clone_from=self.model_id, auth_token=self.token)
        print(repo2.model_dir)
        self.assertEqual(repo1.model_dir, repo2.model_dir)
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,107 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import time | |||||
| import unittest | |||||
| import uuid | |||||
| from os.path import expanduser | |||||
| from requests import delete | |||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.hub.errors import NotExistError | |||||
| from modelscope.hub.file_download import model_file_download | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| logger.setLevel('DEBUG') | |||||
| USER_NAME = 'maasadmin' | |||||
| PASSWORD = '12345678' | |||||
| model_chinese_name = '达摩卡通化模型' | |||||
| model_org = 'unittest' | |||||
| DEFAULT_GIT_PATH = 'git' | |||||
| download_model_file_name = 'mnist-12.onnx' | |||||
def delete_credential():
    """Remove the ``~/.modelscope/credentials`` directory tree.

    A directory that does not exist is treated as already deleted, so the
    function can be called repeatedly or before any login has happened.
    """
    path_credential = expanduser('~/.modelscope/credentials')
    try:
        shutil.rmtree(path_credential)
    except FileNotFoundError:
        # Nothing stored: the desired end state is already reached.
        pass
def delete_stored_git_credential(user):
    """Drop every line that mentions ``user`` from ``~/.git-credentials``.

    Does nothing when the file does not exist.

    Args:
        user: Text to look for; any line containing it is removed.
    """
    credential_path = expanduser('~/.git-credentials')
    if not os.path.exists(credential_path):
        return
    with open(credential_path, 'r+') as f:
        # Build a filtered copy rather than removing entries while iterating:
        # removing in place skips the line that follows each match, which
        # left consecutive credentials of the same user behind.
        kept = [line for line in f.readlines() if user not in line]
        f.seek(0)
        f.writelines(kept)
        f.truncate()
class HubRepositoryTest(unittest.TestCase):
    """Clone and push tests against a freshly created public model repository.

    Each test gets its own public model (created in ``setUp`` and deleted in
    ``tearDown``) and works in a scratch directory that is removed
    automatically once the test has finished.
    """

    def setUp(self):
        self.api = HubApi()
        # NOTE: the hard-coded login is temporary until official account
        # management is ready.
        self.api.login(USER_NAME, PASSWORD)
        self.model_name = uuid.uuid4().hex
        self.model_id = '%s/%s' % (model_org, self.model_name)
        self.api.create_model(
            model_id=self.model_id,
            chinese_name=model_chinese_name,
            visibility=5,  # 1-private, 5-public
            license='apache-2.0')
        # Cleanups run after tearDown, so the scratch directory outlives the
        # remote delete and is then removed instead of leaking.
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.model_dir = os.path.join(scratch.name, self.model_name)

    def tearDown(self):
        self.api.delete_model(model_id=self.model_id)

    def _write_model_file(self, name, text):
        """Write ``text`` plus a trailing newline to ``name`` in the clone.

        Produces the same content as the shell ``echo`` used previously,
        without spawning a shell.
        """
        with open(os.path.join(self.model_dir, name), 'w') as f:
            f.write(text + '\n')

    def test_clone_repo(self):
        Repository(self.model_dir, clone_from=self.model_id)
        self.assertTrue(
            os.path.exists(os.path.join(self.model_dir, 'README.md')))

    def test_clone_public_model_without_token(self):
        try:
            delete_credential()
            delete_stored_git_credential(USER_NAME)
            Repository(self.model_dir, clone_from=self.model_id)
            self.assertTrue(
                os.path.exists(os.path.join(self.model_dir, 'README.md')))
        finally:
            # Re-login for the delete in tearDown; done in ``finally`` so a
            # failed assertion does not leave the session logged out.
            self.api.login(USER_NAME, PASSWORD)

    def test_push_all(self):
        repo = Repository(self.model_dir, clone_from=self.model_id)
        self.assertTrue(
            os.path.exists(os.path.join(self.model_dir, 'README.md')))
        # Put the working directory back afterwards so later tests do not
        # run from inside a scratch directory that is about to be removed.
        self.addCleanup(os.chdir, os.getcwd())
        os.chdir(self.model_dir)
        self._write_model_file('add1.py', '111')
        self._write_model_file('add2.py', '222')
        repo.push('test', all_files=True)
        add1 = model_file_download(self.model_id, 'add1.py')
        self.assertTrue(os.path.exists(add1))
        add2 = model_file_download(self.model_id, 'add2.py')
        self.assertTrue(os.path.exists(add2))

    def test_push_files(self):
        repo = Repository(self.model_dir, clone_from=self.model_id)
        self.assertTrue(
            os.path.exists(os.path.join(self.model_dir, 'README.md')))
        self._write_model_file('add1.py', '111')
        self._write_model_file('add2.py', '222')
        self._write_model_file('add3.py', '333')
        # Only add1.py and add2.py are pushed; add3.py stays local.
        repo.push('test', files=['add1.py', 'add2.py'], all_files=False)
        add1 = model_file_download(self.model_id, 'add1.py')
        self.assertTrue(os.path.exists(add1))
        add2 = model_file_download(self.model_id, 'add2.py')
        self.assertTrue(os.path.exists(add2))
        with self.assertRaises(NotExistError) as cm:
            model_file_download(self.model_id, 'add3.py')
        print(cm.exception)
| if __name__ == '__main__': | |||||
| unittest.main() | |||||