Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9672696 * [to #43887377] fix: concurrent SDK API calls to snapshot download cause downloaded files to conflict (master)
| @@ -79,6 +79,8 @@ def model_file_download( | |||||
| cache_dir = get_cache_dir() | cache_dir = get_cache_dir() | ||||
| if isinstance(cache_dir, Path): | if isinstance(cache_dir, Path): | ||||
| cache_dir = str(cache_dir) | cache_dir = str(cache_dir) | ||||
| temporary_cache_dir = os.path.join(cache_dir, 'temp') | |||||
| os.makedirs(temporary_cache_dir, exist_ok=True) | |||||
| group_or_owner, name = model_id_to_group_owner_name(model_id) | group_or_owner, name = model_id_to_group_owner_name(model_id) | ||||
| @@ -152,12 +154,13 @@ def model_file_download( | |||||
| temp_file_name = next(tempfile._get_candidate_names()) | temp_file_name = next(tempfile._get_candidate_names()) | ||||
| http_get_file( | http_get_file( | ||||
| url_to_download, | url_to_download, | ||||
| cache_dir, | |||||
| temporary_cache_dir, | |||||
| temp_file_name, | temp_file_name, | ||||
| headers=headers, | headers=headers, | ||||
| cookies=None if cookies is None else cookies.get_dict()) | cookies=None if cookies is None else cookies.get_dict()) | ||||
| return cache.put_file(file_to_download_info, | |||||
| os.path.join(cache_dir, temp_file_name)) | |||||
| return cache.put_file( | |||||
| file_to_download_info, | |||||
| os.path.join(temporary_cache_dir, temp_file_name)) | |||||
| def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | ||||
| @@ -1,4 +1,5 @@ | |||||
| import os | import os | ||||
| import tempfile | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| @@ -58,6 +59,8 @@ def snapshot_download(model_id: str, | |||||
| cache_dir = get_cache_dir() | cache_dir = get_cache_dir() | ||||
| if isinstance(cache_dir, Path): | if isinstance(cache_dir, Path): | ||||
| cache_dir = str(cache_dir) | cache_dir = str(cache_dir) | ||||
| temporary_cache_dir = os.path.join(cache_dir, 'temp') | |||||
| os.makedirs(temporary_cache_dir, exist_ok=True) | |||||
| group_or_owner, name = model_id_to_group_owner_name(model_id) | group_or_owner, name = model_id_to_group_owner_name(model_id) | ||||
| @@ -98,31 +101,35 @@ def snapshot_download(model_id: str, | |||||
| headers=snapshot_header, | headers=snapshot_header, | ||||
| ) | ) | ||||
| for model_file in model_files: | |||||
| if model_file['Type'] == 'tree': | |||||
| continue | |||||
| # check model_file is exist in cache, if exist, skip download, otherwise download | |||||
| if cache.exists(model_file): | |||||
| file_name = os.path.basename(model_file['Name']) | |||||
| logger.info( | |||||
| f'File {file_name} already in cache, skip downloading!') | |||||
| continue | |||||
| with tempfile.TemporaryDirectory( | |||||
| dir=temporary_cache_dir) as temp_cache_dir: | |||||
| for model_file in model_files: | |||||
| if model_file['Type'] == 'tree': | |||||
| continue | |||||
| # check model_file is exist in cache, if exist, skip download, otherwise download | |||||
| if cache.exists(model_file): | |||||
| file_name = os.path.basename(model_file['Name']) | |||||
| logger.info( | |||||
| f'File {file_name} already in cache, skip downloading!' | |||||
| ) | |||||
| continue | |||||
| # get download url | |||||
| url = get_file_download_url( | |||||
| model_id=model_id, | |||||
| file_path=model_file['Path'], | |||||
| revision=revision) | |||||
| # get download url | |||||
| url = get_file_download_url( | |||||
| model_id=model_id, | |||||
| file_path=model_file['Path'], | |||||
| revision=revision) | |||||
| # First download to /tmp | |||||
| http_get_file( | |||||
| url=url, | |||||
| local_dir=cache_dir, | |||||
| file_name=model_file['Name'], | |||||
| headers=headers, | |||||
| cookies=cookies) | |||||
| # put file to cache | |||||
| cache.put_file(model_file, | |||||
| os.path.join(cache_dir, model_file['Name'])) | |||||
| # First download to /tmp | |||||
| http_get_file( | |||||
| url=url, | |||||
| local_dir=temp_cache_dir, | |||||
| file_name=model_file['Name'], | |||||
| headers=headers, | |||||
| cookies=cookies) | |||||
| # put file to cache | |||||
| cache.put_file( | |||||
| model_file, os.path.join(temp_cache_dir, | |||||
| model_file['Name'])) | |||||
| return os.path.join(cache.get_root_location()) | return os.path.join(cache.get_root_location()) | ||||
| @@ -21,9 +21,6 @@ DEFAULT_GIT_PATH = 'git' | |||||
| download_model_file_name = 'test.bin' | download_model_file_name = 'test.bin' | ||||
| @unittest.skip( | |||||
| "Access token is always change, we can't login with same access token, so skip!" | |||||
| ) | |||||
| class HubOperationTest(unittest.TestCase): | class HubOperationTest(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| @@ -18,9 +18,6 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, | |||||
| delete_credential) | delete_credential) | ||||
| @unittest.skip( | |||||
| "Access token is always change, we can't login with same access token, so skip!" | |||||
| ) | |||||
| class HubPrivateFileDownloadTest(unittest.TestCase): | class HubPrivateFileDownloadTest(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| @@ -15,9 +15,6 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, | |||||
| DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
| @unittest.skip( | |||||
| "Access token is always change, we can't login with same access token, so skip!" | |||||
| ) | |||||
| class HubPrivateRepositoryTest(unittest.TestCase): | class HubPrivateRepositoryTest(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| @@ -24,9 +24,6 @@ logger.setLevel('DEBUG') | |||||
| DEFAULT_GIT_PATH = 'git' | DEFAULT_GIT_PATH = 'git' | ||||
| @unittest.skip( | |||||
| "Access token is always change, we can't login with same access token, so skip!" | |||||
| ) | |||||
| class HubRepositoryTest(unittest.TestCase): | class HubRepositoryTest(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| @@ -6,8 +6,8 @@ from os.path import expanduser | |||||
| from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH | from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH | ||||
| # for user citest and sdkdev | # for user citest and sdkdev | ||||
| TEST_ACCESS_TOKEN1 = 'OVAzNU9aZ2FYbXFhdGNzZll6VHRtalQ0T1BpZTNGeWVhMkxSSGpTSzU0dkM5WE5ObDFKdFRQWGc2U2ZIdjdPdg==' | |||||
| TEST_ACCESS_TOKEN2 = 'aXRocHhGeG0rNXRWQWhBSnJpTTZUQ0RDbUlkcUJRS1dQR2lNb0xIa0JjRDBrT1JKYklZV05DVzROTTdtamxWcg==' | |||||
| TEST_ACCESS_TOKEN1 = 'RGZZdkh2Z3BlMFU1VktjUkdIcUJtdjdqdnhQUEQrUVROdVBjclAzUGVycHFhU1BFZFBIaGtUOHB1eHQ2OTV3dQ==' | |||||
| TEST_ACCESS_TOKEN2 = 'dFpadllseTZQbHlyK0E4amQxVC84a2RtZHdkUVhmMUl3M1VXZXU4dS9GZlRuVmFUTW5yQm8yTENYWEw2SVh0Uw==' | |||||
| TEST_MODEL_CHINESE_NAME = '内部测试模型' | TEST_MODEL_CHINESE_NAME = '内部测试模型' | ||||
| TEST_MODEL_ORG = 'citest' | TEST_MODEL_ORG = 'citest' | ||||