添加文件下载完整性验证
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9700279
* [to #43913168]fix: add file download integrity check
master
| @@ -4,7 +4,7 @@ DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DO | |||||
| DEFAULT_MODELSCOPE_GROUP = 'damo' | DEFAULT_MODELSCOPE_GROUP = 'damo' | ||||
| MODEL_ID_SEPARATOR = '/' | MODEL_ID_SEPARATOR = '/' | ||||
| FILE_HASH = 'Sha256' | |||||
| LOGGER_NAME = 'ModelScopeHub' | LOGGER_NAME = 'ModelScopeHub' | ||||
| DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials' | DEFAULT_CREDENTIALS_PATH = '~/.modelscope/credentials' | ||||
| API_RESPONSE_FIELD_DATA = 'Data' | API_RESPONSE_FIELD_DATA = 'Data' | ||||
| @@ -23,6 +23,14 @@ class NotLoginException(Exception): | |||||
| pass | pass | ||||
| class FileIntegrityError(Exception): | |||||
| pass | |||||
| class FileDownloadError(Exception): | |||||
| pass | |||||
| def is_ok(rsp): | def is_ok(rsp): | ||||
| """ Check the request is ok | """ Check the request is ok | ||||
| @@ -16,10 +16,11 @@ from modelscope import __version__ | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | from .api import HubApi, ModelScopeConfig | ||||
| from .errors import NotExistError | |||||
| from .constants import FILE_HASH | |||||
| from .errors import FileDownloadError, NotExistError | |||||
| from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
| from .utils.utils import (get_cache_dir, get_endpoint, | |||||
| model_id_to_group_owner_name) | |||||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | |||||
| get_endpoint, model_id_to_group_owner_name) | |||||
| SESSION_ID = uuid4().hex | SESSION_ID = uuid4().hex | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -143,24 +144,29 @@ def model_file_download( | |||||
| # we need to download again | # we need to download again | ||||
| url_to_download = get_file_download_url(model_id, file_path, revision) | url_to_download = get_file_download_url(model_id, file_path, revision) | ||||
| file_to_download_info = { | file_to_download_info = { | ||||
| 'Path': file_path, | |||||
| 'Path': | |||||
| file_path, | |||||
| 'Revision': | 'Revision': | ||||
| revision if is_commit_id else file_to_download_info['Revision'] | |||||
| revision if is_commit_id else file_to_download_info['Revision'], | |||||
| FILE_HASH: | |||||
| None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||||
| file_to_download_info[FILE_HASH] | |||||
| } | } | ||||
| # Prevent parallel downloads of the same file with a lock. | |||||
| lock_path = cache.get_root_location() + '.lock' | |||||
| with FileLock(lock_path): | |||||
| temp_file_name = next(tempfile._get_candidate_names()) | |||||
| http_get_file( | |||||
| url_to_download, | |||||
| temporary_cache_dir, | |||||
| temp_file_name, | |||||
| headers=headers, | |||||
| cookies=None if cookies is None else cookies.get_dict()) | |||||
| return cache.put_file( | |||||
| file_to_download_info, | |||||
| os.path.join(temporary_cache_dir, temp_file_name)) | |||||
| temp_file_name = next(tempfile._get_candidate_names()) | |||||
| http_get_file( | |||||
| url_to_download, | |||||
| temporary_cache_dir, | |||||
| temp_file_name, | |||||
| headers=headers, | |||||
| cookies=None if cookies is None else cookies.get_dict()) | |||||
| temp_file_path = os.path.join(temporary_cache_dir, temp_file_name) | |||||
| # for download with commit we can't get Sha256 | |||||
| if file_to_download_info[FILE_HASH] is not None: | |||||
| file_integrity_validation(temp_file_path, | |||||
| file_to_download_info[FILE_HASH]) | |||||
| return cache.put_file(file_to_download_info, | |||||
| os.path.join(temporary_cache_dir, temp_file_name)) | |||||
| def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | ||||
| @@ -222,6 +228,7 @@ def http_get_file( | |||||
| http headers to carry necessary info when requesting the remote file | http headers to carry necessary info when requesting the remote file | ||||
| """ | """ | ||||
| total = -1 | |||||
| temp_file_manager = partial( | temp_file_manager = partial( | ||||
| tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) | tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) | ||||
| @@ -250,4 +257,12 @@ def http_get_file( | |||||
| progress.close() | progress.close() | ||||
| logger.info('storing %s in cache at %s', url, local_dir) | logger.info('storing %s in cache at %s', url, local_dir) | ||||
| downloaded_length = os.path.getsize(temp_file.name) | |||||
| if total != downloaded_length: | |||||
| os.remove(temp_file.name) | |||||
| msg = 'File %s download incomplete, content_length: %s but the \ | |||||
| file downloaded length: %s, please download again' % ( | |||||
| file_name, total, downloaded_length) | |||||
| logger.error(msg) | |||||
| raise FileDownloadError(msg) | |||||
| os.replace(temp_file.name, os.path.join(local_dir, file_name)) | os.replace(temp_file.name, os.path.join(local_dir, file_name)) | ||||
| @@ -6,11 +6,13 @@ from typing import Dict, Optional, Union | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | from .api import HubApi, ModelScopeConfig | ||||
| from .constants import FILE_HASH | |||||
| from .errors import NotExistError | from .errors import NotExistError | ||||
| from .file_download import (get_file_download_url, http_get_file, | from .file_download import (get_file_download_url, http_get_file, | ||||
| http_user_agent) | http_user_agent) | ||||
| from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
| from .utils.utils import get_cache_dir, model_id_to_group_owner_name | |||||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | |||||
| model_id_to_group_owner_name) | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -127,9 +129,11 @@ def snapshot_download(model_id: str, | |||||
| file_name=model_file['Name'], | file_name=model_file['Name'], | ||||
| headers=headers, | headers=headers, | ||||
| cookies=cookies) | cookies=cookies) | ||||
| # check file integrity | |||||
| temp_file = os.path.join(temp_cache_dir, model_file['Name']) | |||||
| if FILE_HASH in model_file: | |||||
| file_integrity_validation(temp_file, model_file[FILE_HASH]) | |||||
| # put file to cache | # put file to cache | ||||
| cache.put_file( | |||||
| model_file, os.path.join(temp_cache_dir, | |||||
| model_file['Name'])) | |||||
| cache.put_file(model_file, temp_file) | |||||
| return os.path.join(cache.get_root_location()) | return os.path.join(cache.get_root_location()) | ||||
| @@ -1,10 +1,15 @@ | |||||
| import hashlib | |||||
| import os | import os | ||||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | ||||
| DEFAULT_MODELSCOPE_GROUP, | DEFAULT_MODELSCOPE_GROUP, | ||||
| MODEL_ID_SEPARATOR, | MODEL_ID_SEPARATOR, | ||||
| MODELSCOPE_URL_SCHEME) | MODELSCOPE_URL_SCHEME) | ||||
| from modelscope.hub.errors import FileIntegrityError | |||||
| from modelscope.utils.file_utils import get_default_cache_dir | from modelscope.utils.file_utils import get_default_cache_dir | ||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| def model_id_to_group_owner_name(model_id): | def model_id_to_group_owner_name(model_id): | ||||
| @@ -31,3 +36,34 @@ def get_endpoint(): | |||||
| modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | ||||
| DEFAULT_MODELSCOPE_DOMAIN) | DEFAULT_MODELSCOPE_DOMAIN) | ||||
| return MODELSCOPE_URL_SCHEME + modelscope_domain | return MODELSCOPE_URL_SCHEME + modelscope_domain | ||||
| def compute_hash(file_path): | |||||
| BUFFER_SIZE = 1024 * 64 # 64k buffer size | |||||
| sha256_hash = hashlib.sha256() | |||||
| with open(file_path, 'rb') as f: | |||||
| while True: | |||||
| data = f.read(BUFFER_SIZE) | |||||
| if not data: | |||||
| break | |||||
| sha256_hash.update(data) | |||||
| return sha256_hash.hexdigest() | |||||
| def file_integrity_validation(file_path, expected_sha256): | |||||
| """Validate the file hash is expected, if not, delete the file | |||||
| Args: | |||||
| file_path (str): The file to validate | |||||
| expected_sha256 (str): The expected sha256 hash | |||||
| Raises: | |||||
| FileIntegrityError: If file_path hash is not expected. | |||||
| """ | |||||
| file_sha256 = compute_hash(file_path) | |||||
| if not file_sha256 == expected_sha256: | |||||
| os.remove(file_path) | |||||
| msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path | |||||
| logger.error(msg) | |||||
| raise FileIntegrityError(msg) | |||||