diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 60e0e274..b871a713 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -2,6 +2,7 @@ # yapf: disable import datetime +import functools import os import pickle import platform @@ -14,10 +15,12 @@ from http.cookiejar import CookieJar from os.path import expanduser from typing import Dict, List, Optional, Tuple, Union -import requests +from requests import Session +from requests.adapters import HTTPAdapter, Retry from modelscope import __version__ -from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, +from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT, + API_RESPONSE_FIELD_DATA, API_RESPONSE_FIELD_EMAIL, API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, API_RESPONSE_FIELD_MESSAGE, @@ -25,7 +28,8 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, DEFAULT_CREDENTIALS_PATH, MODELSCOPE_CLOUD_ENVIRONMENT, MODELSCOPE_CLOUD_USERNAME, - ONE_YEAR_SECONDS, Licenses, + ONE_YEAR_SECONDS, + REQUESTS_API_HTTP_METHOD, Licenses, ModelVisibility) from modelscope.hub.errors import (InvalidParameter, NotExistError, NotLoginException, NoValidRevisionError, @@ -54,6 +58,17 @@ class HubApi: def __init__(self, endpoint=None): self.endpoint = endpoint if endpoint is not None else get_endpoint() self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} + self.session = Session() + retry = Retry(total=2, read=2, connect=2, backoff_factor=1, + status_forcelist=(500, 502, 503, 504),) + adapter = HTTPAdapter(max_retries=retry) + self.session.mount('http://', adapter) + self.session.mount('https://', adapter) + # set http timeout + for method in REQUESTS_API_HTTP_METHOD: + setattr(self.session, + method, + functools.partial(getattr(self.session, method), timeout=API_HTTP_CLIENT_TIMEOUT)) def login( self, @@ -73,7 +88,7 @@ class HubApi: """ path = f'{self.endpoint}/api/v1/login' - r = requests.post( + r = self.session.post( path, json={'AccessToken': access_token}, headers=self.headers) raise_for_http_status(r) d = r.json() @@ -129,7 +144,7 @@ class HubApi: 'Visibility': visibility, # server check 'License': license } - r = requests.post( + r = self.session.post( path, json=body, cookies=cookies, headers=self.headers) handle_http_post_error(r, path, body) raise_on_error(r.json()) @@ -150,7 +165,7 @@ class HubApi: raise ValueError('Token does not exist, please login first.') path = f'{self.endpoint}/api/v1/models/{model_id}' - r = requests.delete(path, cookies=cookies, headers=self.headers) + r = self.session.delete(path, cookies=cookies, headers=self.headers) raise_for_http_status(r) raise_on_error(r.json()) @@ -183,7 +198,7 @@ class HubApi: else: path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}' - r = requests.get(path, cookies=cookies, headers=self.headers) + r = self.session.get(path, cookies=cookies, headers=self.headers) handle_http_response(r, logger, cookies, model_id) if r.status_code == HTTPStatus.OK: if is_ok(r.json()): @@ -311,7 +326,7 @@ class HubApi: """ cookies = ModelScopeConfig.get_cookies() path = f'{self.endpoint}/api/v1/models/' - r = requests.put( + r = self.session.put( path, data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % (owner_or_group, page_number, page_size), @@ -360,7 +375,7 @@ class HubApi: if cutoff_timestamp is None: cutoff_timestamp = get_release_datetime() path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp - r = requests.get(path, cookies=cookies, headers=self.headers) + r = self.session.get(path, cookies=cookies, headers=self.headers) handle_http_response(r, logger, cookies, model_id) d = r.json() raise_on_error(d) @@ -422,7 +437,7 @@ class HubApi: cookies = self._check_cookie(use_cookies) path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' - r = requests.get(path, cookies=cookies, headers=self.headers) + r = self.session.get(path, cookies=cookies, headers=self.headers) handle_http_response(r, logger, cookies, model_id) d = r.json() raise_on_error(d) @@ -467,7 +482,7 @@ class HubApi: if root is not None: path = path + f'&Root={root}' - r = requests.get( + r = self.session.get( path, cookies=cookies, headers={ **headers, **self.headers @@ -488,7 +503,7 @@ class HubApi: def list_datasets(self): path = f'{self.endpoint}/api/v1/datasets' params = {} - r = requests.get(path, params=params, headers=self.headers) + r = self.session.get(path, params=params, headers=self.headers) raise_for_http_status(r) dataset_list = r.json()[API_RESPONSE_FIELD_DATA] return [x['Name'] for x in dataset_list] @@ -514,13 +529,13 @@ class HubApi: os.makedirs(cache_dir, exist_ok=True) datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}' cookies = ModelScopeConfig.get_cookies() - r = requests.get(datahub_url, cookies=cookies) + r = self.session.get(datahub_url, cookies=cookies) resp = r.json() datahub_raise_on_error(datahub_url, resp) dataset_id = resp['Data']['Id'] dataset_type = resp['Data']['Type'] datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' - r = requests.get(datahub_url, cookies=cookies, headers=self.headers) + r = self.session.get(datahub_url, cookies=cookies, headers=self.headers) resp = r.json() datahub_raise_on_error(datahub_url, resp) file_list = resp['Data'] @@ -539,7 +554,7 @@ class HubApi: if extension in dataset_meta_format: datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ f'Revision={revision}&FilePath={file_path}' - r = requests.get(datahub_url, cookies=cookies) + r = self.session.get(datahub_url, cookies=cookies) raise_for_http_status(r) local_path = os.path.join(cache_dir, file_path) if os.path.exists(local_path): @@ -584,7 +599,7 @@ class HubApi: datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ f'ststoken?Revision={revision}' - r = requests.get(url=datahub_url, cookies=cookies, headers=self.headers) + r = self.session.get(url=datahub_url, cookies=cookies, headers=self.headers) resp = r.json() raise_on_error(resp) return resp['Data'] @@ -595,7 +610,7 @@ class HubApi: f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' cookies = ModelScopeConfig.get_cookies() - resp = requests.get(url=url, cookies=cookies) + resp = self.session.get(url=url, cookies=cookies) resp = resp.json() raise_on_error(resp) resp = resp['Data'] @@ -604,7 +619,7 @@ class HubApi: def on_dataset_download(self, dataset_name: str, namespace: str) -> None: url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' cookies = ModelScopeConfig.get_cookies() - r = requests.post(url, cookies=cookies, headers=self.headers) + r = self.session.post(url, cookies=cookies, headers=self.headers) raise_for_http_status(r) def delete_oss_dataset_object(self, object_name: str, dataset_name: str, @@ -615,7 +630,7 @@ class HubApi: url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}' cookies = self.check_local_cookies(use_cookies=True) - resp = requests.delete(url=url, cookies=cookies) + resp = self.session.delete(url=url, cookies=cookies) resp = resp.json() raise_on_error(resp) resp = resp['Message'] @@ -630,16 +645,15 @@ class HubApi: f'&Revision={revision}' cookies = self.check_local_cookies(use_cookies=True) - resp = requests.delete(url=url, cookies=cookies) + resp = self.session.delete(url=url, cookies=cookies) resp = resp.json() raise_on_error(resp) resp = resp['Message'] return resp - @staticmethod - def datahub_remote_call(url): + def datahub_remote_call(self, url): cookies = ModelScopeConfig.get_cookies() - r = requests.get(url, cookies=cookies, headers={'user-agent': ModelScopeConfig.get_user_agent()}) + r = self.session.get(url, cookies=cookies, headers={'user-agent': ModelScopeConfig.get_user_agent()}) resp = r.json() datahub_raise_on_error(url, resp) return resp['Data'] @@ -661,7 +675,7 @@ class HubApi: url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}' cookies = ModelScopeConfig.get_cookies() - r = requests.post(url, cookies=cookies, headers=self.headers) + r = self.session.post(url, cookies=cookies, headers=self.headers) resp = r.json() raise_on_error(resp) return resp['Message'] diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py index 83991e4e..9d5881e8 100644 --- a/modelscope/hub/constants.py +++ b/modelscope/hub/constants.py @@ -11,7 +11,11 @@ MODEL_ID_SEPARATOR = '/' FILE_HASH = 'Sha256' LOGGER_NAME = 'ModelScopeHub' DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials') +REQUESTS_API_HTTP_METHOD = ['get', 'head', 'post', 'put', 'patch', 'delete'] +API_HTTP_CLIENT_TIMEOUT = 60 API_RESPONSE_FIELD_DATA = 'Data' +API_FILE_DOWNLOAD_RETRY_TIMES = 5 +API_FILE_DOWNLOAD_CHUNK_SIZE = 4096 API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' API_RESPONSE_FIELD_USERNAME = 'Username' API_RESPONSE_FIELD_EMAIL = 'Email' diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 042ea6a6..dd062516 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -9,13 +9,15 @@ from pathlib import Path from typing import Dict, Optional, Union import requests +from requests.adapters import Retry from tqdm import tqdm from modelscope import __version__ from modelscope.hub.api import HubApi, ModelScopeConfig +from modelscope.hub.constants import (API_FILE_DOWNLOAD_CHUNK_SIZE, + API_FILE_DOWNLOAD_RETRY_TIMES, FILE_HASH) from modelscope.utils.constant import DEFAULT_MODEL_REVISION from modelscope.utils.logger import get_logger -from .constants import FILE_HASH from .errors import FileDownloadError, NotExistError from .utils.caching import ModelFileSystemCache from .utils.utils import (file_integrity_validation, get_cache_dir, @@ -184,10 +186,7 @@ def http_get_file( headers: Optional[Dict[str, str]] = None, ): """ - Download remote file. Do not gobble up errors. - This method is only used by snapshot_download, since the behavior is quite different with single file download - TODO: consolidate with http_get_file() to avoild duplicate code - + Download remote file, will retry 5 times before giving up on errors. Args: url(`str`): actual download url of the file @@ -204,30 +203,46 @@ def http_get_file( total = -1 temp_file_manager = partial( tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) - + get_headers = copy.deepcopy(headers) with temp_file_manager() as temp_file: logger.info('downloading %s to %s', url, temp_file.name) - headers = copy.deepcopy(headers) - - r = requests.get(url, stream=True, headers=headers, cookies=cookies) - r.raise_for_status() - - content_length = r.headers.get('Content-Length') - total = int(content_length) if content_length is not None else None - - progress = tqdm( - unit='B', - unit_scale=True, - unit_divisor=1024, - total=total, - initial=0, - desc='Downloading', - ) - for chunk in r.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks - progress.update(len(chunk)) - temp_file.write(chunk) - progress.close() + # retry sleep 0.5s, 1s, 2s, 4s + retry = Retry( + total=API_FILE_DOWNLOAD_RETRY_TIMES, + backoff_factor=1, + allowed_methods=['GET']) + while True: + try: + downloaded_size = temp_file.tell() + get_headers['Range'] = 'bytes=%d-' % downloaded_size + r = requests.get( + url, + stream=True, + headers=get_headers, + cookies=cookies, + timeout=5) + r.raise_for_status() + content_length = r.headers.get('Content-Length') + total = int( + content_length) if content_length is not None else None + progress = tqdm( + unit='B', + unit_scale=True, + unit_divisor=1024, + total=total, + initial=downloaded_size, + desc='Downloading', + ) + for chunk in r.iter_content( + chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + break + except (Exception) as e: # no matter what happen, we will retry. + retry = retry.increment('GET', url, error=e) + retry.sleep() logger.info('storing %s in cache at %s', url, local_dir) downloaded_length = os.path.getsize(temp_file.name) diff --git a/tests/hub/test_hub_retry.py b/tests/hub/test_hub_retry.py new file mode 100644 index 00000000..e294cb68 --- /dev/null +++ b/tests/hub/test_hub_retry.py @@ -0,0 +1,164 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest +from http.client import HTTPMessage, HTTPResponse +from io import StringIO +from unittest.mock import Mock, patch + +import requests +from urllib3.exceptions import MaxRetryError + +from modelscope.hub.api import HubApi +from modelscope.hub.file_download import http_get_file + + +class HubOperationTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + self.model_id = 'damo/ofa_text-to-image-synthesis_coco_large_en' + + @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn') + def test_retry_exception(self, getconn_mock): + getconn_mock.return_value.getresponse.side_effect = [ + Mock(status=500, msg=HTTPMessage()), + Mock(status=502, msg=HTTPMessage()), + Mock(status=500, msg=HTTPMessage()), + ] + with self.assertRaises(requests.exceptions.RetryError): + self.api.get_model_files( + model_id=self.model_id, + recursive=True, + ) + + @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn') + def test_retry_and_success(self, getconn_mock): + response_body = '{"Code": 200, "Data": { "Files": [ {"CommitMessage": \ + "update","CommittedDate": 1667548386,"CommitterName": "行嗔","InCheck": false, \ + "IsLFS": false, "Mode": "33188", "Name": "README.md", "Path": "README.md", \ + "Revision": "e45fcc158894f18a7a8cfa3caf8b3dd1a2b26dc9",\ + "Sha256": "8bf99f410ae0a572e5a4a85a3949ad268d49023e5c6ef200c9bd4307f9ed0660", \ + "Size": 6399, "Type": "blob" } ] }, "Message": "success",\ + "RequestId": "8c2a8249-ce50-49f4-85ea-36debf918714","Success": true}' + + first = 0 + + def get_content(p): + nonlocal first + if first > 0: + return None + else: + first += 1 + return response_body.encode('utf-8') + + rsp = HTTPResponse(getconn_mock) + rsp.status = 200 + rsp.msg = HTTPMessage() + rsp.read = get_content + rsp.chunked = False + # retry 2 times and success. + getconn_mock.return_value.getresponse.side_effect = [ + Mock(status=500, msg=HTTPMessage()), + Mock( + status=502, + msg=HTTPMessage(), + body=response_body, + read=StringIO(response_body)), + rsp, + ] + model_files = self.api.get_model_files( + model_id=self.model_id, + recursive=True, + ) + assert len(model_files) > 0 + + @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn') + def test_retry_broken_continue(self, getconn_mock): + test_file_name = 'video_inpainting_test.mp4' + fp = 0 + + def get_content(content_length): + nonlocal fp + with open('data/test/videos/%s' % test_file_name, 'rb') as f: + f.seek(fp) + content = f.read(content_length) + fp += len(content) + return content + + success_rsp = HTTPResponse(getconn_mock) + success_rsp.status = 200 + success_rsp.msg = HTTPMessage() + success_rsp.msg.add_header('Content-Length', '2957783') + success_rsp.read = get_content + success_rsp.chunked = True + + failed_rsp = HTTPResponse(getconn_mock) + failed_rsp.status = 502 + failed_rsp.msg = HTTPMessage() + failed_rsp.msg.add_header('Content-Length', '2957783') + failed_rsp.read = get_content + failed_rsp.chunked = True + + # retry 5 times and success. + getconn_mock.return_value.getresponse.side_effect = [ + failed_rsp, + failed_rsp, + failed_rsp, + failed_rsp, + failed_rsp, + success_rsp, + ] + url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name + http_get_file( + url=url, + local_dir='./', + file_name=test_file_name, + headers={}, + cookies=None) + + assert os.path.exists('./%s' % test_file_name) + os.remove('./%s' % test_file_name) + + @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn') + def test_retry_broken_continue_retry_failed(self, getconn_mock): + test_file_name = 'video_inpainting_test.mp4' + fp = 0 + + def get_content(content_length): + nonlocal fp + with open('data/test/videos/%s' % test_file_name, 'rb') as f: + f.seek(fp) + content = f.read(content_length) + fp += len(content) + return content + + failed_rsp = HTTPResponse(getconn_mock) + failed_rsp.status = 502 + failed_rsp.msg = HTTPMessage() + failed_rsp.msg.add_header('Content-Length', '2957783') + failed_rsp.read = get_content + failed_rsp.chunked = True + + # retry 6 times and success. + getconn_mock.return_value.getresponse.side_effect = [ + failed_rsp, + failed_rsp, + failed_rsp, + failed_rsp, + failed_rsp, + failed_rsp, + ] + url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name + with self.assertRaises(MaxRetryError): + http_get_file( + url=url, + local_dir='./', + file_name=test_file_name, + headers={}, + cookies=None) + + assert not os.path.exists('./%s' % test_file_name) + + +if __name__ == '__main__': + unittest.main()