添加文件下载完整性验证
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9700279
* [to #43913168]fix: add file download integrity check
master
| @@ -4,7 +4,7 @@ DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DO | |||
| DEFAULT_MODELSCOPE_GROUP = 'damo' | |||
| MODEL_ID_SEPARATOR = '/' | |||
| FILE_HASH = 'Sha256' | |||
| LOGGER_NAME = 'ModelScopeHub' | |||
| DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials' | |||
| API_RESPONSE_FIELD_DATA = 'Data' | |||
| @@ -23,6 +23,14 @@ class NotLoginException(Exception): | |||
| pass | |||
class FileIntegrityError(Exception):
    """Raised when a downloaded file's hash does not match the expected sha256."""
    pass
class FileDownloadError(Exception):
    """Raised when a file download fails (e.g. downloaded length != content length)."""
    pass
| def is_ok(rsp): | |||
| """ Check the request is ok | |||
| @@ -16,10 +16,11 @@ from modelscope import __version__ | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import HubApi, ModelScopeConfig | |||
| from .errors import NotExistError | |||
| from .constants import FILE_HASH | |||
| from .errors import FileDownloadError, NotExistError | |||
| from .utils.caching import ModelFileSystemCache | |||
| from .utils.utils import (get_cache_dir, get_endpoint, | |||
| model_id_to_group_owner_name) | |||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
| get_endpoint, model_id_to_group_owner_name) | |||
| SESSION_ID = uuid4().hex | |||
| logger = get_logger() | |||
| @@ -143,24 +144,29 @@ def model_file_download( | |||
| # we need to download again | |||
| url_to_download = get_file_download_url(model_id, file_path, revision) | |||
| file_to_download_info = { | |||
| 'Path': file_path, | |||
| 'Path': | |||
| file_path, | |||
| 'Revision': | |||
| revision if is_commit_id else file_to_download_info['Revision'] | |||
| revision if is_commit_id else file_to_download_info['Revision'], | |||
| FILE_HASH: | |||
| None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||
| file_to_download_info[FILE_HASH] | |||
| } | |||
| # Prevent parallel downloads of the same file with a lock. | |||
| lock_path = cache.get_root_location() + '.lock' | |||
| with FileLock(lock_path): | |||
| temp_file_name = next(tempfile._get_candidate_names()) | |||
| http_get_file( | |||
| url_to_download, | |||
| temporary_cache_dir, | |||
| temp_file_name, | |||
| headers=headers, | |||
| cookies=None if cookies is None else cookies.get_dict()) | |||
| return cache.put_file( | |||
| file_to_download_info, | |||
| os.path.join(temporary_cache_dir, temp_file_name)) | |||
| temp_file_name = next(tempfile._get_candidate_names()) | |||
| http_get_file( | |||
| url_to_download, | |||
| temporary_cache_dir, | |||
| temp_file_name, | |||
| headers=headers, | |||
| cookies=None if cookies is None else cookies.get_dict()) | |||
| temp_file_path = os.path.join(temporary_cache_dir, temp_file_name) | |||
| # for download with commit we can't get Sha256 | |||
| if file_to_download_info[FILE_HASH] is not None: | |||
| file_integrity_validation(temp_file_path, | |||
| file_to_download_info[FILE_HASH]) | |||
| return cache.put_file(file_to_download_info, | |||
| os.path.join(temporary_cache_dir, temp_file_name)) | |||
| def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||
| @@ -222,6 +228,7 @@ def http_get_file( | |||
| http headers to carry necessary info when requesting the remote file | |||
| """ | |||
| total = -1 | |||
| temp_file_manager = partial( | |||
| tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) | |||
| @@ -250,4 +257,12 @@ def http_get_file( | |||
| progress.close() | |||
| logger.info('storing %s in cache at %s', url, local_dir) | |||
| downloaded_length = os.path.getsize(temp_file.name) | |||
| if total != downloaded_length: | |||
| os.remove(temp_file.name) | |||
| msg = 'File %s download incomplete, content_length: %s but the \ | |||
| file downloaded length: %s, please download again' % ( | |||
| file_name, total, downloaded_length) | |||
| logger.error(msg) | |||
| raise FileDownloadError(msg) | |||
| os.replace(temp_file.name, os.path.join(local_dir, file_name)) | |||
| @@ -6,11 +6,13 @@ from typing import Dict, Optional, Union | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import HubApi, ModelScopeConfig | |||
| from .constants import FILE_HASH | |||
| from .errors import NotExistError | |||
| from .file_download import (get_file_download_url, http_get_file, | |||
| http_user_agent) | |||
| from .utils.caching import ModelFileSystemCache | |||
| from .utils.utils import get_cache_dir, model_id_to_group_owner_name | |||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
| model_id_to_group_owner_name) | |||
| logger = get_logger() | |||
| @@ -127,9 +129,11 @@ def snapshot_download(model_id: str, | |||
| file_name=model_file['Name'], | |||
| headers=headers, | |||
| cookies=cookies) | |||
| # check file integrity | |||
| temp_file = os.path.join(temp_cache_dir, model_file['Name']) | |||
| if FILE_HASH in model_file: | |||
| file_integrity_validation(temp_file, model_file[FILE_HASH]) | |||
| # put file to cache | |||
| cache.put_file( | |||
| model_file, os.path.join(temp_cache_dir, | |||
| model_file['Name'])) | |||
| cache.put_file(model_file, temp_file) | |||
| return os.path.join(cache.get_root_location()) | |||
| @@ -1,10 +1,15 @@ | |||
| import hashlib | |||
| import os | |||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||
| DEFAULT_MODELSCOPE_GROUP, | |||
| MODEL_ID_SEPARATOR, | |||
| MODELSCOPE_URL_SCHEME) | |||
| from modelscope.hub.errors import FileIntegrityError | |||
| from modelscope.utils.file_utils import get_default_cache_dir | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| def model_id_to_group_owner_name(model_id): | |||
| @@ -31,3 +36,34 @@ def get_endpoint(): | |||
| modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | |||
| DEFAULT_MODELSCOPE_DOMAIN) | |||
| return MODELSCOPE_URL_SCHEME + modelscope_domain | |||
def compute_hash(file_path):
    """Return the sha256 hex digest of the file at ``file_path``.

    The file is read in fixed-size chunks so arbitrarily large files
    can be hashed with constant memory.
    """
    CHUNK_SIZE = 1024 * 64  # 64k per read keeps memory bounded
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # iter() with a b'' sentinel stops cleanly at end-of-file.
        for chunk in iter(lambda: stream.read(CHUNK_SIZE), b''):
            digest.update(chunk)
    return digest.hexdigest()
def file_integrity_validation(file_path, expected_sha256):
    """Validate that the file's sha256 hash is the expected one.

    If the hash does not match, the (presumably corrupt or truncated)
    file is deleted before raising, so that a later retry starts from
    a clean cache state.

    Args:
        file_path (str): The file to validate.
        expected_sha256 (str): The expected sha256 hex digest.

    Raises:
        FileIntegrityError: If file_path hash is not expected.
    """
    file_sha256 = compute_hash(file_path)
    # Use the direct inequality test rather than `not ... == ...`.
    if file_sha256 != expected_sha256:
        # Remove the bad file so the cache never retains corrupt content.
        os.remove(file_path)
        msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
        logger.error(msg)
        raise FileIntegrityError(msg)