diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index af323081..d0b8a102 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -79,6 +79,8 @@ def model_file_download( cache_dir = get_cache_dir() if isinstance(cache_dir, Path): cache_dir = str(cache_dir) + temporary_cache_dir = os.path.join(cache_dir, 'temp') + os.makedirs(temporary_cache_dir, exist_ok=True) group_or_owner, name = model_id_to_group_owner_name(model_id) @@ -152,12 +154,13 @@ def model_file_download( temp_file_name = next(tempfile._get_candidate_names()) http_get_file( url_to_download, - cache_dir, + temporary_cache_dir, temp_file_name, headers=headers, cookies=None if cookies is None else cookies.get_dict()) - return cache.put_file(file_to_download_info, - os.path.join(cache_dir, temp_file_name)) + return cache.put_file( + file_to_download_info, + os.path.join(temporary_cache_dir, temp_file_name)) def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 655806dc..5f9548e9 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -1,4 +1,5 @@ import os +import tempfile from pathlib import Path from typing import Dict, Optional, Union @@ -58,6 +59,8 @@ def snapshot_download(model_id: str, cache_dir = get_cache_dir() if isinstance(cache_dir, Path): cache_dir = str(cache_dir) + temporary_cache_dir = os.path.join(cache_dir, 'temp') + os.makedirs(temporary_cache_dir, exist_ok=True) group_or_owner, name = model_id_to_group_owner_name(model_id) @@ -98,31 +101,35 @@ def snapshot_download(model_id: str, headers=snapshot_header, ) - for model_file in model_files: - if model_file['Type'] == 'tree': - continue - # check model_file is exist in cache, if exist, skip download, otherwise download - if cache.exists(model_file): - file_name = os.path.basename(model_file['Name']) - logger.info( - f'File {file_name} already in cache, skip downloading!') - continue + with tempfile.TemporaryDirectory( + dir=temporary_cache_dir) as temp_cache_dir: + for model_file in model_files: + if model_file['Type'] == 'tree': + continue + # check model_file is exist in cache, if exist, skip download, otherwise download + if cache.exists(model_file): + file_name = os.path.basename(model_file['Name']) + logger.info( + f'File {file_name} already in cache, skip downloading!' + ) + continue - # get download url - url = get_file_download_url( - model_id=model_id, - file_path=model_file['Path'], - revision=revision) + # get download url + url = get_file_download_url( + model_id=model_id, + file_path=model_file['Path'], + revision=revision) - # First download to /tmp - http_get_file( - url=url, - local_dir=cache_dir, - file_name=model_file['Name'], - headers=headers, - cookies=cookies) - # put file to cache - cache.put_file(model_file, - os.path.join(cache_dir, model_file['Name'])) + # First download to /tmp + http_get_file( + url=url, + local_dir=temp_cache_dir, + file_name=model_file['Name'], + headers=headers, + cookies=cookies) + # put file to cache + cache.put_file( + model_file, os.path.join(temp_cache_dir, + model_file['Name'])) return os.path.join(cache.get_root_location()) diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py index 1cad1c2b..636e987e 100644 --- a/tests/hub/test_hub_operation.py +++ b/tests/hub/test_hub_operation.py @@ -21,9 +21,6 @@ DEFAULT_GIT_PATH = 'git' download_model_file_name = 'test.bin' -@unittest.skip( - "Access token is always change, we can't login with same access token, so skip!" -) class HubOperationTest(unittest.TestCase): def setUp(self): diff --git a/tests/hub/test_hub_private_files.py b/tests/hub/test_hub_private_files.py index 70015d96..d19a7c64 100644 --- a/tests/hub/test_hub_private_files.py +++ b/tests/hub/test_hub_private_files.py @@ -18,9 +18,6 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, delete_credential) -@unittest.skip( - "Access token is always change, we can't login with same access token, so skip!" -) class HubPrivateFileDownloadTest(unittest.TestCase): def setUp(self): diff --git a/tests/hub/test_hub_private_repository.py b/tests/hub/test_hub_private_repository.py index 8a614b30..8683a884 100644 --- a/tests/hub/test_hub_private_repository.py +++ b/tests/hub/test_hub_private_repository.py @@ -15,9 +15,6 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, DEFAULT_GIT_PATH = 'git' -@unittest.skip( - "Access token is always change, we can't login with same access token, so skip!" -) class HubPrivateRepositoryTest(unittest.TestCase): def setUp(self): diff --git a/tests/hub/test_hub_repository.py b/tests/hub/test_hub_repository.py index b0e7237d..9dfe8efd 100644 --- a/tests/hub/test_hub_repository.py +++ b/tests/hub/test_hub_repository.py @@ -24,9 +24,6 @@ logger.setLevel('DEBUG') DEFAULT_GIT_PATH = 'git' -@unittest.skip( - "Access token is always change, we can't login with same access token, so skip!" -) class HubRepositoryTest(unittest.TestCase): def setUp(self): diff --git a/tests/hub/test_utils.py b/tests/hub/test_utils.py index 2b3a184e..adb3e566 100644 --- a/tests/hub/test_utils.py +++ b/tests/hub/test_utils.py @@ -6,8 +6,8 @@ from os.path import expanduser from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH # for user citest and sdkdev -TEST_ACCESS_TOKEN1 = 'OVAzNU9aZ2FYbXFhdGNzZll6VHRtalQ0T1BpZTNGeWVhMkxSSGpTSzU0dkM5WE5ObDFKdFRQWGc2U2ZIdjdPdg==' -TEST_ACCESS_TOKEN2 = 'aXRocHhGeG0rNXRWQWhBSnJpTTZUQ0RDbUlkcUJRS1dQR2lNb0xIa0JjRDBrT1JKYklZV05DVzROTTdtamxWcg==' +TEST_ACCESS_TOKEN1 = 'RGZZdkh2Z3BlMFU1VktjUkdIcUJtdjdqdnhQUEQrUVROdVBjclAzUGVycHFhU1BFZFBIaGtUOHB1eHQ2OTV3dQ==' +TEST_ACCESS_TOKEN2 = 'dFpadllseTZQbHlyK0E4amQxVC84a2RtZHdkUVhmMUl3M1VXZXU4dS9GZlRuVmFUTW5yQm8yTENYWEw2SVh0Uw==' TEST_MODEL_CHINESE_NAME = '内部测试模型' TEST_MODEL_ORG = 'citest'