Browse Source

[MNT] modify details in benchmark class

tags/v0.3.2
Gene 2 years ago
parent
commit
21392edfd7
3 changed files with 26 additions and 15 deletions
  1. +21
    -9
      learnware/tests/benchmarks/__init__.py
  2. +2
    -2
      learnware/tests/benchmarks/config.py
  3. +3
    -4
      learnware/tests/data.py

+ 21
- 9
learnware/tests/benchmarks/__init__.py View File

@@ -12,16 +12,18 @@ from ...config import C

@dataclass
class Benchmark:
learnware_ids: List[str]
name: str
user_num: int
learnware_ids: List[str]
test_X_paths: List[str]
test_y_paths: List[str]
train_X_paths: Optional[List[str]] = None
train_y_paths: Optional[List[str]] = None
extra_info_path: Optional[str] = None

def get_test_data(self, user_ids: Union[str, List[str]]):
if isinstance(user_ids, str):
def get_test_data(self, user_ids: Union[int, List[int]]):
raw_user_ids = user_ids
if isinstance(user_ids, int):
user_ids = [user_ids]

ret = []
@@ -34,13 +36,17 @@ class Benchmark:

ret.append((test_X, test_y))

return ret
if isinstance(raw_user_ids, int):
return ret[0]
else:
return ret

def get_train_data(self, user_ids):
def get_train_data(self, user_ids: Union[int, List[int]]):
if self.train_X_paths is None or self.train_y_paths is None:
return None

if isinstance(user_ids, str):
raw_user_ids = user_ids
if isinstance(user_ids, int):
user_ids = [user_ids]

ret = []
@@ -53,7 +59,10 @@ class Benchmark:

ret.append((train_X, train_y))

return ret
if isinstance(raw_user_ids, int):
return ret[0]
else:
return ret


class LearnwareBenchmark:
@@ -107,7 +116,7 @@ class LearnwareBenchmark:
with zipfile.ZipFile(test_data_zippath, "r") as z_file:
z_file.extractall(save_path)

def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> Tuple(List[str], List[str]):
def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> Tuple[List[str], List[str]]:
"""Load data from local cache path

Parameters
@@ -131,6 +140,8 @@ class LearnwareBenchmark:
X_paths.append(user_X_path)
y_paths.append(user_y_path)

return X_paths, y_paths

def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]):
if isinstance(benchmark_config, str):
benchmark_config = self.benchmark_configs[benchmark_config]
@@ -151,8 +162,9 @@ class LearnwareBenchmark:
self._download_data(benchmark_config.extra_info_path, extra_info_path)

return Benchmark(
learnware_ids=benchmark_config.learnware_ids,
name=benchmark_config.name,
user_num=benchmark_config.user_num,
learnware_ids=benchmark_config.learnware_ids,
test_X_paths=test_X_paths,
test_y_paths=test_y_paths,
train_X_paths=train_X_paths,


+ 2
- 2
learnware/tests/benchmarks/config.py View File

@@ -5,8 +5,8 @@ from typing import Optional, List
@dataclass
class BenchmarkConfig:
name: str
learnware_ids: List[str]
user_num: int
learnware_ids: List[str]
test_data_path: str
train_data_path: Optional[str] = None
extra_info_path: Optional[str] = None
@@ -15,8 +15,8 @@ class BenchmarkConfig:
benchmark_configs = {
"example": BenchmarkConfig(
name="example",
learnware_ids=["00001951", "00001980", "00001987"],
user_num=3,
learnware_ids=["00001951", "00001980", "00001987"],
test_data_path="example_path1",
train_data_path="example_path2",
extra_info_path="example_path3",


+ 3
- 4
learnware/tests/data.py View File

@@ -1,4 +1,3 @@

import json
import requests
from tqdm import tqdm
@@ -18,12 +17,12 @@ class GetData:
self.chunk_size = chunk_size

def download_file(self, file_path: str, save_path: str):
url = f"{self.host}/engine/download"
url = f"{self.host}/datasets/download_datasets"

response = requests.get(
url,
params={
"file_path": file_path,
"dataset": file_path,
},
stream=True,
)
@@ -37,4 +36,4 @@ class GetData:
with open(save_path, "wb") as f:
for chunk in response.iter_content(chunk_size=self.chunk_size):
f.write(chunk)
bar.update(1)
bar.update(1)

Loading…
Cancel
Save