Browse Source

* [to #45486649]feat: modelscope model version use model repo tag, unsupport branch or commit it, client user-agent header unified

master
mulin.lyh 3 years ago
parent
commit
384377b8f5
21 changed files with 825 additions and 210 deletions
  1. +22
    -16
      .dev_scripts/ci_container_test.sh
  2. +46
    -22
      .dev_scripts/dockerci.sh
  3. +2
    -2
      modelscope/__init__.py
  4. +155
    -35
      modelscope/hub/api.py
  5. +3
    -0
      modelscope/hub/constants.py
  6. +12
    -14
      modelscope/hub/deploy.py
  7. +34
    -0
      modelscope/hub/errors.py
  8. +38
    -66
      modelscope/hub/file_download.py
  9. +20
    -0
      modelscope/hub/git.py
  10. +41
    -5
      modelscope/hub/repository.py
  11. +13
    -13
      modelscope/hub/snapshot_download.py
  12. +14
    -1
      modelscope/hub/utils/utils.py
  13. +5
    -1
      modelscope/msdatasets/ms_dataset.py
  14. +3
    -1
      modelscope/utils/constant.py
  15. +4
    -0
      modelscope/version.py
  16. +25
    -11
      tests/hub/test_hub_operation.py
  17. +44
    -15
      tests/hub/test_hub_private_files.py
  18. +2
    -3
      tests/hub/test_hub_private_repository.py
  19. +7
    -5
      tests/hub/test_hub_repository.py
  20. +145
    -0
      tests/hub/test_hub_revision.py
  21. +190
    -0
      tests/hub/test_hub_revision_release_mode.py

+ 22
- 16
.dev_scripts/ci_container_test.sh View File

@@ -1,22 +1,28 @@
awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/tests.txt
echo "Testing envs"
printenv
echo "ENV END"
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/tests.txt


git config --global --add safe.directory /Maas-lib
git config --global --add safe.directory /Maas-lib


# linter test
# use internal project for pre-commit due to the network problem
pre-commit run -c .pre-commit-config_local.yaml --all-files
if [ $? -ne 0 ]; then
echo "linter test failed, please run 'pre-commit run --all-files' to check"
exit -1
# linter test
# use internal project for pre-commit due to the network problem
pre-commit run -c .pre-commit-config_local.yaml --all-files
if [ $? -ne 0 ]; then
echo "linter test failed, please run 'pre-commit run --all-files' to check"
exit -1
fi
# test with install
python setup.py install
else
echo "Running case in release image, run case directly!"
fi fi
# test with install
python setup.py install

if [ $# -eq 0 ]; then if [ $# -eq 0 ]; then
ci_command="python tests/run.py --subprocess" ci_command="python tests/run.py --subprocess"
else else


+ 46
- 22
.dev_scripts/dockerci.sh View File

@@ -20,28 +20,52 @@ do


# pull image if there are update # pull image if there are update
docker pull ${IMAGE_NAME}:${IMAGE_VERSION} docker pull ${IMAGE_NAME}:${IMAGE_VERSION}
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
--cpuset-cpus=${cpu_sets_arr[$gpu]} \
--gpus="device=$gpu" \
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
-v $MODELSCOPE_HOME_CACHE/$gpu:/root \
-v /home/admin/pre-commit:/home/admin/pre-commit \
-e CI_TEST=True \
-e TEST_LEVEL=$TEST_LEVEL \
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
-e TEST_LEVEL=$TEST_LEVEL \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
--workdir=$CODE_DIR_IN_CONTAINER \
--net host \
${IMAGE_NAME}:${IMAGE_VERSION} \
$CI_COMMAND

if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
--cpuset-cpus=${cpu_sets_arr[$gpu]} \
--gpus="device=$gpu" \
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
-v $MODELSCOPE_HOME_CACHE/$gpu:/root \
-v /home/admin/pre-commit:/home/admin/pre-commit \
-e CI_TEST=True \
-e TEST_LEVEL=$TEST_LEVEL \
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
-e MODELSCOPE_SDK_DEBUG=True \
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
-e TEST_LEVEL=$TEST_LEVEL \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
--workdir=$CODE_DIR_IN_CONTAINER \
--net host \
${IMAGE_NAME}:${IMAGE_VERSION} \
$CI_COMMAND
else
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
--cpuset-cpus=${cpu_sets_arr[$gpu]} \
--gpus="device=$gpu" \
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
-v $MODELSCOPE_HOME_CACHE/$gpu:/root \
-v /home/admin/pre-commit:/home/admin/pre-commit \
-e CI_TEST=True \
-e TEST_LEVEL=$TEST_LEVEL \
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
-e TEST_LEVEL=$TEST_LEVEL \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
--workdir=$CODE_DIR_IN_CONTAINER \
--net host \
${IMAGE_NAME}:${IMAGE_VERSION} \
$CI_COMMAND
fi
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Running test case failed, please check the log!" echo "Running test case failed, please check the log!"
exit -1 exit -1


+ 2
- 2
modelscope/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
from .version import __version__
from .version import __release_datetime__, __version__


__all__ = ['__version__']
__all__ = ['__version__', '__release_datetime__']

+ 155
- 35
modelscope/hub/api.py View File

@@ -4,40 +4,45 @@
import datetime import datetime
import os import os
import pickle import pickle
import platform
import shutil import shutil
import tempfile import tempfile
import uuid
from collections import defaultdict from collections import defaultdict
from http import HTTPStatus from http import HTTPStatus
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
from os.path import expanduser from os.path import expanduser
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union


import requests import requests


from modelscope import __version__
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
API_RESPONSE_FIELD_EMAIL, API_RESPONSE_FIELD_EMAIL,
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
API_RESPONSE_FIELD_MESSAGE, API_RESPONSE_FIELD_MESSAGE,
API_RESPONSE_FIELD_USERNAME, API_RESPONSE_FIELD_USERNAME,
DEFAULT_CREDENTIALS_PATH, Licenses,
ModelVisibility)
DEFAULT_CREDENTIALS_PATH,
MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS,
Licenses, ModelVisibility)
from modelscope.hub.errors import (InvalidParameter, NotExistError, from modelscope.hub.errors import (InvalidParameter, NotExistError,
NotLoginException, RequestError,
datahub_raise_on_error,
NotLoginException, NoValidRevisionError,
RequestError, datahub_raise_on_error,
handle_http_post_error, handle_http_post_error,
handle_http_response, is_ok, raise_on_error)
handle_http_response, is_ok,
raise_for_http_status, raise_on_error)
from modelscope.hub.git import GitCommandWrapper from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.repository import Repository from modelscope.hub.repository import Repository
from modelscope.hub.utils.utils import (get_endpoint,
model_id_to_group_owner_name)
from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION, DEFAULT_MODEL_REVISION,
DatasetFormations, DatasetMetaFormats,
DownloadMode, ModelFile)
DEFAULT_REPOSITORY_REVISION,
MASTER_MODEL_BRANCH, DatasetFormations,
DatasetMetaFormats, DownloadMode,
ModelFile)
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger

# yapf: enable
from .utils.utils import (get_endpoint, get_release_datetime,
model_id_to_group_owner_name)


logger = get_logger() logger = get_logger()


@@ -46,6 +51,7 @@ class HubApi:


def __init__(self, endpoint=None): def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else get_endpoint() self.endpoint = endpoint if endpoint is not None else get_endpoint()
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}


def login( def login(
self, self,
@@ -65,8 +71,9 @@ class HubApi:
</Tip> </Tip>
""" """
path = f'{self.endpoint}/api/v1/login' path = f'{self.endpoint}/api/v1/login'
r = requests.post(path, json={'AccessToken': access_token})
r.raise_for_status()
r = requests.post(
path, json={'AccessToken': access_token}, headers=self.headers)
raise_for_http_status(r)
d = r.json() d = r.json()
raise_on_error(d) raise_on_error(d)


@@ -120,7 +127,8 @@ class HubApi:
'Visibility': visibility, # server check 'Visibility': visibility, # server check
'License': license 'License': license
} }
r = requests.post(path, json=body, cookies=cookies)
r = requests.post(
path, json=body, cookies=cookies, headers=self.headers)
handle_http_post_error(r, path, body) handle_http_post_error(r, path, body)
raise_on_error(r.json()) raise_on_error(r.json())
model_repo_url = f'{get_endpoint()}/{model_id}' model_repo_url = f'{get_endpoint()}/{model_id}'
@@ -140,8 +148,8 @@ class HubApi:
raise ValueError('Token does not exist, please login first.') raise ValueError('Token does not exist, please login first.')
path = f'{self.endpoint}/api/v1/models/{model_id}' path = f'{self.endpoint}/api/v1/models/{model_id}'


r = requests.delete(path, cookies=cookies)
r.raise_for_status()
r = requests.delete(path, cookies=cookies, headers=self.headers)
raise_for_http_status(r)
raise_on_error(r.json()) raise_on_error(r.json())


def get_model_url(self, model_id): def get_model_url(self, model_id):
@@ -170,7 +178,7 @@ class HubApi:
owner_or_group, name = model_id_to_group_owner_name(model_id) owner_or_group, name = model_id_to_group_owner_name(model_id)
path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'


r = requests.get(path, cookies=cookies)
r = requests.get(path, cookies=cookies, headers=self.headers)
handle_http_response(r, logger, cookies, model_id) handle_http_response(r, logger, cookies, model_id)
if r.status_code == HTTPStatus.OK: if r.status_code == HTTPStatus.OK:
if is_ok(r.json()): if is_ok(r.json()):
@@ -178,7 +186,7 @@ class HubApi:
else: else:
raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE])
else: else:
r.raise_for_status()
raise_for_http_status(r)


def push_model(self, def push_model(self,
model_id: str, model_id: str,
@@ -187,7 +195,7 @@ class HubApi:
license: str = Licenses.APACHE_V2, license: str = Licenses.APACHE_V2,
chinese_name: Optional[str] = None, chinese_name: Optional[str] = None,
commit_message: Optional[str] = 'upload model', commit_message: Optional[str] = 'upload model',
revision: Optional[str] = DEFAULT_MODEL_REVISION):
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION):
""" """
Upload model from a given directory to given repository. A valid model directory Upload model from a given directory to given repository. A valid model directory
must contain a configuration.json file. must contain a configuration.json file.
@@ -269,7 +277,7 @@ class HubApi:
date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
commit_message = '[automsg] push model %s to hub at %s' % ( commit_message = '[automsg] push model %s to hub at %s' % (
model_id, date) model_id, date)
repo.push(commit_message=commit_message, branch=revision)
repo.push(commit_message=commit_message, local_branch=revision, remote_branch=revision)
except Exception: except Exception:
raise raise
finally: finally:
@@ -294,7 +302,8 @@ class HubApi:
path, path,
data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' %
(owner_or_group, page_number, page_size), (owner_or_group, page_number, page_size),
cookies=cookies)
cookies=cookies,
headers=self.headers)
handle_http_response(r, logger, cookies, 'list_model') handle_http_response(r, logger, cookies, 'list_model')
if r.status_code == HTTPStatus.OK: if r.status_code == HTTPStatus.OK:
if is_ok(r.json()): if is_ok(r.json()):
@@ -303,7 +312,7 @@ class HubApi:
else: else:
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
else: else:
r.raise_for_status()
raise_for_http_status(r)
return None return None


def _check_cookie(self, def _check_cookie(self,
@@ -318,10 +327,70 @@ class HubApi:
raise ValueError('Token does not exist, please login first.') raise ValueError('Token does not exist, please login first.')
return cookies return cookies


def list_model_revisions(
self,
model_id: str,
cutoff_timestamp: int = None,
use_cookies: Union[bool, CookieJar] = False) -> List[str]:
"""Get model branch and tags.

Args:
model_id (str): The model id
cutoff_timestamp (int): Tags created before the cutoff will be included.
The timestamp is represented by the seconds elasped from the epoch time.
use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will
will load cookie from local. Defaults to False.
Returns:
Tuple[List[str], List[str]]: Return list of branch name and tags
"""
cookies = self._check_cookie(use_cookies)
if cutoff_timestamp is None:
cutoff_timestamp = get_release_datetime()
path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp
r = requests.get(path, cookies=cookies, headers=self.headers)
handle_http_response(r, logger, cookies, model_id)
d = r.json()
raise_on_error(d)
info = d[API_RESPONSE_FIELD_DATA]
# tags returned from backend are guaranteed to be ordered by create-time
tags = [x['Revision'] for x in info['RevisionMap']['Tags']
] if info['RevisionMap']['Tags'] else []
return tags

def get_valid_revision(self, model_id: str, revision=None, cookies: Optional[CookieJar] = None):
release_timestamp = get_release_datetime()
current_timestamp = int(round(datetime.datetime.now().timestamp()))
# for active development in library codes (non-release-branches), release_timestamp
# is set to be a far-away-time-in-the-future, to ensure that we shall
# get the master-HEAD version from model repo by default (when no revision is provided)
if release_timestamp > current_timestamp + ONE_YEAR_SECONDS:
branches, tags = self.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
if revision is None:
revision = MASTER_MODEL_BRANCH
logger.info('Model revision not specified, use default: %s in development mode' % revision)
if revision not in branches and revision not in tags:
raise NotExistError('The model: %s has no branch or tag : %s .' % revision)
else:
revisions = self.list_model_revisions(
model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies)
if revision is None:
if len(revisions) == 0:
raise NoValidRevisionError('The model: %s has no valid revision!' % model_id)
# tags (revisions) returned from backend are guaranteed to be ordered by create-time
# we shall obtain the latest revision created earlier than release version of this branch
revision = revisions[0]
logger.info('Model revision not specified, use the latest revision: %s' % revision)
else:
if revision not in revisions:
raise NotExistError(
'The model: %s has no revision: %s !' % (model_id, revision))
return revision

def get_model_branches_and_tags( def get_model_branches_and_tags(
self, self,
model_id: str, model_id: str,
use_cookies: Union[bool, CookieJar] = False
use_cookies: Union[bool, CookieJar] = False,
) -> Tuple[List[str], List[str]]: ) -> Tuple[List[str], List[str]]:
"""Get model branch and tags. """Get model branch and tags.


@@ -335,7 +404,7 @@ class HubApi:
cookies = self._check_cookie(use_cookies) cookies = self._check_cookie(use_cookies)


path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
r = requests.get(path, cookies=cookies)
r = requests.get(path, cookies=cookies, headers=self.headers)
handle_http_response(r, logger, cookies, model_id) handle_http_response(r, logger, cookies, model_id)
d = r.json() d = r.json()
raise_on_error(d) raise_on_error(d)
@@ -376,7 +445,11 @@ class HubApi:
if root is not None: if root is not None:
path = path + f'&Root={root}' path = path + f'&Root={root}'


r = requests.get(path, cookies=cookies, headers=headers)
r = requests.get(
path, cookies=cookies, headers={
**headers,
**self.headers
})


handle_http_response(r, logger, cookies, model_id) handle_http_response(r, logger, cookies, model_id)
d = r.json() d = r.json()
@@ -392,10 +465,9 @@ class HubApi:


def list_datasets(self): def list_datasets(self):
path = f'{self.endpoint}/api/v1/datasets' path = f'{self.endpoint}/api/v1/datasets'
headers = None
params = {} params = {}
r = requests.get(path, params=params, headers=headers)
r.raise_for_status()
r = requests.get(path, params=params, headers=self.headers)
raise_for_http_status(r)
dataset_list = r.json()[API_RESPONSE_FIELD_DATA] dataset_list = r.json()[API_RESPONSE_FIELD_DATA]
return [x['Name'] for x in dataset_list] return [x['Name'] for x in dataset_list]


@@ -425,7 +497,7 @@ class HubApi:
dataset_id = resp['Data']['Id'] dataset_id = resp['Data']['Id']
dataset_type = resp['Data']['Type'] dataset_type = resp['Data']['Type']
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
r = requests.get(datahub_url)
r = requests.get(datahub_url, headers=self.headers)
resp = r.json() resp = r.json()
datahub_raise_on_error(datahub_url, resp) datahub_raise_on_error(datahub_url, resp)
file_list = resp['Data'] file_list = resp['Data']
@@ -445,7 +517,7 @@ class HubApi:
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
f'Revision={revision}&FilePath={file_path}' f'Revision={revision}&FilePath={file_path}'
r = requests.get(datahub_url) r = requests.get(datahub_url)
r.raise_for_status()
raise_for_http_status(r)
local_path = os.path.join(cache_dir, file_path) local_path = os.path.join(cache_dir, file_path)
if os.path.exists(local_path): if os.path.exists(local_path):
logger.warning( logger.warning(
@@ -490,7 +562,8 @@ class HubApi:
f'ststoken?Revision={revision}' f'ststoken?Revision={revision}'


cookies = requests.utils.dict_from_cookiejar(cookies) cookies = requests.utils.dict_from_cookiejar(cookies)
r = requests.get(url=datahub_url, cookies=cookies)
r = requests.get(
url=datahub_url, cookies=cookies, headers=self.headers)
resp = r.json() resp = r.json()
raise_on_error(resp) raise_on_error(resp)
return resp['Data'] return resp['Data']
@@ -512,12 +585,12 @@ class HubApi:


def on_dataset_download(self, dataset_name: str, namespace: str) -> None: def on_dataset_download(self, dataset_name: str, namespace: str) -> None:
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
r = requests.post(url)
r.raise_for_status()
r = requests.post(url, headers=self.headers)
raise_for_http_status(r)


@staticmethod @staticmethod
def datahub_remote_call(url): def datahub_remote_call(url):
r = requests.get(url)
r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()})
resp = r.json() resp = r.json()
datahub_raise_on_error(url, resp) datahub_raise_on_error(url, resp)
return resp['Data'] return resp['Data']
@@ -531,6 +604,7 @@ class ModelScopeConfig:
COOKIES_FILE_NAME = 'cookies' COOKIES_FILE_NAME = 'cookies'
GIT_TOKEN_FILE_NAME = 'git_token' GIT_TOKEN_FILE_NAME = 'git_token'
USER_INFO_FILE_NAME = 'user' USER_INFO_FILE_NAME = 'user'
USER_SESSION_ID_FILE_NAME = 'session'


@staticmethod @staticmethod
def make_sure_credential_path_exist(): def make_sure_credential_path_exist():
@@ -559,6 +633,23 @@ class ModelScopeConfig:
return cookies return cookies
return None return None


@staticmethod
def get_user_session_id():
session_path = os.path.join(ModelScopeConfig.path_credential,
ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
session_id = ''
if os.path.exists(session_path):
with open(session_path, 'rb') as f:
session_id = str(f.readline().strip(), encoding='utf-8')
return session_id
if session_id == '' or len(session_id) != 32:
session_id = str(uuid.uuid4().hex)
ModelScopeConfig.make_sure_credential_path_exist()
with open(session_path, 'w+') as wf:
wf.write(session_id)

return session_id

@staticmethod @staticmethod
def save_token(token: str): def save_token(token: str):
ModelScopeConfig.make_sure_credential_path_exist() ModelScopeConfig.make_sure_credential_path_exist()
@@ -607,3 +698,32 @@ class ModelScopeConfig:
except FileNotFoundError: except FileNotFoundError:
pass pass
return token return token

@staticmethod
def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
"""Formats a user-agent string with basic info about a request.

Args:
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string.

Returns:
The formatted user-agent string.
"""
env = 'custom'
if MODELSCOPE_ENVIRONMENT in os.environ:
env = os.environ[MODELSCOPE_ENVIRONMENT]

ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % (
__version__,
platform.python_version(),
ModelScopeConfig.get_user_session_id(),
platform.platform(),
platform.processor(),
env,
)
if isinstance(user_agent, dict):
ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += ';' + user_agent
return ua

+ 3
- 0
modelscope/hub/constants.py View File

@@ -16,6 +16,9 @@ API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken'
API_RESPONSE_FIELD_USERNAME = 'Username' API_RESPONSE_FIELD_USERNAME = 'Username'
API_RESPONSE_FIELD_EMAIL = 'Email' API_RESPONSE_FIELD_EMAIL = 'Email'
API_RESPONSE_FIELD_MESSAGE = 'Message' API_RESPONSE_FIELD_MESSAGE = 'Message'
MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60




class Licenses(object): class Licenses(object):


+ 12
- 14
modelscope/hub/deploy.py View File

@@ -3,7 +3,6 @@ from abc import ABC
from http import HTTPStatus from http import HTTPStatus
from typing import Optional from typing import Optional


import attrs
import json import json
import requests import requests
from attrs import asdict, define, field, validators from attrs import asdict, define, field, validators
@@ -12,7 +11,8 @@ from modelscope.hub.api import ModelScopeConfig
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
API_RESPONSE_FIELD_MESSAGE) API_RESPONSE_FIELD_MESSAGE)
from modelscope.hub.errors import (NotLoginException, NotSupportError, from modelscope.hub.errors import (NotLoginException, NotSupportError,
RequestError, handle_http_response, is_ok)
RequestError, handle_http_response, is_ok,
raise_for_http_status)
from modelscope.hub.utils.utils import get_endpoint from modelscope.hub.utils.utils import get_endpoint
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger


@@ -188,6 +188,7 @@ class ServiceDeployer(object):


def __init__(self, endpoint=None): def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else get_endpoint() self.endpoint = endpoint if endpoint is not None else get_endpoint()
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
self.cookies = ModelScopeConfig.get_cookies() self.cookies = ModelScopeConfig.get_cookies()
if self.cookies is None: if self.cookies is None:
raise NotLoginException( raise NotLoginException(
@@ -227,12 +228,9 @@ class ServiceDeployer(object):
resource=resource, resource=resource,
provider=provider) provider=provider)
path = f'{self.endpoint}/api/v1/deployer/endpoint' path = f'{self.endpoint}/api/v1/deployer/endpoint'
body = attrs.asdict(create_params)
body = asdict(create_params)
r = requests.post( r = requests.post(
path,
json=body,
cookies=self.cookies,
)
path, json=body, cookies=self.cookies, headers=self.headers)
handle_http_response(r, logger, self.cookies, 'create_service') handle_http_response(r, logger, self.cookies, 'create_service')
if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES:
if is_ok(r.json()): if is_ok(r.json()):
@@ -241,7 +239,7 @@ class ServiceDeployer(object):
else: else:
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
else: else:
r.raise_for_status()
raise_for_http_status(r)
return None return None


def get(self, instance_name: str, provider: ServiceProviderParameters): def get(self, instance_name: str, provider: ServiceProviderParameters):
@@ -262,7 +260,7 @@ class ServiceDeployer(object):
params = GetServiceParameters(provider=provider) params = GetServiceParameters(provider=provider)
path = '%s/api/v1/deployer/endpoint/%s?%s' % ( path = '%s/api/v1/deployer/endpoint/%s?%s' % (
self.endpoint, instance_name, params.to_query_str()) self.endpoint, instance_name, params.to_query_str())
r = requests.get(path, cookies=self.cookies)
r = requests.get(path, cookies=self.cookies, headers=self.headers)
handle_http_response(r, logger, self.cookies, 'get_service') handle_http_response(r, logger, self.cookies, 'get_service')
if r.status_code == HTTPStatus.OK: if r.status_code == HTTPStatus.OK:
if is_ok(r.json()): if is_ok(r.json()):
@@ -271,7 +269,7 @@ class ServiceDeployer(object):
else: else:
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
else: else:
r.raise_for_status()
raise_for_http_status(r)
return None return None


def delete(self, instance_name: str, provider: ServiceProviderParameters): def delete(self, instance_name: str, provider: ServiceProviderParameters):
@@ -293,7 +291,7 @@ class ServiceDeployer(object):
params = DeleteServiceParameters(provider=provider) params = DeleteServiceParameters(provider=provider)
path = '%s/api/v1/deployer/endpoint/%s?%s' % ( path = '%s/api/v1/deployer/endpoint/%s?%s' % (
self.endpoint, instance_name, params.to_query_str()) self.endpoint, instance_name, params.to_query_str())
r = requests.delete(path, cookies=self.cookies)
r = requests.delete(path, cookies=self.cookies, headers=self.headers)
handle_http_response(r, logger, self.cookies, 'delete_service') handle_http_response(r, logger, self.cookies, 'delete_service')
if r.status_code == HTTPStatus.OK: if r.status_code == HTTPStatus.OK:
if is_ok(r.json()): if is_ok(r.json()):
@@ -302,7 +300,7 @@ class ServiceDeployer(object):
else: else:
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
else: else:
r.raise_for_status()
raise_for_http_status(r)
return None return None


def list(self, def list(self,
@@ -328,7 +326,7 @@ class ServiceDeployer(object):
provider=provider, skip=skip, limit=limit) provider=provider, skip=skip, limit=limit)
path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint,
params.to_query_str()) params.to_query_str())
r = requests.get(path, cookies=self.cookies)
r = requests.get(path, cookies=self.cookies, headers=self.headers)
handle_http_response(r, logger, self.cookies, 'list_service_instances') handle_http_response(r, logger, self.cookies, 'list_service_instances')
if r.status_code == HTTPStatus.OK: if r.status_code == HTTPStatus.OK:
if is_ok(r.json()): if is_ok(r.json()):
@@ -337,5 +335,5 @@ class ServiceDeployer(object):
else: else:
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
else: else:
r.raise_for_status()
raise_for_http_status(r)
return None return None

+ 34
- 0
modelscope/hub/errors.py View File

@@ -13,6 +13,10 @@ class NotSupportError(Exception):
pass pass




class NoValidRevisionError(Exception):
pass


class NotExistError(Exception): class NotExistError(Exception):
pass pass


@@ -99,3 +103,33 @@ def datahub_raise_on_error(url, rsp):
raise RequestError( raise RequestError(
f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}"
) )


def raise_for_http_status(rsp):
"""
Attempt to decode utf-8 first since some servers
localize reason strings, for invalid utf-8, fall back
to decoding with iso-8859-1.
"""
http_error_msg = ''
if isinstance(rsp.reason, bytes):
try:
reason = rsp.reason.decode('utf-8')
except UnicodeDecodeError:
reason = rsp.reason.decode('iso-8859-1')
else:
reason = rsp.reason

if 400 <= rsp.status_code < 500:
http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code,
reason, rsp.url)

elif 500 <= rsp.status_code < 600:
http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code,
reason, rsp.url)

if http_error_msg:
req = rsp.request
if req.method == 'POST':
http_error_msg = u'%s, body: %s' % (http_error_msg, req.body)
raise HTTPError(http_error_msg, response=rsp)

+ 38
- 66
modelscope/hub/file_download.py View File

@@ -2,13 +2,11 @@


import copy import copy
import os import os
import sys
import tempfile import tempfile
from functools import partial from functools import partial
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from uuid import uuid4


import requests import requests
from tqdm import tqdm from tqdm import tqdm
@@ -23,7 +21,6 @@ from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_cache_dir, from .utils.utils import (file_integrity_validation, get_cache_dir,
get_endpoint, model_id_to_group_owner_name) get_endpoint, model_id_to_group_owner_name)


SESSION_ID = uuid4().hex
logger = get_logger() logger = get_logger()




@@ -34,6 +31,7 @@ def model_file_download(
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
user_agent: Union[Dict, str, None] = None, user_agent: Union[Dict, str, None] = None,
local_files_only: Optional[bool] = False, local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
) -> Optional[str]: # pragma: no cover ) -> Optional[str]: # pragma: no cover
""" """
Download from a given URL and cache it if it's not already present in the Download from a given URL and cache it if it's not already present in the
@@ -104,54 +102,47 @@ def model_file_download(
" online, set 'local_files_only' to False.") " online, set 'local_files_only' to False.")


_api = HubApi() _api = HubApi()
headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
cookies = ModelScopeConfig.get_cookies()
branches, tags = _api.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
headers = {
'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, )
}
if cookies is None:
cookies = ModelScopeConfig.get_cookies()

revision = _api.get_valid_revision(
model_id, revision=revision, cookies=cookies)
file_to_download_info = None file_to_download_info = None
is_commit_id = False
if revision in branches or revision in tags: # The revision is version or tag,
# we need to confirm the version is up to date
# we need to get the file list to check if the lateast version is cached, if so return, otherwise download
model_files = _api.get_model_files(
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies)

for model_file in model_files:
if model_file['Type'] == 'tree':
continue

if model_file['Path'] == file_path:
if cache.exists(model_file):
return cache.get_file_by_info(model_file)
else:
file_to_download_info = model_file
break

if file_to_download_info is None:
raise NotExistError('The file path: %s not exist in: %s' %
(file_path, model_id))
else: # the revision is commit id.
cached_file_path = cache.get_file_by_path_and_commit_id(
file_path, revision)
if cached_file_path is not None:
file_name = os.path.basename(cached_file_path)
logger.info(
f'File {file_name} already in cache, skip downloading!')
return cached_file_path # the file is in cache.
is_commit_id = True
# we need to confirm the version is up-to-date
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
model_files = _api.get_model_files(
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=False if cookies is None else cookies)

for model_file in model_files:
if model_file['Type'] == 'tree':
continue

if model_file['Path'] == file_path:
if cache.exists(model_file):
logger.info(
f'File {model_file["Name"]} already in cache, skip downloading!'
)
return cache.get_file_by_info(model_file)
else:
file_to_download_info = model_file
break

if file_to_download_info is None:
raise NotExistError('The file path: %s not exist in: %s' %
(file_path, model_id))

# we need to download again # we need to download again
url_to_download = get_file_download_url(model_id, file_path, revision) url_to_download = get_file_download_url(model_id, file_path, revision)
file_to_download_info = { file_to_download_info = {
'Path':
file_path,
'Revision':
revision if is_commit_id else file_to_download_info['Revision'],
FILE_HASH:
None if (is_commit_id or FILE_HASH not in file_to_download_info) else
file_to_download_info[FILE_HASH]
'Path': file_path,
'Revision': file_to_download_info['Revision'],
FILE_HASH: file_to_download_info[FILE_HASH]
} }


temp_file_name = next(tempfile._get_candidate_names()) temp_file_name = next(tempfile._get_candidate_names())
@@ -170,25 +161,6 @@ def model_file_download(
os.path.join(temporary_cache_dir, temp_file_name)) os.path.join(temporary_cache_dir, temp_file_name))




def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
"""Formats a user-agent string with basic info about a request.

Args:
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string.

Returns:
The formatted user-agent string.
"""
ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}'

if isinstance(user_agent, dict):
ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua = user_agent
return ua


def get_file_download_url(model_id: str, file_path: str, revision: str): def get_file_download_url(model_id: str, file_path: str, revision: str):
""" """
Format file download url according to `model_id`, `revision` and `file_path`. Format file download url according to `model_id`, `revision` and `file_path`.


+ 20
- 0
modelscope/hub/git.py View File

@@ -5,6 +5,7 @@ import subprocess
from typing import List from typing import List


from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from ..utils.constant import MASTER_MODEL_BRANCH
from .errors import GitError from .errors import GitError


logger = get_logger() logger = get_logger()
@@ -227,3 +228,22 @@ class GitCommandWrapper(metaclass=Singleton):
files.append(line.split(' ')[-1]) files.append(line.split(' ')[-1])


return files return files

def tag(self,
repo_dir: str,
tag_name: str,
message: str,
ref: str = MASTER_MODEL_BRANCH):
cmd_args = [
'-C', repo_dir, 'tag', tag_name, '-m',
'"%s"' % message, ref
]
rsp = self._run_git_command(*cmd_args)
logger.debug(rsp.stdout.decode('utf8'))
return rsp

def push_tag(self, repo_dir: str, tag_name):
cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name]
rsp = self._run_git_command(*cmd_args)
logger.debug(rsp.stdout.decode('utf8'))
return rsp

+ 41
- 5
modelscope/hub/repository.py View File

@@ -5,7 +5,8 @@ from typing import Optional


from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION)
DEFAULT_REPOSITORY_REVISION,
MASTER_MODEL_BRANCH)
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from .git import GitCommandWrapper from .git import GitCommandWrapper
from .utils.utils import get_endpoint from .utils.utils import get_endpoint
@@ -20,7 +21,7 @@ class Repository:
def __init__(self, def __init__(self,
model_dir: str, model_dir: str,
clone_from: str, clone_from: str,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
auth_token: Optional[str] = None, auth_token: Optional[str] = None,
git_path: Optional[str] = None): git_path: Optional[str] = None):
""" """
@@ -89,7 +90,8 @@ class Repository:


def push(self, def push(self,
commit_message: str, commit_message: str,
branch: Optional[str] = DEFAULT_MODEL_REVISION,
local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION,
remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION,
force: bool = False): force: bool = False):
"""Push local files to remote, this method will do. """Push local files to remote, this method will do.
git pull git pull
@@ -116,14 +118,48 @@ class Repository:


url = self.git_wrapper.get_repo_remote_url(self.model_dir) url = self.git_wrapper.get_repo_remote_url(self.model_dir)
self.git_wrapper.pull(self.model_dir) self.git_wrapper.pull(self.model_dir)

self.git_wrapper.add(self.model_dir, all_files=True) self.git_wrapper.add(self.model_dir, all_files=True)
self.git_wrapper.commit(self.model_dir, commit_message) self.git_wrapper.commit(self.model_dir, commit_message)
self.git_wrapper.push( self.git_wrapper.push(
repo_dir=self.model_dir, repo_dir=self.model_dir,
token=self.auth_token, token=self.auth_token,
url=url, url=url,
local_branch=branch,
remote_branch=branch)
local_branch=local_branch,
remote_branch=remote_branch)

def tag(self, tag_name: str, message: str, ref: str = MASTER_MODEL_BRANCH):
"""Create a new tag.
Args:
tag_name (str): The name of the tag
message (str): The tag message.
ref (str): The tag reference, can be commit id or branch.
"""
if tag_name is None or tag_name == '':
msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.'
raise InvalidParameter(msg)
if message is None or message == '':
msg = 'We use annotated tag, therefore message cannot None or empty.'
self.git_wrapper.tag(
repo_dir=self.model_dir,
tag_name=tag_name,
message=message,
ref=ref)

def tag_and_push(self,
tag_name: str,
message: str,
ref: str = MASTER_MODEL_BRANCH):
"""Create tag and push to remote

Args:
tag_name (str): The name of the tag
message (str): The tag message.
ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH.
"""
self.tag(tag_name, message, ref)

self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name)




class DatasetRepository: class DatasetRepository:


+ 13
- 13
modelscope/hub/snapshot_download.py View File

@@ -2,6 +2,7 @@


import os import os
import tempfile import tempfile
from http.cookiejar import CookieJar
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union


@@ -9,9 +10,7 @@ from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from .constants import FILE_HASH from .constants import FILE_HASH
from .errors import NotExistError
from .file_download import (get_file_download_url, http_get_file,
http_user_agent)
from .file_download import get_file_download_url, http_get_file
from .utils.caching import ModelFileSystemCache from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_cache_dir, from .utils.utils import (file_integrity_validation, get_cache_dir,
model_id_to_group_owner_name) model_id_to_group_owner_name)
@@ -23,7 +22,8 @@ def snapshot_download(model_id: str,
revision: Optional[str] = DEFAULT_MODEL_REVISION, revision: Optional[str] = DEFAULT_MODEL_REVISION,
cache_dir: Union[str, Path, None] = None, cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None, user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False) -> str:
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None) -> str:
"""Download all files of a repo. """Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This Downloads a whole snapshot of a repo's files at the specified revision. This
is useful when you want all files from a repo, because you don't know which is useful when you want all files from a repo, because you don't know which
@@ -81,15 +81,15 @@ def snapshot_download(model_id: str,
) # we can not confirm the cached file is for snapshot 'revision' ) # we can not confirm the cached file is for snapshot 'revision'
else: else:
# make headers # make headers
headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
headers = {
'user-agent':
ModelScopeConfig.get_user_agent(user_agent=user_agent, )
}
_api = HubApi() _api = HubApi()
cookies = ModelScopeConfig.get_cookies()
# get file list from model repo
branches, tags = _api.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
if revision not in branches and revision not in tags:
raise NotExistError('The specified branch or tag : %s not exist!'
% revision)
if cookies is None:
cookies = ModelScopeConfig.get_cookies()
revision = _api.get_valid_revision(
model_id, revision=revision, cookies=cookies)


snapshot_header = headers if 'CI_TEST' in os.environ else { snapshot_header = headers if 'CI_TEST' in os.environ else {
**headers, **headers,
@@ -110,7 +110,7 @@ def snapshot_download(model_id: str,
for model_file in model_files: for model_file in model_files:
if model_file['Type'] == 'tree': if model_file['Type'] == 'tree':
continue continue
# check model_file is exist in cache, if exist, skip download, otherwise download
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(model_file): if cache.exists(model_file):
file_name = os.path.basename(model_file['Name']) file_name = os.path.basename(model_file['Name'])
logger.info( logger.info(


+ 14
- 1
modelscope/hub/utils/utils.py View File

@@ -2,11 +2,12 @@


import hashlib import hashlib
import os import os
from datetime import datetime
from typing import Optional from typing import Optional


from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GROUP, DEFAULT_MODELSCOPE_GROUP,
MODEL_ID_SEPARATOR,
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG,
MODELSCOPE_URL_SCHEME) MODELSCOPE_URL_SCHEME)
from modelscope.hub.errors import FileIntegrityError from modelscope.hub.errors import FileIntegrityError
from modelscope.utils.file_utils import get_default_cache_dir from modelscope.utils.file_utils import get_default_cache_dir
@@ -37,6 +38,18 @@ def get_cache_dir(model_id: Optional[str] = None):
base_path, model_id + '/') base_path, model_id + '/')




def get_release_datetime():
if MODELSCOPE_SDK_DEBUG in os.environ:
rt = int(round(datetime.now().timestamp()))
else:
from modelscope import version
rt = int(
round(
datetime.strptime(version.__release_datetime__,
'%Y-%m-%d %H:%M:%S').timestamp()))
return rt


def get_endpoint(): def get_endpoint():
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
DEFAULT_MODELSCOPE_DOMAIN) DEFAULT_MODELSCOPE_DOMAIN)


+ 5
- 1
modelscope/msdatasets/ms_dataset.py View File

@@ -675,4 +675,8 @@ class MsDataset:
revision=revision, revision=revision,
auth_token=auth_token, auth_token=auth_token,
git_path=git_path) git_path=git_path)
_repo.push(commit_message=commit_message, branch=revision, force=force)
_repo.push(
commit_message=commit_message,
local_branch=revision,
remote_branch=revision,
force=force)

+ 3
- 1
modelscope/utils/constant.py View File

@@ -311,7 +311,9 @@ class Frameworks(object):
kaldi = 'kaldi' kaldi = 'kaldi'




DEFAULT_MODEL_REVISION = 'master'
DEFAULT_MODEL_REVISION = None
MASTER_MODEL_BRANCH = 'master'
DEFAULT_REPOSITORY_REVISION = 'master'
DEFAULT_DATASET_REVISION = 'master' DEFAULT_DATASET_REVISION = 'master'
DEFAULT_DATASET_NAMESPACE = 'modelscope' DEFAULT_DATASET_NAMESPACE = 'modelscope'




+ 4
- 0
modelscope/version.py View File

@@ -1 +1,5 @@
# Make sure to modify __release_datetime__ to release time when making official release.
__version__ = '0.5.0' __version__ = '0.5.0'
# default release datetime for branches under active development is set
# to be a time far-far-away-into-the-future
__release_datetime__ = '2099-10-13 08:56:12'

+ 25
- 11
tests/hub/test_hub_operation.py View File

@@ -25,10 +25,10 @@ class HubOperationTest(unittest.TestCase):


def setUp(self): def setUp(self):
self.api = HubApi() self.api = HubApi()
# note this is temporary before official account management is ready
self.api.login(TEST_ACCESS_TOKEN1) self.api.login(TEST_ACCESS_TOKEN1)
self.model_name = uuid.uuid4().hex
self.model_name = 'op-%s' % (uuid.uuid4().hex)
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
self.revision = 'v0.1_test_revision'
self.api.create_model( self.api.create_model(
model_id=self.model_id, model_id=self.model_id,
visibility=ModelVisibility.PUBLIC, visibility=ModelVisibility.PUBLIC,
@@ -46,6 +46,7 @@ class HubOperationTest(unittest.TestCase):
os.system("echo 'testtest'>%s" os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name)) % os.path.join(self.model_dir, download_model_file_name))
repo.push('add model') repo.push('add model')
repo.tag_and_push(self.revision, 'Test revision')


def test_model_repo_creation(self): def test_model_repo_creation(self):
# change to proper model names before use # change to proper model names before use
@@ -61,7 +62,9 @@ class HubOperationTest(unittest.TestCase):
def test_download_single_file(self): def test_download_single_file(self):
self.prepare_case() self.prepare_case()
downloaded_file = model_file_download( downloaded_file = model_file_download(
model_id=self.model_id, file_path=download_model_file_name)
model_id=self.model_id,
file_path=download_model_file_name,
revision=self.revision)
assert os.path.exists(downloaded_file) assert os.path.exists(downloaded_file)
mdtime1 = os.path.getmtime(downloaded_file) mdtime1 = os.path.getmtime(downloaded_file)
# download again # download again
@@ -78,17 +81,16 @@ class HubOperationTest(unittest.TestCase):
assert os.path.exists(downloaded_file_path) assert os.path.exists(downloaded_file_path)
mdtime1 = os.path.getmtime(downloaded_file_path) mdtime1 = os.path.getmtime(downloaded_file_path)
# download again # download again
snapshot_path = snapshot_download(model_id=self.model_id)
snapshot_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
mdtime2 = os.path.getmtime(downloaded_file_path) mdtime2 = os.path.getmtime(downloaded_file_path)
assert mdtime1 == mdtime2 assert mdtime1 == mdtime2
model_file_download(
model_id=self.model_id,
file_path=download_model_file_name) # not add counter


def test_download_public_without_login(self): def test_download_public_without_login(self):
self.prepare_case() self.prepare_case()
rmtree(ModelScopeConfig.path_credential) rmtree(ModelScopeConfig.path_credential)
snapshot_path = snapshot_download(model_id=self.model_id)
snapshot_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
downloaded_file_path = os.path.join(snapshot_path, downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name) download_model_file_name)
assert os.path.exists(downloaded_file_path) assert os.path.exists(downloaded_file_path)
@@ -96,26 +98,38 @@ class HubOperationTest(unittest.TestCase):
downloaded_file = model_file_download( downloaded_file = model_file_download(
model_id=self.model_id, model_id=self.model_id,
file_path=download_model_file_name, file_path=download_model_file_name,
revision=self.revision,
cache_dir=temporary_dir) cache_dir=temporary_dir)
assert os.path.exists(downloaded_file) assert os.path.exists(downloaded_file)
self.api.login(TEST_ACCESS_TOKEN1) self.api.login(TEST_ACCESS_TOKEN1)


def test_snapshot_delete_download_cache_file(self): def test_snapshot_delete_download_cache_file(self):
self.prepare_case() self.prepare_case()
snapshot_path = snapshot_download(model_id=self.model_id)
snapshot_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
downloaded_file_path = os.path.join(snapshot_path, downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name) download_model_file_name)
assert os.path.exists(downloaded_file_path) assert os.path.exists(downloaded_file_path)
os.remove(downloaded_file_path) os.remove(downloaded_file_path)
# download again in cache # download again in cache
file_download_path = model_file_download( file_download_path = model_file_download(
model_id=self.model_id, file_path=ModelFile.README)
model_id=self.model_id,
file_path=ModelFile.README,
revision=self.revision)
assert os.path.exists(file_download_path) assert os.path.exists(file_download_path)
# deleted file need download again # deleted file need download again
file_download_path = model_file_download( file_download_path = model_file_download(
model_id=self.model_id, file_path=download_model_file_name)
model_id=self.model_id,
file_path=download_model_file_name,
revision=self.revision)
assert os.path.exists(file_download_path) assert os.path.exists(file_download_path)


def test_snapshot_download_default_revision(self):
pass # TOTO

def test_file_download_default_revision(self):
pass # TODO

def get_model_download_times(self): def get_model_download_times(self):
url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads' url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads'
cookies = ModelScopeConfig.get_cookies() cookies = ModelScopeConfig.get_cookies()


+ 44
- 15
tests/hub/test_hub_private_files.py View File

@@ -17,23 +17,34 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG,
delete_credential) delete_credential)


download_model_file_name = 'test.bin'



class HubPrivateFileDownloadTest(unittest.TestCase): class HubPrivateFileDownloadTest(unittest.TestCase):


def setUp(self): def setUp(self):
self.old_cwd = os.getcwd() self.old_cwd = os.getcwd()
self.api = HubApi() self.api = HubApi()
# note this is temporary before official account management is ready
self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) self.token, _ = self.api.login(TEST_ACCESS_TOKEN1)
self.model_name = uuid.uuid4().hex
self.model_name = 'pf-%s' % (uuid.uuid4().hex)
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
self.revision = 'v0.1_test_revision'
self.api.create_model( self.api.create_model(
model_id=self.model_id, model_id=self.model_id,
visibility=ModelVisibility.PRIVATE, # 1-private, 5-public
visibility=ModelVisibility.PRIVATE,
license=Licenses.APACHE_V2, license=Licenses.APACHE_V2,
chinese_name=TEST_MODEL_CHINESE_NAME, chinese_name=TEST_MODEL_CHINESE_NAME,
) )


def prepare_case(self):
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)
repo = Repository(self.model_dir, clone_from=self.model_id)
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name))
repo.push('add model')
repo.tag_and_push(self.revision, 'Test revision')

def tearDown(self): def tearDown(self):
# credential may deleted or switch login name, we need re-login here # credential may deleted or switch login name, we need re-login here
# to ensure the temporary model is deleted. # to ensure the temporary model is deleted.
@@ -42,49 +53,67 @@ class HubPrivateFileDownloadTest(unittest.TestCase):
self.api.delete_model(model_id=self.model_id) self.api.delete_model(model_id=self.model_id)


def test_snapshot_download_private_model(self): def test_snapshot_download_private_model(self):
snapshot_path = snapshot_download(self.model_id)
self.prepare_case()
snapshot_path = snapshot_download(self.model_id, self.revision)
assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) assert os.path.exists(os.path.join(snapshot_path, ModelFile.README))


def test_snapshot_download_private_model_no_permission(self): def test_snapshot_download_private_model_no_permission(self):
self.prepare_case()
self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) self.token, _ = self.api.login(TEST_ACCESS_TOKEN2)
with self.assertRaises(HTTPError): with self.assertRaises(HTTPError):
snapshot_download(self.model_id)
snapshot_download(self.model_id, self.revision)


def test_snapshot_download_private_model_without_login(self): def test_snapshot_download_private_model_without_login(self):
self.prepare_case()
delete_credential() delete_credential()
with self.assertRaises(HTTPError): with self.assertRaises(HTTPError):
snapshot_download(self.model_id)
snapshot_download(self.model_id, self.revision)


def test_download_file_private_model(self): def test_download_file_private_model(self):
file_path = model_file_download(self.model_id, ModelFile.README)
self.prepare_case()
file_path = model_file_download(self.model_id, ModelFile.README,
self.revision)
assert os.path.exists(file_path) assert os.path.exists(file_path)


def test_download_file_private_model_no_permission(self): def test_download_file_private_model_no_permission(self):
self.prepare_case()
self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) self.token, _ = self.api.login(TEST_ACCESS_TOKEN2)
with self.assertRaises(HTTPError): with self.assertRaises(HTTPError):
model_file_download(self.model_id, ModelFile.README)
model_file_download(self.model_id, ModelFile.README, self.revision)


def test_download_file_private_model_without_login(self): def test_download_file_private_model_without_login(self):
self.prepare_case()
delete_credential() delete_credential()
with self.assertRaises(HTTPError): with self.assertRaises(HTTPError):
model_file_download(self.model_id, ModelFile.README)
model_file_download(self.model_id, ModelFile.README, self.revision)


def test_snapshot_download_local_only(self): def test_snapshot_download_local_only(self):
self.prepare_case()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
snapshot_download(self.model_id, local_files_only=True)
snapshot_path = snapshot_download(self.model_id)
snapshot_download(
self.model_id, self.revision, local_files_only=True)
snapshot_path = snapshot_download(self.model_id, self.revision)
assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) assert os.path.exists(os.path.join(snapshot_path, ModelFile.README))
snapshot_path = snapshot_download(self.model_id, local_files_only=True)
snapshot_path = snapshot_download(
self.model_id, self.revision, local_files_only=True)
assert os.path.exists(snapshot_path) assert os.path.exists(snapshot_path)


def test_file_download_local_only(self): def test_file_download_local_only(self):
self.prepare_case()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
model_file_download( model_file_download(
self.model_id, ModelFile.README, local_files_only=True)
file_path = model_file_download(self.model_id, ModelFile.README)
self.model_id,
ModelFile.README,
self.revision,
local_files_only=True)
file_path = model_file_download(self.model_id, ModelFile.README,
self.revision)
assert os.path.exists(file_path) assert os.path.exists(file_path)
file_path = model_file_download( file_path = model_file_download(
self.model_id, ModelFile.README, local_files_only=True)
self.model_id,
ModelFile.README,
revision=self.revision,
local_files_only=True)
assert os.path.exists(file_path) assert os.path.exists(file_path)






+ 2
- 3
tests/hub/test_hub_private_repository.py View File

@@ -21,13 +21,12 @@ class HubPrivateRepositoryTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.old_cwd = os.getcwd() self.old_cwd = os.getcwd()
self.api = HubApi() self.api = HubApi()
# note this is temporary before official account management is ready
self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) self.token, _ = self.api.login(TEST_ACCESS_TOKEN1)
self.model_name = uuid.uuid4().hex
self.model_name = 'pr-%s' % (uuid.uuid4().hex)
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
self.api.create_model( self.api.create_model(
model_id=self.model_id, model_id=self.model_id,
visibility=ModelVisibility.PRIVATE, # 1-private, 5-public
visibility=ModelVisibility.PRIVATE,
license=Licenses.APACHE_V2, license=Licenses.APACHE_V2,
chinese_name=TEST_MODEL_CHINESE_NAME, chinese_name=TEST_MODEL_CHINESE_NAME,
) )


+ 7
- 5
tests/hub/test_hub_repository.py View File

@@ -22,6 +22,7 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
logger = get_logger() logger = get_logger()
logger.setLevel('DEBUG') logger.setLevel('DEBUG')
DEFAULT_GIT_PATH = 'git' DEFAULT_GIT_PATH = 'git'
download_model_file_name = 'test.bin'




class HubRepositoryTest(unittest.TestCase): class HubRepositoryTest(unittest.TestCase):
@@ -29,13 +30,13 @@ class HubRepositoryTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.old_cwd = os.getcwd() self.old_cwd = os.getcwd()
self.api = HubApi() self.api = HubApi()
# note this is temporary before official account management is ready
self.api.login(TEST_ACCESS_TOKEN1) self.api.login(TEST_ACCESS_TOKEN1)
self.model_name = uuid.uuid4().hex
self.model_name = 'repo-%s' % (uuid.uuid4().hex)
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
self.revision = 'v0.1_test_revision'
self.api.create_model( self.api.create_model(
model_id=self.model_id, model_id=self.model_id,
visibility=ModelVisibility.PUBLIC, # 1-private, 5-public
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2, license=Licenses.APACHE_V2,
chinese_name=TEST_MODEL_CHINESE_NAME, chinese_name=TEST_MODEL_CHINESE_NAME,
) )
@@ -67,9 +68,10 @@ class HubRepositoryTest(unittest.TestCase):
os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1)) os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1))
os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2)) os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2))
repo.push('test') repo.push('test')
add1 = model_file_download(self.model_id, 'add1.py')
repo.tag_and_push(self.revision, 'Test revision')
add1 = model_file_download(self.model_id, 'add1.py', self.revision)
assert os.path.exists(add1) assert os.path.exists(add1)
add2 = model_file_download(self.model_id, 'add2.py')
add2 = model_file_download(self.model_id, 'add2.py', self.revision)
assert os.path.exists(add2) assert os.path.exists(add2)
# check lfs files. # check lfs files.
git_wrapper = GitCommandWrapper() git_wrapper = GitCommandWrapper()


+ 145
- 0
tests/hub/test_hub_revision.py View File

@@ -0,0 +1,145 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import unittest
import uuid
from datetime import datetime

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import NotExistError, NoValidRevisionError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)

logger = get_logger()
logger.setLevel('DEBUG')
download_model_file_name = 'test.bin'
download_model_file_name2 = 'test2.bin'


class HubRevisionTest(unittest.TestCase):

def setUp(self):
self.api = HubApi()
self.api.login(TEST_ACCESS_TOKEN1)
self.model_name = 'rv-%s' % (uuid.uuid4().hex)
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
self.revision = 'v0.1_test_revision'
self.revision2 = 'v0.2_test_revision'
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2,
chinese_name=TEST_MODEL_CHINESE_NAME,
)

def tearDown(self):
self.api.delete_model(model_id=self.model_id)

def prepare_repo_data(self):
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)
self.repo = Repository(self.model_dir, clone_from=self.model_id)
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name))
self.repo.push('add model')
self.repo.tag_and_push(self.revision, 'Test revision')

def test_no_tag(self):
with self.assertRaises(NoValidRevisionError):
snapshot_download(self.model_id, None)

with self.assertRaises(NoValidRevisionError):
model_file_download(self.model_id, ModelFile.README)

def test_with_only_one_tag(self):
self.prepare_repo_data()
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id, cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id, ModelFile.README, cache_dir=temp_cache_dir)
assert os.path.exists(file_path)

def add_new_file_and_tag(self):
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name2))
self.repo.push('add new file')
self.repo.tag_and_push(self.revision2, 'Test revision')

def test_snapshot_download_different_revision(self):
self.prepare_repo_data()
t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('First time stamp: %s' % t1)
snapshot_path = snapshot_download(self.model_id, self.revision)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
self.add_new_file_and_tag()
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id,
revision=self.revision,
cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
assert not os.path.exists(
os.path.join(snapshot_path, download_model_file_name2))
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id,
revision=self.revision2,
cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name2))

def test_file_download_different_revision(self):
self.prepare_repo_data()
t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('First time stamp: %s' % t1)
file_path = model_file_download(self.model_id,
download_model_file_name,
self.revision)
assert os.path.exists(file_path)
self.add_new_file_and_tag()
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
revision=self.revision,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
with self.assertRaises(NotExistError):
model_file_download(
self.model_id,
download_model_file_name2,
revision=self.revision,
cache_dir=temp_cache_dir)

with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
revision=self.revision2,
cache_dir=temp_cache_dir)
print('Downloaded file path: %s' % file_path)
assert os.path.exists(file_path)
file_path = model_file_download(
self.model_id,
download_model_file_name2,
revision=self.revision2,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)


if __name__ == '__main__':
unittest.main()

+ 190
- 0
tests/hub/test_hub_revision_release_mode.py View File

@@ -0,0 +1,190 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import time
import unittest
import uuid
from datetime import datetime
from unittest import mock

from modelscope import version
from modelscope.hub.api import HubApi
from modelscope.hub.constants import (MODELSCOPE_SDK_DEBUG, Licenses,
ModelVisibility)
from modelscope.hub.errors import NotExistError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.logger import get_logger
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)

logger = get_logger()
logger.setLevel('DEBUG')
download_model_file_name = 'test.bin'
download_model_file_name2 = 'test2.bin'


class HubRevisionTest(unittest.TestCase):

def setUp(self):
self.api = HubApi()
self.api.login(TEST_ACCESS_TOKEN1)
self.model_name = 'rvr-%s' % (uuid.uuid4().hex)
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
self.revision = 'v0.1_test_revision'
self.revision2 = 'v0.2_test_revision'
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2,
chinese_name=TEST_MODEL_CHINESE_NAME,
)
names_to_remove = {MODELSCOPE_SDK_DEBUG}
self.modified_environ = {
k: v
for k, v in os.environ.items() if k not in names_to_remove
}

def tearDown(self):
self.api.delete_model(model_id=self.model_id)

def prepare_repo_data(self):
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)
self.repo = Repository(self.model_dir, clone_from=self.model_id)
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name))
self.repo.push('add model')

def prepare_repo_data_and_tag(self):
self.prepare_repo_data()
self.repo.tag_and_push(self.revision, 'Test revision')

def add_new_file_and_tag_to_repo(self):
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name2))
self.repo.push('add new file')
self.repo.tag_and_push(self.revision2, 'Test revision')

def add_new_file_and_branch_to_repo(self, branch_name):
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name2))
self.repo.push('add new file', remote_branch=branch_name)

def test_dev_mode_default_master(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data() # no tag, default get master
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id, cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)

def test_dev_mode_specify_branch(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data() # no tag, default get master
branch_name = 'test'
self.add_new_file_and_branch_to_repo(branch_name)
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id,
revision=branch_name,
cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
revision=branch_name,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)

def test_snapshot_download_revision(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data_and_tag()
t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('First time: %s' % t1)
time.sleep(10)
self.add_new_file_and_tag_to_repo()
t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('Secnod time: %s' % t2)
# set
release_datetime_backup = version.__release_datetime__
logger.info('Origin __release_datetime__: %s'
% version.__release_datetime__)
try:
logger.info('Setting __release_datetime__ to: %s' % t1)
version.__release_datetime__ = t1
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id, cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
assert not os.path.exists(
os.path.join(snapshot_path, download_model_file_name2))
version.__release_datetime__ = t2
logger.info('Setting __release_datetime__ to: %s' % t2)
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id, cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name2))
finally:
version.__release_datetime__ = release_datetime_backup

def test_file_download_revision(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data_and_tag()
t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('First time stamp: %s' % t1)
time.sleep(10)
self.add_new_file_and_tag_to_repo()
t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('Second time: %s' % t2)
release_datetime_backup = version.__release_datetime__
logger.info('Origin __release_datetime__: %s'
% version.__release_datetime__)
try:
version.__release_datetime__ = t1
logger.info('Setting __release_datetime__ to: %s' % t1)
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
with self.assertRaises(NotExistError):
model_file_download(
self.model_id,
download_model_file_name2,
cache_dir=temp_cache_dir)
version.__release_datetime__ = t2
logger.info('Setting __release_datetime__ to: %s' % t2)
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
cache_dir=temp_cache_dir)
print('Downloaded file path: %s' % file_path)
assert os.path.exists(file_path)
file_path = model_file_download(
self.model_id,
download_model_file_name2,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
finally:
version.__release_datetime__ = release_datetime_backup


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save