|
|
|
@@ -1,130 +1,161 @@ |
|
|
|
import os |
|
|
|
import pickle |
|
|
|
import atexit |
|
|
|
import tempfile |
|
|
|
import zipfile |
|
|
|
from dataclasses import dataclass, field |
|
|
|
from typing import Optional, List, Union, Tuple |
|
|
|
from dataclasses import dataclass |
|
|
|
from typing import Tuple, Optional, List, Union |
|
|
|
|
|
|
|
from .config import OnlineBenchmark, online_benchmarks |
|
|
|
from .config import BenchmarkConfig, benchmark_configs |
|
|
|
from ..data import GetData |
|
|
|
from ...config import C |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
|
class Benchmark: |
|
|
|
learnware_ids: List[str] |
|
|
|
user_num: int |
|
|
|
unlabeled_feature_paths: List[str] |
|
|
|
unlabeled_groudtruths_paths: List[str] |
|
|
|
labeled_feature_paths: Optional[List[str]] = None |
|
|
|
labeled_label_paths: Optional[List[str]] = None |
|
|
|
test_X_paths: List[str] |
|
|
|
test_y_paths: List[str] |
|
|
|
train_X_paths: Optional[List[str]] = None |
|
|
|
train_y_paths: Optional[List[str]] = None |
|
|
|
extra_info_path: Optional[str] = None |
|
|
|
|
|
|
|
# TODO: add more method for benchmark |
|
|
|
|
|
|
|
def get_unlabeled_data(self, user_ids: Union[str, List[str]]): |
|
|
|
|
|
|
|
def get_test_data(self, user_ids: Union[str, List[str]]): |
|
|
|
if isinstance(user_ids, str): |
|
|
|
user_ids = [user_ids] |
|
|
|
|
|
|
|
|
|
|
|
ret = [] |
|
|
|
for user_id in user_ids: |
|
|
|
with open(self.unlabeled_feature_paths[user_id], "rb") as fin: |
|
|
|
unlabeled_feature = pickle.load(fin) |
|
|
|
|
|
|
|
with open(self.unlabeled_groudtruths_paths[user_id], "rb") as fin: |
|
|
|
unlabeled_groudtruth = pickle.load(fin) |
|
|
|
|
|
|
|
ret.append((unlabeled_feature, unlabeled_groudtruth)) |
|
|
|
with open(self.test_X_paths[user_id], "rb") as fin: |
|
|
|
test_X = pickle.load(fin) |
|
|
|
|
|
|
|
with open(self.test_y_paths[user_id], "rb") as fin: |
|
|
|
test_y = pickle.load(fin) |
|
|
|
|
|
|
|
ret.append((test_X, test_y)) |
|
|
|
|
|
|
|
return ret |
|
|
|
|
|
|
|
def get_labeled_data(self, user_ids): |
|
|
|
if self.labeled_feature_paths is None or self.labeled_label_paths is None: |
|
|
|
|
|
|
|
def get_train_data(self, user_ids): |
|
|
|
if self.train_X_paths is None or self.train_y_paths is None: |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(user_ids, str): |
|
|
|
user_ids = [user_ids] |
|
|
|
|
|
|
|
|
|
|
|
ret = [] |
|
|
|
for user_id in user_ids: |
|
|
|
with open(self.labeled_feature_paths[user_id], "rb") as fin: |
|
|
|
labeled_feature = pickle.load(fin) |
|
|
|
|
|
|
|
with open(self.labeled_label_paths[user_id], "rb") as fin: |
|
|
|
labeled_groudtruth = pickle.load(fin) |
|
|
|
|
|
|
|
ret.append((labeled_feature, labeled_groudtruth)) |
|
|
|
with open(self.train_X_paths[user_id], "rb") as fin: |
|
|
|
train_X = pickle.load(fin) |
|
|
|
|
|
|
|
with open(self.train_y_paths[user_id], "rb") as fin: |
|
|
|
train_y = pickle.load(fin) |
|
|
|
|
|
|
|
ret.append((train_X, train_y)) |
|
|
|
|
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LearnwareBenchmark: |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
self.online_benchmarks = online_benchmarks |
|
|
|
self.tempdir_list = [] |
|
|
|
atexit.register(self.cleanup) |
|
|
|
|
|
|
|
self.benchmark_configs = benchmark_configs |
|
|
|
|
|
|
|
def list_benchmarks(self): |
|
|
|
return list(self.online_benchmarks.keys()) |
|
|
|
|
|
|
|
def get_benchmark(self, online_benchmark: Union[str, OnlineBenchmark]): |
|
|
|
if isinstance(online_benchmark, str): |
|
|
|
online_benchmark = self.online_benchmarks[online_benchmark] |
|
|
|
|
|
|
|
self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_benchmark")) |
|
|
|
save_folder = self.tempdir_list[-1].name |
|
|
|
|
|
|
|
unlabeled_data_localpath = os.path.join(save_folder, "unlabeled_data.zip") |
|
|
|
GetData().download_file(online_benchmark.unlabeled_data_path, unlabeled_data_localpath) |
|
|
|
|
|
|
|
unlabeled_feature_paths = [] |
|
|
|
unlabeled_groudtruth_paths = [] |
|
|
|
|
|
|
|
with zipfile.ZipFile(unlabeled_data_localpath, "r") as z_file: |
|
|
|
unlabeled_data_dirpath = os.path.join(save_folder, "unlabeled_data") |
|
|
|
z_file.extractall(unlabeled_data_dirpath) |
|
|
|
for user_id in range(online_benchmark.user_num): |
|
|
|
user_feature_filepath = os.path.isfile(os.path.join(unlabeled_data_dirpath, f"user{user_id}_feature.pkl")) |
|
|
|
user_groudtruth_filepath = os.path.isfile(os.path.join(unlabeled_data_dirpath, f"user{user_id}_groudtruth.pkl")) |
|
|
|
assert os.path.isfile(user_feature_filepath), f"user {user_id} unlabeled feature is not valid!" |
|
|
|
assert os.path.isfile(user_groudtruth_filepath), f"user {user_id} unlabeled groudtruth is not valid!" |
|
|
|
unlabeled_feature_paths.append(user_feature_filepath) |
|
|
|
unlabeled_groudtruth_paths.append(user_groudtruth_filepath) |
|
|
|
|
|
|
|
labeled_feature_paths = None |
|
|
|
labeled_label_paths = None |
|
|
|
if online_benchmark.labeled_data_path is not None: |
|
|
|
labeled_data_localpath = os.path.join(save_folder, "labeled_data.zip") |
|
|
|
GetData().download_file(online_benchmark.labeled_data_path, labeled_data_localpath) |
|
|
|
|
|
|
|
labeled_feature_paths = [] |
|
|
|
labeled_label_paths = [] |
|
|
|
|
|
|
|
with zipfile.ZipFile(labeled_data_localpath, "r") as z_file: |
|
|
|
labeled_data_dirpath = os.path.join(save_folder, "labeled_data") |
|
|
|
z_file.extractall(labeled_data_dirpath) |
|
|
|
for user_id in range(online_benchmark.user_num): |
|
|
|
user_feature_filepath = os.path.isfile(os.path.join(labeled_data_dirpath, f"user{user_id}_feature.pkl")) |
|
|
|
user_groudtruth_filepath = os.path.isfile(os.path.join(labeled_data_dirpath, f"user{user_id}_label.pkl")) |
|
|
|
assert os.path.isfile(user_feature_filepath), f"user {user_id} labeled feature is not valid!" |
|
|
|
assert os.path.isfile(user_groudtruth_filepath), f"user {user_id} labeled label is not valid!" |
|
|
|
labeled_feature_paths.append(user_feature_filepath) |
|
|
|
labeled_label_paths.append(user_groudtruth_filepath) |
|
|
|
|
|
|
|
extra_zip_localpath = None |
|
|
|
if online_benchmark.extra_info_path is not None: |
|
|
|
extra_zip_localpath = os.path.join(save_folder, os.path.basename(online_benchmark.extra_info_path)) |
|
|
|
GetData().download_file(online_benchmark.extra_info_path, extra_zip_localpath) |
|
|
|
|
|
|
|
return list(self.benchmark_configs.keys()) |
|
|
|
|
|
|
|
def _check_cache_data_valid(self, benchmark_config: BenchmarkConfig, data_type: str) -> bool: |
|
|
|
"""Check if the cache data is valid |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
benchmark_config : BenchmarkConfig |
|
|
|
benchmark config |
|
|
|
data_type : str |
|
|
|
"test" for test data or "train" for train data |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
bool |
|
|
|
A flag indicating if the cache data is valid |
|
|
|
""" |
|
|
|
cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") |
|
|
|
if os.path.exists(cache_folder): |
|
|
|
for user_id in range(benchmark_config.user_num): |
|
|
|
X_path = os.path.join(cache_folder, f"user{user_id}_X.pkl") |
|
|
|
y_path = os.path.join(cache_folder, f"user{user_id}_X.pkl") |
|
|
|
if not os.path.isfile(X_path) or not os.path.isfile(y_path): |
|
|
|
return False |
|
|
|
return True |
|
|
|
else: |
|
|
|
return False |
|
|
|
|
|
|
|
def _download_data(self, download_path: str, save_path: str): |
|
|
|
"""Download data from backend |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
download_path : str |
|
|
|
data path for download in backend |
|
|
|
save_path : str |
|
|
|
local cache path for saving data |
|
|
|
""" |
|
|
|
with tempfile.TemporaryDirectory(prefix="learnware_benchmark_") as tempdir: |
|
|
|
test_data_zippath = os.path.join(tempdir, "benchmark_data.zip") |
|
|
|
GetData().download_file(download_path, test_data_zippath) |
|
|
|
|
|
|
|
os.makedirs(save_path, exist_ok=True) |
|
|
|
with zipfile.ZipFile(test_data_zippath, "r") as z_file: |
|
|
|
z_file.extractall(save_path) |
|
|
|
|
|
|
|
def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> Tuple(List[str], List[str]): |
|
|
|
"""Load data from local cache path |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
benchmark_config : BenchmarkConfig |
|
|
|
benchmark config |
|
|
|
data_type : str |
|
|
|
"test" for test data or "train" for train data |
|
|
|
""" |
|
|
|
cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") |
|
|
|
if not self._check_cache_data_valid(benchmark_config, data_type): |
|
|
|
download_path = getattr(benchmark_config, f"{data_type}_data_path", None) |
|
|
|
self._download_data(download_path, cache_folder) |
|
|
|
|
|
|
|
X_paths, y_paths = [], [] |
|
|
|
for user_id in range(benchmark_config.user_num): |
|
|
|
user_X_path = os.path.join(cache_folder, f"user{user_id}_X.pkl") |
|
|
|
user_y_path = os.path.join(cache_folder, f"user{user_id}_y.pkl") |
|
|
|
assert os.path.isfile(user_X_path), f"user {user_id} {data_type}_X is not valid!" |
|
|
|
assert os.path.isfile(user_y_path), f"user {user_id} {data_type}_y is not valid!" |
|
|
|
X_paths.append(user_X_path) |
|
|
|
y_paths.append(user_y_path) |
|
|
|
|
|
|
|
def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]): |
|
|
|
if isinstance(benchmark_config, str): |
|
|
|
benchmark_config = self.benchmark_configs[benchmark_config] |
|
|
|
|
|
|
|
# Load test data |
|
|
|
test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test") |
|
|
|
|
|
|
|
# Load train data |
|
|
|
train_X_paths, train_y_paths = None, None |
|
|
|
if benchmark_config.train_data_path is not None: |
|
|
|
train_X_paths, train_y_paths = self._load_cache_data(benchmark_config, "train") |
|
|
|
|
|
|
|
# Load extra info |
|
|
|
extra_info_path = None |
|
|
|
if benchmark_config.extra_info_path is not None: |
|
|
|
extra_info_path = os.path.join(C.cache_path, benchmark_config.name, "extra_info") |
|
|
|
if not os.path.exists(extra_info_path): |
|
|
|
self._download_data(benchmark_config.extra_info_path, extra_info_path) |
|
|
|
|
|
|
|
return Benchmark( |
|
|
|
learnware_ids=online_benchmark.learnware_ids, |
|
|
|
user_num=online_benchmark.user_num, |
|
|
|
unlabeled_feature_paths=unlabeled_feature_paths, |
|
|
|
unlabeled_groudtruths_paths=unlabeled_groudtruth_paths, |
|
|
|
labeled_feature_paths=labeled_feature_paths, |
|
|
|
labeled_label_paths=labeled_label_paths, |
|
|
|
extra_info_path=extra_zip_localpath, |
|
|
|
learnware_ids=benchmark_config.learnware_ids, |
|
|
|
user_num=benchmark_config.user_num, |
|
|
|
test_X_paths=test_X_paths, |
|
|
|
test_y_paths=test_y_paths, |
|
|
|
train_X_paths=train_X_paths, |
|
|
|
train_y_paths=train_y_paths, |
|
|
|
extra_info_path=extra_info_path, |
|
|
|
) |
|
|
|
|
|
|
|
def cleanup(self): |
|
|
|
for tempdir in self.tempdir_list: |
|
|
|
tempdir.cleanup() |