| @@ -1,22 +1,28 @@ | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| pip install -r requirements/tests.txt | |||||
| echo "Testing envs" | |||||
| printenv | |||||
| echo "ENV END" | |||||
| if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||||
| pip install -r requirements/tests.txt | |||||
| git config --global --add safe.directory /Maas-lib | |||||
| git config --global --add safe.directory /Maas-lib | |||||
| # linter test | |||||
| # use internal project for pre-commit due to the network problem | |||||
| pre-commit run -c .pre-commit-config_local.yaml --all-files | |||||
| if [ $? -ne 0 ]; then | |||||
| echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||||
| exit -1 | |||||
| # linter test | |||||
| # use internal project for pre-commit due to the network problem | |||||
| pre-commit run -c .pre-commit-config_local.yaml --all-files | |||||
| if [ $? -ne 0 ]; then | |||||
| echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||||
| exit -1 | |||||
| fi | |||||
| # test with install | |||||
| python setup.py install | |||||
| else | |||||
| echo "Running case in release image, run case directly!" | |||||
| fi | fi | ||||
| # test with install | |||||
| python setup.py install | |||||
| if [ $# -eq 0 ]; then | if [ $# -eq 0 ]; then | ||||
| ci_command="python tests/run.py --subprocess" | ci_command="python tests/run.py --subprocess" | ||||
| else | else | ||||
| @@ -20,28 +20,52 @@ do | |||||
| # pull image if there are update | # pull image if there are update | ||||
| docker pull ${IMAGE_NAME}:${IMAGE_VERSION} | docker pull ${IMAGE_NAME}:${IMAGE_VERSION} | ||||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||||
| --gpus="device=$gpu" \ | |||||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||||
| -e CI_TEST=True \ | |||||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||||
| --net host \ | |||||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||||
| $CI_COMMAND | |||||
| if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||||
| --gpus="device=$gpu" \ | |||||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||||
| -e CI_TEST=True \ | |||||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||||
| -e MODELSCOPE_SDK_DEBUG=True \ | |||||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||||
| --net host \ | |||||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||||
| $CI_COMMAND | |||||
| else | |||||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||||
| --gpus="device=$gpu" \ | |||||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||||
| -e CI_TEST=True \ | |||||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||||
| --net host \ | |||||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||||
| $CI_COMMAND | |||||
| fi | |||||
| if [ $? -ne 0 ]; then | if [ $? -ne 0 ]; then | ||||
| echo "Running test case failed, please check the log!" | echo "Running test case failed, please check the log!" | ||||
| exit -1 | exit -1 | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from .version import __version__ | |||||
| from .version import __release_datetime__, __version__ | |||||
| __all__ = ['__version__'] | |||||
| __all__ = ['__version__', '__release_datetime__'] | |||||
| @@ -4,40 +4,45 @@ | |||||
| import datetime | import datetime | ||||
| import os | import os | ||||
| import pickle | import pickle | ||||
| import platform | |||||
| import shutil | import shutil | ||||
| import tempfile | import tempfile | ||||
| import uuid | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from http import HTTPStatus | from http import HTTPStatus | ||||
| from http.cookiejar import CookieJar | from http.cookiejar import CookieJar | ||||
| from os.path import expanduser | from os.path import expanduser | ||||
| from typing import List, Optional, Tuple, Union | |||||
| from typing import Dict, List, Optional, Tuple, Union | |||||
| import requests | import requests | ||||
| from modelscope import __version__ | |||||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | ||||
| API_RESPONSE_FIELD_EMAIL, | API_RESPONSE_FIELD_EMAIL, | ||||
| API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | ||||
| API_RESPONSE_FIELD_MESSAGE, | API_RESPONSE_FIELD_MESSAGE, | ||||
| API_RESPONSE_FIELD_USERNAME, | API_RESPONSE_FIELD_USERNAME, | ||||
| DEFAULT_CREDENTIALS_PATH, Licenses, | |||||
| ModelVisibility) | |||||
| DEFAULT_CREDENTIALS_PATH, | |||||
| MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS, | |||||
| Licenses, ModelVisibility) | |||||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | from modelscope.hub.errors import (InvalidParameter, NotExistError, | ||||
| NotLoginException, RequestError, | |||||
| datahub_raise_on_error, | |||||
| NotLoginException, NoValidRevisionError, | |||||
| RequestError, datahub_raise_on_error, | |||||
| handle_http_post_error, | handle_http_post_error, | ||||
| handle_http_response, is_ok, raise_on_error) | |||||
| handle_http_response, is_ok, | |||||
| raise_for_http_status, raise_on_error) | |||||
| from modelscope.hub.git import GitCommandWrapper | from modelscope.hub.git import GitCommandWrapper | ||||
| from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
| from modelscope.hub.utils.utils import (get_endpoint, | |||||
| model_id_to_group_owner_name) | |||||
| from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | ||||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | ||||
| DEFAULT_MODEL_REVISION, | DEFAULT_MODEL_REVISION, | ||||
| DatasetFormations, DatasetMetaFormats, | |||||
| DownloadMode, ModelFile) | |||||
| DEFAULT_REPOSITORY_REVISION, | |||||
| MASTER_MODEL_BRANCH, DatasetFormations, | |||||
| DatasetMetaFormats, DownloadMode, | |||||
| ModelFile) | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| # yapf: enable | |||||
| from .utils.utils import (get_endpoint, get_release_datetime, | |||||
| model_id_to_group_owner_name) | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -46,6 +51,7 @@ class HubApi: | |||||
| def __init__(self, endpoint=None): | def __init__(self, endpoint=None): | ||||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | self.endpoint = endpoint if endpoint is not None else get_endpoint() | ||||
| self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||||
| def login( | def login( | ||||
| self, | self, | ||||
| @@ -65,8 +71,9 @@ class HubApi: | |||||
| </Tip> | </Tip> | ||||
| """ | """ | ||||
| path = f'{self.endpoint}/api/v1/login' | path = f'{self.endpoint}/api/v1/login' | ||||
| r = requests.post(path, json={'AccessToken': access_token}) | |||||
| r.raise_for_status() | |||||
| r = requests.post( | |||||
| path, json={'AccessToken': access_token}, headers=self.headers) | |||||
| raise_for_http_status(r) | |||||
| d = r.json() | d = r.json() | ||||
| raise_on_error(d) | raise_on_error(d) | ||||
| @@ -120,7 +127,8 @@ class HubApi: | |||||
| 'Visibility': visibility, # server check | 'Visibility': visibility, # server check | ||||
| 'License': license | 'License': license | ||||
| } | } | ||||
| r = requests.post(path, json=body, cookies=cookies) | |||||
| r = requests.post( | |||||
| path, json=body, cookies=cookies, headers=self.headers) | |||||
| handle_http_post_error(r, path, body) | handle_http_post_error(r, path, body) | ||||
| raise_on_error(r.json()) | raise_on_error(r.json()) | ||||
| model_repo_url = f'{get_endpoint()}/{model_id}' | model_repo_url = f'{get_endpoint()}/{model_id}' | ||||
| @@ -140,8 +148,8 @@ class HubApi: | |||||
| raise ValueError('Token does not exist, please login first.') | raise ValueError('Token does not exist, please login first.') | ||||
| path = f'{self.endpoint}/api/v1/models/{model_id}' | path = f'{self.endpoint}/api/v1/models/{model_id}' | ||||
| r = requests.delete(path, cookies=cookies) | |||||
| r.raise_for_status() | |||||
| r = requests.delete(path, cookies=cookies, headers=self.headers) | |||||
| raise_for_http_status(r) | |||||
| raise_on_error(r.json()) | raise_on_error(r.json()) | ||||
| def get_model_url(self, model_id): | def get_model_url(self, model_id): | ||||
| @@ -170,7 +178,7 @@ class HubApi: | |||||
| owner_or_group, name = model_id_to_group_owner_name(model_id) | owner_or_group, name = model_id_to_group_owner_name(model_id) | ||||
| path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' | path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' | ||||
| r = requests.get(path, cookies=cookies) | |||||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, cookies, model_id) | handle_http_response(r, logger, cookies, model_id) | ||||
| if r.status_code == HTTPStatus.OK: | if r.status_code == HTTPStatus.OK: | ||||
| if is_ok(r.json()): | if is_ok(r.json()): | ||||
| @@ -178,7 +186,7 @@ class HubApi: | |||||
| else: | else: | ||||
| raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | ||||
| else: | else: | ||||
| r.raise_for_status() | |||||
| raise_for_http_status(r) | |||||
| def push_model(self, | def push_model(self, | ||||
| model_id: str, | model_id: str, | ||||
| @@ -187,7 +195,7 @@ class HubApi: | |||||
| license: str = Licenses.APACHE_V2, | license: str = Licenses.APACHE_V2, | ||||
| chinese_name: Optional[str] = None, | chinese_name: Optional[str] = None, | ||||
| commit_message: Optional[str] = 'upload model', | commit_message: Optional[str] = 'upload model', | ||||
| revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||||
| revision: Optional[str] = DEFAULT_REPOSITORY_REVISION): | |||||
| """ | """ | ||||
| Upload model from a given directory to given repository. A valid model directory | Upload model from a given directory to given repository. A valid model directory | ||||
| must contain a configuration.json file. | must contain a configuration.json file. | ||||
| @@ -269,7 +277,7 @@ class HubApi: | |||||
| date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | ||||
| commit_message = '[automsg] push model %s to hub at %s' % ( | commit_message = '[automsg] push model %s to hub at %s' % ( | ||||
| model_id, date) | model_id, date) | ||||
| repo.push(commit_message=commit_message, branch=revision) | |||||
| repo.push(commit_message=commit_message, local_branch=revision, remote_branch=revision) | |||||
| except Exception: | except Exception: | ||||
| raise | raise | ||||
| finally: | finally: | ||||
| @@ -294,7 +302,8 @@ class HubApi: | |||||
| path, | path, | ||||
| data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % | data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % | ||||
| (owner_or_group, page_number, page_size), | (owner_or_group, page_number, page_size), | ||||
| cookies=cookies) | |||||
| cookies=cookies, | |||||
| headers=self.headers) | |||||
| handle_http_response(r, logger, cookies, 'list_model') | handle_http_response(r, logger, cookies, 'list_model') | ||||
| if r.status_code == HTTPStatus.OK: | if r.status_code == HTTPStatus.OK: | ||||
| if is_ok(r.json()): | if is_ok(r.json()): | ||||
| @@ -303,7 +312,7 @@ class HubApi: | |||||
| else: | else: | ||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | ||||
| else: | else: | ||||
| r.raise_for_status() | |||||
| raise_for_http_status(r) | |||||
| return None | return None | ||||
| def _check_cookie(self, | def _check_cookie(self, | ||||
| @@ -318,10 +327,70 @@ class HubApi: | |||||
| raise ValueError('Token does not exist, please login first.') | raise ValueError('Token does not exist, please login first.') | ||||
| return cookies | return cookies | ||||
| def list_model_revisions( | |||||
| self, | |||||
| model_id: str, | |||||
| cutoff_timestamp: int = None, | |||||
| use_cookies: Union[bool, CookieJar] = False) -> List[str]: | |||||
| """Get model branch and tags. | |||||
| Args: | |||||
| model_id (str): The model id | |||||
| cutoff_timestamp (int): Tags created before the cutoff will be included. | |||||
| The timestamp is represented by the seconds elasped from the epoch time. | |||||
| use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will | |||||
| will load cookie from local. Defaults to False. | |||||
| Returns: | |||||
| Tuple[List[str], List[str]]: Return list of branch name and tags | |||||
| """ | |||||
| cookies = self._check_cookie(use_cookies) | |||||
| if cutoff_timestamp is None: | |||||
| cutoff_timestamp = get_release_datetime() | |||||
| path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp | |||||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, cookies, model_id) | |||||
| d = r.json() | |||||
| raise_on_error(d) | |||||
| info = d[API_RESPONSE_FIELD_DATA] | |||||
| # tags returned from backend are guaranteed to be ordered by create-time | |||||
| tags = [x['Revision'] for x in info['RevisionMap']['Tags'] | |||||
| ] if info['RevisionMap']['Tags'] else [] | |||||
| return tags | |||||
| def get_valid_revision(self, model_id: str, revision=None, cookies: Optional[CookieJar] = None): | |||||
| release_timestamp = get_release_datetime() | |||||
| current_timestamp = int(round(datetime.datetime.now().timestamp())) | |||||
| # for active development in library codes (non-release-branches), release_timestamp | |||||
| # is set to be a far-away-time-in-the-future, to ensure that we shall | |||||
| # get the master-HEAD version from model repo by default (when no revision is provided) | |||||
| if release_timestamp > current_timestamp + ONE_YEAR_SECONDS: | |||||
| branches, tags = self.get_model_branches_and_tags( | |||||
| model_id, use_cookies=False if cookies is None else cookies) | |||||
| if revision is None: | |||||
| revision = MASTER_MODEL_BRANCH | |||||
| logger.info('Model revision not specified, use default: %s in development mode' % revision) | |||||
| if revision not in branches and revision not in tags: | |||||
| raise NotExistError('The model: %s has no branch or tag : %s .' % revision) | |||||
| else: | |||||
| revisions = self.list_model_revisions( | |||||
| model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) | |||||
| if revision is None: | |||||
| if len(revisions) == 0: | |||||
| raise NoValidRevisionError('The model: %s has no valid revision!' % model_id) | |||||
| # tags (revisions) returned from backend are guaranteed to be ordered by create-time | |||||
| # we shall obtain the latest revision created earlier than release version of this branch | |||||
| revision = revisions[0] | |||||
| logger.info('Model revision not specified, use the latest revision: %s' % revision) | |||||
| else: | |||||
| if revision not in revisions: | |||||
| raise NotExistError( | |||||
| 'The model: %s has no revision: %s !' % (model_id, revision)) | |||||
| return revision | |||||
| def get_model_branches_and_tags( | def get_model_branches_and_tags( | ||||
| self, | self, | ||||
| model_id: str, | model_id: str, | ||||
| use_cookies: Union[bool, CookieJar] = False | |||||
| use_cookies: Union[bool, CookieJar] = False, | |||||
| ) -> Tuple[List[str], List[str]]: | ) -> Tuple[List[str], List[str]]: | ||||
| """Get model branch and tags. | """Get model branch and tags. | ||||
| @@ -335,7 +404,7 @@ class HubApi: | |||||
| cookies = self._check_cookie(use_cookies) | cookies = self._check_cookie(use_cookies) | ||||
| path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | ||||
| r = requests.get(path, cookies=cookies) | |||||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, cookies, model_id) | handle_http_response(r, logger, cookies, model_id) | ||||
| d = r.json() | d = r.json() | ||||
| raise_on_error(d) | raise_on_error(d) | ||||
| @@ -376,7 +445,11 @@ class HubApi: | |||||
| if root is not None: | if root is not None: | ||||
| path = path + f'&Root={root}' | path = path + f'&Root={root}' | ||||
| r = requests.get(path, cookies=cookies, headers=headers) | |||||
| r = requests.get( | |||||
| path, cookies=cookies, headers={ | |||||
| **headers, | |||||
| **self.headers | |||||
| }) | |||||
| handle_http_response(r, logger, cookies, model_id) | handle_http_response(r, logger, cookies, model_id) | ||||
| d = r.json() | d = r.json() | ||||
| @@ -392,10 +465,9 @@ class HubApi: | |||||
| def list_datasets(self): | def list_datasets(self): | ||||
| path = f'{self.endpoint}/api/v1/datasets' | path = f'{self.endpoint}/api/v1/datasets' | ||||
| headers = None | |||||
| params = {} | params = {} | ||||
| r = requests.get(path, params=params, headers=headers) | |||||
| r.raise_for_status() | |||||
| r = requests.get(path, params=params, headers=self.headers) | |||||
| raise_for_http_status(r) | |||||
| dataset_list = r.json()[API_RESPONSE_FIELD_DATA] | dataset_list = r.json()[API_RESPONSE_FIELD_DATA] | ||||
| return [x['Name'] for x in dataset_list] | return [x['Name'] for x in dataset_list] | ||||
| @@ -425,7 +497,7 @@ class HubApi: | |||||
| dataset_id = resp['Data']['Id'] | dataset_id = resp['Data']['Id'] | ||||
| dataset_type = resp['Data']['Type'] | dataset_type = resp['Data']['Type'] | ||||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | ||||
| r = requests.get(datahub_url) | |||||
| r = requests.get(datahub_url, headers=self.headers) | |||||
| resp = r.json() | resp = r.json() | ||||
| datahub_raise_on_error(datahub_url, resp) | datahub_raise_on_error(datahub_url, resp) | ||||
| file_list = resp['Data'] | file_list = resp['Data'] | ||||
| @@ -445,7 +517,7 @@ class HubApi: | |||||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | ||||
| f'Revision={revision}&FilePath={file_path}' | f'Revision={revision}&FilePath={file_path}' | ||||
| r = requests.get(datahub_url) | r = requests.get(datahub_url) | ||||
| r.raise_for_status() | |||||
| raise_for_http_status(r) | |||||
| local_path = os.path.join(cache_dir, file_path) | local_path = os.path.join(cache_dir, file_path) | ||||
| if os.path.exists(local_path): | if os.path.exists(local_path): | ||||
| logger.warning( | logger.warning( | ||||
| @@ -490,7 +562,8 @@ class HubApi: | |||||
| f'ststoken?Revision={revision}' | f'ststoken?Revision={revision}' | ||||
| cookies = requests.utils.dict_from_cookiejar(cookies) | cookies = requests.utils.dict_from_cookiejar(cookies) | ||||
| r = requests.get(url=datahub_url, cookies=cookies) | |||||
| r = requests.get( | |||||
| url=datahub_url, cookies=cookies, headers=self.headers) | |||||
| resp = r.json() | resp = r.json() | ||||
| raise_on_error(resp) | raise_on_error(resp) | ||||
| return resp['Data'] | return resp['Data'] | ||||
| @@ -512,12 +585,12 @@ class HubApi: | |||||
| def on_dataset_download(self, dataset_name: str, namespace: str) -> None: | def on_dataset_download(self, dataset_name: str, namespace: str) -> None: | ||||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' | url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' | ||||
| r = requests.post(url) | |||||
| r.raise_for_status() | |||||
| r = requests.post(url, headers=self.headers) | |||||
| raise_for_http_status(r) | |||||
| @staticmethod | @staticmethod | ||||
| def datahub_remote_call(url): | def datahub_remote_call(url): | ||||
| r = requests.get(url) | |||||
| r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()}) | |||||
| resp = r.json() | resp = r.json() | ||||
| datahub_raise_on_error(url, resp) | datahub_raise_on_error(url, resp) | ||||
| return resp['Data'] | return resp['Data'] | ||||
| @@ -531,6 +604,7 @@ class ModelScopeConfig: | |||||
| COOKIES_FILE_NAME = 'cookies' | COOKIES_FILE_NAME = 'cookies' | ||||
| GIT_TOKEN_FILE_NAME = 'git_token' | GIT_TOKEN_FILE_NAME = 'git_token' | ||||
| USER_INFO_FILE_NAME = 'user' | USER_INFO_FILE_NAME = 'user' | ||||
| USER_SESSION_ID_FILE_NAME = 'session' | |||||
| @staticmethod | @staticmethod | ||||
| def make_sure_credential_path_exist(): | def make_sure_credential_path_exist(): | ||||
| @@ -559,6 +633,23 @@ class ModelScopeConfig: | |||||
| return cookies | return cookies | ||||
| return None | return None | ||||
| @staticmethod | |||||
| def get_user_session_id(): | |||||
| session_path = os.path.join(ModelScopeConfig.path_credential, | |||||
| ModelScopeConfig.USER_SESSION_ID_FILE_NAME) | |||||
| session_id = '' | |||||
| if os.path.exists(session_path): | |||||
| with open(session_path, 'rb') as f: | |||||
| session_id = str(f.readline().strip(), encoding='utf-8') | |||||
| return session_id | |||||
| if session_id == '' or len(session_id) != 32: | |||||
| session_id = str(uuid.uuid4().hex) | |||||
| ModelScopeConfig.make_sure_credential_path_exist() | |||||
| with open(session_path, 'w+') as wf: | |||||
| wf.write(session_id) | |||||
| return session_id | |||||
| @staticmethod | @staticmethod | ||||
| def save_token(token: str): | def save_token(token: str): | ||||
| ModelScopeConfig.make_sure_credential_path_exist() | ModelScopeConfig.make_sure_credential_path_exist() | ||||
| @@ -607,3 +698,32 @@ class ModelScopeConfig: | |||||
| except FileNotFoundError: | except FileNotFoundError: | ||||
| pass | pass | ||||
| return token | return token | ||||
| @staticmethod | |||||
| def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||||
| """Formats a user-agent string with basic info about a request. | |||||
| Args: | |||||
| user_agent (`str`, `dict`, *optional*): | |||||
| The user agent info in the form of a dictionary or a single string. | |||||
| Returns: | |||||
| The formatted user-agent string. | |||||
| """ | |||||
| env = 'custom' | |||||
| if MODELSCOPE_ENVIRONMENT in os.environ: | |||||
| env = os.environ[MODELSCOPE_ENVIRONMENT] | |||||
| ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( | |||||
| __version__, | |||||
| platform.python_version(), | |||||
| ModelScopeConfig.get_user_session_id(), | |||||
| platform.platform(), | |||||
| platform.processor(), | |||||
| env, | |||||
| ) | |||||
| if isinstance(user_agent, dict): | |||||
| ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||||
| elif isinstance(user_agent, str): | |||||
| ua += ';' + user_agent | |||||
| return ua | |||||
| @@ -16,6 +16,9 @@ API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' | |||||
| API_RESPONSE_FIELD_USERNAME = 'Username' | API_RESPONSE_FIELD_USERNAME = 'Username' | ||||
| API_RESPONSE_FIELD_EMAIL = 'Email' | API_RESPONSE_FIELD_EMAIL = 'Email' | ||||
| API_RESPONSE_FIELD_MESSAGE = 'Message' | API_RESPONSE_FIELD_MESSAGE = 'Message' | ||||
| MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' | |||||
| MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' | |||||
| ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 | |||||
| class Licenses(object): | class Licenses(object): | ||||
| @@ -3,7 +3,6 @@ from abc import ABC | |||||
| from http import HTTPStatus | from http import HTTPStatus | ||||
| from typing import Optional | from typing import Optional | ||||
| import attrs | |||||
| import json | import json | ||||
| import requests | import requests | ||||
| from attrs import asdict, define, field, validators | from attrs import asdict, define, field, validators | ||||
| @@ -12,7 +11,8 @@ from modelscope.hub.api import ModelScopeConfig | |||||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | ||||
| API_RESPONSE_FIELD_MESSAGE) | API_RESPONSE_FIELD_MESSAGE) | ||||
| from modelscope.hub.errors import (NotLoginException, NotSupportError, | from modelscope.hub.errors import (NotLoginException, NotSupportError, | ||||
| RequestError, handle_http_response, is_ok) | |||||
| RequestError, handle_http_response, is_ok, | |||||
| raise_for_http_status) | |||||
| from modelscope.hub.utils.utils import get_endpoint | from modelscope.hub.utils.utils import get_endpoint | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -188,6 +188,7 @@ class ServiceDeployer(object): | |||||
| def __init__(self, endpoint=None): | def __init__(self, endpoint=None): | ||||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | self.endpoint = endpoint if endpoint is not None else get_endpoint() | ||||
| self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||||
| self.cookies = ModelScopeConfig.get_cookies() | self.cookies = ModelScopeConfig.get_cookies() | ||||
| if self.cookies is None: | if self.cookies is None: | ||||
| raise NotLoginException( | raise NotLoginException( | ||||
| @@ -227,12 +228,9 @@ class ServiceDeployer(object): | |||||
| resource=resource, | resource=resource, | ||||
| provider=provider) | provider=provider) | ||||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | path = f'{self.endpoint}/api/v1/deployer/endpoint' | ||||
| body = attrs.asdict(create_params) | |||||
| body = asdict(create_params) | |||||
| r = requests.post( | r = requests.post( | ||||
| path, | |||||
| json=body, | |||||
| cookies=self.cookies, | |||||
| ) | |||||
| path, json=body, cookies=self.cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, self.cookies, 'create_service') | handle_http_response(r, logger, self.cookies, 'create_service') | ||||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | ||||
| if is_ok(r.json()): | if is_ok(r.json()): | ||||
| @@ -241,7 +239,7 @@ class ServiceDeployer(object): | |||||
| else: | else: | ||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | ||||
| else: | else: | ||||
| r.raise_for_status() | |||||
| raise_for_http_status(r) | |||||
| return None | return None | ||||
| def get(self, instance_name: str, provider: ServiceProviderParameters): | def get(self, instance_name: str, provider: ServiceProviderParameters): | ||||
| @@ -262,7 +260,7 @@ class ServiceDeployer(object): | |||||
| params = GetServiceParameters(provider=provider) | params = GetServiceParameters(provider=provider) | ||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | ||||
| self.endpoint, instance_name, params.to_query_str()) | self.endpoint, instance_name, params.to_query_str()) | ||||
| r = requests.get(path, cookies=self.cookies) | |||||
| r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, self.cookies, 'get_service') | handle_http_response(r, logger, self.cookies, 'get_service') | ||||
| if r.status_code == HTTPStatus.OK: | if r.status_code == HTTPStatus.OK: | ||||
| if is_ok(r.json()): | if is_ok(r.json()): | ||||
| @@ -271,7 +269,7 @@ class ServiceDeployer(object): | |||||
| else: | else: | ||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | ||||
| else: | else: | ||||
| r.raise_for_status() | |||||
| raise_for_http_status(r) | |||||
| return None | return None | ||||
| def delete(self, instance_name: str, provider: ServiceProviderParameters): | def delete(self, instance_name: str, provider: ServiceProviderParameters): | ||||
| @@ -293,7 +291,7 @@ class ServiceDeployer(object): | |||||
| params = DeleteServiceParameters(provider=provider) | params = DeleteServiceParameters(provider=provider) | ||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | ||||
| self.endpoint, instance_name, params.to_query_str()) | self.endpoint, instance_name, params.to_query_str()) | ||||
| r = requests.delete(path, cookies=self.cookies) | |||||
| r = requests.delete(path, cookies=self.cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, self.cookies, 'delete_service') | handle_http_response(r, logger, self.cookies, 'delete_service') | ||||
| if r.status_code == HTTPStatus.OK: | if r.status_code == HTTPStatus.OK: | ||||
| if is_ok(r.json()): | if is_ok(r.json()): | ||||
| @@ -302,7 +300,7 @@ class ServiceDeployer(object): | |||||
| else: | else: | ||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | ||||
| else: | else: | ||||
| r.raise_for_status() | |||||
| raise_for_http_status(r) | |||||
| return None | return None | ||||
| def list(self, | def list(self, | ||||
| @@ -328,7 +326,7 @@ class ServiceDeployer(object): | |||||
| provider=provider, skip=skip, limit=limit) | provider=provider, skip=skip, limit=limit) | ||||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | ||||
| params.to_query_str()) | params.to_query_str()) | ||||
| r = requests.get(path, cookies=self.cookies) | |||||
| r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, self.cookies, 'list_service_instances') | handle_http_response(r, logger, self.cookies, 'list_service_instances') | ||||
| if r.status_code == HTTPStatus.OK: | if r.status_code == HTTPStatus.OK: | ||||
| if is_ok(r.json()): | if is_ok(r.json()): | ||||
| @@ -337,5 +335,5 @@ class ServiceDeployer(object): | |||||
| else: | else: | ||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | ||||
| else: | else: | ||||
| r.raise_for_status() | |||||
| raise_for_http_status(r) | |||||
| return None | return None | ||||
| @@ -13,6 +13,10 @@ class NotSupportError(Exception): | |||||
| pass | pass | ||||
| class NoValidRevisionError(Exception): | |||||
| pass | |||||
| class NotExistError(Exception): | class NotExistError(Exception): | ||||
| pass | pass | ||||
| @@ -99,3 +103,33 @@ def datahub_raise_on_error(url, rsp): | |||||
| raise RequestError( | raise RequestError( | ||||
| f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" | f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" | ||||
| ) | ) | ||||
| def raise_for_http_status(rsp): | |||||
| """ | |||||
| Attempt to decode utf-8 first since some servers | |||||
| localize reason strings, for invalid utf-8, fall back | |||||
| to decoding with iso-8859-1. | |||||
| """ | |||||
| http_error_msg = '' | |||||
| if isinstance(rsp.reason, bytes): | |||||
| try: | |||||
| reason = rsp.reason.decode('utf-8') | |||||
| except UnicodeDecodeError: | |||||
| reason = rsp.reason.decode('iso-8859-1') | |||||
| else: | |||||
| reason = rsp.reason | |||||
| if 400 <= rsp.status_code < 500: | |||||
| http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code, | |||||
| reason, rsp.url) | |||||
| elif 500 <= rsp.status_code < 600: | |||||
| http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code, | |||||
| reason, rsp.url) | |||||
| if http_error_msg: | |||||
| req = rsp.request | |||||
| if req.method == 'POST': | |||||
| http_error_msg = u'%s, body: %s' % (http_error_msg, req.body) | |||||
| raise HTTPError(http_error_msg, response=rsp) | |||||
| @@ -2,13 +2,11 @@ | |||||
| import copy | import copy | ||||
| import os | import os | ||||
| import sys | |||||
| import tempfile | import tempfile | ||||
| from functools import partial | from functools import partial | ||||
| from http.cookiejar import CookieJar | from http.cookiejar import CookieJar | ||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| from uuid import uuid4 | |||||
| import requests | import requests | ||||
| from tqdm import tqdm | from tqdm import tqdm | ||||
| @@ -23,7 +21,6 @@ from .utils.caching import ModelFileSystemCache | |||||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | from .utils.utils import (file_integrity_validation, get_cache_dir, | ||||
| get_endpoint, model_id_to_group_owner_name) | get_endpoint, model_id_to_group_owner_name) | ||||
| SESSION_ID = uuid4().hex | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -34,6 +31,7 @@ def model_file_download( | |||||
| cache_dir: Optional[str] = None, | cache_dir: Optional[str] = None, | ||||
| user_agent: Union[Dict, str, None] = None, | user_agent: Union[Dict, str, None] = None, | ||||
| local_files_only: Optional[bool] = False, | local_files_only: Optional[bool] = False, | ||||
| cookies: Optional[CookieJar] = None, | |||||
| ) -> Optional[str]: # pragma: no cover | ) -> Optional[str]: # pragma: no cover | ||||
| """ | """ | ||||
| Download from a given URL and cache it if it's not already present in the | Download from a given URL and cache it if it's not already present in the | ||||
| @@ -104,54 +102,47 @@ def model_file_download( | |||||
| " online, set 'local_files_only' to False.") | " online, set 'local_files_only' to False.") | ||||
| _api = HubApi() | _api = HubApi() | ||||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| branches, tags = _api.get_model_branches_and_tags( | |||||
| model_id, use_cookies=False if cookies is None else cookies) | |||||
| headers = { | |||||
| 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ) | |||||
| } | |||||
| if cookies is None: | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| revision = _api.get_valid_revision( | |||||
| model_id, revision=revision, cookies=cookies) | |||||
| file_to_download_info = None | file_to_download_info = None | ||||
| is_commit_id = False | |||||
| if revision in branches or revision in tags: # The revision is version or tag, | |||||
| # we need to confirm the version is up to date | |||||
| # we need to get the file list to check if the lateast version is cached, if so return, otherwise download | |||||
| model_files = _api.get_model_files( | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| recursive=True, | |||||
| use_cookies=False if cookies is None else cookies) | |||||
| for model_file in model_files: | |||||
| if model_file['Type'] == 'tree': | |||||
| continue | |||||
| if model_file['Path'] == file_path: | |||||
| if cache.exists(model_file): | |||||
| return cache.get_file_by_info(model_file) | |||||
| else: | |||||
| file_to_download_info = model_file | |||||
| break | |||||
| if file_to_download_info is None: | |||||
| raise NotExistError('The file path: %s not exist in: %s' % | |||||
| (file_path, model_id)) | |||||
| else: # the revision is commit id. | |||||
| cached_file_path = cache.get_file_by_path_and_commit_id( | |||||
| file_path, revision) | |||||
| if cached_file_path is not None: | |||||
| file_name = os.path.basename(cached_file_path) | |||||
| logger.info( | |||||
| f'File {file_name} already in cache, skip downloading!') | |||||
| return cached_file_path # the file is in cache. | |||||
| is_commit_id = True | |||||
| # we need to confirm the version is up-to-date | |||||
| # we need to get the file list to check if the latest version is cached, if so return, otherwise download | |||||
| model_files = _api.get_model_files( | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| recursive=True, | |||||
| use_cookies=False if cookies is None else cookies) | |||||
| for model_file in model_files: | |||||
| if model_file['Type'] == 'tree': | |||||
| continue | |||||
| if model_file['Path'] == file_path: | |||||
| if cache.exists(model_file): | |||||
| logger.info( | |||||
| f'File {model_file["Name"]} already in cache, skip downloading!' | |||||
| ) | |||||
| return cache.get_file_by_info(model_file) | |||||
| else: | |||||
| file_to_download_info = model_file | |||||
| break | |||||
| if file_to_download_info is None: | |||||
| raise NotExistError('The file path: %s not exist in: %s' % | |||||
| (file_path, model_id)) | |||||
| # we need to download again | # we need to download again | ||||
| url_to_download = get_file_download_url(model_id, file_path, revision) | url_to_download = get_file_download_url(model_id, file_path, revision) | ||||
| file_to_download_info = { | file_to_download_info = { | ||||
| 'Path': | |||||
| file_path, | |||||
| 'Revision': | |||||
| revision if is_commit_id else file_to_download_info['Revision'], | |||||
| FILE_HASH: | |||||
| None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||||
| file_to_download_info[FILE_HASH] | |||||
| 'Path': file_path, | |||||
| 'Revision': file_to_download_info['Revision'], | |||||
| FILE_HASH: file_to_download_info[FILE_HASH] | |||||
| } | } | ||||
| temp_file_name = next(tempfile._get_candidate_names()) | temp_file_name = next(tempfile._get_candidate_names()) | ||||
| @@ -170,25 +161,6 @@ def model_file_download( | |||||
| os.path.join(temporary_cache_dir, temp_file_name)) | os.path.join(temporary_cache_dir, temp_file_name)) | ||||
| def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||||
| """Formats a user-agent string with basic info about a request. | |||||
| Args: | |||||
| user_agent (`str`, `dict`, *optional*): | |||||
| The user agent info in the form of a dictionary or a single string. | |||||
| Returns: | |||||
| The formatted user-agent string. | |||||
| """ | |||||
| ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}' | |||||
| if isinstance(user_agent, dict): | |||||
| ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||||
| elif isinstance(user_agent, str): | |||||
| ua = user_agent | |||||
| return ua | |||||
| def get_file_download_url(model_id: str, file_path: str, revision: str): | def get_file_download_url(model_id: str, file_path: str, revision: str): | ||||
| """ | """ | ||||
| Format file download url according to `model_id`, `revision` and `file_path`. | Format file download url according to `model_id`, `revision` and `file_path`. | ||||
| @@ -5,6 +5,7 @@ import subprocess | |||||
| from typing import List | from typing import List | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from ..utils.constant import MASTER_MODEL_BRANCH | |||||
| from .errors import GitError | from .errors import GitError | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -227,3 +228,22 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| files.append(line.split(' ')[-1]) | files.append(line.split(' ')[-1]) | ||||
| return files | return files | ||||
| def tag(self, | |||||
| repo_dir: str, | |||||
| tag_name: str, | |||||
| message: str, | |||||
| ref: str = MASTER_MODEL_BRANCH): | |||||
| cmd_args = [ | |||||
| '-C', repo_dir, 'tag', tag_name, '-m', | |||||
| '"%s"' % message, ref | |||||
| ] | |||||
| rsp = self._run_git_command(*cmd_args) | |||||
| logger.debug(rsp.stdout.decode('utf8')) | |||||
| return rsp | |||||
| def push_tag(self, repo_dir: str, tag_name): | |||||
| cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name] | |||||
| rsp = self._run_git_command(*cmd_args) | |||||
| logger.debug(rsp.stdout.decode('utf8')) | |||||
| return rsp | |||||
| @@ -5,7 +5,8 @@ from typing import Optional | |||||
| from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | ||||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | ||||
| DEFAULT_MODEL_REVISION) | |||||
| DEFAULT_REPOSITORY_REVISION, | |||||
| MASTER_MODEL_BRANCH) | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .git import GitCommandWrapper | from .git import GitCommandWrapper | ||||
| from .utils.utils import get_endpoint | from .utils.utils import get_endpoint | ||||
| @@ -20,7 +21,7 @@ class Repository: | |||||
| def __init__(self, | def __init__(self, | ||||
| model_dir: str, | model_dir: str, | ||||
| clone_from: str, | clone_from: str, | ||||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||||
| revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||||
| auth_token: Optional[str] = None, | auth_token: Optional[str] = None, | ||||
| git_path: Optional[str] = None): | git_path: Optional[str] = None): | ||||
| """ | """ | ||||
| @@ -89,7 +90,8 @@ class Repository: | |||||
| def push(self, | def push(self, | ||||
| commit_message: str, | commit_message: str, | ||||
| branch: Optional[str] = DEFAULT_MODEL_REVISION, | |||||
| local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||||
| remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||||
| force: bool = False): | force: bool = False): | ||||
| """Push local files to remote, this method will do. | """Push local files to remote, this method will do. | ||||
| git pull | git pull | ||||
| @@ -116,14 +118,48 @@ class Repository: | |||||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | url = self.git_wrapper.get_repo_remote_url(self.model_dir) | ||||
| self.git_wrapper.pull(self.model_dir) | self.git_wrapper.pull(self.model_dir) | ||||
| self.git_wrapper.add(self.model_dir, all_files=True) | self.git_wrapper.add(self.model_dir, all_files=True) | ||||
| self.git_wrapper.commit(self.model_dir, commit_message) | self.git_wrapper.commit(self.model_dir, commit_message) | ||||
| self.git_wrapper.push( | self.git_wrapper.push( | ||||
| repo_dir=self.model_dir, | repo_dir=self.model_dir, | ||||
| token=self.auth_token, | token=self.auth_token, | ||||
| url=url, | url=url, | ||||
| local_branch=branch, | |||||
| remote_branch=branch) | |||||
| local_branch=local_branch, | |||||
| remote_branch=remote_branch) | |||||
| def tag(self, tag_name: str, message: str, ref: str = MASTER_MODEL_BRANCH): | |||||
| """Create a new tag. | |||||
| Args: | |||||
| tag_name (str): The name of the tag | |||||
| message (str): The tag message. | |||||
| ref (str): The tag reference, can be commit id or branch. | |||||
| """ | |||||
| if tag_name is None or tag_name == '': | |||||
| msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.' | |||||
| raise InvalidParameter(msg) | |||||
| if message is None or message == '': | |||||
| msg = 'We use annotated tag, therefore message cannot None or empty.' | |||||
| self.git_wrapper.tag( | |||||
| repo_dir=self.model_dir, | |||||
| tag_name=tag_name, | |||||
| message=message, | |||||
| ref=ref) | |||||
| def tag_and_push(self, | |||||
| tag_name: str, | |||||
| message: str, | |||||
| ref: str = MASTER_MODEL_BRANCH): | |||||
| """Create tag and push to remote | |||||
| Args: | |||||
| tag_name (str): The name of the tag | |||||
| message (str): The tag message. | |||||
| ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH. | |||||
| """ | |||||
| self.tag(tag_name, message, ref) | |||||
| self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name) | |||||
| class DatasetRepository: | class DatasetRepository: | ||||
| @@ -2,6 +2,7 @@ | |||||
| import os | import os | ||||
| import tempfile | import tempfile | ||||
| from http.cookiejar import CookieJar | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| @@ -9,9 +10,7 @@ from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .constants import FILE_HASH | from .constants import FILE_HASH | ||||
| from .errors import NotExistError | |||||
| from .file_download import (get_file_download_url, http_get_file, | |||||
| http_user_agent) | |||||
| from .file_download import get_file_download_url, http_get_file | |||||
| from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | from .utils.utils import (file_integrity_validation, get_cache_dir, | ||||
| model_id_to_group_owner_name) | model_id_to_group_owner_name) | ||||
| @@ -23,7 +22,8 @@ def snapshot_download(model_id: str, | |||||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | revision: Optional[str] = DEFAULT_MODEL_REVISION, | ||||
| cache_dir: Union[str, Path, None] = None, | cache_dir: Union[str, Path, None] = None, | ||||
| user_agent: Optional[Union[Dict, str]] = None, | user_agent: Optional[Union[Dict, str]] = None, | ||||
| local_files_only: Optional[bool] = False) -> str: | |||||
| local_files_only: Optional[bool] = False, | |||||
| cookies: Optional[CookieJar] = None) -> str: | |||||
| """Download all files of a repo. | """Download all files of a repo. | ||||
| Downloads a whole snapshot of a repo's files at the specified revision. This | Downloads a whole snapshot of a repo's files at the specified revision. This | ||||
| is useful when you want all files from a repo, because you don't know which | is useful when you want all files from a repo, because you don't know which | ||||
| @@ -81,15 +81,15 @@ def snapshot_download(model_id: str, | |||||
| ) # we can not confirm the cached file is for snapshot 'revision' | ) # we can not confirm the cached file is for snapshot 'revision' | ||||
| else: | else: | ||||
| # make headers | # make headers | ||||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||||
| headers = { | |||||
| 'user-agent': | |||||
| ModelScopeConfig.get_user_agent(user_agent=user_agent, ) | |||||
| } | |||||
| _api = HubApi() | _api = HubApi() | ||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| # get file list from model repo | |||||
| branches, tags = _api.get_model_branches_and_tags( | |||||
| model_id, use_cookies=False if cookies is None else cookies) | |||||
| if revision not in branches and revision not in tags: | |||||
| raise NotExistError('The specified branch or tag : %s not exist!' | |||||
| % revision) | |||||
| if cookies is None: | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| revision = _api.get_valid_revision( | |||||
| model_id, revision=revision, cookies=cookies) | |||||
| snapshot_header = headers if 'CI_TEST' in os.environ else { | snapshot_header = headers if 'CI_TEST' in os.environ else { | ||||
| **headers, | **headers, | ||||
| @@ -110,7 +110,7 @@ def snapshot_download(model_id: str, | |||||
| for model_file in model_files: | for model_file in model_files: | ||||
| if model_file['Type'] == 'tree': | if model_file['Type'] == 'tree': | ||||
| continue | continue | ||||
| # check model_file is exist in cache, if exist, skip download, otherwise download | |||||
| # check model_file is exist in cache, if existed, skip download, otherwise download | |||||
| if cache.exists(model_file): | if cache.exists(model_file): | ||||
| file_name = os.path.basename(model_file['Name']) | file_name = os.path.basename(model_file['Name']) | ||||
| logger.info( | logger.info( | ||||
| @@ -2,11 +2,12 @@ | |||||
| import hashlib | import hashlib | ||||
| import os | import os | ||||
| from datetime import datetime | |||||
| from typing import Optional | from typing import Optional | ||||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | ||||
| DEFAULT_MODELSCOPE_GROUP, | DEFAULT_MODELSCOPE_GROUP, | ||||
| MODEL_ID_SEPARATOR, | |||||
| MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, | |||||
| MODELSCOPE_URL_SCHEME) | MODELSCOPE_URL_SCHEME) | ||||
| from modelscope.hub.errors import FileIntegrityError | from modelscope.hub.errors import FileIntegrityError | ||||
| from modelscope.utils.file_utils import get_default_cache_dir | from modelscope.utils.file_utils import get_default_cache_dir | ||||
| @@ -37,6 +38,18 @@ def get_cache_dir(model_id: Optional[str] = None): | |||||
| base_path, model_id + '/') | base_path, model_id + '/') | ||||
| def get_release_datetime(): | |||||
| if MODELSCOPE_SDK_DEBUG in os.environ: | |||||
| rt = int(round(datetime.now().timestamp())) | |||||
| else: | |||||
| from modelscope import version | |||||
| rt = int( | |||||
| round( | |||||
| datetime.strptime(version.__release_datetime__, | |||||
| '%Y-%m-%d %H:%M:%S').timestamp())) | |||||
| return rt | |||||
| def get_endpoint(): | def get_endpoint(): | ||||
| modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | ||||
| DEFAULT_MODELSCOPE_DOMAIN) | DEFAULT_MODELSCOPE_DOMAIN) | ||||
| @@ -675,4 +675,8 @@ class MsDataset: | |||||
| revision=revision, | revision=revision, | ||||
| auth_token=auth_token, | auth_token=auth_token, | ||||
| git_path=git_path) | git_path=git_path) | ||||
| _repo.push(commit_message=commit_message, branch=revision, force=force) | |||||
| _repo.push( | |||||
| commit_message=commit_message, | |||||
| local_branch=revision, | |||||
| remote_branch=revision, | |||||
| force=force) | |||||
| @@ -311,7 +311,9 @@ class Frameworks(object): | |||||
| kaldi = 'kaldi' | kaldi = 'kaldi' | ||||
| DEFAULT_MODEL_REVISION = 'master' | |||||
| DEFAULT_MODEL_REVISION = None | |||||
| MASTER_MODEL_BRANCH = 'master' | |||||
| DEFAULT_REPOSITORY_REVISION = 'master' | |||||
| DEFAULT_DATASET_REVISION = 'master' | DEFAULT_DATASET_REVISION = 'master' | ||||
| DEFAULT_DATASET_NAMESPACE = 'modelscope' | DEFAULT_DATASET_NAMESPACE = 'modelscope' | ||||
| @@ -1 +1,5 @@ | |||||
| # Make sure to modify __release_datetime__ to release time when making official release. | |||||
| __version__ = '0.5.0' | __version__ = '0.5.0' | ||||
| # default release datetime for branches under active development is set | |||||
| # to be a time far-far-away-into-the-future | |||||
| __release_datetime__ = '2099-10-13 08:56:12' | |||||
| @@ -25,10 +25,10 @@ class HubOperationTest(unittest.TestCase): | |||||
| def setUp(self): | def setUp(self): | ||||
| self.api = HubApi() | self.api = HubApi() | ||||
| # note this is temporary before official account management is ready | |||||
| self.api.login(TEST_ACCESS_TOKEN1) | self.api.login(TEST_ACCESS_TOKEN1) | ||||
| self.model_name = uuid.uuid4().hex | |||||
| self.model_name = 'op-%s' % (uuid.uuid4().hex) | |||||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | ||||
| self.revision = 'v0.1_test_revision' | |||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| visibility=ModelVisibility.PUBLIC, | visibility=ModelVisibility.PUBLIC, | ||||
| @@ -46,6 +46,7 @@ class HubOperationTest(unittest.TestCase): | |||||
| os.system("echo 'testtest'>%s" | os.system("echo 'testtest'>%s" | ||||
| % os.path.join(self.model_dir, download_model_file_name)) | % os.path.join(self.model_dir, download_model_file_name)) | ||||
| repo.push('add model') | repo.push('add model') | ||||
| repo.tag_and_push(self.revision, 'Test revision') | |||||
| def test_model_repo_creation(self): | def test_model_repo_creation(self): | ||||
| # change to proper model names before use | # change to proper model names before use | ||||
| @@ -61,7 +62,9 @@ class HubOperationTest(unittest.TestCase): | |||||
| def test_download_single_file(self): | def test_download_single_file(self): | ||||
| self.prepare_case() | self.prepare_case() | ||||
| downloaded_file = model_file_download( | downloaded_file = model_file_download( | ||||
| model_id=self.model_id, file_path=download_model_file_name) | |||||
| model_id=self.model_id, | |||||
| file_path=download_model_file_name, | |||||
| revision=self.revision) | |||||
| assert os.path.exists(downloaded_file) | assert os.path.exists(downloaded_file) | ||||
| mdtime1 = os.path.getmtime(downloaded_file) | mdtime1 = os.path.getmtime(downloaded_file) | ||||
| # download again | # download again | ||||
| @@ -78,17 +81,16 @@ class HubOperationTest(unittest.TestCase): | |||||
| assert os.path.exists(downloaded_file_path) | assert os.path.exists(downloaded_file_path) | ||||
| mdtime1 = os.path.getmtime(downloaded_file_path) | mdtime1 = os.path.getmtime(downloaded_file_path) | ||||
| # download again | # download again | ||||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||||
| snapshot_path = snapshot_download( | |||||
| model_id=self.model_id, revision=self.revision) | |||||
| mdtime2 = os.path.getmtime(downloaded_file_path) | mdtime2 = os.path.getmtime(downloaded_file_path) | ||||
| assert mdtime1 == mdtime2 | assert mdtime1 == mdtime2 | ||||
| model_file_download( | |||||
| model_id=self.model_id, | |||||
| file_path=download_model_file_name) # not add counter | |||||
| def test_download_public_without_login(self): | def test_download_public_without_login(self): | ||||
| self.prepare_case() | self.prepare_case() | ||||
| rmtree(ModelScopeConfig.path_credential) | rmtree(ModelScopeConfig.path_credential) | ||||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||||
| snapshot_path = snapshot_download( | |||||
| model_id=self.model_id, revision=self.revision) | |||||
| downloaded_file_path = os.path.join(snapshot_path, | downloaded_file_path = os.path.join(snapshot_path, | ||||
| download_model_file_name) | download_model_file_name) | ||||
| assert os.path.exists(downloaded_file_path) | assert os.path.exists(downloaded_file_path) | ||||
| @@ -96,26 +98,38 @@ class HubOperationTest(unittest.TestCase): | |||||
| downloaded_file = model_file_download( | downloaded_file = model_file_download( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| file_path=download_model_file_name, | file_path=download_model_file_name, | ||||
| revision=self.revision, | |||||
| cache_dir=temporary_dir) | cache_dir=temporary_dir) | ||||
| assert os.path.exists(downloaded_file) | assert os.path.exists(downloaded_file) | ||||
| self.api.login(TEST_ACCESS_TOKEN1) | self.api.login(TEST_ACCESS_TOKEN1) | ||||
| def test_snapshot_delete_download_cache_file(self): | def test_snapshot_delete_download_cache_file(self): | ||||
| self.prepare_case() | self.prepare_case() | ||||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||||
| snapshot_path = snapshot_download( | |||||
| model_id=self.model_id, revision=self.revision) | |||||
| downloaded_file_path = os.path.join(snapshot_path, | downloaded_file_path = os.path.join(snapshot_path, | ||||
| download_model_file_name) | download_model_file_name) | ||||
| assert os.path.exists(downloaded_file_path) | assert os.path.exists(downloaded_file_path) | ||||
| os.remove(downloaded_file_path) | os.remove(downloaded_file_path) | ||||
| # download again in cache | # download again in cache | ||||
| file_download_path = model_file_download( | file_download_path = model_file_download( | ||||
| model_id=self.model_id, file_path=ModelFile.README) | |||||
| model_id=self.model_id, | |||||
| file_path=ModelFile.README, | |||||
| revision=self.revision) | |||||
| assert os.path.exists(file_download_path) | assert os.path.exists(file_download_path) | ||||
| # deleted file need download again | # deleted file need download again | ||||
| file_download_path = model_file_download( | file_download_path = model_file_download( | ||||
| model_id=self.model_id, file_path=download_model_file_name) | |||||
| model_id=self.model_id, | |||||
| file_path=download_model_file_name, | |||||
| revision=self.revision) | |||||
| assert os.path.exists(file_download_path) | assert os.path.exists(file_download_path) | ||||
| def test_snapshot_download_default_revision(self): | |||||
| pass # TOTO | |||||
| def test_file_download_default_revision(self): | |||||
| pass # TODO | |||||
| def get_model_download_times(self): | def get_model_download_times(self): | ||||
| url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads' | url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads' | ||||
| cookies = ModelScopeConfig.get_cookies() | cookies = ModelScopeConfig.get_cookies() | ||||
| @@ -17,23 +17,34 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, | |||||
| TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | ||||
| delete_credential) | delete_credential) | ||||
| download_model_file_name = 'test.bin' | |||||
| class HubPrivateFileDownloadTest(unittest.TestCase): | class HubPrivateFileDownloadTest(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| self.old_cwd = os.getcwd() | self.old_cwd = os.getcwd() | ||||
| self.api = HubApi() | self.api = HubApi() | ||||
| # note this is temporary before official account management is ready | |||||
| self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | ||||
| self.model_name = uuid.uuid4().hex | |||||
| self.model_name = 'pf-%s' % (uuid.uuid4().hex) | |||||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | ||||
| self.revision = 'v0.1_test_revision' | |||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| visibility=ModelVisibility.PRIVATE, # 1-private, 5-public | |||||
| visibility=ModelVisibility.PRIVATE, | |||||
| license=Licenses.APACHE_V2, | license=Licenses.APACHE_V2, | ||||
| chinese_name=TEST_MODEL_CHINESE_NAME, | chinese_name=TEST_MODEL_CHINESE_NAME, | ||||
| ) | ) | ||||
| def prepare_case(self): | |||||
| temporary_dir = tempfile.mkdtemp() | |||||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||||
| os.system("echo 'testtest'>%s" | |||||
| % os.path.join(self.model_dir, download_model_file_name)) | |||||
| repo.push('add model') | |||||
| repo.tag_and_push(self.revision, 'Test revision') | |||||
| def tearDown(self): | def tearDown(self): | ||||
| # credential may deleted or switch login name, we need re-login here | # credential may deleted or switch login name, we need re-login here | ||||
| # to ensure the temporary model is deleted. | # to ensure the temporary model is deleted. | ||||
| @@ -42,49 +53,67 @@ class HubPrivateFileDownloadTest(unittest.TestCase): | |||||
| self.api.delete_model(model_id=self.model_id) | self.api.delete_model(model_id=self.model_id) | ||||
| def test_snapshot_download_private_model(self): | def test_snapshot_download_private_model(self): | ||||
| snapshot_path = snapshot_download(self.model_id) | |||||
| self.prepare_case() | |||||
| snapshot_path = snapshot_download(self.model_id, self.revision) | |||||
| assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | ||||
| def test_snapshot_download_private_model_no_permission(self): | def test_snapshot_download_private_model_no_permission(self): | ||||
| self.prepare_case() | |||||
| self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | ||||
| with self.assertRaises(HTTPError): | with self.assertRaises(HTTPError): | ||||
| snapshot_download(self.model_id) | |||||
| snapshot_download(self.model_id, self.revision) | |||||
| def test_snapshot_download_private_model_without_login(self): | def test_snapshot_download_private_model_without_login(self): | ||||
| self.prepare_case() | |||||
| delete_credential() | delete_credential() | ||||
| with self.assertRaises(HTTPError): | with self.assertRaises(HTTPError): | ||||
| snapshot_download(self.model_id) | |||||
| snapshot_download(self.model_id, self.revision) | |||||
| def test_download_file_private_model(self): | def test_download_file_private_model(self): | ||||
| file_path = model_file_download(self.model_id, ModelFile.README) | |||||
| self.prepare_case() | |||||
| file_path = model_file_download(self.model_id, ModelFile.README, | |||||
| self.revision) | |||||
| assert os.path.exists(file_path) | assert os.path.exists(file_path) | ||||
| def test_download_file_private_model_no_permission(self): | def test_download_file_private_model_no_permission(self): | ||||
| self.prepare_case() | |||||
| self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | ||||
| with self.assertRaises(HTTPError): | with self.assertRaises(HTTPError): | ||||
| model_file_download(self.model_id, ModelFile.README) | |||||
| model_file_download(self.model_id, ModelFile.README, self.revision) | |||||
| def test_download_file_private_model_without_login(self): | def test_download_file_private_model_without_login(self): | ||||
| self.prepare_case() | |||||
| delete_credential() | delete_credential() | ||||
| with self.assertRaises(HTTPError): | with self.assertRaises(HTTPError): | ||||
| model_file_download(self.model_id, ModelFile.README) | |||||
| model_file_download(self.model_id, ModelFile.README, self.revision) | |||||
| def test_snapshot_download_local_only(self): | def test_snapshot_download_local_only(self): | ||||
| self.prepare_case() | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| snapshot_download(self.model_id, local_files_only=True) | |||||
| snapshot_path = snapshot_download(self.model_id) | |||||
| snapshot_download( | |||||
| self.model_id, self.revision, local_files_only=True) | |||||
| snapshot_path = snapshot_download(self.model_id, self.revision) | |||||
| assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | ||||
| snapshot_path = snapshot_download(self.model_id, local_files_only=True) | |||||
| snapshot_path = snapshot_download( | |||||
| self.model_id, self.revision, local_files_only=True) | |||||
| assert os.path.exists(snapshot_path) | assert os.path.exists(snapshot_path) | ||||
| def test_file_download_local_only(self): | def test_file_download_local_only(self): | ||||
| self.prepare_case() | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| model_file_download( | model_file_download( | ||||
| self.model_id, ModelFile.README, local_files_only=True) | |||||
| file_path = model_file_download(self.model_id, ModelFile.README) | |||||
| self.model_id, | |||||
| ModelFile.README, | |||||
| self.revision, | |||||
| local_files_only=True) | |||||
| file_path = model_file_download(self.model_id, ModelFile.README, | |||||
| self.revision) | |||||
| assert os.path.exists(file_path) | assert os.path.exists(file_path) | ||||
| file_path = model_file_download( | file_path = model_file_download( | ||||
| self.model_id, ModelFile.README, local_files_only=True) | |||||
| self.model_id, | |||||
| ModelFile.README, | |||||
| revision=self.revision, | |||||
| local_files_only=True) | |||||
| assert os.path.exists(file_path) | assert os.path.exists(file_path) | ||||
| @@ -21,13 +21,12 @@ class HubPrivateRepositoryTest(unittest.TestCase): | |||||
| def setUp(self): | def setUp(self): | ||||
| self.old_cwd = os.getcwd() | self.old_cwd = os.getcwd() | ||||
| self.api = HubApi() | self.api = HubApi() | ||||
| # note this is temporary before official account management is ready | |||||
| self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | ||||
| self.model_name = uuid.uuid4().hex | |||||
| self.model_name = 'pr-%s' % (uuid.uuid4().hex) | |||||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | ||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| visibility=ModelVisibility.PRIVATE, # 1-private, 5-public | |||||
| visibility=ModelVisibility.PRIVATE, | |||||
| license=Licenses.APACHE_V2, | license=Licenses.APACHE_V2, | ||||
| chinese_name=TEST_MODEL_CHINESE_NAME, | chinese_name=TEST_MODEL_CHINESE_NAME, | ||||
| ) | ) | ||||
| @@ -22,6 +22,7 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||||
| logger = get_logger() | logger = get_logger() | ||||
| logger.setLevel('DEBUG') | logger.setLevel('DEBUG') | ||||
| DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
| download_model_file_name = 'test.bin' | |||||
| class HubRepositoryTest(unittest.TestCase): | class HubRepositoryTest(unittest.TestCase): | ||||
| @@ -29,13 +30,13 @@ class HubRepositoryTest(unittest.TestCase): | |||||
| def setUp(self): | def setUp(self): | ||||
| self.old_cwd = os.getcwd() | self.old_cwd = os.getcwd() | ||||
| self.api = HubApi() | self.api = HubApi() | ||||
| # note this is temporary before official account management is ready | |||||
| self.api.login(TEST_ACCESS_TOKEN1) | self.api.login(TEST_ACCESS_TOKEN1) | ||||
| self.model_name = uuid.uuid4().hex | |||||
| self.model_name = 'repo-%s' % (uuid.uuid4().hex) | |||||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | ||||
| self.revision = 'v0.1_test_revision' | |||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| visibility=ModelVisibility.PUBLIC, # 1-private, 5-public | |||||
| visibility=ModelVisibility.PUBLIC, | |||||
| license=Licenses.APACHE_V2, | license=Licenses.APACHE_V2, | ||||
| chinese_name=TEST_MODEL_CHINESE_NAME, | chinese_name=TEST_MODEL_CHINESE_NAME, | ||||
| ) | ) | ||||
| @@ -67,9 +68,10 @@ class HubRepositoryTest(unittest.TestCase): | |||||
| os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1)) | os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1)) | ||||
| os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2)) | os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2)) | ||||
| repo.push('test') | repo.push('test') | ||||
| add1 = model_file_download(self.model_id, 'add1.py') | |||||
| repo.tag_and_push(self.revision, 'Test revision') | |||||
| add1 = model_file_download(self.model_id, 'add1.py', self.revision) | |||||
| assert os.path.exists(add1) | assert os.path.exists(add1) | ||||
| add2 = model_file_download(self.model_id, 'add2.py') | |||||
| add2 = model_file_download(self.model_id, 'add2.py', self.revision) | |||||
| assert os.path.exists(add2) | assert os.path.exists(add2) | ||||
| # check lfs files. | # check lfs files. | ||||
| git_wrapper = GitCommandWrapper() | git_wrapper = GitCommandWrapper() | ||||
| @@ -0,0 +1,145 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import tempfile | |||||
| import unittest | |||||
| import uuid | |||||
| from datetime import datetime | |||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.errors import NotExistError, NoValidRevisionError | |||||
| from modelscope.hub.file_download import model_file_download | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||||
| TEST_MODEL_ORG) | |||||
| logger = get_logger() | |||||
| logger.setLevel('DEBUG') | |||||
| download_model_file_name = 'test.bin' | |||||
| download_model_file_name2 = 'test2.bin' | |||||
class HubRevisionTest(unittest.TestCase):
    """Tests for downloading models by tagged revision.

    Each test creates a fresh public model, pushes dummy files via a
    local git clone and tags commits, then checks that
    ``snapshot_download`` and ``model_file_download`` honour the
    requested revision.
    """

    def setUp(self):
        # Create a uniquely named public model for this test run.
        self.api = HubApi()
        self.api.login(TEST_ACCESS_TOKEN1)
        self.model_name = 'rv-%s' % (uuid.uuid4().hex)
        self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
        self.revision = 'v0.1_test_revision'
        self.revision2 = 'v0.2_test_revision'
        self.api.create_model(
            model_id=self.model_id,
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2,
            chinese_name=TEST_MODEL_CHINESE_NAME,
        )

    def tearDown(self):
        # Remove the temporary model created in setUp.
        self.api.delete_model(model_id=self.model_id)

    def prepare_repo_data(self):
        """Clone the repo, push one dummy file and tag it with revision 1."""
        temporary_dir = tempfile.mkdtemp()
        self.model_dir = os.path.join(temporary_dir, self.model_name)
        self.repo = Repository(self.model_dir, clone_from=self.model_id)
        os.system("echo 'testtest'>%s"
                  % os.path.join(self.model_dir, download_model_file_name))
        self.repo.push('add model')
        self.repo.tag_and_push(self.revision, 'Test revision')

    def test_no_tag(self):
        # A repo with no tag yields no valid revision to resolve.
        with self.assertRaises(NoValidRevisionError):
            snapshot_download(self.model_id, None)
        with self.assertRaises(NoValidRevisionError):
            model_file_download(self.model_id, ModelFile.README)

    def test_with_only_one_tag(self):
        # With a single tag, downloads without an explicit revision work.
        self.prepare_repo_data()
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            snapshot_path = snapshot_download(
                self.model_id, cache_dir=temp_cache_dir)
            assert os.path.exists(
                os.path.join(snapshot_path, download_model_file_name))
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            file_path = model_file_download(
                self.model_id, ModelFile.README, cache_dir=temp_cache_dir)
            assert os.path.exists(file_path)

    def add_new_file_and_tag(self):
        """Push a second dummy file and tag it with revision 2."""
        os.system("echo 'testtest'>%s"
                  % os.path.join(self.model_dir, download_model_file_name2))
        self.repo.push('add new file')
        self.repo.tag_and_push(self.revision2, 'Test revision')

    def test_snapshot_download_different_revision(self):
        self.prepare_repo_data()
        t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
        logger.info('First time stamp: %s' % t1)
        snapshot_path = snapshot_download(self.model_id, self.revision)
        assert os.path.exists(
            os.path.join(snapshot_path, download_model_file_name))
        self.add_new_file_and_tag()
        # Revision 1 must contain only the first file ...
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            snapshot_path = snapshot_download(
                self.model_id,
                revision=self.revision,
                cache_dir=temp_cache_dir)
            assert os.path.exists(
                os.path.join(snapshot_path, download_model_file_name))
            assert not os.path.exists(
                os.path.join(snapshot_path, download_model_file_name2))
        # ... while revision 2 contains both files.
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            snapshot_path = snapshot_download(
                self.model_id,
                revision=self.revision2,
                cache_dir=temp_cache_dir)
            assert os.path.exists(
                os.path.join(snapshot_path, download_model_file_name))
            assert os.path.exists(
                os.path.join(snapshot_path, download_model_file_name2))

    def test_file_download_different_revision(self):
        self.prepare_repo_data()
        t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
        logger.info('First time stamp: %s' % t1)
        file_path = model_file_download(self.model_id,
                                        download_model_file_name,
                                        self.revision)
        assert os.path.exists(file_path)
        self.add_new_file_and_tag()
        # At revision 1 the second file does not exist yet.
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            file_path = model_file_download(
                self.model_id,
                download_model_file_name,
                revision=self.revision,
                cache_dir=temp_cache_dir)
            assert os.path.exists(file_path)
            with self.assertRaises(NotExistError):
                model_file_download(
                    self.model_id,
                    download_model_file_name2,
                    revision=self.revision,
                    cache_dir=temp_cache_dir)
        # At revision 2 both files are downloadable.
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            file_path = model_file_download(
                self.model_id,
                download_model_file_name,
                revision=self.revision2,
                cache_dir=temp_cache_dir)
            print('Downloaded file path: %s' % file_path)
            assert os.path.exists(file_path)
            file_path = model_file_download(
                self.model_id,
                download_model_file_name2,
                revision=self.revision2,
                cache_dir=temp_cache_dir)
            assert os.path.exists(file_path)


if __name__ == '__main__':
    unittest.main()
| @@ -0,0 +1,190 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import tempfile | |||||
| import time | |||||
| import unittest | |||||
| import uuid | |||||
| from datetime import datetime | |||||
| from unittest import mock | |||||
| from modelscope import version | |||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.hub.constants import (MODELSCOPE_SDK_DEBUG, Licenses, | |||||
| ModelVisibility) | |||||
| from modelscope.hub.errors import NotExistError | |||||
| from modelscope.hub.file_download import model_file_download | |||||
| from modelscope.hub.repository import Repository | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||||
| TEST_MODEL_ORG) | |||||
| logger = get_logger() | |||||
| logger.setLevel('DEBUG') | |||||
| download_model_file_name = 'test.bin' | |||||
| download_model_file_name2 = 'test2.bin' | |||||
class HubRevisionTest(unittest.TestCase):
    """Tests for revision resolution with MODELSCOPE_SDK_DEBUG removed.

    Each test patches ``os.environ`` to a copy without
    ``MODELSCOPE_SDK_DEBUG`` (via ``mock.patch.dict(..., clear=True)``),
    then exercises branch-based downloads and release-datetime-based
    revision selection through ``version.__release_datetime__``.
    """

    def setUp(self):
        # Create a uniquely named public model for this test run.
        self.api = HubApi()
        self.api.login(TEST_ACCESS_TOKEN1)
        self.model_name = 'rvr-%s' % (uuid.uuid4().hex)
        self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
        self.revision = 'v0.1_test_revision'
        self.revision2 = 'v0.2_test_revision'
        self.api.create_model(
            model_id=self.model_id,
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2,
            chinese_name=TEST_MODEL_CHINESE_NAME,
        )
        # Snapshot of the environment without MODELSCOPE_SDK_DEBUG;
        # NOTE(review): presumably this switches the SDK out of debug
        # mode during the tests — confirm against hub revision logic.
        names_to_remove = {MODELSCOPE_SDK_DEBUG}
        self.modified_environ = {
            k: v
            for k, v in os.environ.items() if k not in names_to_remove
        }

    def tearDown(self):
        # Remove the temporary model created in setUp.
        self.api.delete_model(model_id=self.model_id)

    def prepare_repo_data(self):
        """Clone the repo and push one dummy model file (no tag yet)."""
        temporary_dir = tempfile.mkdtemp()
        self.model_dir = os.path.join(temporary_dir, self.model_name)
        self.repo = Repository(self.model_dir, clone_from=self.model_id)
        os.system("echo 'testtest'>%s"
                  % os.path.join(self.model_dir, download_model_file_name))
        self.repo.push('add model')

    def prepare_repo_data_and_tag(self):
        """Push the first file and tag the commit with revision 1."""
        self.prepare_repo_data()
        self.repo.tag_and_push(self.revision, 'Test revision')

    def add_new_file_and_tag_to_repo(self):
        """Push a second dummy file and tag it with revision 2."""
        os.system("echo 'testtest'>%s"
                  % os.path.join(self.model_dir, download_model_file_name2))
        self.repo.push('add new file')
        self.repo.tag_and_push(self.revision2, 'Test revision')

    def add_new_file_and_branch_to_repo(self, branch_name):
        """Push a second dummy file to a new remote branch."""
        os.system("echo 'testtest'>%s"
                  % os.path.join(self.model_dir, download_model_file_name2))
        self.repo.push('add new file', remote_branch=branch_name)

    def test_dev_mode_default_master(self):
        # Without a tag, downloads fall back to the default master branch.
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data()  # no tag, default get master
            with tempfile.TemporaryDirectory() as temp_cache_dir:
                snapshot_path = snapshot_download(
                    self.model_id, cache_dir=temp_cache_dir)
                assert os.path.exists(
                    os.path.join(snapshot_path, download_model_file_name))
            with tempfile.TemporaryDirectory() as temp_cache_dir:
                file_path = model_file_download(
                    self.model_id,
                    download_model_file_name,
                    cache_dir=temp_cache_dir)
                assert os.path.exists(file_path)

    def test_dev_mode_specify_branch(self):
        # An explicit branch name can be passed as the revision.
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data()  # no tag, default get master
            branch_name = 'test'
            self.add_new_file_and_branch_to_repo(branch_name)
            with tempfile.TemporaryDirectory() as temp_cache_dir:
                snapshot_path = snapshot_download(
                    self.model_id,
                    revision=branch_name,
                    cache_dir=temp_cache_dir)
                assert os.path.exists(
                    os.path.join(snapshot_path, download_model_file_name))
            with tempfile.TemporaryDirectory() as temp_cache_dir:
                file_path = model_file_download(
                    self.model_id,
                    download_model_file_name,
                    revision=branch_name,
                    cache_dir=temp_cache_dir)
                assert os.path.exists(file_path)

    def test_snapshot_download_revision(self):
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data_and_tag()
            t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
            logger.info('First time: %s' % t1)
            time.sleep(10)  # ensure t1 and t2 differ at second resolution
            self.add_new_file_and_tag_to_repo()
            t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
            # Fixed typo in log message ('Secnod' -> 'Second').
            logger.info('Second time: %s' % t2)
            # Temporarily override the SDK release datetime so revision
            # selection picks the tag created before that instant.
            release_datetime_backup = version.__release_datetime__
            logger.info('Origin __release_datetime__: %s'
                        % version.__release_datetime__)
            try:
                logger.info('Setting __release_datetime__ to: %s' % t1)
                version.__release_datetime__ = t1
                # At t1 only the first tag exists: one file expected.
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    snapshot_path = snapshot_download(
                        self.model_id, cache_dir=temp_cache_dir)
                    assert os.path.exists(
                        os.path.join(snapshot_path, download_model_file_name))
                    assert not os.path.exists(
                        os.path.join(snapshot_path, download_model_file_name2))
                version.__release_datetime__ = t2
                logger.info('Setting __release_datetime__ to: %s' % t2)
                # At t2 the second tag exists: both files expected.
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    snapshot_path = snapshot_download(
                        self.model_id, cache_dir=temp_cache_dir)
                    assert os.path.exists(
                        os.path.join(snapshot_path, download_model_file_name))
                    assert os.path.exists(
                        os.path.join(snapshot_path, download_model_file_name2))
            finally:
                # Always restore the real release datetime.
                version.__release_datetime__ = release_datetime_backup

    def test_file_download_revision(self):
        with mock.patch.dict(os.environ, self.modified_environ, clear=True):
            self.prepare_repo_data_and_tag()
            t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
            logger.info('First time stamp: %s' % t1)
            time.sleep(10)  # ensure t1 and t2 differ at second resolution
            self.add_new_file_and_tag_to_repo()
            t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
            logger.info('Second time: %s' % t2)
            release_datetime_backup = version.__release_datetime__
            logger.info('Origin __release_datetime__: %s'
                        % version.__release_datetime__)
            try:
                version.__release_datetime__ = t1
                logger.info('Setting __release_datetime__ to: %s' % t1)
                # At t1 only the first file is visible; the second file
                # was tagged later and must not resolve.
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name,
                        cache_dir=temp_cache_dir)
                    assert os.path.exists(file_path)
                    with self.assertRaises(NotExistError):
                        model_file_download(
                            self.model_id,
                            download_model_file_name2,
                            cache_dir=temp_cache_dir)
                version.__release_datetime__ = t2
                logger.info('Setting __release_datetime__ to: %s' % t2)
                # At t2 both files are downloadable.
                with tempfile.TemporaryDirectory() as temp_cache_dir:
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name,
                        cache_dir=temp_cache_dir)
                    print('Downloaded file path: %s' % file_path)
                    assert os.path.exists(file_path)
                    file_path = model_file_download(
                        self.model_id,
                        download_model_file_name2,
                        cache_dir=temp_cache_dir)
                    assert os.path.exists(file_path)
            finally:
                # Always restore the real release datetime.
                version.__release_datetime__ = release_datetime_backup


if __name__ == '__main__':
    unittest.main()