From 83c2d33aaf7489ed2e1ca2323ebb5d6124a2c81b Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 5 Dec 2023 20:44:40 +0800 Subject: [PATCH] [MNT] modify GetData --- learnware/tests/data.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/learnware/tests/data.py b/learnware/tests/data.py index 422392c..30c6c81 100644 --- a/learnware/tests/data.py +++ b/learnware/tests/data.py @@ -1,3 +1,40 @@ +import json +import requests +from tqdm import tqdm + +from ..config import C + + class GetData: - pass \ No newline at end of file + def __init__(self, host=None, chunk_size=1024 * 1024): + self.headers = None + + if host is None: + self.host = C.backend_host + else: + self.host = host + + self.chunk_size = chunk_size + + def download_file(self, file_path: str, save_path: str): + url = f"{self.host}/engine/download" + + response = requests.get( + url, + params={ + "file_path": file_path, + }, + stream=True, + ) + + if response.status_code != 200: + raise Exception("download failed: " + json.dumps(response.json())) + + num_chunks = int(response.headers["Content-Length"]) // self.chunk_size + 1 + bar = tqdm(total=num_chunks, desc="Downloading", unit="MB") + + with open(save_path, "wb") as f: + for chunk in response.iter_content(chunk_size=self.chunk_size): + f.write(chunk) + bar.update(1) \ No newline at end of file