From ccfccfa86843090e8bc2f5a6a2a7ea5800d5b071 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 8 Dec 2023 23:42:26 +0800 Subject: [PATCH] [MNT] add cache in benchmark --- learnware/tests/benchmarks/__init__.py | 231 ++++++++++++++----------- learnware/tests/benchmarks/config.py | 25 +-- 2 files changed, 145 insertions(+), 111 deletions(-) diff --git a/learnware/tests/benchmarks/__init__.py b/learnware/tests/benchmarks/__init__.py index 715037d..7caa6c7 100644 --- a/learnware/tests/benchmarks/__init__.py +++ b/learnware/tests/benchmarks/__init__.py @@ -1,130 +1,161 @@ import os import pickle -import atexit import tempfile import zipfile -from dataclasses import dataclass, field -from typing import Optional, List, Union, Tuple +from dataclasses import dataclass +from typing import Tuple, Optional, List, Union -from .config import OnlineBenchmark, online_benchmarks +from .config import BenchmarkConfig, benchmark_configs from ..data import GetData +from ...config import C + + @dataclass class Benchmark: learnware_ids: List[str] user_num: int - unlabeled_feature_paths: List[str] - unlabeled_groudtruths_paths: List[str] - labeled_feature_paths: Optional[List[str]] = None - labeled_label_paths: Optional[List[str]] = None + test_X_paths: List[str] + test_y_paths: List[str] + train_X_paths: Optional[List[str]] = None + train_y_paths: Optional[List[str]] = None extra_info_path: Optional[str] = None - - # TODO: add more method for benchmark - - def get_unlabeled_data(self, user_ids: Union[str, List[str]]): + + def get_test_data(self, user_ids: Union[str, List[str]]): if isinstance(user_ids, str): user_ids = [user_ids] - + ret = [] for user_id in user_ids: - with open(self.unlabeled_feature_paths[user_id], "rb") as fin: - unlabeled_feature = pickle.load(fin) - - with open(self.unlabeled_groudtruths_paths[user_id], "rb") as fin: - unlabeled_groudtruth = pickle.load(fin) - - ret.append((unlabeled_feature, unlabeled_groudtruth)) + with open(self.test_X_paths[user_id], "rb") as fin: + test_X = pickle.load(fin) + + with open(self.test_y_paths[user_id], "rb") as fin: + test_y = pickle.load(fin) + + ret.append((test_X, test_y)) return ret - - def get_labeled_data(self, user_ids): - if self.labeled_feature_paths is None or self.labeled_label_paths is None: + + def get_train_data(self, user_ids): + if self.train_X_paths is None or self.train_y_paths is None: return None - + if isinstance(user_ids, str): user_ids = [user_ids] - + ret = [] for user_id in user_ids: - with open(self.labeled_feature_paths[user_id], "rb") as fin: - labeled_feature = pickle.load(fin) - - with open(self.labeled_label_paths[user_id], "rb") as fin: - labeled_groudtruth = pickle.load(fin) - - ret.append((labeled_feature, labeled_groudtruth)) + with open(self.train_X_paths[user_id], "rb") as fin: + train_X = pickle.load(fin) + + with open(self.train_y_paths[user_id], "rb") as fin: + train_y = pickle.load(fin) + + ret.append((train_X, train_y)) return ret - + class LearnwareBenchmark: - def __init__(self): - self.online_benchmarks = online_benchmarks - self.tempdir_list = [] - atexit.register(self.cleanup) - + self.benchmark_configs = benchmark_configs + def list_benchmarks(self): - return list(self.online_benchmarks.keys()) - - def get_benchmark(self, online_benchmark: Union[str, OnlineBenchmark]): - if isinstance(online_benchmark, str): - online_benchmark = self.online_benchmarks[online_benchmark] - - self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_benchmark")) - save_folder = self.tempdir_list[-1].name - - unlabeled_data_localpath = os.path.join(save_folder, "unlabeled_data.zip") - GetData().download_file(online_benchmark.unlabeled_data_path, unlabeled_data_localpath) - - unlabeled_feature_paths = [] - unlabeled_groudtruth_paths = [] - - with zipfile.ZipFile(unlabeled_data_localpath, "r") as z_file: - unlabeled_data_dirpath = os.path.join(save_folder, "unlabeled_data") - z_file.extractall(unlabeled_data_dirpath) - for user_id in range(online_benchmark.user_num): - user_feature_filepath = os.path.isfile(os.path.join(unlabeled_data_dirpath, f"user{user_id}_feature.pkl")) - user_groudtruth_filepath = os.path.isfile(os.path.join(unlabeled_data_dirpath, f"user{user_id}_groudtruth.pkl")) - assert os.path.isfile(user_feature_filepath), f"user {user_id} unlabeled feature is not valid!" - assert os.path.isfile(user_groudtruth_filepath), f"user {user_id} unlabeled groudtruth is not valid!" - unlabeled_feature_paths.append(user_feature_filepath) - unlabeled_groudtruth_paths.append(user_groudtruth_filepath) - - labeled_feature_paths = None - labeled_label_paths = None - if online_benchmark.labeled_data_path is not None: - labeled_data_localpath = os.path.join(save_folder, "labeled_data.zip") - GetData().download_file(online_benchmark.labeled_data_path, labeled_data_localpath) - - labeled_feature_paths = [] - labeled_label_paths = [] - - with zipfile.ZipFile(labeled_data_localpath, "r") as z_file: - labeled_data_dirpath = os.path.join(save_folder, "labeled_data") - z_file.extractall(labeled_data_dirpath) - for user_id in range(online_benchmark.user_num): - user_feature_filepath = os.path.isfile(os.path.join(labeled_data_dirpath, f"user{user_id}_feature.pkl")) - user_groudtruth_filepath = os.path.isfile(os.path.join(labeled_data_dirpath, f"user{user_id}_label.pkl")) - assert os.path.isfile(user_feature_filepath), f"user {user_id} labeled feature is not valid!" - assert os.path.isfile(user_groudtruth_filepath), f"user {user_id} labeled label is not valid!" - labeled_feature_paths.append(user_feature_filepath) - labeled_label_paths.append(user_groudtruth_filepath) - - extra_zip_localpath = None - if online_benchmark.extra_info_path is not None: - extra_zip_localpath = os.path.join(save_folder, os.path.basename(online_benchmark.extra_info_path)) - GetData().download_file(online_benchmark.extra_info_path, extra_zip_localpath) - + return list(self.benchmark_configs.keys()) + + def _check_cache_data_valid(self, benchmark_config: BenchmarkConfig, data_type: str) -> bool: + """Check if the cache data is valid + + Parameters + ---------- + benchmark_config : BenchmarkConfig + benchmark config + data_type : str + "test" for test data or "train" for train data + + Returns + ------- + bool + A flag indicating if the cache data is valid + """ + cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") + if os.path.exists(cache_folder): + for user_id in range(benchmark_config.user_num): + X_path = os.path.join(cache_folder, f"user{user_id}_X.pkl") + y_path = os.path.join(cache_folder, f"user{user_id}_X.pkl") + if not os.path.isfile(X_path) or not os.path.isfile(y_path): + return False + return True + else: + return False + + def _download_data(self, download_path: str, save_path: str): + """Download data from backend + + Parameters + ---------- + download_path : str + data path for download in backend + save_path : str + local cache path for saving data + """ + with tempfile.TemporaryDirectory(prefix="learnware_benchmark_") as tempdir: + test_data_zippath = os.path.join(tempdir, "benchmark_data.zip") + GetData().download_file(download_path, test_data_zippath) + + os.makedirs(save_path, exist_ok=True) + with zipfile.ZipFile(test_data_zippath, "r") as z_file: + z_file.extractall(save_path) + + def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> Tuple(List[str], List[str]): + """Load data from local cache path + + Parameters + ---------- + benchmark_config : BenchmarkConfig + benchmark config + data_type : str + "test" for test data or "train" for train data + """ + cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") + if not self._check_cache_data_valid(benchmark_config, data_type): + download_path = getattr(benchmark_config, f"{data_type}_data_path", None) + self._download_data(download_path, cache_folder) + + X_paths, y_paths = [], [] + for user_id in range(benchmark_config.user_num): + user_X_path = os.path.join(cache_folder, f"user{user_id}_X.pkl") + user_y_path = os.path.join(cache_folder, f"user{user_id}_y.pkl") + assert os.path.isfile(user_X_path), f"user {user_id} {data_type}_X is not valid!" + assert os.path.isfile(user_y_path), f"user {user_id} {data_type}_y is not valid!" + X_paths.append(user_X_path) + y_paths.append(user_y_path) + + def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]): + if isinstance(benchmark_config, str): + benchmark_config = self.benchmark_configs[benchmark_config] + + # Load test data + test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test") + + # Load train data + train_X_paths, train_y_paths = None, None + if benchmark_config.train_data_path is not None: + train_X_paths, train_y_paths = self._load_cache_data(benchmark_config, "train") + + # Load extra info + extra_info_path = None + if benchmark_config.extra_info_path is not None: + extra_info_path = os.path.join(C.cache_path, benchmark_config.name, "extra_info") + if not os.path.exists(extra_info_path): + self._download_data(benchmark_config.extra_info_path, extra_info_path) + return Benchmark( - learnware_ids=online_benchmark.learnware_ids, - user_num=online_benchmark.user_num, - unlabeled_feature_paths=unlabeled_feature_paths, - unlabeled_groudtruths_paths=unlabeled_groudtruth_paths, - labeled_feature_paths=labeled_feature_paths, - labeled_label_paths=labeled_label_paths, - extra_info_path=extra_zip_localpath, + learnware_ids=benchmark_config.learnware_ids, + user_num=benchmark_config.user_num, + test_X_paths=test_X_paths, + test_y_paths=test_y_paths, + train_X_paths=train_X_paths, + train_y_paths=train_y_paths, + extra_info_path=extra_info_path, ) - - def cleanup(self): - for tempdir in self.tempdir_list: - tempdir.cleanup() \ No newline at end of file diff --git a/learnware/tests/benchmarks/config.py b/learnware/tests/benchmarks/config.py index a523a55..289cb50 100644 --- a/learnware/tests/benchmarks/config.py +++ b/learnware/tests/benchmarks/config.py @@ -1,21 +1,24 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Optional, List -from ...learnware import Learnware + @dataclass -class OnlineBenchmark: +class BenchmarkConfig: + name: str learnware_ids: List[str] user_num: int - unlabeled_data_path: str - labeled_data_path: Optional[str] = None + test_data_path: str + train_data_path: Optional[str] = None extra_info_path: Optional[str] = None -online_benchmarks = { - "example": OnlineBenchmark( + +benchmark_configs = { + "example": BenchmarkConfig( + name="example", learnware_ids=["00001951", "00001980", "00001987"], user_num=3, - unlabeled_data_path="example_path1", - labeled_data_path="example_path2", - extra_info_path="example_path3" + test_data_path="example_path1", + train_data_path="example_path2", + extra_info_path="example_path3", ) -} \ No newline at end of file +}