|
|
|
@@ -12,16 +12,18 @@ from ...config import C |
|
|
|
|
|
|
|
@dataclass |
|
|
|
class Benchmark: |
|
|
|
learnware_ids: List[str] |
|
|
|
name: str |
|
|
|
user_num: int |
|
|
|
learnware_ids: List[str] |
|
|
|
test_X_paths: List[str] |
|
|
|
test_y_paths: List[str] |
|
|
|
train_X_paths: Optional[List[str]] = None |
|
|
|
train_y_paths: Optional[List[str]] = None |
|
|
|
extra_info_path: Optional[str] = None |
|
|
|
|
|
|
|
def get_test_data(self, user_ids: Union[str, List[str]]): |
|
|
|
if isinstance(user_ids, str): |
|
|
|
def get_test_data(self, user_ids: Union[int, List[int]]): |
|
|
|
raw_user_ids = user_ids |
|
|
|
if isinstance(user_ids, int): |
|
|
|
user_ids = [user_ids] |
|
|
|
|
|
|
|
ret = [] |
|
|
|
@@ -34,13 +36,17 @@ class Benchmark: |
|
|
|
|
|
|
|
ret.append((test_X, test_y)) |
|
|
|
|
|
|
|
return ret |
|
|
|
if isinstance(raw_user_ids, int): |
|
|
|
return ret[0] |
|
|
|
else: |
|
|
|
return ret |
|
|
|
|
|
|
|
def get_train_data(self, user_ids): |
|
|
|
def get_train_data(self, user_ids: Union[int, List[int]]): |
|
|
|
if self.train_X_paths is None or self.train_y_paths is None: |
|
|
|
return None |
|
|
|
|
|
|
|
if isinstance(user_ids, str): |
|
|
|
raw_user_ids = user_ids |
|
|
|
if isinstance(user_ids, int): |
|
|
|
user_ids = [user_ids] |
|
|
|
|
|
|
|
ret = [] |
|
|
|
@@ -53,7 +59,10 @@ class Benchmark: |
|
|
|
|
|
|
|
ret.append((train_X, train_y)) |
|
|
|
|
|
|
|
return ret |
|
|
|
if isinstance(raw_user_ids, int): |
|
|
|
return ret[0] |
|
|
|
else: |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
class LearnwareBenchmark: |
|
|
|
@@ -107,7 +116,7 @@ class LearnwareBenchmark: |
|
|
|
with zipfile.ZipFile(test_data_zippath, "r") as z_file: |
|
|
|
z_file.extractall(save_path) |
|
|
|
|
|
|
|
def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> Tuple(List[str], List[str]): |
|
|
|
def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> Tuple[List[str], List[str]]: |
|
|
|
"""Load data from local cache path |
|
|
|
|
|
|
|
Parameters |
|
|
|
@@ -131,6 +140,8 @@ class LearnwareBenchmark: |
|
|
|
X_paths.append(user_X_path) |
|
|
|
y_paths.append(user_y_path) |
|
|
|
|
|
|
|
return X_paths, y_paths |
|
|
|
|
|
|
|
def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]): |
|
|
|
if isinstance(benchmark_config, str): |
|
|
|
benchmark_config = self.benchmark_configs[benchmark_config] |
|
|
|
@@ -151,8 +162,9 @@ class LearnwareBenchmark: |
|
|
|
self._download_data(benchmark_config.extra_info_path, extra_info_path) |
|
|
|
|
|
|
|
return Benchmark( |
|
|
|
learnware_ids=benchmark_config.learnware_ids, |
|
|
|
name=benchmark_config.name, |
|
|
|
user_num=benchmark_config.user_num, |
|
|
|
learnware_ids=benchmark_config.learnware_ids, |
|
|
|
test_X_paths=test_X_paths, |
|
|
|
test_y_paths=test_y_paths, |
|
|
|
train_X_paths=train_X_paths, |
|
|
|
|