jiaqi.sjq 3 years ago
parent
commit
8fa385e27c
3 changed files with 291 additions and 0 deletions
  1. +10
    -0
      modelscope/hub/git.py
  2. +117
    -0
      modelscope/hub/upload.py
  3. +164
    -0
      tests/hub/test_hub_upload.py

+ 10
- 0
modelscope/hub/git.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import re
import subprocess
from typing import List
from xmlrpc.client import Boolean
@@ -177,6 +178,15 @@ class GitCommandWrapper(metaclass=Singleton):
cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision]
return self._run_git_command(*cmds)

def get_remote_branches(self, repo_dir: str):
cmds = ['-C', '%s' % repo_dir, 'branch', '-r']
rsp = self._run_git_command(*cmds)
info = [
line.strip()
for line in rsp.stdout.decode('utf8').strip().split(os.linesep)
][1:]
return ['/'.join(line.split('/')[1:]) for line in info]

def pull(self, repo_dir: str):
cmds = ['-C', repo_dir, 'pull']
return self._run_git_command(*cmds)


+ 117
- 0
modelscope/hub/upload.py View File

@@ -0,0 +1,117 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import datetime
import os
import shutil
import tempfile
import uuid
from typing import Dict, Optional
from uuid import uuid4

from filelock import FileLock

from modelscope import __version__
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import InvalidParameter, NotLoginException
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.repository import Repository
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.logger import get_logger

logger = get_logger()


def upload_folder(model_id: str,
model_dir: str,
visibility: int = 0,
license: str = None,
chinese_name: Optional[str] = None,
commit_message: Optional[str] = None,
revision: Optional[str] = DEFAULT_MODEL_REVISION):
"""
Upload model from a given directory to given repository. A valid model directory
must contain a configuration.json file.

This function upload the files in given directory to given repository. If the
given repository is not exists in remote, it will automatically create it with
given visibility, license and chinese_name parameters. If the revision is also
not exists in remote repository, it will create a new branch for it.

This function must be called before calling HubApi's login with a valid token
which can be obtained from ModelScope's website.

Args:
model_id (`str`):
The model id to be uploaded, caller must have write permission for it.
model_dir(`str`):
The Absolute Path of the finetune result.
visibility(`int`, defaults to `0`):
Visibility of the new created model(1-private, 5-public). If the model is
not exists in ModelScope, this function will create a new model with this
visibility and this parameter is required. You can ignore this parameter
if you make sure the model's existence.
license(`str`, defaults to `None`):
License of the new created model(see License). If the model is not exists
in ModelScope, this function will create a new model with this license
and this parameter is required. You can ignore this parameter if you
make sure the model's existence.
chinese_name(`str`, *optional*, defaults to `None`):
chinese name of the new created model.
commit_message(`str`, *optional*, defaults to `None`):
commit message of the push request.
revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION):
which branch to push. If the branch is not exists, It will create a new
branch and push to it.
"""
if model_id is None:
raise InvalidParameter('model_id cannot be empty!')
if model_dir is None:
raise InvalidParameter('model_dir cannot be empty!')
if not os.path.exists(model_dir) or os.path.isfile(model_dir):
raise InvalidParameter('model_dir must be a valid directory.')
cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
if not os.path.exists(cfg_file):
raise ValueError(f'{model_dir} must contain a configuration.json.')
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise NotLoginException('Must login before upload!')
files_to_save = os.listdir(model_dir)
api = HubApi()
try:
api.get_model(model_id=model_id)
except Exception:
if visibility is None or license is None:
raise InvalidParameter(
'visibility and license cannot be empty if want to create new repo'
)
logger.info('Create new model %s' % model_id)
api.create_model(
model_id=model_id,
visibility=visibility,
license=license,
chinese_name=chinese_name)
tmp_dir = tempfile.mkdtemp()
git_wrapper = GitCommandWrapper()
try:
repo = Repository(model_dir=tmp_dir, clone_from=model_id)
branches = git_wrapper.get_remote_branches(tmp_dir)
if revision not in branches:
logger.info('Create new branch %s' % revision)
git_wrapper.new_branch(tmp_dir, revision)
git_wrapper.checkout(tmp_dir, revision)
for f in files_to_save:
if f[0] != '.':
src = os.path.join(model_dir, f)
if os.path.isdir(src):
shutil.copytree(src, os.path.join(tmp_dir, f))
else:
shutil.copy(src, tmp_dir)
if not commit_message:
date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
commit_message = '[automsg] push model %s to hub at %s' % (
model_id, date)
repo.push(commit_message=commit_message, branch=revision)
except Exception:
raise
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)

+ 164
- 0
tests/hub/test_hub_upload.py View File

@@ -0,0 +1,164 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.repository import Repository
from modelscope.hub.upload import upload_folder
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
from .test_utils import TEST_ACCESS_TOKEN1, delete_credential

logger = get_logger()


class HubUploadTest(unittest.TestCase):

def setUp(self):
logger.info('SetUp')
self.api = HubApi()
self.user = os.environ.get('TEST_MODEL_ORG', 'citest')
logger.info(self.user)
self.create_model_name = '%s/%s' % (self.user, 'test_model_upload')
temporary_dir = tempfile.mkdtemp()
self.work_dir = temporary_dir
self.model_dir = os.path.join(temporary_dir, self.create_model_name)
self.finetune_path = os.path.join(self.work_dir, 'finetune_path')
self.repo_path = os.path.join(self.work_dir, 'repo_path')
os.mkdir(self.finetune_path)
os.system("echo '{}'>%s"
% os.path.join(self.finetune_path, ModelFile.CONFIGURATION))

def tearDown(self):
logger.info('TearDown')
shutil.rmtree(self.model_dir, ignore_errors=True)
self.api.delete_model(model_id=self.create_model_name)

def test_upload_exits_repo_master(self):
logger.info('basic test for upload!')
self.api.login(TEST_ACCESS_TOKEN1)
self.api.create_model(
model_id=self.create_model_name,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
os.system("echo '111'>%s"
% os.path.join(self.finetune_path, 'add1.py'))
upload_folder(
model_id=self.create_model_name, model_dir=self.finetune_path)
Repository(model_dir=self.repo_path, clone_from=self.create_model_name)
assert os.path.exists(os.path.join(self.repo_path, 'add1.py'))
shutil.rmtree(self.repo_path, ignore_errors=True)
os.system("echo '222'>%s"
% os.path.join(self.finetune_path, 'add2.py'))
upload_folder(
model_id=self.create_model_name,
model_dir=self.finetune_path,
revision='new_revision/version1')
Repository(
model_dir=self.repo_path,
clone_from=self.create_model_name,
revision='new_revision/version1')
assert os.path.exists(os.path.join(self.repo_path, 'add2.py'))
shutil.rmtree(self.repo_path, ignore_errors=True)
os.system("echo '333'>%s"
% os.path.join(self.finetune_path, 'add3.py'))
upload_folder(
model_id=self.create_model_name,
model_dir=self.finetune_path,
revision='new_revision/version2',
commit_message='add add3.py')
Repository(
model_dir=self.repo_path,
clone_from=self.create_model_name,
revision='new_revision/version2')
assert os.path.exists(os.path.join(self.repo_path, 'add2.py'))
assert os.path.exists(os.path.join(self.repo_path, 'add3.py'))
shutil.rmtree(self.repo_path, ignore_errors=True)
add4_path = os.path.join(self.finetune_path, 'temp')
os.mkdir(add4_path)
os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py'))
upload_folder(
model_id=self.create_model_name,
model_dir=self.finetune_path,
revision='new_revision/version1')
Repository(
model_dir=self.repo_path,
clone_from=self.create_model_name,
revision='new_revision/version1')
assert os.path.exists(os.path.join(add4_path, 'add4.py'))
shutil.rmtree(self.repo_path, ignore_errors=True)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_upload_non_exists_repo(self):
logger.info('test upload non exists repo!')
self.api.login(TEST_ACCESS_TOKEN1)
os.system("echo '111'>%s"
% os.path.join(self.finetune_path, 'add1.py'))
upload_folder(
model_id=self.create_model_name,
model_dir=self.finetune_path,
revision='new_model_new_revision',
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
Repository(
model_dir=self.repo_path,
clone_from=self.create_model_name,
revision='new_model_new_revision')
assert os.path.exists(os.path.join(self.repo_path, 'add1.py'))
shutil.rmtree(self.repo_path, ignore_errors=True)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_upload_without_token(self):
logger.info('test upload without login!')
self.api.login(TEST_ACCESS_TOKEN1)
delete_credential()
try:
upload_folder(
model_id=self.create_model_name,
model_dir=self.finetune_path,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
except Exception as e:
logger.info(e)
self.api.login(TEST_ACCESS_TOKEN1)
upload_folder(
model_id=self.create_model_name,
model_dir=self.finetune_path,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
Repository(
model_dir=self.repo_path, clone_from=self.create_model_name)
assert os.path.exists(
os.path.join(self.repo_path, 'configuration.json'))
shutil.rmtree(self.repo_path, ignore_errors=True)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_upload_invalid_repo(self):
logger.info('test upload to invalid repo!')
self.api.login(TEST_ACCESS_TOKEN1)
try:
upload_folder(
model_id='%s/%s' % ('speech_tts', 'invalid_model_test'),
model_dir=self.finetune_path,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
except Exception as e:
logger.info(e)
upload_folder(
model_id=self.create_model_name,
model_dir=self.finetune_path,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
Repository(
model_dir=self.repo_path, clone_from=self.create_model_name)
assert os.path.exists(
os.path.join(self.repo_path, 'configuration.json'))
shutil.rmtree(self.repo_path, ignore_errors=True)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save