| @@ -1,22 +1,28 @@ | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| pip install -r requirements/tests.txt | |||
| echo "Testing envs" | |||
| printenv | |||
| echo "ENV END" | |||
| if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| pip install -r requirements/tests.txt | |||
| git config --global --add safe.directory /Maas-lib | |||
| git config --global --add safe.directory /Maas-lib | |||
| # linter test | |||
| # use internal project for pre-commit due to the network problem | |||
| pre-commit run -c .pre-commit-config_local.yaml --all-files | |||
| if [ $? -ne 0 ]; then | |||
| echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||
| exit -1 | |||
| # linter test | |||
| # use internal project for pre-commit due to the network problem | |||
| pre-commit run -c .pre-commit-config_local.yaml --all-files | |||
| if [ $? -ne 0 ]; then | |||
| echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||
| exit -1 | |||
| fi | |||
| # test with install | |||
| python setup.py install | |||
| else | |||
| echo "Running case in release image, run case directly!" | |||
| fi | |||
| # test with install | |||
| python setup.py install | |||
| if [ $# -eq 0 ]; then | |||
| ci_command="python tests/run.py --subprocess" | |||
| else | |||
| @@ -20,28 +20,52 @@ do | |||
| # pull image if there are update | |||
| docker pull ${IMAGE_NAME}:${IMAGE_VERSION} | |||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
| --gpus="device=$gpu" \ | |||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
| -e CI_TEST=True \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||
| --net host \ | |||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
| $CI_COMMAND | |||
| if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
| --gpus="device=$gpu" \ | |||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
| -e CI_TEST=True \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
| -e MODELSCOPE_SDK_DEBUG=True \ | |||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||
| --net host \ | |||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
| $CI_COMMAND | |||
| else | |||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
| --gpus="device=$gpu" \ | |||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
| -e CI_TEST=True \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||
| --net host \ | |||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
| $CI_COMMAND | |||
| fi | |||
| if [ $? -ne 0 ]; then | |||
| echo "Running test case failed, please check the log!" | |||
| exit -1 | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .version import __version__ | |||
| from .version import __release_datetime__, __version__ | |||
| __all__ = ['__version__'] | |||
| __all__ = ['__version__', '__release_datetime__'] | |||
| @@ -4,40 +4,45 @@ | |||
| import datetime | |||
| import os | |||
| import pickle | |||
| import platform | |||
| import shutil | |||
| import tempfile | |||
| import uuid | |||
| from collections import defaultdict | |||
| from http import HTTPStatus | |||
| from http.cookiejar import CookieJar | |||
| from os.path import expanduser | |||
| from typing import List, Optional, Tuple, Union | |||
| from typing import Dict, List, Optional, Tuple, Union | |||
| import requests | |||
| from modelscope import __version__ | |||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_EMAIL, | |||
| API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | |||
| API_RESPONSE_FIELD_MESSAGE, | |||
| API_RESPONSE_FIELD_USERNAME, | |||
| DEFAULT_CREDENTIALS_PATH, Licenses, | |||
| ModelVisibility) | |||
| DEFAULT_CREDENTIALS_PATH, | |||
| MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS, | |||
| Licenses, ModelVisibility) | |||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
| NotLoginException, RequestError, | |||
| datahub_raise_on_error, | |||
| NotLoginException, NoValidRevisionError, | |||
| RequestError, datahub_raise_on_error, | |||
| handle_http_post_error, | |||
| handle_http_response, is_ok, raise_on_error) | |||
| handle_http_response, is_ok, | |||
| raise_for_http_status, raise_on_error) | |||
| from modelscope.hub.git import GitCommandWrapper | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.hub.utils.utils import (get_endpoint, | |||
| model_id_to_group_owner_name) | |||
| from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | |||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
| DEFAULT_MODEL_REVISION, | |||
| DatasetFormations, DatasetMetaFormats, | |||
| DownloadMode, ModelFile) | |||
| DEFAULT_REPOSITORY_REVISION, | |||
| MASTER_MODEL_BRANCH, DatasetFormations, | |||
| DatasetMetaFormats, DownloadMode, | |||
| ModelFile) | |||
| from modelscope.utils.logger import get_logger | |||
| # yapf: enable | |||
| from .utils.utils import (get_endpoint, get_release_datetime, | |||
| model_id_to_group_owner_name) | |||
| logger = get_logger() | |||
| @@ -46,6 +51,7 @@ class HubApi: | |||
| def __init__(self, endpoint=None): | |||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||
| self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||
| def login( | |||
| self, | |||
| @@ -65,8 +71,9 @@ class HubApi: | |||
| </Tip> | |||
| """ | |||
| path = f'{self.endpoint}/api/v1/login' | |||
| r = requests.post(path, json={'AccessToken': access_token}) | |||
| r.raise_for_status() | |||
| r = requests.post( | |||
| path, json={'AccessToken': access_token}, headers=self.headers) | |||
| raise_for_http_status(r) | |||
| d = r.json() | |||
| raise_on_error(d) | |||
| @@ -120,7 +127,8 @@ class HubApi: | |||
| 'Visibility': visibility, # server check | |||
| 'License': license | |||
| } | |||
| r = requests.post(path, json=body, cookies=cookies) | |||
| r = requests.post( | |||
| path, json=body, cookies=cookies, headers=self.headers) | |||
| handle_http_post_error(r, path, body) | |||
| raise_on_error(r.json()) | |||
| model_repo_url = f'{get_endpoint()}/{model_id}' | |||
| @@ -140,8 +148,8 @@ class HubApi: | |||
| raise ValueError('Token does not exist, please login first.') | |||
| path = f'{self.endpoint}/api/v1/models/{model_id}' | |||
| r = requests.delete(path, cookies=cookies) | |||
| r.raise_for_status() | |||
| r = requests.delete(path, cookies=cookies, headers=self.headers) | |||
| raise_for_http_status(r) | |||
| raise_on_error(r.json()) | |||
| def get_model_url(self, model_id): | |||
| @@ -170,7 +178,7 @@ class HubApi: | |||
| owner_or_group, name = model_id_to_group_owner_name(model_id) | |||
| path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' | |||
| r = requests.get(path, cookies=cookies) | |||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||
| handle_http_response(r, logger, cookies, model_id) | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| @@ -178,7 +186,7 @@ class HubApi: | |||
| else: | |||
| raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| def push_model(self, | |||
| model_id: str, | |||
| @@ -187,7 +195,7 @@ class HubApi: | |||
| license: str = Licenses.APACHE_V2, | |||
| chinese_name: Optional[str] = None, | |||
| commit_message: Optional[str] = 'upload model', | |||
| revision: Optional[str] = DEFAULT_MODEL_REVISION): | |||
| revision: Optional[str] = DEFAULT_REPOSITORY_REVISION): | |||
| """ | |||
| Upload model from a given directory to given repository. A valid model directory | |||
| must contain a configuration.json file. | |||
| @@ -269,7 +277,7 @@ class HubApi: | |||
| date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | |||
| commit_message = '[automsg] push model %s to hub at %s' % ( | |||
| model_id, date) | |||
| repo.push(commit_message=commit_message, branch=revision) | |||
| repo.push(commit_message=commit_message, local_branch=revision, remote_branch=revision) | |||
| except Exception: | |||
| raise | |||
| finally: | |||
| @@ -294,7 +302,8 @@ class HubApi: | |||
| path, | |||
| data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % | |||
| (owner_or_group, page_number, page_size), | |||
| cookies=cookies) | |||
| cookies=cookies, | |||
| headers=self.headers) | |||
| handle_http_response(r, logger, cookies, 'list_model') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| @@ -303,7 +312,7 @@ class HubApi: | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| return None | |||
| def _check_cookie(self, | |||
| @@ -318,10 +327,70 @@ class HubApi: | |||
| raise ValueError('Token does not exist, please login first.') | |||
| return cookies | |||
| def list_model_revisions( | |||
| self, | |||
| model_id: str, | |||
| cutoff_timestamp: int = None, | |||
| use_cookies: Union[bool, CookieJar] = False) -> List[str]: | |||
| """Get model branch and tags. | |||
| Args: | |||
| model_id (str): The model id | |||
| cutoff_timestamp (int): Tags created before the cutoff will be included. | |||
| The timestamp is represented by the seconds elasped from the epoch time. | |||
| use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will | |||
| will load cookie from local. Defaults to False. | |||
| Returns: | |||
| Tuple[List[str], List[str]]: Return list of branch name and tags | |||
| """ | |||
| cookies = self._check_cookie(use_cookies) | |||
| if cutoff_timestamp is None: | |||
| cutoff_timestamp = get_release_datetime() | |||
| path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp | |||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||
| handle_http_response(r, logger, cookies, model_id) | |||
| d = r.json() | |||
| raise_on_error(d) | |||
| info = d[API_RESPONSE_FIELD_DATA] | |||
| # tags returned from backend are guaranteed to be ordered by create-time | |||
| tags = [x['Revision'] for x in info['RevisionMap']['Tags'] | |||
| ] if info['RevisionMap']['Tags'] else [] | |||
| return tags | |||
| def get_valid_revision(self, model_id: str, revision=None, cookies: Optional[CookieJar] = None): | |||
| release_timestamp = get_release_datetime() | |||
| current_timestamp = int(round(datetime.datetime.now().timestamp())) | |||
| # for active development in library codes (non-release-branches), release_timestamp | |||
| # is set to be a far-away-time-in-the-future, to ensure that we shall | |||
| # get the master-HEAD version from model repo by default (when no revision is provided) | |||
| if release_timestamp > current_timestamp + ONE_YEAR_SECONDS: | |||
| branches, tags = self.get_model_branches_and_tags( | |||
| model_id, use_cookies=False if cookies is None else cookies) | |||
| if revision is None: | |||
| revision = MASTER_MODEL_BRANCH | |||
| logger.info('Model revision not specified, use default: %s in development mode' % revision) | |||
| if revision not in branches and revision not in tags: | |||
| raise NotExistError('The model: %s has no branch or tag : %s .' % revision) | |||
| else: | |||
| revisions = self.list_model_revisions( | |||
| model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) | |||
| if revision is None: | |||
| if len(revisions) == 0: | |||
| raise NoValidRevisionError('The model: %s has no valid revision!' % model_id) | |||
| # tags (revisions) returned from backend are guaranteed to be ordered by create-time | |||
| # we shall obtain the latest revision created earlier than release version of this branch | |||
| revision = revisions[0] | |||
| logger.info('Model revision not specified, use the latest revision: %s' % revision) | |||
| else: | |||
| if revision not in revisions: | |||
| raise NotExistError( | |||
| 'The model: %s has no revision: %s !' % (model_id, revision)) | |||
| return revision | |||
| def get_model_branches_and_tags( | |||
| self, | |||
| model_id: str, | |||
| use_cookies: Union[bool, CookieJar] = False | |||
| use_cookies: Union[bool, CookieJar] = False, | |||
| ) -> Tuple[List[str], List[str]]: | |||
| """Get model branch and tags. | |||
| @@ -335,7 +404,7 @@ class HubApi: | |||
| cookies = self._check_cookie(use_cookies) | |||
| path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | |||
| r = requests.get(path, cookies=cookies) | |||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||
| handle_http_response(r, logger, cookies, model_id) | |||
| d = r.json() | |||
| raise_on_error(d) | |||
| @@ -376,7 +445,11 @@ class HubApi: | |||
| if root is not None: | |||
| path = path + f'&Root={root}' | |||
| r = requests.get(path, cookies=cookies, headers=headers) | |||
| r = requests.get( | |||
| path, cookies=cookies, headers={ | |||
| **headers, | |||
| **self.headers | |||
| }) | |||
| handle_http_response(r, logger, cookies, model_id) | |||
| d = r.json() | |||
| @@ -392,10 +465,9 @@ class HubApi: | |||
| def list_datasets(self): | |||
| path = f'{self.endpoint}/api/v1/datasets' | |||
| headers = None | |||
| params = {} | |||
| r = requests.get(path, params=params, headers=headers) | |||
| r.raise_for_status() | |||
| r = requests.get(path, params=params, headers=self.headers) | |||
| raise_for_http_status(r) | |||
| dataset_list = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return [x['Name'] for x in dataset_list] | |||
| @@ -425,7 +497,7 @@ class HubApi: | |||
| dataset_id = resp['Data']['Id'] | |||
| dataset_type = resp['Data']['Type'] | |||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | |||
| r = requests.get(datahub_url) | |||
| r = requests.get(datahub_url, headers=self.headers) | |||
| resp = r.json() | |||
| datahub_raise_on_error(datahub_url, resp) | |||
| file_list = resp['Data'] | |||
| @@ -445,7 +517,7 @@ class HubApi: | |||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
| f'Revision={revision}&FilePath={file_path}' | |||
| r = requests.get(datahub_url) | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| local_path = os.path.join(cache_dir, file_path) | |||
| if os.path.exists(local_path): | |||
| logger.warning( | |||
| @@ -490,7 +562,8 @@ class HubApi: | |||
| f'ststoken?Revision={revision}' | |||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||
| r = requests.get(url=datahub_url, cookies=cookies) | |||
| r = requests.get( | |||
| url=datahub_url, cookies=cookies, headers=self.headers) | |||
| resp = r.json() | |||
| raise_on_error(resp) | |||
| return resp['Data'] | |||
| @@ -512,12 +585,12 @@ class HubApi: | |||
| def on_dataset_download(self, dataset_name: str, namespace: str) -> None: | |||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' | |||
| r = requests.post(url) | |||
| r.raise_for_status() | |||
| r = requests.post(url, headers=self.headers) | |||
| raise_for_http_status(r) | |||
| @staticmethod | |||
| def datahub_remote_call(url): | |||
| r = requests.get(url) | |||
| r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()}) | |||
| resp = r.json() | |||
| datahub_raise_on_error(url, resp) | |||
| return resp['Data'] | |||
| @@ -531,6 +604,7 @@ class ModelScopeConfig: | |||
| COOKIES_FILE_NAME = 'cookies' | |||
| GIT_TOKEN_FILE_NAME = 'git_token' | |||
| USER_INFO_FILE_NAME = 'user' | |||
| USER_SESSION_ID_FILE_NAME = 'session' | |||
| @staticmethod | |||
| def make_sure_credential_path_exist(): | |||
| @@ -559,6 +633,23 @@ class ModelScopeConfig: | |||
| return cookies | |||
| return None | |||
| @staticmethod | |||
| def get_user_session_id(): | |||
| session_path = os.path.join(ModelScopeConfig.path_credential, | |||
| ModelScopeConfig.USER_SESSION_ID_FILE_NAME) | |||
| session_id = '' | |||
| if os.path.exists(session_path): | |||
| with open(session_path, 'rb') as f: | |||
| session_id = str(f.readline().strip(), encoding='utf-8') | |||
| return session_id | |||
| if session_id == '' or len(session_id) != 32: | |||
| session_id = str(uuid.uuid4().hex) | |||
| ModelScopeConfig.make_sure_credential_path_exist() | |||
| with open(session_path, 'w+') as wf: | |||
| wf.write(session_id) | |||
| return session_id | |||
| @staticmethod | |||
| def save_token(token: str): | |||
| ModelScopeConfig.make_sure_credential_path_exist() | |||
| @@ -607,3 +698,32 @@ class ModelScopeConfig: | |||
| except FileNotFoundError: | |||
| pass | |||
| return token | |||
| @staticmethod | |||
| def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||
| """Formats a user-agent string with basic info about a request. | |||
| Args: | |||
| user_agent (`str`, `dict`, *optional*): | |||
| The user agent info in the form of a dictionary or a single string. | |||
| Returns: | |||
| The formatted user-agent string. | |||
| """ | |||
| env = 'custom' | |||
| if MODELSCOPE_ENVIRONMENT in os.environ: | |||
| env = os.environ[MODELSCOPE_ENVIRONMENT] | |||
| ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( | |||
| __version__, | |||
| platform.python_version(), | |||
| ModelScopeConfig.get_user_session_id(), | |||
| platform.platform(), | |||
| platform.processor(), | |||
| env, | |||
| ) | |||
| if isinstance(user_agent, dict): | |||
| ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||
| elif isinstance(user_agent, str): | |||
| ua += ';' + user_agent | |||
| return ua | |||
| @@ -16,6 +16,9 @@ API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' | |||
| API_RESPONSE_FIELD_USERNAME = 'Username' | |||
| API_RESPONSE_FIELD_EMAIL = 'Email' | |||
| API_RESPONSE_FIELD_MESSAGE = 'Message' | |||
| MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' | |||
| MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' | |||
| ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 | |||
| class Licenses(object): | |||
| @@ -3,7 +3,6 @@ from abc import ABC | |||
| from http import HTTPStatus | |||
| from typing import Optional | |||
| import attrs | |||
| import json | |||
| import requests | |||
| from attrs import asdict, define, field, validators | |||
| @@ -12,7 +11,8 @@ from modelscope.hub.api import ModelScopeConfig | |||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_MESSAGE) | |||
| from modelscope.hub.errors import (NotLoginException, NotSupportError, | |||
| RequestError, handle_http_response, is_ok) | |||
| RequestError, handle_http_response, is_ok, | |||
| raise_for_http_status) | |||
| from modelscope.hub.utils.utils import get_endpoint | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -188,6 +188,7 @@ class ServiceDeployer(object): | |||
| def __init__(self, endpoint=None): | |||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||
| self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||
| self.cookies = ModelScopeConfig.get_cookies() | |||
| if self.cookies is None: | |||
| raise NotLoginException( | |||
| @@ -227,12 +228,9 @@ class ServiceDeployer(object): | |||
| resource=resource, | |||
| provider=provider) | |||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||
| body = attrs.asdict(create_params) | |||
| body = asdict(create_params) | |||
| r = requests.post( | |||
| path, | |||
| json=body, | |||
| cookies=self.cookies, | |||
| ) | |||
| path, json=body, cookies=self.cookies, headers=self.headers) | |||
| handle_http_response(r, logger, self.cookies, 'create_service') | |||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||
| if is_ok(r.json()): | |||
| @@ -241,7 +239,7 @@ class ServiceDeployer(object): | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| return None | |||
| def get(self, instance_name: str, provider: ServiceProviderParameters): | |||
| @@ -262,7 +260,7 @@ class ServiceDeployer(object): | |||
| params = GetServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.get(path, cookies=self.cookies) | |||
| r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||
| handle_http_response(r, logger, self.cookies, 'get_service') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| @@ -271,7 +269,7 @@ class ServiceDeployer(object): | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| return None | |||
| def delete(self, instance_name: str, provider: ServiceProviderParameters): | |||
| @@ -293,7 +291,7 @@ class ServiceDeployer(object): | |||
| params = DeleteServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.delete(path, cookies=self.cookies) | |||
| r = requests.delete(path, cookies=self.cookies, headers=self.headers) | |||
| handle_http_response(r, logger, self.cookies, 'delete_service') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| @@ -302,7 +300,7 @@ class ServiceDeployer(object): | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| return None | |||
| def list(self, | |||
| @@ -328,7 +326,7 @@ class ServiceDeployer(object): | |||
| provider=provider, skip=skip, limit=limit) | |||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||
| params.to_query_str()) | |||
| r = requests.get(path, cookies=self.cookies) | |||
| r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||
| handle_http_response(r, logger, self.cookies, 'list_service_instances') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| @@ -337,5 +335,5 @@ class ServiceDeployer(object): | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| return None | |||
| @@ -13,6 +13,10 @@ class NotSupportError(Exception): | |||
| pass | |||
| class NoValidRevisionError(Exception): | |||
| pass | |||
| class NotExistError(Exception): | |||
| pass | |||
| @@ -99,3 +103,33 @@ def datahub_raise_on_error(url, rsp): | |||
| raise RequestError( | |||
| f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" | |||
| ) | |||
| def raise_for_http_status(rsp): | |||
| """ | |||
| Attempt to decode utf-8 first since some servers | |||
| localize reason strings, for invalid utf-8, fall back | |||
| to decoding with iso-8859-1. | |||
| """ | |||
| http_error_msg = '' | |||
| if isinstance(rsp.reason, bytes): | |||
| try: | |||
| reason = rsp.reason.decode('utf-8') | |||
| except UnicodeDecodeError: | |||
| reason = rsp.reason.decode('iso-8859-1') | |||
| else: | |||
| reason = rsp.reason | |||
| if 400 <= rsp.status_code < 500: | |||
| http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code, | |||
| reason, rsp.url) | |||
| elif 500 <= rsp.status_code < 600: | |||
| http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code, | |||
| reason, rsp.url) | |||
| if http_error_msg: | |||
| req = rsp.request | |||
| if req.method == 'POST': | |||
| http_error_msg = u'%s, body: %s' % (http_error_msg, req.body) | |||
| raise HTTPError(http_error_msg, response=rsp) | |||
| @@ -2,13 +2,11 @@ | |||
| import copy | |||
| import os | |||
| import sys | |||
| import tempfile | |||
| from functools import partial | |||
| from http.cookiejar import CookieJar | |||
| from pathlib import Path | |||
| from typing import Dict, Optional, Union | |||
| from uuid import uuid4 | |||
| import requests | |||
| from tqdm import tqdm | |||
| @@ -23,7 +21,6 @@ from .utils.caching import ModelFileSystemCache | |||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
| get_endpoint, model_id_to_group_owner_name) | |||
| SESSION_ID = uuid4().hex | |||
| logger = get_logger() | |||
| @@ -34,6 +31,7 @@ def model_file_download( | |||
| cache_dir: Optional[str] = None, | |||
| user_agent: Union[Dict, str, None] = None, | |||
| local_files_only: Optional[bool] = False, | |||
| cookies: Optional[CookieJar] = None, | |||
| ) -> Optional[str]: # pragma: no cover | |||
| """ | |||
| Download from a given URL and cache it if it's not already present in the | |||
| @@ -104,54 +102,47 @@ def model_file_download( | |||
| " online, set 'local_files_only' to False.") | |||
| _api = HubApi() | |||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| branches, tags = _api.get_model_branches_and_tags( | |||
| model_id, use_cookies=False if cookies is None else cookies) | |||
| headers = { | |||
| 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ) | |||
| } | |||
| if cookies is None: | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| revision = _api.get_valid_revision( | |||
| model_id, revision=revision, cookies=cookies) | |||
| file_to_download_info = None | |||
| is_commit_id = False | |||
| if revision in branches or revision in tags: # The revision is version or tag, | |||
| # we need to confirm the version is up to date | |||
| # we need to get the file list to check if the lateast version is cached, if so return, otherwise download | |||
| model_files = _api.get_model_files( | |||
| model_id=model_id, | |||
| revision=revision, | |||
| recursive=True, | |||
| use_cookies=False if cookies is None else cookies) | |||
| for model_file in model_files: | |||
| if model_file['Type'] == 'tree': | |||
| continue | |||
| if model_file['Path'] == file_path: | |||
| if cache.exists(model_file): | |||
| return cache.get_file_by_info(model_file) | |||
| else: | |||
| file_to_download_info = model_file | |||
| break | |||
| if file_to_download_info is None: | |||
| raise NotExistError('The file path: %s not exist in: %s' % | |||
| (file_path, model_id)) | |||
| else: # the revision is commit id. | |||
| cached_file_path = cache.get_file_by_path_and_commit_id( | |||
| file_path, revision) | |||
| if cached_file_path is not None: | |||
| file_name = os.path.basename(cached_file_path) | |||
| logger.info( | |||
| f'File {file_name} already in cache, skip downloading!') | |||
| return cached_file_path # the file is in cache. | |||
| is_commit_id = True | |||
| # we need to confirm the version is up-to-date | |||
| # we need to get the file list to check if the latest version is cached, if so return, otherwise download | |||
| model_files = _api.get_model_files( | |||
| model_id=model_id, | |||
| revision=revision, | |||
| recursive=True, | |||
| use_cookies=False if cookies is None else cookies) | |||
| for model_file in model_files: | |||
| if model_file['Type'] == 'tree': | |||
| continue | |||
| if model_file['Path'] == file_path: | |||
| if cache.exists(model_file): | |||
| logger.info( | |||
| f'File {model_file["Name"]} already in cache, skip downloading!' | |||
| ) | |||
| return cache.get_file_by_info(model_file) | |||
| else: | |||
| file_to_download_info = model_file | |||
| break | |||
| if file_to_download_info is None: | |||
| raise NotExistError('The file path: %s not exist in: %s' % | |||
| (file_path, model_id)) | |||
| # we need to download again | |||
| url_to_download = get_file_download_url(model_id, file_path, revision) | |||
| file_to_download_info = { | |||
| 'Path': | |||
| file_path, | |||
| 'Revision': | |||
| revision if is_commit_id else file_to_download_info['Revision'], | |||
| FILE_HASH: | |||
| None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||
| file_to_download_info[FILE_HASH] | |||
| 'Path': file_path, | |||
| 'Revision': file_to_download_info['Revision'], | |||
| FILE_HASH: file_to_download_info[FILE_HASH] | |||
| } | |||
| temp_file_name = next(tempfile._get_candidate_names()) | |||
| @@ -170,25 +161,6 @@ def model_file_download( | |||
| os.path.join(temporary_cache_dir, temp_file_name)) | |||
| def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||
| """Formats a user-agent string with basic info about a request. | |||
| Args: | |||
| user_agent (`str`, `dict`, *optional*): | |||
| The user agent info in the form of a dictionary or a single string. | |||
| Returns: | |||
| The formatted user-agent string. | |||
| """ | |||
| ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}' | |||
| if isinstance(user_agent, dict): | |||
| ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||
| elif isinstance(user_agent, str): | |||
| ua = user_agent | |||
| return ua | |||
| def get_file_download_url(model_id: str, file_path: str, revision: str): | |||
| """ | |||
| Format file download url according to `model_id`, `revision` and `file_path`. | |||
| @@ -5,6 +5,7 @@ import subprocess | |||
| from typing import List | |||
| from modelscope.utils.logger import get_logger | |||
| from ..utils.constant import MASTER_MODEL_BRANCH | |||
| from .errors import GitError | |||
| logger = get_logger() | |||
| @@ -227,3 +228,22 @@ class GitCommandWrapper(metaclass=Singleton): | |||
| files.append(line.split(' ')[-1]) | |||
| return files | |||
| def tag(self, | |||
| repo_dir: str, | |||
| tag_name: str, | |||
| message: str, | |||
| ref: str = MASTER_MODEL_BRANCH): | |||
| cmd_args = [ | |||
| '-C', repo_dir, 'tag', tag_name, '-m', | |||
| '"%s"' % message, ref | |||
| ] | |||
| rsp = self._run_git_command(*cmd_args) | |||
| logger.debug(rsp.stdout.decode('utf8')) | |||
| return rsp | |||
| def push_tag(self, repo_dir: str, tag_name): | |||
| cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name] | |||
| rsp = self._run_git_command(*cmd_args) | |||
| logger.debug(rsp.stdout.decode('utf8')) | |||
| return rsp | |||
| @@ -5,7 +5,8 @@ from typing import Optional | |||
| from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
| DEFAULT_MODEL_REVISION) | |||
| DEFAULT_REPOSITORY_REVISION, | |||
| MASTER_MODEL_BRANCH) | |||
| from modelscope.utils.logger import get_logger | |||
| from .git import GitCommandWrapper | |||
| from .utils.utils import get_endpoint | |||
| @@ -20,7 +21,7 @@ class Repository: | |||
| def __init__(self, | |||
| model_dir: str, | |||
| clone_from: str, | |||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
| auth_token: Optional[str] = None, | |||
| git_path: Optional[str] = None): | |||
| """ | |||
| @@ -89,7 +90,8 @@ class Repository: | |||
| def push(self, | |||
| commit_message: str, | |||
| branch: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
| remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
| force: bool = False): | |||
| """Push local files to remote, this method will do. | |||
| git pull | |||
| @@ -116,14 +118,48 @@ class Repository: | |||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
| self.git_wrapper.pull(self.model_dir) | |||
| self.git_wrapper.add(self.model_dir, all_files=True) | |||
| self.git_wrapper.commit(self.model_dir, commit_message) | |||
| self.git_wrapper.push( | |||
| repo_dir=self.model_dir, | |||
| token=self.auth_token, | |||
| url=url, | |||
| local_branch=branch, | |||
| remote_branch=branch) | |||
| local_branch=local_branch, | |||
| remote_branch=remote_branch) | |||
| def tag(self, tag_name: str, message: str, ref: str = MASTER_MODEL_BRANCH): | |||
| """Create a new tag. | |||
| Args: | |||
| tag_name (str): The name of the tag | |||
| message (str): The tag message. | |||
| ref (str): The tag reference, can be commit id or branch. | |||
| """ | |||
| if tag_name is None or tag_name == '': | |||
| msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.' | |||
| raise InvalidParameter(msg) | |||
| if message is None or message == '': | |||
| msg = 'We use annotated tag, therefore message cannot None or empty.' | |||
| self.git_wrapper.tag( | |||
| repo_dir=self.model_dir, | |||
| tag_name=tag_name, | |||
| message=message, | |||
| ref=ref) | |||
| def tag_and_push(self, | |||
| tag_name: str, | |||
| message: str, | |||
| ref: str = MASTER_MODEL_BRANCH): | |||
| """Create tag and push to remote | |||
| Args: | |||
| tag_name (str): The name of the tag | |||
| message (str): The tag message. | |||
| ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH. | |||
| """ | |||
| self.tag(tag_name, message, ref) | |||
| self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name) | |||
| class DatasetRepository: | |||
| @@ -2,6 +2,7 @@ | |||
| import os | |||
| import tempfile | |||
| from http.cookiejar import CookieJar | |||
| from pathlib import Path | |||
| from typing import Dict, Optional, Union | |||
| @@ -9,9 +10,7 @@ from modelscope.hub.api import HubApi, ModelScopeConfig | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| from .constants import FILE_HASH | |||
| from .errors import NotExistError | |||
| from .file_download import (get_file_download_url, http_get_file, | |||
| http_user_agent) | |||
| from .file_download import get_file_download_url, http_get_file | |||
| from .utils.caching import ModelFileSystemCache | |||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
| model_id_to_group_owner_name) | |||
| @@ -23,7 +22,8 @@ def snapshot_download(model_id: str, | |||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| cache_dir: Union[str, Path, None] = None, | |||
| user_agent: Optional[Union[Dict, str]] = None, | |||
| local_files_only: Optional[bool] = False) -> str: | |||
| local_files_only: Optional[bool] = False, | |||
| cookies: Optional[CookieJar] = None) -> str: | |||
| """Download all files of a repo. | |||
| Downloads a whole snapshot of a repo's files at the specified revision. This | |||
| is useful when you want all files from a repo, because you don't know which | |||
| @@ -81,15 +81,15 @@ def snapshot_download(model_id: str, | |||
| ) # we can not confirm the cached file is for snapshot 'revision' | |||
| else: | |||
| # make headers | |||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
| headers = { | |||
| 'user-agent': | |||
| ModelScopeConfig.get_user_agent(user_agent=user_agent, ) | |||
| } | |||
| _api = HubApi() | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| # get file list from model repo | |||
| branches, tags = _api.get_model_branches_and_tags( | |||
| model_id, use_cookies=False if cookies is None else cookies) | |||
| if revision not in branches and revision not in tags: | |||
| raise NotExistError('The specified branch or tag : %s not exist!' | |||
| % revision) | |||
| if cookies is None: | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| revision = _api.get_valid_revision( | |||
| model_id, revision=revision, cookies=cookies) | |||
| snapshot_header = headers if 'CI_TEST' in os.environ else { | |||
| **headers, | |||
| @@ -110,7 +110,7 @@ def snapshot_download(model_id: str, | |||
| for model_file in model_files: | |||
| if model_file['Type'] == 'tree': | |||
| continue | |||
| # check model_file is exist in cache, if exist, skip download, otherwise download | |||
| # check model_file is exist in cache, if existed, skip download, otherwise download | |||
| if cache.exists(model_file): | |||
| file_name = os.path.basename(model_file['Name']) | |||
| logger.info( | |||
| @@ -2,11 +2,12 @@ | |||
| import hashlib | |||
| import os | |||
| from datetime import datetime | |||
| from typing import Optional | |||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||
| DEFAULT_MODELSCOPE_GROUP, | |||
| MODEL_ID_SEPARATOR, | |||
| MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, | |||
| MODELSCOPE_URL_SCHEME) | |||
| from modelscope.hub.errors import FileIntegrityError | |||
| from modelscope.utils.file_utils import get_default_cache_dir | |||
| @@ -37,6 +38,18 @@ def get_cache_dir(model_id: Optional[str] = None): | |||
| base_path, model_id + '/') | |||
| def get_release_datetime(): | |||
| if MODELSCOPE_SDK_DEBUG in os.environ: | |||
| rt = int(round(datetime.now().timestamp())) | |||
| else: | |||
| from modelscope import version | |||
| rt = int( | |||
| round( | |||
| datetime.strptime(version.__release_datetime__, | |||
| '%Y-%m-%d %H:%M:%S').timestamp())) | |||
| return rt | |||
| def get_endpoint(): | |||
| modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | |||
| DEFAULT_MODELSCOPE_DOMAIN) | |||
| @@ -675,4 +675,8 @@ class MsDataset: | |||
| revision=revision, | |||
| auth_token=auth_token, | |||
| git_path=git_path) | |||
| _repo.push(commit_message=commit_message, branch=revision, force=force) | |||
| _repo.push( | |||
| commit_message=commit_message, | |||
| local_branch=revision, | |||
| remote_branch=revision, | |||
| force=force) | |||
| @@ -311,7 +311,9 @@ class Frameworks(object): | |||
| kaldi = 'kaldi' | |||
| DEFAULT_MODEL_REVISION = 'master' | |||
| DEFAULT_MODEL_REVISION = None | |||
| MASTER_MODEL_BRANCH = 'master' | |||
| DEFAULT_REPOSITORY_REVISION = 'master' | |||
| DEFAULT_DATASET_REVISION = 'master' | |||
| DEFAULT_DATASET_NAMESPACE = 'modelscope' | |||
| @@ -1 +1,5 @@ | |||
| # Make sure to modify __release_datetime__ to release time when making official release. | |||
| __version__ = '0.5.0' | |||
| # default release datetime for branches under active development is set | |||
| # to be a time far-far-away-into-the-future | |||
| __release_datetime__ = '2099-10-13 08:56:12' | |||
| @@ -25,10 +25,10 @@ class HubOperationTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.api = HubApi() | |||
| # note this is temporary before official account management is ready | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| self.model_name = uuid.uuid4().hex | |||
| self.model_name = 'op-%s' % (uuid.uuid4().hex) | |||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
| self.revision = 'v0.1_test_revision' | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PUBLIC, | |||
| @@ -46,6 +46,7 @@ class HubOperationTest(unittest.TestCase): | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, download_model_file_name)) | |||
| repo.push('add model') | |||
| repo.tag_and_push(self.revision, 'Test revision') | |||
| def test_model_repo_creation(self): | |||
| # change to proper model names before use | |||
| @@ -61,7 +62,9 @@ class HubOperationTest(unittest.TestCase): | |||
| def test_download_single_file(self): | |||
| self.prepare_case() | |||
| downloaded_file = model_file_download( | |||
| model_id=self.model_id, file_path=download_model_file_name) | |||
| model_id=self.model_id, | |||
| file_path=download_model_file_name, | |||
| revision=self.revision) | |||
| assert os.path.exists(downloaded_file) | |||
| mdtime1 = os.path.getmtime(downloaded_file) | |||
| # download again | |||
| @@ -78,17 +81,16 @@ class HubOperationTest(unittest.TestCase): | |||
| assert os.path.exists(downloaded_file_path) | |||
| mdtime1 = os.path.getmtime(downloaded_file_path) | |||
| # download again | |||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||
| snapshot_path = snapshot_download( | |||
| model_id=self.model_id, revision=self.revision) | |||
| mdtime2 = os.path.getmtime(downloaded_file_path) | |||
| assert mdtime1 == mdtime2 | |||
| model_file_download( | |||
| model_id=self.model_id, | |||
| file_path=download_model_file_name) # not add counter | |||
| def test_download_public_without_login(self): | |||
| self.prepare_case() | |||
| rmtree(ModelScopeConfig.path_credential) | |||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||
| snapshot_path = snapshot_download( | |||
| model_id=self.model_id, revision=self.revision) | |||
| downloaded_file_path = os.path.join(snapshot_path, | |||
| download_model_file_name) | |||
| assert os.path.exists(downloaded_file_path) | |||
| @@ -96,26 +98,38 @@ class HubOperationTest(unittest.TestCase): | |||
| downloaded_file = model_file_download( | |||
| model_id=self.model_id, | |||
| file_path=download_model_file_name, | |||
| revision=self.revision, | |||
| cache_dir=temporary_dir) | |||
| assert os.path.exists(downloaded_file) | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| def test_snapshot_delete_download_cache_file(self): | |||
| self.prepare_case() | |||
| snapshot_path = snapshot_download(model_id=self.model_id) | |||
| snapshot_path = snapshot_download( | |||
| model_id=self.model_id, revision=self.revision) | |||
| downloaded_file_path = os.path.join(snapshot_path, | |||
| download_model_file_name) | |||
| assert os.path.exists(downloaded_file_path) | |||
| os.remove(downloaded_file_path) | |||
| # download again in cache | |||
| file_download_path = model_file_download( | |||
| model_id=self.model_id, file_path=ModelFile.README) | |||
| model_id=self.model_id, | |||
| file_path=ModelFile.README, | |||
| revision=self.revision) | |||
| assert os.path.exists(file_download_path) | |||
| # deleted file need download again | |||
| file_download_path = model_file_download( | |||
| model_id=self.model_id, file_path=download_model_file_name) | |||
| model_id=self.model_id, | |||
| file_path=download_model_file_name, | |||
| revision=self.revision) | |||
| assert os.path.exists(file_download_path) | |||
| def test_snapshot_download_default_revision(self): | |||
| pass # TOTO | |||
| def test_file_download_default_revision(self): | |||
| pass # TODO | |||
| def get_model_download_times(self): | |||
| url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads' | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| @@ -17,23 +17,34 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, | |||
| TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, | |||
| delete_credential) | |||
| download_model_file_name = 'test.bin' | |||
| class HubPrivateFileDownloadTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.old_cwd = os.getcwd() | |||
| self.api = HubApi() | |||
| # note this is temporary before official account management is ready | |||
| self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | |||
| self.model_name = uuid.uuid4().hex | |||
| self.model_name = 'pf-%s' % (uuid.uuid4().hex) | |||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
| self.revision = 'v0.1_test_revision' | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PRIVATE, # 1-private, 5-public | |||
| visibility=ModelVisibility.PRIVATE, | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=TEST_MODEL_CHINESE_NAME, | |||
| ) | |||
| def prepare_case(self): | |||
| temporary_dir = tempfile.mkdtemp() | |||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
| repo = Repository(self.model_dir, clone_from=self.model_id) | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, download_model_file_name)) | |||
| repo.push('add model') | |||
| repo.tag_and_push(self.revision, 'Test revision') | |||
| def tearDown(self): | |||
| # credential may deleted or switch login name, we need re-login here | |||
| # to ensure the temporary model is deleted. | |||
| @@ -42,49 +53,67 @@ class HubPrivateFileDownloadTest(unittest.TestCase): | |||
| self.api.delete_model(model_id=self.model_id) | |||
| def test_snapshot_download_private_model(self): | |||
| snapshot_path = snapshot_download(self.model_id) | |||
| self.prepare_case() | |||
| snapshot_path = snapshot_download(self.model_id, self.revision) | |||
| assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | |||
| def test_snapshot_download_private_model_no_permission(self): | |||
| self.prepare_case() | |||
| self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | |||
| with self.assertRaises(HTTPError): | |||
| snapshot_download(self.model_id) | |||
| snapshot_download(self.model_id, self.revision) | |||
| def test_snapshot_download_private_model_without_login(self): | |||
| self.prepare_case() | |||
| delete_credential() | |||
| with self.assertRaises(HTTPError): | |||
| snapshot_download(self.model_id) | |||
| snapshot_download(self.model_id, self.revision) | |||
| def test_download_file_private_model(self): | |||
| file_path = model_file_download(self.model_id, ModelFile.README) | |||
| self.prepare_case() | |||
| file_path = model_file_download(self.model_id, ModelFile.README, | |||
| self.revision) | |||
| assert os.path.exists(file_path) | |||
| def test_download_file_private_model_no_permission(self): | |||
| self.prepare_case() | |||
| self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) | |||
| with self.assertRaises(HTTPError): | |||
| model_file_download(self.model_id, ModelFile.README) | |||
| model_file_download(self.model_id, ModelFile.README, self.revision) | |||
| def test_download_file_private_model_without_login(self): | |||
| self.prepare_case() | |||
| delete_credential() | |||
| with self.assertRaises(HTTPError): | |||
| model_file_download(self.model_id, ModelFile.README) | |||
| model_file_download(self.model_id, ModelFile.README, self.revision) | |||
| def test_snapshot_download_local_only(self): | |||
| self.prepare_case() | |||
| with self.assertRaises(ValueError): | |||
| snapshot_download(self.model_id, local_files_only=True) | |||
| snapshot_path = snapshot_download(self.model_id) | |||
| snapshot_download( | |||
| self.model_id, self.revision, local_files_only=True) | |||
| snapshot_path = snapshot_download(self.model_id, self.revision) | |||
| assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) | |||
| snapshot_path = snapshot_download(self.model_id, local_files_only=True) | |||
| snapshot_path = snapshot_download( | |||
| self.model_id, self.revision, local_files_only=True) | |||
| assert os.path.exists(snapshot_path) | |||
| def test_file_download_local_only(self): | |||
| self.prepare_case() | |||
| with self.assertRaises(ValueError): | |||
| model_file_download( | |||
| self.model_id, ModelFile.README, local_files_only=True) | |||
| file_path = model_file_download(self.model_id, ModelFile.README) | |||
| self.model_id, | |||
| ModelFile.README, | |||
| self.revision, | |||
| local_files_only=True) | |||
| file_path = model_file_download(self.model_id, ModelFile.README, | |||
| self.revision) | |||
| assert os.path.exists(file_path) | |||
| file_path = model_file_download( | |||
| self.model_id, ModelFile.README, local_files_only=True) | |||
| self.model_id, | |||
| ModelFile.README, | |||
| revision=self.revision, | |||
| local_files_only=True) | |||
| assert os.path.exists(file_path) | |||
| @@ -21,13 +21,12 @@ class HubPrivateRepositoryTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.old_cwd = os.getcwd() | |||
| self.api = HubApi() | |||
| # note this is temporary before official account management is ready | |||
| self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) | |||
| self.model_name = uuid.uuid4().hex | |||
| self.model_name = 'pr-%s' % (uuid.uuid4().hex) | |||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PRIVATE, # 1-private, 5-public | |||
| visibility=ModelVisibility.PRIVATE, | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=TEST_MODEL_CHINESE_NAME, | |||
| ) | |||
| @@ -22,6 +22,7 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||
| logger = get_logger() | |||
| logger.setLevel('DEBUG') | |||
| DEFAULT_GIT_PATH = 'git' | |||
| download_model_file_name = 'test.bin' | |||
| class HubRepositoryTest(unittest.TestCase): | |||
| @@ -29,13 +30,13 @@ class HubRepositoryTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.old_cwd = os.getcwd() | |||
| self.api = HubApi() | |||
| # note this is temporary before official account management is ready | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| self.model_name = uuid.uuid4().hex | |||
| self.model_name = 'repo-%s' % (uuid.uuid4().hex) | |||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
| self.revision = 'v0.1_test_revision' | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PUBLIC, # 1-private, 5-public | |||
| visibility=ModelVisibility.PUBLIC, | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=TEST_MODEL_CHINESE_NAME, | |||
| ) | |||
| @@ -67,9 +68,10 @@ class HubRepositoryTest(unittest.TestCase): | |||
| os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1)) | |||
| os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2)) | |||
| repo.push('test') | |||
| add1 = model_file_download(self.model_id, 'add1.py') | |||
| repo.tag_and_push(self.revision, 'Test revision') | |||
| add1 = model_file_download(self.model_id, 'add1.py', self.revision) | |||
| assert os.path.exists(add1) | |||
| add2 = model_file_download(self.model_id, 'add2.py') | |||
| add2 = model_file_download(self.model_id, 'add2.py', self.revision) | |||
| assert os.path.exists(add2) | |||
| # check lfs files. | |||
| git_wrapper = GitCommandWrapper() | |||
| @@ -0,0 +1,145 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| import uuid | |||
| from datetime import datetime | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||
| from modelscope.hub.errors import NotExistError, NoValidRevisionError | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||
| TEST_MODEL_ORG) | |||
| logger = get_logger() | |||
| logger.setLevel('DEBUG') | |||
| download_model_file_name = 'test.bin' | |||
| download_model_file_name2 = 'test2.bin' | |||
| class HubRevisionTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.api = HubApi() | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| self.model_name = 'rv-%s' % (uuid.uuid4().hex) | |||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
| self.revision = 'v0.1_test_revision' | |||
| self.revision2 = 'v0.2_test_revision' | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PUBLIC, | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=TEST_MODEL_CHINESE_NAME, | |||
| ) | |||
| def tearDown(self): | |||
| self.api.delete_model(model_id=self.model_id) | |||
| def prepare_repo_data(self): | |||
| temporary_dir = tempfile.mkdtemp() | |||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
| self.repo = Repository(self.model_dir, clone_from=self.model_id) | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, download_model_file_name)) | |||
| self.repo.push('add model') | |||
| self.repo.tag_and_push(self.revision, 'Test revision') | |||
| def test_no_tag(self): | |||
| with self.assertRaises(NoValidRevisionError): | |||
| snapshot_download(self.model_id, None) | |||
| with self.assertRaises(NoValidRevisionError): | |||
| model_file_download(self.model_id, ModelFile.README) | |||
| def test_with_only_one_tag(self): | |||
| self.prepare_repo_data() | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| snapshot_path = snapshot_download( | |||
| self.model_id, cache_dir=temp_cache_dir) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name)) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| file_path = model_file_download( | |||
| self.model_id, ModelFile.README, cache_dir=temp_cache_dir) | |||
| assert os.path.exists(file_path) | |||
| def add_new_file_and_tag(self): | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, download_model_file_name2)) | |||
| self.repo.push('add new file') | |||
| self.repo.tag_and_push(self.revision2, 'Test revision') | |||
| def test_snapshot_download_different_revision(self): | |||
| self.prepare_repo_data() | |||
| t1 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
| logger.info('First time stamp: %s' % t1) | |||
| snapshot_path = snapshot_download(self.model_id, self.revision) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name)) | |||
| self.add_new_file_and_tag() | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| snapshot_path = snapshot_download( | |||
| self.model_id, | |||
| revision=self.revision, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name)) | |||
| assert not os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name2)) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| snapshot_path = snapshot_download( | |||
| self.model_id, | |||
| revision=self.revision2, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name)) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name2)) | |||
| def test_file_download_different_revision(self): | |||
| self.prepare_repo_data() | |||
| t1 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
| logger.info('First time stamp: %s' % t1) | |||
| file_path = model_file_download(self.model_id, | |||
| download_model_file_name, | |||
| self.revision) | |||
| assert os.path.exists(file_path) | |||
| self.add_new_file_and_tag() | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| file_path = model_file_download( | |||
| self.model_id, | |||
| download_model_file_name, | |||
| revision=self.revision, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists(file_path) | |||
| with self.assertRaises(NotExistError): | |||
| model_file_download( | |||
| self.model_id, | |||
| download_model_file_name2, | |||
| revision=self.revision, | |||
| cache_dir=temp_cache_dir) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| file_path = model_file_download( | |||
| self.model_id, | |||
| download_model_file_name, | |||
| revision=self.revision2, | |||
| cache_dir=temp_cache_dir) | |||
| print('Downloaded file path: %s' % file_path) | |||
| assert os.path.exists(file_path) | |||
| file_path = model_file_download( | |||
| self.model_id, | |||
| download_model_file_name2, | |||
| revision=self.revision2, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists(file_path) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,190 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import tempfile | |||
| import time | |||
| import unittest | |||
| import uuid | |||
| from datetime import datetime | |||
| from unittest import mock | |||
| from modelscope import version | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.constants import (MODELSCOPE_SDK_DEBUG, Licenses, | |||
| ModelVisibility) | |||
| from modelscope.hub.errors import NotExistError | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.utils.logger import get_logger | |||
| from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, | |||
| TEST_MODEL_ORG) | |||
| logger = get_logger() | |||
| logger.setLevel('DEBUG') | |||
| download_model_file_name = 'test.bin' | |||
| download_model_file_name2 = 'test2.bin' | |||
| class HubRevisionTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.api = HubApi() | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| self.model_name = 'rvr-%s' % (uuid.uuid4().hex) | |||
| self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) | |||
| self.revision = 'v0.1_test_revision' | |||
| self.revision2 = 'v0.2_test_revision' | |||
| self.api.create_model( | |||
| model_id=self.model_id, | |||
| visibility=ModelVisibility.PUBLIC, | |||
| license=Licenses.APACHE_V2, | |||
| chinese_name=TEST_MODEL_CHINESE_NAME, | |||
| ) | |||
| names_to_remove = {MODELSCOPE_SDK_DEBUG} | |||
| self.modified_environ = { | |||
| k: v | |||
| for k, v in os.environ.items() if k not in names_to_remove | |||
| } | |||
| def tearDown(self): | |||
| self.api.delete_model(model_id=self.model_id) | |||
| def prepare_repo_data(self): | |||
| temporary_dir = tempfile.mkdtemp() | |||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | |||
| self.repo = Repository(self.model_dir, clone_from=self.model_id) | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, download_model_file_name)) | |||
| self.repo.push('add model') | |||
| def prepare_repo_data_and_tag(self): | |||
| self.prepare_repo_data() | |||
| self.repo.tag_and_push(self.revision, 'Test revision') | |||
| def add_new_file_and_tag_to_repo(self): | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, download_model_file_name2)) | |||
| self.repo.push('add new file') | |||
| self.repo.tag_and_push(self.revision2, 'Test revision') | |||
| def add_new_file_and_branch_to_repo(self, branch_name): | |||
| os.system("echo 'testtest'>%s" | |||
| % os.path.join(self.model_dir, download_model_file_name2)) | |||
| self.repo.push('add new file', remote_branch=branch_name) | |||
| def test_dev_mode_default_master(self): | |||
| with mock.patch.dict(os.environ, self.modified_environ, clear=True): | |||
| self.prepare_repo_data() # no tag, default get master | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| snapshot_path = snapshot_download( | |||
| self.model_id, cache_dir=temp_cache_dir) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name)) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| file_path = model_file_download( | |||
| self.model_id, | |||
| download_model_file_name, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists(file_path) | |||
| def test_dev_mode_specify_branch(self): | |||
| with mock.patch.dict(os.environ, self.modified_environ, clear=True): | |||
| self.prepare_repo_data() # no tag, default get master | |||
| branch_name = 'test' | |||
| self.add_new_file_and_branch_to_repo(branch_name) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| snapshot_path = snapshot_download( | |||
| self.model_id, | |||
| revision=branch_name, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name)) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| file_path = model_file_download( | |||
| self.model_id, | |||
| download_model_file_name, | |||
| revision=branch_name, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists(file_path) | |||
| def test_snapshot_download_revision(self): | |||
| with mock.patch.dict(os.environ, self.modified_environ, clear=True): | |||
| self.prepare_repo_data_and_tag() | |||
| t1 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
| logger.info('First time: %s' % t1) | |||
| time.sleep(10) | |||
| self.add_new_file_and_tag_to_repo() | |||
| t2 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
| logger.info('Secnod time: %s' % t2) | |||
| # set | |||
| release_datetime_backup = version.__release_datetime__ | |||
| logger.info('Origin __release_datetime__: %s' | |||
| % version.__release_datetime__) | |||
| try: | |||
| logger.info('Setting __release_datetime__ to: %s' % t1) | |||
| version.__release_datetime__ = t1 | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| snapshot_path = snapshot_download( | |||
| self.model_id, cache_dir=temp_cache_dir) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name)) | |||
| assert not os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name2)) | |||
| version.__release_datetime__ = t2 | |||
| logger.info('Setting __release_datetime__ to: %s' % t2) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| snapshot_path = snapshot_download( | |||
| self.model_id, cache_dir=temp_cache_dir) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name)) | |||
| assert os.path.exists( | |||
| os.path.join(snapshot_path, download_model_file_name2)) | |||
| finally: | |||
| version.__release_datetime__ = release_datetime_backup | |||
| def test_file_download_revision(self): | |||
| with mock.patch.dict(os.environ, self.modified_environ, clear=True): | |||
| self.prepare_repo_data_and_tag() | |||
| t1 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
| logger.info('First time stamp: %s' % t1) | |||
| time.sleep(10) | |||
| self.add_new_file_and_tag_to_repo() | |||
| t2 = datetime.now().isoformat(sep=' ', timespec='seconds') | |||
| logger.info('Second time: %s' % t2) | |||
| release_datetime_backup = version.__release_datetime__ | |||
| logger.info('Origin __release_datetime__: %s' | |||
| % version.__release_datetime__) | |||
| try: | |||
| version.__release_datetime__ = t1 | |||
| logger.info('Setting __release_datetime__ to: %s' % t1) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| file_path = model_file_download( | |||
| self.model_id, | |||
| download_model_file_name, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists(file_path) | |||
| with self.assertRaises(NotExistError): | |||
| model_file_download( | |||
| self.model_id, | |||
| download_model_file_name2, | |||
| cache_dir=temp_cache_dir) | |||
| version.__release_datetime__ = t2 | |||
| logger.info('Setting __release_datetime__ to: %s' % t2) | |||
| with tempfile.TemporaryDirectory() as temp_cache_dir: | |||
| file_path = model_file_download( | |||
| self.model_id, | |||
| download_model_file_name, | |||
| cache_dir=temp_cache_dir) | |||
| print('Downloaded file path: %s' % file_path) | |||
| assert os.path.exists(file_path) | |||
| file_path = model_file_download( | |||
| self.model_id, | |||
| download_model_file_name2, | |||
| cache_dir=temp_cache_dir) | |||
| assert os.path.exists(file_path) | |||
| finally: | |||
| version.__release_datetime__ = release_datetime_backup | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||