# ---------------------------------------------------------------------------
# modelscope/hub/git.py
# New method of ``GitCommandWrapper`` (shown here at module level because the
# enclosing class is outside this view).
# ---------------------------------------------------------------------------
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import re
import subprocess
from typing import List
from xmlrpc.client import Boolean


def get_remote_branches(self, repo_dir: str):
    """List the remote-tracking branches of a local clone.

    Runs ``git -C <repo_dir> branch -r`` through ``self._run_git_command``
    and removes the leading remote name from every entry, so that
    ``origin/feature/x`` is reported as ``feature/x``.

    Args:
        repo_dir (`str`): Path of the local git working tree.

    Returns:
        `List[str]`: Branch names without the remote prefix. The symbolic
        ref entry (``origin/HEAD -> origin/master``) is not included.
    """
    cmds = ['-C', '%s' % repo_dir, 'branch', '-r']
    rsp = self._run_git_command(*cmds)
    branches = []
    # splitlines() handles both '\n' and '\r\n'; splitting on os.linesep
    # only works when git happens to use the platform's separator.
    for raw_line in rsp.stdout.decode('utf8').splitlines():
        line = raw_line.strip()
        # Skip blank lines and the symbolic ref ("origin/HEAD -> origin/x").
        # Dropping the first line unconditionally would lose a real branch
        # whenever the symbolic ref entry is not present.
        if not line or '->' in line:
            continue
        # Keep everything after the remote name; branch names may contain
        # '/' themselves.
        branches.append('/'.join(line.split('/')[1:]))
    return branches


# ---------------------------------------------------------------------------
# modelscope/hub/upload.py (new file)
# ---------------------------------------------------------------------------
# Copyright (c) Alibaba, Inc. and its affiliates.

import datetime
import os
import shutil
import tempfile
import uuid
from typing import Dict, Optional
from uuid import uuid4

from filelock import FileLock

from modelscope import __version__
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import InvalidParameter, NotLoginException
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.repository import Repository
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.logger import get_logger

logger = get_logger()


def _copy_tree_merging(src: str, dst: str):
    """Recursively copy directory ``src`` to ``dst``, merging if ``dst`` exists."""
    os.makedirs(dst, exist_ok=True)
    for name in os.listdir(src):
        src_entry = os.path.join(src, name)
        dst_entry = os.path.join(dst, name)
        if os.path.isdir(src_entry):
            _copy_tree_merging(src_entry, dst_entry)
        else:
            shutil.copy2(src_entry, dst_entry)


def upload_folder(model_id: str,
                  model_dir: str,
                  visibility: int = 0,
                  license: str = None,
                  chinese_name: Optional[str] = None,
                  commit_message: Optional[str] = None,
                  revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """
    Upload a model from a given directory to a given repository. A valid
    model directory must contain a configuration.json file.

    This function uploads the files in the given directory to the given
    repository. If the repository does not exist on the remote, it is created
    with the given visibility, license and chinese_name parameters. If the
    revision does not exist in the remote repository either, a new branch is
    created for it. Top-level entries of `model_dir` whose name starts with
    '.' are not uploaded.

    This function must be called after HubApi's login has been called with a
    valid token, which can be obtained from ModelScope's website.

    Args:
        model_id (`str`):
            The model id to be uploaded, caller must have write permission
            for it.
        model_dir(`str`):
            The absolute path of the finetune result.
        visibility(`int`, defaults to `0`):
            Visibility of the newly created model(1-private, 5-public). If
            the model does not exist in ModelScope, this function creates a
            new model with this visibility and this parameter is required.
            You can ignore this parameter if you are sure the model exists.
        license(`str`, defaults to `None`):
            License of the newly created model(see License). If the model
            does not exist in ModelScope, this function creates a new model
            with this license and this parameter is required. You can ignore
            this parameter if you are sure the model exists.
        chinese_name(`str`, *optional*, defaults to `None`):
            Chinese name of the newly created model.
        commit_message(`str`, *optional*, defaults to `None`):
            Commit message of the push request. A timestamped message is
            generated when it is empty.
        revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION):
            Which branch to push to. If the branch does not exist, a new
            branch is created and pushed.

    Raises:
        InvalidParameter: If `model_id` or `model_dir` is empty, `model_dir`
            is not a directory, or the model has to be created while
            `visibility` or `license` is `None`.
        ValueError: If `model_dir` has no configuration file.
        NotLoginException: If no login cookies are available.
    """
    if not model_id:
        raise InvalidParameter('model_id cannot be empty!')
    if not model_dir:
        raise InvalidParameter('model_dir cannot be empty!')
    if not os.path.isdir(model_dir):
        raise InvalidParameter('model_dir must be a valid directory.')
    cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
    if not os.path.exists(cfg_file):
        raise ValueError(f'{model_dir} must contain a configuration.json.')
    cookies = ModelScopeConfig.get_cookies()
    if cookies is None:
        raise NotLoginException('Must login before upload!')
    files_to_save = os.listdir(model_dir)
    api = HubApi()
    try:
        api.get_model(model_id=model_id)
    except Exception:
        # Any failure of the lookup is treated as "the model does not exist
        # yet", so the repository is created here.
        if visibility is None or license is None:
            raise InvalidParameter(
                'visibility and license cannot be empty if want to create new repo'
            )
        logger.info('Create new model %s' % model_id)
        api.create_model(
            model_id=model_id,
            visibility=visibility,
            license=license,
            chinese_name=chinese_name)
    tmp_dir = tempfile.mkdtemp()
    git_wrapper = GitCommandWrapper()
    try:
        repo = Repository(model_dir=tmp_dir, clone_from=model_id)
        branches = git_wrapper.get_remote_branches(tmp_dir)
        if revision not in branches:
            logger.info('Create new branch %s' % revision)
            git_wrapper.new_branch(tmp_dir, revision)
        git_wrapper.checkout(tmp_dir, revision)
        for entry in files_to_save:
            if entry.startswith('.'):
                continue
            src = os.path.join(model_dir, entry)
            if os.path.isdir(src):
                # The directory may already exist in the cloned branch (for
                # example on a second upload); merge instead of failing with
                # FileExistsError as shutil.copytree would.
                _copy_tree_merging(src, os.path.join(tmp_dir, entry))
            else:
                shutil.copy(src, tmp_dir)
        if not commit_message:
            date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
            commit_message = '[automsg] push model %s to hub at %s' % (
                model_id, date)
        repo.push(commit_message=commit_message, branch=revision)
    finally:
        # Always drop the temporary clone, whether or not the push succeeded.
        shutil.rmtree(tmp_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# tests/hub/test_hub_upload.py (new file)
# ---------------------------------------------------------------------------
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.repository import Repository
from modelscope.hub.upload import upload_folder
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
from .test_utils import TEST_ACCESS_TOKEN1, delete_credential

logger = get_logger()


def _write_file(path, content):
    """Write ``content`` plus a trailing newline to ``path``.

    Replaces the shell ``echo '...' > file`` calls so that the fixture does
    not depend on the quoting rules of the platform's shell.
    """
    with open(path, 'w') as f:
        f.write(content + '\n')


class HubUploadTest(unittest.TestCase):
    """Tests for ``upload_folder`` against a model repository on the hub."""

    def setUp(self):
        logger.info('SetUp')
        self.api = HubApi()
        self.user = os.environ.get('TEST_MODEL_ORG', 'citest')
        logger.info(self.user)
        self.create_model_name = '%s/%s' % (self.user, 'test_model_upload')
        temporary_dir = tempfile.mkdtemp()
        self.work_dir = temporary_dir
        self.model_dir = os.path.join(temporary_dir, self.create_model_name)
        self.finetune_path = os.path.join(self.work_dir, 'finetune_path')
        self.repo_path = os.path.join(self.work_dir, 'repo_path')
        os.mkdir(self.finetune_path)
        # upload_folder requires a configuration file in the model dir.
        _write_file(
            os.path.join(self.finetune_path, ModelFile.CONFIGURATION), '{}')

    def tearDown(self):
        logger.info('TearDown')
        # Remove the whole temporary work dir (finetune_path, repo_path and
        # model_dir all live inside it) instead of only model_dir, which is
        # never created and therefore left the temp dir behind.
        shutil.rmtree(self.work_dir, ignore_errors=True)
        self.api.delete_model(model_id=self.create_model_name)

    def _add_file(self, name, content):
        """Create a file with the given content inside the finetune dir."""
        _write_file(os.path.join(self.finetune_path, name), content)

    def _assert_in_remote(self, relative_paths, revision=None):
        """Clone the test model and check that every given path exists.

        The clone in ``self.repo_path`` is removed afterwards, also when an
        assertion fails, so the next clone starts from a clean directory.
        """
        clone_kwargs = {} if revision is None else {'revision': revision}
        Repository(
            model_dir=self.repo_path,
            clone_from=self.create_model_name,
            **clone_kwargs)
        try:
            for relative_path in relative_paths:
                self.assertTrue(
                    os.path.exists(
                        os.path.join(self.repo_path, relative_path)),
                    '%s is missing in the cloned repo' % relative_path)
        finally:
            shutil.rmtree(self.repo_path, ignore_errors=True)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_upload_exits_repo_master(self):
        logger.info('basic test for upload!')
        self.api.login(TEST_ACCESS_TOKEN1)
        self.api.create_model(
            model_id=self.create_model_name,
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2)

        # Default revision of an existing repository.
        self._add_file('add1.py', '111')
        upload_folder(
            model_id=self.create_model_name, model_dir=self.finetune_path)
        self._assert_in_remote(['add1.py'])

        # New branch, default commit message.
        self._add_file('add2.py', '222')
        upload_folder(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_revision/version1')
        self._assert_in_remote(['add2.py'], revision='new_revision/version1')

        # Another new branch, explicit commit message.
        self._add_file('add3.py', '333')
        upload_folder(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_revision/version2',
            commit_message='add add3.py')
        self._assert_in_remote(['add2.py', 'add3.py'],
                               revision='new_revision/version2')

        # Existing branch, file inside a sub directory.
        add4_path = os.path.join(self.finetune_path, 'temp')
        os.mkdir(add4_path)
        _write_file(os.path.join(add4_path, 'add4.py'), '444')
        upload_folder(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_revision/version1')
        # Check the cloned repository, not the local source directory: the
        # source file always exists, so asserting on it proves nothing.
        self._assert_in_remote([os.path.join('temp', 'add4.py')],
                               revision='new_revision/version1')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_upload_non_exists_repo(self):
        logger.info('test upload non exists repo!')
        self.api.login(TEST_ACCESS_TOKEN1)
        self._add_file('add1.py', '111')
        upload_folder(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_model_new_revision',
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2)
        self._assert_in_remote(['add1.py'], revision='new_model_new_revision')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_upload_without_token(self):
        logger.info('test upload without login!')
        self.api.login(TEST_ACCESS_TOKEN1)
        delete_credential()
        # NOTE(review): the failure is only logged, so this step does not
        # fail the test when the upload unexpectedly succeeds. Consider
        # assertRaises once the exact exception type is confirmed.
        try:
            upload_folder(
                model_id=self.create_model_name,
                model_dir=self.finetune_path,
                visibility=ModelVisibility.PUBLIC,
                license=Licenses.APACHE_V2)
        except Exception as e:
            logger.info(e)
            self.api.login(TEST_ACCESS_TOKEN1)
            upload_folder(
                model_id=self.create_model_name,
                model_dir=self.finetune_path,
                visibility=ModelVisibility.PUBLIC,
                license=Licenses.APACHE_V2)
            self._assert_in_remote(['configuration.json'])

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_upload_invalid_repo(self):
        logger.info('test upload to invalid repo!')
        self.api.login(TEST_ACCESS_TOKEN1)
        # NOTE(review): as above, the expected failure is only logged.
        try:
            upload_folder(
                model_id='%s/%s' % ('speech_tts', 'invalid_model_test'),
                model_dir=self.finetune_path,
                visibility=ModelVisibility.PUBLIC,
                license=Licenses.APACHE_V2)
        except Exception as e:
            logger.info(e)
            upload_folder(
                model_id=self.create_model_name,
                model_dir=self.finetune_path,
                visibility=ModelVisibility.PUBLIC,
                license=Licenses.APACHE_V2)
            self._assert_in_remote(['configuration.json'])


if __name__ == '__main__':
    unittest.main()