Browse Source

[MNT] add cache in benchmark

tags/v0.3.2
Gene 2 years ago
parent
commit
ccfccfa868
2 changed files with 145 additions and 111 deletions
  1. +131
    -100
      learnware/tests/benchmarks/__init__.py
  2. +14
    -11
      learnware/tests/benchmarks/config.py

+ 131
- 100
learnware/tests/benchmarks/__init__.py View File

@@ -1,130 +1,161 @@
import os
import pickle
import atexit
import tempfile
import zipfile
from dataclasses import dataclass, field
from typing import Optional, List, Union, Tuple
from dataclasses import dataclass
from typing import Tuple, Optional, List, Union

from .config import OnlineBenchmark, online_benchmarks
from .config import BenchmarkConfig, benchmark_configs
from ..data import GetData
from ...config import C


@dataclass
class Benchmark:
    """Paths to the locally available data of one benchmark.

    ``test_X_paths`` / ``test_y_paths`` (and the optional ``train_*`` lists)
    are indexed by user id: entry ``i`` is the pickle file of user ``i``.
    Required fields are declared before the defaulted ones because a
    dataclass rejects a non-default field that follows a default field.
    """

    learnware_ids: List[str]
    user_num: int
    test_X_paths: List[str]
    test_y_paths: List[str]
    train_X_paths: Optional[List[str]] = None
    train_y_paths: Optional[List[str]] = None
    extra_info_path: Optional[str] = None

    # TODO: add more methods for benchmark

    @staticmethod
    def _load_pairs(X_paths: List[str], y_paths: List[str], user_ids) -> List[Tuple]:
        """Unpickle the ``(X, y)`` pair of every requested user, in request order."""
        # A single id (int or str) is treated as a one-element request.
        if isinstance(user_ids, (int, str)):
            user_ids = [user_ids]

        pairs = []
        for user_id in user_ids:
            # The path lists are plain lists, so the id must be an integer
            # position; ids given as numeric strings are converted here.
            index = int(user_id)
            # NOTE: pickle.load can execute arbitrary code; the files must come
            # from a trusted benchmark archive.
            with open(X_paths[index], "rb") as fin:
                X = pickle.load(fin)
            with open(y_paths[index], "rb") as fin:
                y = pickle.load(fin)
            pairs.append((X, y))
        return pairs

    def get_test_data(self, user_ids: Union[int, str, List[Union[int, str]]]):
        """Load the test data of the given user(s).

        Parameters
        ----------
        user_ids : Union[int, str, List[Union[int, str]]]
            One user id or a list of user ids

        Returns
        -------
        List[Tuple]
            One ``(test_X, test_y)`` tuple per requested user
        """
        return self._load_pairs(self.test_X_paths, self.test_y_paths, user_ids)

    def get_train_data(self, user_ids: Union[int, str, List[Union[int, str]]]):
        """Load the train data of the given user(s).

        Parameters
        ----------
        user_ids : Union[int, str, List[Union[int, str]]]
            One user id or a list of user ids

        Returns
        -------
        Optional[List[Tuple]]
            One ``(train_X, train_y)`` tuple per requested user, or None when
            the benchmark has no train data paths
        """
        if self.train_X_paths is None or self.train_y_paths is None:
            return None
        return self._load_pairs(self.train_X_paths, self.train_y_paths, user_ids)

class LearnwareBenchmark:
    """List the configured benchmarks and fetch their data through a local cache.

    Benchmark archives are unpacked below ``C.cache_path/<benchmark name>/``;
    a later request for the same benchmark reuses those files when every
    expected per-user file is present.
    """

    def __init__(self):
        self.benchmark_configs = benchmark_configs

    def list_benchmarks(self):
        """Return the names of all configured benchmarks."""
        return list(self.benchmark_configs.keys())

    @staticmethod
    def _user_data_paths(cache_folder: str, user_id: int) -> Tuple[str, str]:
        """Return the ``(X, y)`` pickle paths of one user inside ``cache_folder``.

        Shared by the cache check and the cache loader so both always look
        for the same file names.
        """
        return (
            os.path.join(cache_folder, f"user{user_id}_X.pkl"),
            os.path.join(cache_folder, f"user{user_id}_y.pkl"),
        )

    def _check_cache_data_valid(self, benchmark_config: BenchmarkConfig, data_type: str) -> bool:
        """Check if the cache data is valid

        Parameters
        ----------
        benchmark_config : BenchmarkConfig
            benchmark config
        data_type : str
            "test" for test data or "train" for train data

        Returns
        -------
        bool
            A flag indicating if the cache data is valid
        """
        cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data")
        if not os.path.exists(cache_folder):
            return False

        # Both the feature file and the label file of every user must exist.
        for user_id in range(benchmark_config.user_num):
            X_path, y_path = self._user_data_paths(cache_folder, user_id)
            if not (os.path.isfile(X_path) and os.path.isfile(y_path)):
                return False
        return True

    def _download_data(self, download_path: str, save_path: str):
        """Download data from backend

        Parameters
        ----------
        download_path : str
            data path for download in backend
        save_path : str
            local cache path for saving data
        """
        # The zip archive only lives in a temporary directory; just the
        # extracted content is kept under ``save_path``.
        with tempfile.TemporaryDirectory(prefix="learnware_benchmark_") as tempdir:
            zip_path = os.path.join(tempdir, "benchmark_data.zip")
            GetData().download_file(download_path, zip_path)

            os.makedirs(save_path, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as z_file:
                z_file.extractall(save_path)

    def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> Tuple[List[str], List[str]]:
        """Load data from local cache path

        Downloads the data first when the cache is missing or incomplete.

        Parameters
        ----------
        benchmark_config : BenchmarkConfig
            benchmark config
        data_type : str
            "test" for test data or "train" for train data

        Returns
        -------
        Tuple[List[str], List[str]]
            Per-user X file paths and per-user y file paths, ordered by user id

        Raises
        ------
        ValueError
            If the cache is invalid and the config has no ``<data_type>_data_path``
        AssertionError
            If an expected per-user file is still missing after the download
        """
        cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data")
        if not self._check_cache_data_valid(benchmark_config, data_type):
            download_path = getattr(benchmark_config, f"{data_type}_data_path", None)
            if download_path is None:
                raise ValueError(f"benchmark {benchmark_config.name} has no {data_type}_data_path to download from")
            self._download_data(download_path, cache_folder)

        X_paths, y_paths = [], []
        for user_id in range(benchmark_config.user_num):
            user_X_path, user_y_path = self._user_data_paths(cache_folder, user_id)
            assert os.path.isfile(user_X_path), f"user {user_id} {data_type}_X is not valid!"
            assert os.path.isfile(user_y_path), f"user {user_id} {data_type}_y is not valid!"
            X_paths.append(user_X_path)
            y_paths.append(user_y_path)

        return X_paths, y_paths

    def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]):
        """Build a ``Benchmark`` whose data files are available locally.

        Parameters
        ----------
        benchmark_config : Union[str, BenchmarkConfig]
            A benchmark config, or the name of a configured benchmark

        Returns
        -------
        Benchmark
            Benchmark pointing at the cached test data, and at the train data
            and extra info when the config declares them
        """
        if isinstance(benchmark_config, str):
            benchmark_config = self.benchmark_configs[benchmark_config]

        # Load test data
        test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test")

        # Load train data
        train_X_paths, train_y_paths = None, None
        if benchmark_config.train_data_path is not None:
            train_X_paths, train_y_paths = self._load_cache_data(benchmark_config, "train")

        # Load extra info
        extra_info_path = None
        if benchmark_config.extra_info_path is not None:
            extra_info_path = os.path.join(C.cache_path, benchmark_config.name, "extra_info")
            # An existing folder is reused as-is; its content is not re-validated.
            if not os.path.exists(extra_info_path):
                self._download_data(benchmark_config.extra_info_path, extra_info_path)

        return Benchmark(
            learnware_ids=benchmark_config.learnware_ids,
            user_num=benchmark_config.user_num,
            test_X_paths=test_X_paths,
            test_y_paths=test_y_paths,
            train_X_paths=train_X_paths,
            train_y_paths=train_y_paths,
            extra_info_path=extra_info_path,
        )

+ 14
- 11
learnware/tests/benchmarks/config.py View File

@@ -1,21 +1,24 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Optional, List
from ...learnware import Learnware

@dataclass
class BenchmarkConfig:
    """Static description of one benchmark and where its data can be downloaded.

    Required fields come first because a dataclass rejects a non-default
    field that follows a default field.
    """

    name: str  # benchmark name
    learnware_ids: List[str]  # ids of the learnwares belonging to the benchmark
    user_num: int  # number of users in the benchmark
    test_data_path: str  # download path of the test data
    train_data_path: Optional[str] = None  # download path of the train data, if any
    extra_info_path: Optional[str] = None  # download path of the extra info, if any

# Registry of the available benchmarks, keyed by benchmark name.
benchmark_configs = {
    "example": BenchmarkConfig(
        name="example",
        learnware_ids=["00001951", "00001980", "00001987"],
        user_num=3,
        test_data_path="example_path1",
        train_data_path="example_path2",
        extra_info_path="example_path3",
    ),
}

Loading…
Cancel
Save