From 3a22a48ae50005a8e99d32e66e7a1f4dbd7fc119 Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Thu, 9 Nov 2023 20:36:18 +0800 Subject: [PATCH] [FIX] fix dataset_m5_workflow --- examples/dataset_m5_workflow/example_init.py | 2 +- examples/dataset_m5_workflow/main.py | 31 +++++++++++++------ examples/dataset_pfs_workflow/main.py | 27 +++++++++++----- examples/dataset_text_workflow/main.py | 4 --- learnware/market/easy/checker.py | 1 + learnware/market/easy/searcher.py | 2 +- learnware/specification/regular/table/rkme.py | 29 +++++++++++++---- setup.py | 6 ++-- 8 files changed, 71 insertions(+), 31 deletions(-) diff --git a/examples/dataset_m5_workflow/example_init.py b/examples/dataset_m5_workflow/example_init.py index e0aabdd..eade812 100644 --- a/examples/dataset_m5_workflow/example_init.py +++ b/examples/dataset_m5_workflow/example_init.py @@ -7,7 +7,7 @@ from learnware.model import BaseModel class Model(BaseModel): def __init__(self): - super(Model, self).__init__(input_shape=(82,), output_shape=()) + super(Model, self).__init__(input_shape=(82,), output_shape=(1,)) dir_path = os.path.dirname(os.path.abspath(__file__)) self.model = lgb.Booster(model_file=os.path.join(dir_path, "model.out")) diff --git a/examples/dataset_m5_workflow/main.py b/examples/dataset_m5_workflow/main.py index 7a8d971..2e126e0 100644 --- a/examples/dataset_m5_workflow/main.py +++ b/examples/dataset_m5_workflow/main.py @@ -8,7 +8,7 @@ from shutil import copyfile, rmtree import learnware from learnware.market import instantiate_learnware_market, BaseUserInfo -from learnware.market import database_ops +# from learnware.market import database_ops from learnware.reuse import JobSelectorReuser, AveragingReuser from learnware.specification import generate_rkme_spec from m5 import DataLoader @@ -17,27 +17,40 @@ from learnware.logger import get_module_logger logger = get_module_logger("m5_test", level="INFO") +output_description = { + "Dimension": 1, + "Description": {}, +} + +input_description = { + "Dimension": 82, + "Description": {}, +} + semantic_specs = [ { - "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": {"Values": ["Classification"], "Type": "Class"}, + "Data": {"Values": ["Table"], "Type": "Class"}, + "Task": {"Values": ["Regression"], "Type": "Class"}, "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "learnware_1", "Type": "String"}, + "Input": input_description, + "Output": output_description, } ] user_semantic = { - "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": {"Values": ["Classification"], "Type": "Class"}, + "Data": {"Values": ["Table"], "Type": "Class"}, + "Task": {"Values": ["Regression"], "Type": "Class"}, "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "", "Type": "String"}, + "Input": input_description, + "Output": output_description, } - class M5DatasetWorkflow: def _init_m5_dataset(self): m5 = DataLoader() @@ -69,8 +82,8 @@ class M5DatasetWorkflow: easy_market.add_learnware(zip_path, semantic_spec) print("Total Item:", len(easy_market)) - curr_inds = easy_market._get_ids() - print("Available ids:", curr_inds) + # curr_inds = easy_market._get_ids() + # print("Available ids:", curr_inds) def prepare_learnware(self, regenerate_flag=False): if regenerate_flag: @@ -171,7 +184,7 @@ class M5DatasetWorkflow: job_selector_score = m5.score(test_y, job_selector_predict_y) print(f"mixture reuse loss (job selector): {job_selector_score}") - reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote") + reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob") ensemble_predict_y = reuse_ensemble.predict(user_data=test_x) ensemble_score = m5.score(test_y, ensemble_predict_y) print(f"mixture reuse loss (ensemble): {ensemble_score}\n") diff --git a/examples/dataset_pfs_workflow/main.py b/examples/dataset_pfs_workflow/main.py index d66074e..9ce02f7 100644 --- a/examples/dataset_pfs_workflow/main.py +++ b/examples/dataset_pfs_workflow/main.py @@ -15,25 +15,38 @@ from learnware.logger import get_module_logger logger = get_module_logger("pfs_test", level="INFO") +output_description = { + "Dimension": 1, + "Description": {}, +} + +input_description = { + "Dimension": 31, + "Description": {}, +} semantic_specs = [ { - "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": {"Values": ["Classification"], "Type": "Class"}, + "Data": {"Values": ["Table"], "Type": "Class"}, + "Task": {"Values": ["Regression"], "Type": "Class"}, "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "learnware_1", "Type": "String"}, + "Input": input_description, + "Output": output_description, } ] user_semantic = { - "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": {"Values": ["Classification"], "Type": "Class"}, + "Data": {"Values": ["Table"], "Type": "Class"}, + "Task": {"Values": ["Regression"], "Type": "Class"}, "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, - "Name": {"Values": "", "Type": "String"}, + "Name": {"Values": "learnware_1", "Type": "String"}, + "Input": input_description, + "Output": output_description, } @@ -66,8 +79,8 @@ class PFSDatasetWorkflow: easy_market.add_learnware(zip_path, semantic_spec) print("Total Item:", len(easy_market)) - curr_inds = easy_market._get_ids() - print("Available ids:", curr_inds) + # curr_inds = easy_market._get_ids() + # print("Available ids:", curr_inds) def prepare_learnware(self, regenerate_flag=False): if regenerate_flag: diff --git a/examples/dataset_text_workflow/main.py b/examples/dataset_text_workflow/main.py index bf20091..179bf53 100644 --- a/examples/dataset_text_workflow/main.py +++ b/examples/dataset_text_workflow/main.py @@ -196,10 +196,6 @@ def test_search(gamma=0.1, load_market=True): ensemble_score_list.append(ensemble_score) print(f"mixture reuse accuracy (ensemble): {ensemble_score}") - select_list.append(acc_list[0]) - avg_list.append(np.mean(acc_list)) - improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list)) - # test reuse (ensemblePruning) reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list) pruning_predict_y = reuse_pruning.predict(user_data=user_data) diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 5e455a7..a988157 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -120,6 +120,7 @@ class EasyStatChecker(BaseChecker): raise ValueError(f"not supported spec type for spec_type = {spec_type}") # Check output + outputs = learnware.predict(inputs) try: outputs = learnware.predict(inputs) except Exception: diff --git a/learnware/market/easy/searcher.py b/learnware/market/easy/searcher.py index cc51b30..5fb5671 100644 --- a/learnware/market/easy/searcher.py +++ b/learnware/market/easy/searcher.py @@ -277,7 +277,7 @@ class EasyStatSearcher(BaseSearcher): # beta must be nonnegative weight, obj = rkme_solve_qp(K, C) - weight = torch.from_numpy(weight).reshape(-1).double().to(user_rkme.device) + weight = weight.double().to(user_rkme.device) score = user_rkme.inner_prod(user_rkme) + 2 * obj return weight.detach().cpu().numpy().reshape(-1), score diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index 3fcd985..8f147b2 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -13,6 +13,7 @@ from qpsolvers import solve_qp, Problem, solve_problem from collections import Counter from typing import Tuple, Any, List, Union, Dict import scipy +from sklearn.cluster import MiniBatchKMeans try: import faiss @@ -27,10 +28,10 @@ from ....logger import get_module_logger logger = get_module_logger("rkme") -if not _FAISS_INSTALLED: - logger.warning( - "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first" - ) +# if not _FAISS_INSTALLED: +# logger.warning( +# "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first" +# ) class RKMETableSpecification(RegularStatsSpecification): @@ -127,8 +128,8 @@ class RKMETableSpecification(RegularStatsSpecification): self.beta = torch.from_numpy(self.beta).double().to(self.device) return - # Initialize Z by clustering, utiliing faiss to speed up the process. - self._init_z_by_faiss(X, K) + # Initialize Z by clustering, utiliing kmeans or faiss to speed up the process. + self._init_z_by_kmeans(X, K) self._update_beta(X, nonnegative_beta) # Alternating optimize Z and beta @@ -156,6 +157,22 @@ class RKMETableSpecification(RegularStatsSpecification): center = torch.from_numpy(kmeans.centroids).double() self.z = center + def _init_z_by_kmeans(self, X: Union[np.ndarray, torch.tensor], K: int): + """Intialize Z by kmeans clustering. + + Parameters + ---------- + X : np.ndarray or torch.tensor + Raw data in np.ndarray format or torch.tensor format. + K : int + Size of the construced reduced set. + """ + X = X.astype("float32") + kmeans = MiniBatchKMeans(n_clusters=K, max_iter=100, verbose=False, n_init="auto") + kmeans.fit(X) + center = torch.from_numpy(kmeans.cluster_centers_).double() + self.z = center + def _update_beta(self, X: Any, nonnegative_beta: bool = True): """Fix Z and update beta using its closed-form solution. diff --git a/setup.py b/setup.py index d8dde7a..a8d6e5d 100644 --- a/setup.py +++ b/setup.py @@ -75,11 +75,11 @@ REQUIRED = [ "langdetect>=1.0.9", "huggingface-hub<0.18", "portalocker>=2.0.0", - "qpsolvers[clarabel]>=4.0.1" + "qpsolvers[clarabel]>=4.0.1", ] -if get_platform() != MACOS: - REQUIRED.append("faiss-cpu>=1.7.1") +# if get_platform() != MACOS: +# REQUIRED.append("faiss-cpu>=1.7.1") here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, "README.md"), encoding="utf-8") as f: