
[to #46289830]feat: hub sdk support retry and continue-download after error

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10814720
master^2 · mulin.lyh, 3 years ago · commit 70deb0190b
4 changed files with 248 additions and 51 deletions
  1. modelscope/hub/api.py (+38 / -24)
  2. modelscope/hub/constants.py (+4 / -0)
  3. modelscope/hub/file_download.py (+42 / -27)
  4. tests/hub/test_hub_retry.py (+164 / -0)

modelscope/hub/api.py (+38 / -24)

@@ -2,6 +2,7 @@

# yapf: disable
import datetime
import functools
import os
import pickle
import platform
@@ -14,10 +15,12 @@ from http.cookiejar import CookieJar
from os.path import expanduser
from typing import Dict, List, Optional, Tuple, Union

import requests
from requests import Session
from requests.adapters import HTTPAdapter, Retry

from modelscope import __version__
- from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
+ from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT,
+ API_RESPONSE_FIELD_DATA,
API_RESPONSE_FIELD_EMAIL,
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
API_RESPONSE_FIELD_MESSAGE,
@@ -25,7 +28,8 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
DEFAULT_CREDENTIALS_PATH,
MODELSCOPE_CLOUD_ENVIRONMENT,
MODELSCOPE_CLOUD_USERNAME,
- ONE_YEAR_SECONDS, Licenses,
+ ONE_YEAR_SECONDS,
+ REQUESTS_API_HTTP_METHOD, Licenses,
ModelVisibility)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
NotLoginException, NoValidRevisionError,
@@ -54,6 +58,17 @@ class HubApi:
def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else get_endpoint()
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
self.session = Session()
retry = Retry(total=2, read=2, connect=2, backoff_factor=1,
status_forcelist=(500, 502, 503, 504),)
adapter = HTTPAdapter(max_retries=retry)
self.session.mount('http://', adapter)
self.session.mount('https://', adapter)
# set http timeout
for method in REQUESTS_API_HTTP_METHOD:
setattr(self.session,
method,
functools.partial(getattr(self.session, method), timeout=API_HTTP_CLIENT_TIMEOUT))
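
For reference, the retry/timeout pattern introduced here can be reproduced in isolation with nothing but requests/urllib3; the sketch below mirrors the __init__ code above and reuses the constant values added in constants.py further down (it is an illustration, not SDK code):

import functools

from requests import Session
from requests.adapters import HTTPAdapter, Retry

REQUESTS_API_HTTP_METHOD = ['get', 'head', 'post', 'put', 'patch', 'delete']
API_HTTP_CLIENT_TIMEOUT = 60  # seconds, as in modelscope/hub/constants.py

session = Session()
# Retry transient server errors (5xx) up to 2 times, with back-off between attempts.
retry = Retry(total=2, read=2, connect=2, backoff_factor=1,
              status_forcelist=(500, 502, 503, 504))
adapter = HTTPAdapter(max_retries=retry)
session.mount('http://', adapter)
session.mount('https://', adapter)

# Give every HTTP verb a default timeout without touching the call sites.
for method in REQUESTS_API_HTTP_METHOD:
    setattr(session, method,
            functools.partial(getattr(session, method), timeout=API_HTTP_CLIENT_TIMEOUT))

# session.get(...), session.post(...), etc. now retry and time out by default.

Wrapping the verbs with functools.partial is what lets the rest of this file change nothing except swapping requests.<verb> for self.session.<verb>, as the remaining hunks of this diff show.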

def login(
self,
@@ -73,7 +88,7 @@ class HubApi:
</Tip>
"""
path = f'{self.endpoint}/api/v1/login'
- r = requests.post(
+ r = self.session.post(
path, json={'AccessToken': access_token}, headers=self.headers)
raise_for_http_status(r)
d = r.json()
@@ -129,7 +144,7 @@ class HubApi:
'Visibility': visibility, # server check
'License': license
}
- r = requests.post(
+ r = self.session.post(
path, json=body, cookies=cookies, headers=self.headers)
handle_http_post_error(r, path, body)
raise_on_error(r.json())
@@ -150,7 +165,7 @@ class HubApi:
raise ValueError('Token does not exist, please login first.')
path = f'{self.endpoint}/api/v1/models/{model_id}'

- r = requests.delete(path, cookies=cookies, headers=self.headers)
+ r = self.session.delete(path, cookies=cookies, headers=self.headers)
raise_for_http_status(r)
raise_on_error(r.json())

@@ -183,7 +198,7 @@ class HubApi:
else:
path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}'

- r = requests.get(path, cookies=cookies, headers=self.headers)
+ r = self.session.get(path, cookies=cookies, headers=self.headers)
handle_http_response(r, logger, cookies, model_id)
if r.status_code == HTTPStatus.OK:
if is_ok(r.json()):
@@ -311,7 +326,7 @@ class HubApi:
"""
cookies = ModelScopeConfig.get_cookies()
path = f'{self.endpoint}/api/v1/models/'
- r = requests.put(
+ r = self.session.put(
path,
data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' %
(owner_or_group, page_number, page_size),
@@ -360,7 +375,7 @@ class HubApi:
if cutoff_timestamp is None:
cutoff_timestamp = get_release_datetime()
path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp
- r = requests.get(path, cookies=cookies, headers=self.headers)
+ r = self.session.get(path, cookies=cookies, headers=self.headers)
handle_http_response(r, logger, cookies, model_id)
d = r.json()
raise_on_error(d)
@@ -422,7 +437,7 @@ class HubApi:
cookies = self._check_cookie(use_cookies)

path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
- r = requests.get(path, cookies=cookies, headers=self.headers)
+ r = self.session.get(path, cookies=cookies, headers=self.headers)
handle_http_response(r, logger, cookies, model_id)
d = r.json()
raise_on_error(d)
@@ -467,7 +482,7 @@ class HubApi:
if root is not None:
path = path + f'&Root={root}'

- r = requests.get(
+ r = self.session.get(
path, cookies=cookies, headers={
**headers,
**self.headers
@@ -488,7 +503,7 @@ class HubApi:
def list_datasets(self):
path = f'{self.endpoint}/api/v1/datasets'
params = {}
- r = requests.get(path, params=params, headers=self.headers)
+ r = self.session.get(path, params=params, headers=self.headers)
raise_for_http_status(r)
dataset_list = r.json()[API_RESPONSE_FIELD_DATA]
return [x['Name'] for x in dataset_list]
@@ -514,13 +529,13 @@ class HubApi:
os.makedirs(cache_dir, exist_ok=True)
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
cookies = ModelScopeConfig.get_cookies()
- r = requests.get(datahub_url, cookies=cookies)
+ r = self.session.get(datahub_url, cookies=cookies)
resp = r.json()
datahub_raise_on_error(datahub_url, resp)
dataset_id = resp['Data']['Id']
dataset_type = resp['Data']['Type']
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
- r = requests.get(datahub_url, cookies=cookies, headers=self.headers)
+ r = self.session.get(datahub_url, cookies=cookies, headers=self.headers)
resp = r.json()
datahub_raise_on_error(datahub_url, resp)
file_list = resp['Data']
@@ -539,7 +554,7 @@ class HubApi:
if extension in dataset_meta_format:
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
f'Revision={revision}&FilePath={file_path}'
- r = requests.get(datahub_url, cookies=cookies)
+ r = self.session.get(datahub_url, cookies=cookies)
raise_for_http_status(r)
local_path = os.path.join(cache_dir, file_path)
if os.path.exists(local_path):
@@ -584,7 +599,7 @@ class HubApi:
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
f'ststoken?Revision={revision}'

- r = requests.get(url=datahub_url, cookies=cookies, headers=self.headers)
+ r = self.session.get(url=datahub_url, cookies=cookies, headers=self.headers)
resp = r.json()
raise_on_error(resp)
return resp['Data']
@@ -595,7 +610,7 @@ class HubApi:
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'

cookies = ModelScopeConfig.get_cookies()
- resp = requests.get(url=url, cookies=cookies)
+ resp = self.session.get(url=url, cookies=cookies)
resp = resp.json()
raise_on_error(resp)
resp = resp['Data']
@@ -604,7 +619,7 @@ class HubApi:
def on_dataset_download(self, dataset_name: str, namespace: str) -> None:
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
cookies = ModelScopeConfig.get_cookies()
- r = requests.post(url, cookies=cookies, headers=self.headers)
+ r = self.session.post(url, cookies=cookies, headers=self.headers)
raise_for_http_status(r)

def delete_oss_dataset_object(self, object_name: str, dataset_name: str,
@@ -615,7 +630,7 @@ class HubApi:
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'

cookies = self.check_local_cookies(use_cookies=True)
- resp = requests.delete(url=url, cookies=cookies)
+ resp = self.session.delete(url=url, cookies=cookies)
resp = resp.json()
raise_on_error(resp)
resp = resp['Message']
@@ -630,16 +645,15 @@ class HubApi:
f'&Revision={revision}'

cookies = self.check_local_cookies(use_cookies=True)
- resp = requests.delete(url=url, cookies=cookies)
+ resp = self.session.delete(url=url, cookies=cookies)
resp = resp.json()
raise_on_error(resp)
resp = resp['Message']
return resp

- @staticmethod
- def datahub_remote_call(url):
+ def datahub_remote_call(self, url):
cookies = ModelScopeConfig.get_cookies()
- r = requests.get(url, cookies=cookies, headers={'user-agent': ModelScopeConfig.get_user_agent()})
+ r = self.session.get(url, cookies=cookies, headers={'user-agent': ModelScopeConfig.get_user_agent()})
resp = r.json()
datahub_raise_on_error(url, resp)
return resp['Data']
@@ -661,7 +675,7 @@ class HubApi:

url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}'
cookies = ModelScopeConfig.get_cookies()
- r = requests.post(url, cookies=cookies, headers=self.headers)
+ r = self.session.post(url, cookies=cookies, headers=self.headers)
resp = r.json()
raise_on_error(resp)
return resp['Message']
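
With the shared session in place, every HubApi call is retried on transient 5xx responses and carries the 60-second default timeout; no call site needs its own error handling. A quick usage sketch (the model id is the one used in the new tests and is only an example):

from modelscope.hub.api import HubApi

api = HubApi()
# Retried on 500/502/503/504 and subject to the 60 s timeout automatically.
files = api.get_model_files(
    model_id='damo/ofa_text-to-image-synthesis_coco_large_en',
    recursive=True)
print('listed %d files' % len(files))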


modelscope/hub/constants.py (+4 / -0)

@@ -11,7 +11,11 @@ MODEL_ID_SEPARATOR = '/'
FILE_HASH = 'Sha256'
LOGGER_NAME = 'ModelScopeHub'
DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
REQUESTS_API_HTTP_METHOD = ['get', 'head', 'post', 'put', 'patch', 'delete']
API_HTTP_CLIENT_TIMEOUT = 60
API_RESPONSE_FIELD_DATA = 'Data'
API_FILE_DOWNLOAD_RETRY_TIMES = 5
API_FILE_DOWNLOAD_CHUNK_SIZE = 4096
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken'
API_RESPONSE_FIELD_USERNAME = 'Username'
API_RESPONSE_FIELD_EMAIL = 'Email'


modelscope/hub/file_download.py (+42 / -27)

@@ -9,13 +9,15 @@ from pathlib import Path
from typing import Dict, Optional, Union

import requests
from requests.adapters import Retry
from tqdm import tqdm

from modelscope import __version__
from modelscope.hub.api import HubApi, ModelScopeConfig
+ from modelscope.hub.constants import (API_FILE_DOWNLOAD_CHUNK_SIZE,
+ API_FILE_DOWNLOAD_RETRY_TIMES, FILE_HASH)
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.logger import get_logger
- from .constants import FILE_HASH
from .errors import FileDownloadError, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_cache_dir,
@@ -184,10 +186,7 @@ def http_get_file(
headers: Optional[Dict[str, str]] = None,
):
"""
- Download remote file. Do not gobble up errors.
- This method is only used by snapshot_download, since the behavior is quite different with single file download
- TODO: consolidate with http_get_file() to avoild duplicate code

+ Download remote file, will retry 5 times before giving up on errors.
Args:
url(`str`):
actual download url of the file
@@ -204,30 +203,46 @@ def http_get_file(
total = -1
temp_file_manager = partial(
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
+ get_headers = copy.deepcopy(headers)
with temp_file_manager() as temp_file:
logger.info('downloading %s to %s', url, temp_file.name)
- headers = copy.deepcopy(headers)

- r = requests.get(url, stream=True, headers=headers, cookies=cookies)
- r.raise_for_status()

- content_length = r.headers.get('Content-Length')
- total = int(content_length) if content_length is not None else None

- progress = tqdm(
- unit='B',
- unit_scale=True,
- unit_divisor=1024,
- total=total,
- initial=0,
- desc='Downloading',
- )
- for chunk in r.iter_content(chunk_size=1024):
- if chunk: # filter out keep-alive new chunks
- progress.update(len(chunk))
- temp_file.write(chunk)
- progress.close()
+ # retry sleep 0.5s, 1s, 2s, 4s
+ retry = Retry(
+ total=API_FILE_DOWNLOAD_RETRY_TIMES,
+ backoff_factor=1,
+ allowed_methods=['GET'])
+ while True:
+ try:
+ downloaded_size = temp_file.tell()
+ get_headers['Range'] = 'bytes=%d-' % downloaded_size
+ r = requests.get(
+ url,
+ stream=True,
+ headers=get_headers,
+ cookies=cookies,
+ timeout=5)
+ r.raise_for_status()
+ content_length = r.headers.get('Content-Length')
+ total = int(
+ content_length) if content_length is not None else None
+ progress = tqdm(
+ unit='B',
+ unit_scale=True,
+ unit_divisor=1024,
+ total=total,
+ initial=downloaded_size,
+ desc='Downloading',
+ )
+ for chunk in r.iter_content(
+ chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
+ if chunk: # filter out keep-alive new chunks
+ progress.update(len(chunk))
+ temp_file.write(chunk)
+ progress.close()
+ break
+ except (Exception) as e: # no matter what happen, we will retry.
+ retry = retry.increment('GET', url, error=e)
+ retry.sleep()

logger.info('storing %s in cache at %s', url, local_dir)
downloaded_length = os.path.getsize(temp_file.name)
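
Two details of the new loop are worth spelling out. The urllib3 Retry object is driven by hand: retry.increment() records each failure and raises MaxRetryError once API_FILE_DOWNLOAD_RETRY_TIMES attempts are spent, while retry.sleep() applies the back-off between attempts. And because the temporary file stays open across attempts, temp_file.tell() is the number of bytes already written, so the Range header asks the server to resume from that offset instead of restarting. Below is a simplified, self-contained sketch of the same resume-and-retry pattern (the function name and defaults are illustrative, not part of the SDK):

import requests
from requests.adapters import Retry


def download_with_resume(url, dest_path, max_retries=5, chunk_size=4096):
    """Simplified sketch of the resume-and-retry loop used by http_get_file."""
    retry = Retry(total=max_retries, backoff_factor=1, allowed_methods=['GET'])
    with open(dest_path, 'wb') as out:
        while True:
            try:
                # Resume from the bytes that are already on disk.
                headers = {'Range': 'bytes=%d-' % out.tell()}
                r = requests.get(url, stream=True, headers=headers, timeout=5)
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        out.write(chunk)
                break  # finished without error
            except Exception as e:
                # Raises urllib3.exceptions.MaxRetryError once retries are exhausted.
                retry = retry.increment('GET', url, error=e)
                retry.sleep()

The resume is only a true continuation when the server honours Range requests (HTTP 206 Partial Content); file_download.py also imports file_integrity_validation, which provides a hash check on the finished file.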


tests/hub/test_hub_retry.py (+164 / -0)

@@ -0,0 +1,164 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from http.client import HTTPMessage, HTTPResponse
from io import StringIO
from unittest.mock import Mock, patch

import requests
from urllib3.exceptions import MaxRetryError

from modelscope.hub.api import HubApi
from modelscope.hub.file_download import http_get_file


class HubOperationTest(unittest.TestCase):

def setUp(self):
self.api = HubApi()
self.model_id = 'damo/ofa_text-to-image-synthesis_coco_large_en'

@patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
def test_retry_exception(self, getconn_mock):
getconn_mock.return_value.getresponse.side_effect = [
Mock(status=500, msg=HTTPMessage()),
Mock(status=502, msg=HTTPMessage()),
Mock(status=500, msg=HTTPMessage()),
]
with self.assertRaises(requests.exceptions.RetryError):
self.api.get_model_files(
model_id=self.model_id,
recursive=True,
)
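
The patch target is the interesting part of these tests: HTTPConnectionPool._get_conn sits below urllib3's retry machinery, so each canned getresponse() result is consumed by exactly one attempt, and with the Retry(total=2) configured in HubApi.__init__ three consecutive 5xx responses exhaust the budget and surface as requests.exceptions.RetryError. The same seam works against a plain session, as in this sketch (URL, statuses and class name are illustrative):

import unittest
from http.client import HTTPMessage
from unittest.mock import Mock, patch

import requests
from requests.adapters import HTTPAdapter, Retry


class RetrySeamSketch(unittest.TestCase):

    @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
    def test_three_5xx_exhaust_two_retries(self, getconn_mock):
        # One mocked response per attempt: the initial request plus two retries.
        getconn_mock.return_value.getresponse.side_effect = [
            Mock(status=500, msg=HTTPMessage()),
            Mock(status=502, msg=HTTPMessage()),
            Mock(status=500, msg=HTTPMessage()),
        ]
        session = requests.Session()
        # backoff_factor=0 keeps the example fast; HubApi uses backoff_factor=1.
        retry = Retry(total=2, backoff_factor=0,
                      status_forcelist=(500, 502, 503, 504))
        session.mount('http://', HTTPAdapter(max_retries=retry))
        with self.assertRaises(requests.exceptions.RetryError):
            session.get('http://www.modelscope.cn/api/v1/models/')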

@patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
def test_retry_and_success(self, getconn_mock):
response_body = '{"Code": 200, "Data": { "Files": [ {"CommitMessage": \
"update","CommittedDate": 1667548386,"CommitterName": "行嗔","InCheck": false, \
"IsLFS": false, "Mode": "33188", "Name": "README.md", "Path": "README.md", \
"Revision": "e45fcc158894f18a7a8cfa3caf8b3dd1a2b26dc9",\
"Sha256": "8bf99f410ae0a572e5a4a85a3949ad268d49023e5c6ef200c9bd4307f9ed0660", \
"Size": 6399, "Type": "blob" } ] }, "Message": "success",\
"RequestId": "8c2a8249-ce50-49f4-85ea-36debf918714","Success": true}'

first = 0

def get_content(p):
nonlocal first
if first > 0:
return None
else:
first += 1
return response_body.encode('utf-8')

rsp = HTTPResponse(getconn_mock)
rsp.status = 200
rsp.msg = HTTPMessage()
rsp.read = get_content
rsp.chunked = False
# retry 2 times and success.
getconn_mock.return_value.getresponse.side_effect = [
Mock(status=500, msg=HTTPMessage()),
Mock(
status=502,
msg=HTTPMessage(),
body=response_body,
read=StringIO(response_body)),
rsp,
]
model_files = self.api.get_model_files(
model_id=self.model_id,
recursive=True,
)
assert len(model_files) > 0

@patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
def test_retry_broken_continue(self, getconn_mock):
test_file_name = 'video_inpainting_test.mp4'
fp = 0

def get_content(content_length):
nonlocal fp
with open('data/test/videos/%s' % test_file_name, 'rb') as f:
f.seek(fp)
content = f.read(content_length)
fp += len(content)
return content

success_rsp = HTTPResponse(getconn_mock)
success_rsp.status = 200
success_rsp.msg = HTTPMessage()
success_rsp.msg.add_header('Content-Length', '2957783')
success_rsp.read = get_content
success_rsp.chunked = True

failed_rsp = HTTPResponse(getconn_mock)
failed_rsp.status = 502
failed_rsp.msg = HTTPMessage()
failed_rsp.msg.add_header('Content-Length', '2957783')
failed_rsp.read = get_content
failed_rsp.chunked = True

# retry 5 times and success.
getconn_mock.return_value.getresponse.side_effect = [
failed_rsp,
failed_rsp,
failed_rsp,
failed_rsp,
failed_rsp,
success_rsp,
]
url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
http_get_file(
url=url,
local_dir='./',
file_name=test_file_name,
headers={},
cookies=None)

assert os.path.exists('./%s' % test_file_name)
os.remove('./%s' % test_file_name)

@patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
def test_retry_broken_continue_retry_failed(self, getconn_mock):
test_file_name = 'video_inpainting_test.mp4'
fp = 0

def get_content(content_length):
nonlocal fp
with open('data/test/videos/%s' % test_file_name, 'rb') as f:
f.seek(fp)
content = f.read(content_length)
fp += len(content)
return content

failed_rsp = HTTPResponse(getconn_mock)
failed_rsp.status = 502
failed_rsp.msg = HTTPMessage()
failed_rsp.msg.add_header('Content-Length', '2957783')
failed_rsp.read = get_content
failed_rsp.chunked = True

# six failed responses in a row: the retry budget (5) is exhausted and the download fails.
getconn_mock.return_value.getresponse.side_effect = [
failed_rsp,
failed_rsp,
failed_rsp,
failed_rsp,
failed_rsp,
failed_rsp,
]
url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
with self.assertRaises(MaxRetryError):
http_get_file(
url=url,
local_dir='./',
file_name=test_file_name,
headers={},
cookies=None)

assert not os.path.exists('./%s' % test_file_name)


if __name__ == '__main__':
unittest.main()
