diff --git a/docs/about/dev.rst b/docs/about/dev.rst index 8586d10..309d1dc 100644 --- a/docs/about/dev.rst +++ b/docs/about/dev.rst @@ -1,8 +1,7 @@ .. _dev: - -============= -Code Standard -============= +================ +For Developer +================ Docstring ============ diff --git a/examples/dataset_image_workflow/main.py b/examples/dataset_image_workflow/main.py index 5ac7749..c91981c 100644 --- a/examples/dataset_image_workflow/main.py +++ b/examples/dataset_image_workflow/main.py @@ -13,7 +13,7 @@ from learnware.learnware import Learnware import time from learnware.market import instantiate_learnware_market, BaseUserInfo -from learnware.market import database_ops +from learnware.market.easy import database_ops from learnware.learnware import Learnware import learnware.specification as specification from learnware.logger import get_module_logger @@ -168,15 +168,14 @@ def test_search(gamma=0.1, load_market=True): user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}) logger.info("Searching Market for user: %d" % i) - sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( - user_info - ) + search_result = image_market.search_learnware(user_info) + single_result = search_result.get_single_results() acc_list = [] - for idx, (score, learnware) in enumerate(zip(sorted_score_list[:5], single_learnware_list[:5])): - pred_y = learnware.predict(user_data) + for idx, single_item in enumerate(single_result[:5]): + pred_y = single_item.learnware.predict(user_data) acc = eval_prediction(pred_y, user_label) acc_list.append(acc) - logger.info("Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) + logger.info("Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, single_item.score, single_item.learnware.id, acc)) # test reuse (job selector) # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) @@ -186,6 +185,7 @@ def test_search(gamma=0.1, load_market=True): # print(f"mixture reuse loss: {reuse_score}") # test reuse (ensemble) + single_learnware_list = [single_item.learnware for single_item in single_result] reuse_ensemble = AveragingReuser(learnware_list=single_learnware_list[:3], mode="vote_by_prob") ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) ensemble_score = eval_prediction(ensemble_predict_y, user_label) diff --git a/examples/dataset_m5_workflow/main.py b/examples/dataset_m5_workflow/main.py index 763669d..60ee439 100644 --- a/examples/dataset_m5_workflow/main.py +++ b/examples/dataset_m5_workflow/main.py @@ -155,29 +155,28 @@ class M5DatasetWorkflow: user_spec.save(user_spec_path) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = easy_market.search_learnware(user_info) - + search_result = easy_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() + print(f"search result of user{idx}:") print( - f"single model num: {len(sorted_score_list)}, max_score: {sorted_score_list[0]}, min_score: {sorted_score_list[-1]}" + f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" ) loss_list = [] - for score, learnware in zip(sorted_score_list, single_learnware_list): - pred_y = learnware.predict(test_x) + for single_item in single_result: + pred_y = single_item.learnware.predict(test_x) loss_list.append(m5.score(test_y, pred_y)) print( - f"Top1-score: {sorted_score_list[0]}, learnware_id: {single_learnware_list[0].id}, loss: {loss_list[0]}" + f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, loss: {loss_list[0]}" ) - mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) - print(f"mixture_score: {mixture_score}, mixture_learnware: {mixture_id}") - if not mixture_learnware_list: - mixture_learnware_list = [single_learnware_list[0]] + if len(multiple_result) > 0: + mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) + print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") + mixture_learnware_list = multiple_result[0].learnwares + else: + mixture_learnware_list = [single_result[0].learnware] reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list, use_herding=False) job_selector_predict_y = reuse_job_selector.predict(user_data=test_x) diff --git a/examples/dataset_pfs_workflow/main.py b/examples/dataset_pfs_workflow/main.py index 784c383..74c4da5 100644 --- a/examples/dataset_pfs_workflow/main.py +++ b/examples/dataset_pfs_workflow/main.py @@ -152,29 +152,28 @@ class PFSDatasetWorkflow: user_spec.save(user_spec_path) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = easy_market.search_learnware(user_info) - + search_result = easy_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() + print(f"search result of user{idx}:") print( - f"single model num: {len(sorted_score_list)}, max_score: {sorted_score_list[0]}, min_score: {sorted_score_list[-1]}" + f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" ) loss_list = [] - for score, learnware in zip(sorted_score_list, single_learnware_list): - pred_y = learnware.predict(test_x) + for single_item in single_result: + pred_y = single_item.learnware.predict(test_x) loss_list.append(pfs.score(test_y, pred_y)) print( - f"Top1-score: {sorted_score_list[0]}, learnware_id: {single_learnware_list[0].id}, loss: {loss_list[0]}, random: {np.mean(loss_list)}" + f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, loss: {loss_list[0]}, random: {np.mean(loss_list)}" ) - mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) - print(f"mixture_score: {mixture_score}, mixture_learnware: {mixture_id}") - if not mixture_learnware_list: - mixture_learnware_list = [single_learnware_list[0]] + if len(multiple_result) > 0: + mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) + print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") + mixture_learnware_list = multiple_result[0].learnwares + else: + mixture_learnware_list = [single_result[0].learnware] reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list, use_herding=False) job_selector_predict_y = reuse_job_selector.predict(user_data=test_x) diff --git a/examples/dataset_text_workflow/main.py b/examples/dataset_text_workflow/main.py index 72de04e..c5715e7 100644 --- a/examples/dataset_text_workflow/main.py +++ b/examples/dataset_text_workflow/main.py @@ -199,31 +199,34 @@ class TextDatasetWorkflow: user_stat_spec.generate_stat_spec_from_data(X=user_data) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETextSpecification": user_stat_spec}) logger.info("Searching Market for user: %d" % (i)) - sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = text_market.search_learnware( - user_info - ) - + + search_result = text_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() + print(f"search result of user{i}:") print( - f"single model num: {len(sorted_score_list)}, max_score: {sorted_score_list[0]}, min_score: {sorted_score_list[-1]}" + f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" ) - l = len(sorted_score_list) + l = len(single_result) acc_list = [] for idx in range(l): - learnware = single_learnware_list[idx] - score = sorted_score_list[idx] + learnware = single_result[idx].learnware + score = single_result[idx].score pred_y = learnware.predict(user_data) acc = eval_prediction(pred_y, user_label) acc_list.append(acc) print( - f"Top1-score: {sorted_score_list[0]}, learnware_id: {single_learnware_list[0].id}, acc: {acc_list[0]}" + f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, acc: {acc_list[0]}" ) - mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) - print(f"mixture_score: {mixture_score}, mixture_learnware: {mixture_id}") - if not mixture_learnware_list: - mixture_learnware_list = [single_learnware_list[0]] + if len(multiple_result) > 0: + mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) + print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") + mixture_learnware_list = multiple_result[0].learnwares + else: + mixture_learnware_list = [single_result[0].learnware] # test reuse (job selector) reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) diff --git a/examples/workflow_by_code/learnware_example/README.md b/examples/workflow_by_code/learnware_example/README.md deleted file mode 100644 index 51aac5a..0000000 --- a/examples/workflow_by_code/learnware_example/README.md +++ /dev/null @@ -1,10 +0,0 @@ -## How to Generate Environment Yaml - -* create env config for conda: -```shell -conda env export | grep -v "^prefix: " > environment.yml -``` -* recover env from config -``` -conda env create -f environment.yml -``` \ No newline at end of file diff --git a/examples/workflow_by_code/learnware_example/environment.yml b/examples/workflow_by_code/learnware_example/environment.yml deleted file mode 100644 index 2923bdb..0000000 --- a/examples/workflow_by_code/learnware_example/environment.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: learnware_example_env -channels: - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - ca-certificates=2023.01.10=h06a4308_0 - - ld_impl_linux-64=2.38=h1181459_1 - - libffi=3.4.2=h6a678d5_6 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libstdcxx-ng=11.2.0=h1234567_1 - - ncurses=6.4=h6a678d5_0 - - openssl=1.1.1t=h7f8727e_0 - - pip=23.0.1=py38h06a4308_0 - - python=3.8.16=h7a1cb2a_3 - - readline=8.2=h5eee18b_0 - - setuptools=66.0.0=py38h06a4308_0 - - sqlite=3.41.2=h5eee18b_0 - - tk=8.6.12=h1ccaba5_0 - - wheel=0.38.4=py38h06a4308_0 - - xz=5.2.10=h5eee18b_1 - - zlib=1.2.13=h5eee18b_0 - - pip: - - joblib==1.2.0 - - learnware==0.0.1.99 - - numpy==1.19.5 diff --git a/examples/workflow_by_code/learnware_example/example.yaml b/examples/workflow_by_code/learnware_example/example.yaml deleted file mode 100644 index 32aa52e..0000000 --- a/examples/workflow_by_code/learnware_example/example.yaml +++ /dev/null @@ -1,8 +0,0 @@ -model: - class_name: SVM - kwargs: {} -stat_specifications: - - module_path: learnware.specification - class_name: RKMETableSpecification - file_name: svm.json - kwargs: {} \ No newline at end of file diff --git a/examples/workflow_by_code/learnware_example/example_init.py b/examples/workflow_by_code/learnware_example/example_init.py deleted file mode 100644 index 47d3708..0000000 --- a/examples/workflow_by_code/learnware_example/example_init.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import joblib -import numpy as np -from learnware.model import BaseModel - - -class SVM(BaseModel): - def __init__(self): - super(SVM, self).__init__(input_shape=(64,), output_shape=(10,)) - dir_path = os.path.dirname(os.path.abspath(__file__)) - self.model = joblib.load(os.path.join(dir_path, "svm.pkl")) - - def fit(self, X: np.ndarray, y: np.ndarray): - pass - - def predict(self, X: np.ndarray) -> np.ndarray: - return self.model.predict_proba(X) - - def finetune(self, X: np.ndarray, y: np.ndarray): - pass diff --git a/examples/workflow_by_code/main.py b/examples/workflow_by_code/main.py deleted file mode 100644 index 8a08a61..0000000 --- a/examples/workflow_by_code/main.py +++ /dev/null @@ -1,197 +0,0 @@ -import os -import fire -import copy -import joblib -import zipfile -import numpy as np -from sklearn import svm -from sklearn.datasets import load_digits -from sklearn.model_selection import train_test_split -from shutil import copyfile, rmtree - -import learnware -from learnware.market import instantiate_learnware_market, BaseUserInfo -from learnware.reuse import JobSelectorReuser, AveragingReuser -from learnware.specification import generate_rkme_table_spec, RKMETableSpecification - -curr_root = os.path.dirname(os.path.abspath(__file__)) - -user_semantic = { - "Data": {"Values": ["Table"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, - "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, - "Scenario": {"Values": ["Education"], "Type": "Tag"}, - "Description": {"Values": "", "Type": "String"}, - "Name": {"Values": "", "Type": "String"}, -} - - -class LearnwareMarketWorkflow: - def _init_learnware_market(self): - """initialize learnware market""" - learnware.init() - np.random.seed(2023) - easy_market = instantiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True) - return easy_market - - def prepare_learnware_randomly(self, learnware_num=5): - self.zip_path_list = [] - X, y = load_digits(return_X_y=True) - - for i in range(learnware_num): - dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i)) - os.makedirs(dir_path, exist_ok=True) - - print("Preparing Learnware: %d" % (i)) - - data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True) - clf = svm.SVC(kernel="linear", probability=True) - clf.fit(data_X, data_y) - - joblib.dump(clf, os.path.join(dir_path, "svm.pkl")) - - spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) - spec.save(os.path.join(dir_path, "svm.json")) - - init_file = os.path.join(dir_path, "__init__.py") - copyfile( - os.path.join(curr_root, "learnware_example/example_init.py"), init_file - ) # cp example_init.py init_file - - yaml_file = os.path.join(dir_path, "learnware.yaml") - copyfile(os.path.join(curr_root, "learnware_example/example.yaml"), yaml_file) # cp example.yaml yaml_file - - zip_file = dir_path + ".zip" - # zip -q -r -j zip_file dir_path - with zipfile.ZipFile(zip_file, "w") as zip_obj: - for foldername, subfolders, filenames in os.walk(dir_path): - for filename in filenames: - file_path = os.path.join(foldername, filename) - zip_info = zipfile.ZipInfo(filename) - zip_info.compress_type = zipfile.ZIP_STORED - with open(file_path, "rb") as file: - zip_obj.writestr(zip_info, file.read()) - - rmtree(dir_path) # rm -r dir_path - - self.zip_path_list.append(zip_file) - - def test_upload_delete_learnware(self, learnware_num=5, delete=False): - easy_market = self._init_learnware_market() - self.prepare_learnware_randomly(learnware_num) - - print("Total Item:", len(easy_market)) - - for idx, zip_path in enumerate(self.zip_path_list): - semantic_spec = copy.deepcopy(user_semantic) - semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) - semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) - easy_market.add_learnware(zip_path, semantic_spec) - - print("Total Item:", len(easy_market)) - curr_inds = easy_market.get_learnware_ids() - print("Available ids After Uploading Learnwares:", curr_inds) - - if delete: - for learnware_id in curr_inds: - easy_market.delete_learnware(learnware_id) - curr_inds = easy_market.get_learnware_ids() - print("Available ids After Deleting Learnwares:", curr_inds) - - return easy_market - - def test_search_semantics(self, learnware_num=5): - easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) - print("Total Item:", len(easy_market)) - - test_folder = os.path.join(curr_root, "test_semantics") - - # unzip -o -q zip_path -d unzip_dir - if os.path.exists(test_folder): - rmtree(test_folder) - os.makedirs(test_folder, exist_ok=True) - - with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj: - zip_obj.extractall(path=test_folder) - - semantic_spec = copy.deepcopy(user_semantic) - semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" - semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}" - - user_info = BaseUserInfo(semantic_spec=semantic_spec) - _, single_learnware_list, _, _ = easy_market.search_learnware(user_info) - - print("User info:", user_info.get_semantic_spec()) - print(f"Search result:") - for learnware in single_learnware_list: - print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) - - rmtree(test_folder) # rm -r test_folder - - def test_stat_search(self, learnware_num=5): - easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) - print("Total Item:", len(easy_market)) - - test_folder = os.path.join(curr_root, "test_stat") - - for idx, zip_path in enumerate(self.zip_path_list): - unzip_dir = os.path.join(test_folder, f"{idx}") - - # unzip -o -q zip_path -d unzip_dir - if os.path.exists(unzip_dir): - rmtree(unzip_dir) - os.makedirs(unzip_dir, exist_ok=True) - with zipfile.ZipFile(zip_path, "r") as zip_obj: - zip_obj.extractall(path=unzip_dir) - - user_spec = RKMETableSpecification() - user_spec.load(os.path.join(unzip_dir, "svm.json")) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = easy_market.search_learnware(user_info) - - print(f"search result of user{idx}:") - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print(f"mixture_score: {mixture_score}\n") - mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) - print(f"mixture_learnware: {mixture_id}\n") - - rmtree(test_folder) # rm -r test_folder - - def test_learnware_reuse(self, learnware_num=5): - easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) - print("Total Item:", len(easy_market)) - - X, y = load_digits(return_X_y=True) - _, data_X, _, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) - - stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) - - _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) - - # print("Mixture Learnware:", mixture_learnware_list) - - # Based on user information, the learnware market returns a list of learnwares (learnware_list) - # Use jobselector reuser to reuse the searched learnwares to make prediction - reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list) - job_selector_predict_y = reuse_job_selector.predict(user_data=data_X) - - # Use averaging ensemble reuser to reuse the searched learnwares to make prediction - reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list) - ensemble_predict_y = reuse_ensemble.predict(user_data=data_X) - - print("Job Selector Acc:", np.sum(np.argmax(job_selector_predict_y, axis=1) == data_y) / len(data_y)) - print("Averaging Selector Acc:", np.sum(np.argmax(ensemble_predict_y, axis=1) == data_y) / len(data_y)) - - -if __name__ == "__main__": - fire.Fire(LearnwareMarketWorkflow) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index b396add..f3cbe61 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -18,6 +18,7 @@ from ..market import BaseChecker, EasySemanticChecker, EasyStatChecker from ..logger import get_module_logger from ..specification import Specification from ..learnware import get_learnware_from_dirpath +from ..market import BaseUserInfo from ..tests import get_semantic_specification CHUNK_SIZE = 1024 * 1024 @@ -204,10 +205,10 @@ class LearnwareClient: return learnware_list @require_login - def search_learnware(self, specification: Specification, page_size=10, page_index=0): + def search_learnware(self, user_info: BaseUserInfo, page_size=10, page_index=0): url = f"{self.host}/engine/search_learnware" - stat_spec = specification.get_stat_spec() + stat_spec = user_info.stat_info if len(stat_spec) > 1: raise Exception("statistical specification must have only one key.") @@ -222,10 +223,7 @@ class LearnwareClient: stat_spec.save(ftemp.name) with open(ftemp.name, "r") as fin: - semantic_specification = specification.get_semantic_spec() - if semantic_specification is None: - semantic_specification = {} - + semantic_specification = user_info.get_semantic_spec() if stat_spec is None: files = None else: @@ -235,7 +233,7 @@ class LearnwareClient: url, files=files, data={ - "semantic_specification": json.dumps(specification.get_semantic_spec()), + "semantic_specification": json.dumps(semantic_specification), "limit": page_size, "page": page_index, }, @@ -249,13 +247,25 @@ class LearnwareClient: for learnware in result["data"]["learnware_list_single"]: returns.append( - { + { + "type": "single", "learnware_id": learnware["learnware_id"], "semantic_specification": learnware["semantic_specification"], "matching": learnware["matching"], } ) - + if len(result["data"]["learnware_list_multi"]) > 0: + multiple_learnware = { + "type": "multiple", + "learnware_ids": [], + "semantic_specifications": [], + "matching": result["data"]["learnware_list_multi"][0]["matching"] + } + for learnware in result["data"]["learnware_list_multi"]: + multiple_learnware["learnware_ids"].append(learnware["learnware_id"]) + multiple_learnware["semantic_specifications"].append(learnware["semantic_specification"]) + + returns.append(multiple_learnware) return returns @require_login diff --git a/learnware/market/base.py b/learnware/market/base.py index ddf3d9a..78d06e6 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -3,11 +3,12 @@ from __future__ import annotations import traceback import zipfile import tempfile -from typing import Tuple, Any, List, Union +from typing import Tuple, Any, List, Union, Dict, Optional +from dataclasses import dataclass from ..learnware import Learnware, get_learnware_from_dirpath from ..logger import get_module_logger -logger = get_module_logger("market_base", "INFO") +logger = get_module_logger("market_base") class BaseUserInfo: @@ -42,6 +43,9 @@ class BaseUserInfo: def get_stat_info(self, name: str): return self.stat_info.get(name, None) + def update_semantic_spec(self, semantic_spec: dict): + self.semantic_spec = semantic_spec + def update_stat_info(self, name: str, item: Any): """Update stat_info by market @@ -55,6 +59,33 @@ class BaseUserInfo: self.stat_info[name] = item +@dataclass +class SingleSearchItem: + learnware: Learnware + score: Optional[float] = None + +@dataclass +class MultipleSearchItem: + learnwares: List[Learnware] + score: float + +class SearchResults: + def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None): + self.update_single_results([] if single_results is None else single_results) + self.update_multiple_results([] if multiple_results is None else multiple_results) + + def get_single_results(self) -> List[SingleSearchItem]: + return self.single_results + + def get_multiple_results(self) -> List[MultipleSearchItem]: + return self.multiple_results + + def update_single_results(self, single_results: List[SingleSearchItem]): + self.single_results = single_results + + def update_multiple_results(self, multiple_results: List[MultipleSearchItem]): + self.multiple_results = multiple_results + class LearnwareMarket: """Base interface for market, it provide the interface of search/add/detele/update learnwares""" @@ -150,7 +181,7 @@ class LearnwareMarket: def search_learnware( self, user_info: BaseUserInfo, check_status: int = None, **kwargs - ) -> Tuple[Any, List[Learnware]]: + ) -> SearchResults: """Search learnwares based on user_info from learnwares with check_status Parameters @@ -163,7 +194,7 @@ class LearnwareMarket: Returns ------- - Tuple[Any, List[Learnware]] + SearchResults Search results """ return self.learnware_searcher(user_info, check_status, **kwargs) @@ -450,7 +481,7 @@ class BaseSearcher: def reset(self, organizer: BaseOrganizer, **kwargs): self.learnware_organizer = organizer - def __call__(self, user_info: BaseUserInfo, check_status: int = None): + def __call__(self, user_info: BaseUserInfo, check_status: int = None) -> SearchResults: """Search learnwares based on user_info from learnwares with check_status Parameters diff --git a/learnware/market/easy/searcher.py b/learnware/market/easy/searcher.py index 5820ee9..99e1e2d 100644 --- a/learnware/market/easy/searcher.py +++ b/learnware/market/easy/searcher.py @@ -2,11 +2,11 @@ import math import torch import numpy as np from rapidfuzz import fuzz -from typing import Tuple, List, Union +from typing import Tuple, List, Union, Optional from .organizer import EasyOrganizer from ..utils import parse_specification_type -from ..base import BaseUserInfo, BaseSearcher +from ..base import BaseUserInfo, BaseSearcher, SearchResults, SingleSearchItem, MultipleSearchItem from ...learnware import Learnware from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification, rkme_solve_qp from ...logger import get_module_logger @@ -57,7 +57,7 @@ class EasyExactSemanticSearcher(BaseSearcher): return True - def __call__(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]: + def __call__(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> SearchResults: match_learnwares = [] for learnware in learnware_list: learnware_semantic_spec = learnware.get_specification().get_semantic_spec() @@ -65,8 +65,7 @@ class EasyExactSemanticSearcher(BaseSearcher): if self._match_semantic_spec(user_semantic_spec, learnware_semantic_spec): match_learnwares.append(learnware) logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list))) - return match_learnwares - + return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in match_learnwares]) class EasyFuzzSemanticSearcher(BaseSearcher): def _match_semantic_spec_tag(self, semantic_spec1, semantic_spec2) -> bool: @@ -111,7 +110,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): def __call__( self, learnware_list: List[Learnware], user_info: BaseUserInfo, max_num: int = 50000, min_score: float = 75.0 - ) -> List[Learnware]: + ) -> SearchResults: """Search learnware by fuzzy matching of semantic spec Parameters @@ -182,7 +181,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): final_result = matched_learnware_tag logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list))) - return final_result + return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in final_result]) class EasyStatSearcher(BaseSearcher): @@ -328,7 +327,7 @@ class EasyStatSearcher(BaseSearcher): user_rkme: RKMETableSpecification, max_search_num: int, weight_cutoff: float = 0.98, - ) -> Tuple[float, List[float], List[Learnware]]: + ) -> Tuple[Optional[float], List[float], List[Learnware]]: """Select learnwares based on a total mixture ratio, then recalculate their mixture weights Parameters @@ -351,7 +350,7 @@ class EasyStatSearcher(BaseSearcher): """ learnware_num = len(learnware_list) if learnware_num == 0: - return [], [] + return None, [], [] if learnware_num < max_search_num: logger.warning("Available Learnware num less than search_num!") max_search_num = learnware_num @@ -370,7 +369,7 @@ class EasyStatSearcher(BaseSearcher): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] - mixture_weight = [1] + mixture_weight = [1.0] mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_spec_type)) else: if len(mixture_list) > max_search_num: @@ -455,7 +454,7 @@ class EasyStatSearcher(BaseSearcher): user_rkme: RKMETableSpecification, max_search_num: int, decay_rate: float = 0.95, - ) -> Tuple[float, List[float], List[Learnware]]: + ) -> Tuple[Optional[float], List[float], List[Learnware]]: """Greedily match learnwares such that their mixture become closer and closer to user's rkme Parameters @@ -484,7 +483,7 @@ class EasyStatSearcher(BaseSearcher): max_search_num = learnware_num flag_list = [0 for _ in range(learnware_num)] - mixture_list, weight_list, mmd_dist = [], None, None + mixture_list, weight_list, mmd_dist = [], [], None intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1)) for k in range(max_search_num): @@ -543,10 +542,10 @@ class EasyStatSearcher(BaseSearcher): the second is the list of Learnware both lists are sorted by mmd dist """ - RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] + rkme_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] mmd_dist_list = [] - for RKME in RKME_list: - mmd_dist = RKME.dist(user_rkme) + for rkme in rkme_list: + mmd_dist = rkme.dist(user_rkme) mmd_dist_list.append(mmd_dist) sorted_idx_list = sorted(range(len(learnware_list)), key=lambda k: mmd_dist_list[k]) @@ -561,7 +560,7 @@ class EasyStatSearcher(BaseSearcher): user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy", - ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: + ) -> SearchResults: self.stat_spec_type = parse_specification_type(stat_specs=user_info.stat_info) if self.stat_spec_type is None: raise KeyError("No supported stat specification is given in the user info") @@ -572,7 +571,7 @@ class EasyStatSearcher(BaseSearcher): sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) if len(single_learnware_list) == 0: - return [], [], None, [] + return SearchResults() processed_learnware_list = single_learnware_list[: max_search_num * max_search_num] if sorted_dist_list[0] > 0 and search_method == "auto": @@ -622,7 +621,16 @@ class EasyStatSearcher(BaseSearcher): mixture_score = min(1, mixture_score * ratio) if mixture_score is not None else None logger.info(f"After filter by rkme spec, learnware_list length is {len(learnware_list)}") - return sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list + search_results = SearchResults() + + search_results.update_single_results( + [SingleSearchItem(learnware=_learnware, score=_score) for _score, _learnware in zip(sorted_score_list, single_learnware_list)] + ) + if mixture_score is not None and len(mixture_learnware_list) > 0: + search_results.update_multiple_results( + [MultipleSearchItem(learnwares=mixture_learnware_list, score=mixture_score)] + ) + return search_results class EasySearcher(BaseSearcher): @@ -638,7 +646,7 @@ class EasySearcher(BaseSearcher): def __call__( self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" - ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: + ) -> SearchResults: """Search learnwares based on user_info from learnwares with check_status Parameters @@ -660,12 +668,13 @@ class EasySearcher(BaseSearcher): the fourth is the list of Learnware (mixture), the size is search_num """ learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) - learnware_list = self.semantic_searcher(learnware_list, user_info) + semantic_search_result = self.semantic_searcher(learnware_list, user_info) + learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] if len(learnware_list) == 0: - return [], [], 0.0, [] + return SearchResults() if parse_specification_type(stat_specs=user_info.stat_info) is not None: return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: - return None, learnware_list, 0.0, None + return semantic_search_result diff --git a/learnware/market/heterogeneous/searcher.py b/learnware/market/heterogeneous/searcher.py index 7a79004..f261703 100644 --- a/learnware/market/heterogeneous/searcher.py +++ b/learnware/market/heterogeneous/searcher.py @@ -2,7 +2,7 @@ import traceback from typing import Tuple, List from .utils import is_hetero -from ..base import BaseUserInfo +from ..base import BaseUserInfo, SearchResults from ..easy import EasySearcher from ..utils import parse_specification_type from ...learnware import Learnware @@ -15,7 +15,7 @@ logger = get_module_logger("hetero_searcher") class HeteroSearcher(EasySearcher): def __call__( self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" - ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: + ) -> SearchResults: """Search learnwares based on user_info from learnwares with check_status. Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. @@ -38,10 +38,11 @@ class HeteroSearcher(EasySearcher): the fourth is the list of Learnware (mixture), the size is search_num """ learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) - learnware_list = self.semantic_searcher(learnware_list, user_info) + semantic_search_result = self.semantic_searcher(learnware_list, user_info) + learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] if len(learnware_list) == 0: - return [], [], 0.0, [] + return SearchResults() if parse_specification_type(stat_specs=user_info.stat_info) is not None: if is_hetero(stat_specs=user_info.stat_info, semantic_spec=user_info.semantic_spec): @@ -49,4 +50,4 @@ class HeteroSearcher(EasySearcher): user_info.update_stat_info(user_hetero_spec.type, user_hetero_spec) return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: - return None, learnware_list, 0.0, None + return semantic_search_result diff --git a/tests/test_hetero_market/test_hetero.py b/tests/test_hetero_market/test_hetero.py index 686fc52..be828e5 100644 --- a/tests/test_hetero_market/test_hetero.py +++ b/tests/test_hetero_market/test_hetero.py @@ -199,26 +199,28 @@ class TestMarket(unittest.TestCase): semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" user_info = BaseUserInfo(semantic_spec=semantic_spec) - _, single_learnware_list, _, _ = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() print("User info:", user_info.get_semantic_spec()) print(f"Search result:") - assert len(single_learnware_list) == 1, f"Exact semantic search failed!" - for learnware in single_learnware_list: - semantic_spec1 = learnware.get_specification().get_semantic_spec() - print("Choose learnware:", learnware.id, semantic_spec1) + assert len(single_result) == 1, f"Exact semantic search failed!" + for search_item in single_result: + semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() + print("Choose learnware:", search_item.learnware.id, semantic_spec1) assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!" semantic_spec["Name"]["Values"] = "laernwaer" user_info = BaseUserInfo(semantic_spec=semantic_spec) - _, single_learnware_list, _, _ = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() print("User info:", user_info.get_semantic_spec()) print(f"Search result:") - assert len(single_learnware_list) == self.learnware_num, f"Fuzzy semantic search failed!" - for learnware in single_learnware_list: - semantic_spec1 = learnware.get_specification().get_semantic_spec() - print("Choose learnware:", learnware.id, semantic_spec1) + assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!" + for search_item in single_result: + semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() + print("Choose learnware:", search_item.learnware.id, semantic_spec1) def test_stat_search(self, learnware_num=5): hetero_market = self.test_train_market_model(learnware_num) @@ -256,49 +258,40 @@ class TestMarket(unittest.TestCase): semantic_spec["Input"]["Description"] = { str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) } - user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) - + + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() + print(f"search result of user{idx}:") - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print( - f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}" - ) + for single_item in single_result: + print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") + + for multiple_item in multiple_result: + print( + f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" + ) # inproper key "Task" in semantic_spec, use homo search and print invalid semantic_spec print(">> test for key 'Task' has empty 'Values':") semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() - assert len(single_learnware_list) == 0, f"Statistical search failed!" + assert len(single_result) == 0, f"Statistical search failed!" # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" print(">> delele key 'Task' test:") semantic_spec.pop("Task") user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() - assert len(single_learnware_list) == 0, f"Statistical search failed!" + assert len(single_result) == 0, f"Statistical search failed!" # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." print(">> mismatch dim test") @@ -310,14 +303,10 @@ class TestMarket(unittest.TestCase): } user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() - assert len(single_learnware_list) == 0, f"Statistical search failed!" + assert len(single_result) == 0, f"Statistical search failed!" rmtree(test_folder) # rm -r test_folder @@ -338,21 +327,19 @@ class TestMarket(unittest.TestCase): user_spec = RKMETableSpecification() user_spec.load(os.path.join(unzip_dir, "stat.json")) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) - - target_spec_num = 3 if idx % 2 == 0 else 2 - assert len(single_learnware_list) >= 1, f"Statistical search failed!" + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() + + assert len(single_result) >= 1, f"Statistical search failed!" print(f"search result of user{idx}:") - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print(f"mixture_score: {mixture_score}\n") - mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) - print(f"mixture_learnware: {mixture_id}\n") + for single_item in single_result: + print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") + + for multiple_item in multiple_result: + print(f"mixture_score: {multiple_item.score}\n") + mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares]) + print(f"mixture_learnware: {mixture_id}\n") rmtree(test_folder) # rm -r test_folder @@ -370,26 +357,24 @@ class TestMarket(unittest.TestCase): # learnware market search hetero_market = self.test_train_market_model(learnware_num) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) - + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() # print search results - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print(f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}") + for single_item in single_result: + print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") + + for multiple_item in multiple_result: + print(f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}") # single model reuse - hetero_learnware = HeteroMapAlignLearnware(single_learnware_list[0], mode="regression") + hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression") hetero_learnware.align(user_spec, X[:100], y[:100]) single_predict_y = hetero_learnware.predict(X) # multi model reuse hetero_learnware_list = [] - for learnware in mixture_learnware_list: + for learnware in multiple_result[0].learnwares: hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") hetero_learnware.align(user_spec, X[:100], y[:100]) hetero_learnware_list.append(hetero_learnware) diff --git a/tests/test_learnware_client/test_all_learnware.py b/tests/test_learnware_client/test_all_learnware.py index 2303089..276ac00 100644 --- a/tests/test_learnware_client/test_all_learnware.py +++ b/tests/test_learnware_client/test_all_learnware.py @@ -6,6 +6,7 @@ import tempfile from learnware.client import LearnwareClient from learnware.specification import Specification +from learnware.market import BaseUserInfo class TestAllLearnware(unittest.TestCase): @@ -30,16 +31,9 @@ class TestAllLearnware(unittest.TestCase): def test_all_learnware(self): max_learnware_num = 1000 - semantic_spec = dict() - semantic_spec["Data"] = {"Type": "Class", "Values": []} - semantic_spec["Task"] = {"Type": "Class", "Values": []} - semantic_spec["Library"] = {"Type": "Class", "Values": []} - semantic_spec["Scenario"] = {"Type": "Tag", "Values": []} - semantic_spec["Name"] = {"Type": "String", "Values": ""} - semantic_spec["Description"] = {"Type": "String", "Values": ""} - - specification = Specification(semantic_spec=semantic_spec) - result = self.client.search_learnware(specification, page_size=max_learnware_num) + semantic_spec = self.client.create_semantic_specification() + user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={}) + result = self.client.search_learnware(user_info, page_size=max_learnware_num) print(f"result size: {len(result)}") print(f"key in result: {[key for key in result[0]]}") diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 38ed69f..0702a13 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -143,12 +143,13 @@ class TestWorkflow(unittest.TestCase): semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}" user_info = BaseUserInfo(semantic_spec=semantic_spec) - _, single_learnware_list, _, _ = easy_market.search_learnware(user_info) - + search_result = easy_market.search_learnware(user_info) + single_result = search_result.get_single_results() + print("User info:", user_info.get_semantic_spec()) print(f"Search result:") - for learnware in single_learnware_list: - print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) + for search_item in single_result: + print("Choose learnware:", search_item.learnware.id, search_item.learnware.get_specification().get_semantic_spec()) rmtree(test_folder) # rm -r test_folder @@ -171,20 +172,20 @@ class TestWorkflow(unittest.TestCase): user_spec = RKMETableSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = easy_market.search_learnware(user_info) - - assert len(single_learnware_list) >= 1, f"Statistical search failed!" + search_results = easy_market.search_learnware(user_info) + + single_result = search_results.get_single_results() + multiple_result = search_results.get_multiple_results() + + assert len(single_result) >= 1, f"Statistical search failed!" print(f"search result of user{idx}:") - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print(f"mixture_score: {mixture_score}\n") - mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) - print(f"mixture_learnware: {mixture_id}\n") + for search_item in single_result: + print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}") + + for mixture_item in multiple_result: + print(f"mixture_score: {mixture_item.score}\n") + mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares]) + print(f"mixture_learnware: {mixture_id}\n") rmtree(test_folder) # rm -r test_folder @@ -198,24 +199,25 @@ class TestWorkflow(unittest.TestCase): stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) - _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) - + search_results = easy_market.search_learnware(user_info) + multiple_result = search_results.get_multiple_results() + mixture_item = multiple_result[0] # Based on user information, the learnware market returns a list of learnwares (learnware_list) # Use jobselector reuser to reuse the searched learnwares to make prediction - reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list) + reuse_job_selector = JobSelectorReuser(learnware_list=mixture_item.learnwares) job_selector_predict_y = reuse_job_selector.predict(user_data=data_X) # Use averaging ensemble reuser to reuse the searched learnwares to make prediction - reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob") + reuse_ensemble = AveragingReuser(learnware_list=mixture_item.learnwares, mode="vote_by_prob") ensemble_predict_y = reuse_ensemble.predict(user_data=data_X) # Use ensemble pruning reuser to reuse the searched learnwares to make prediction - reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification") + reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_item.learnwares, mode="classification") reuse_ensemble.fit(train_X[-200:], train_y[-200:]) ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X) # Use feature augment reuser to reuse the searched learnwares to make prediction - reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_learnware_list, mode="classification") + reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_item.learnwares, mode="classification") reuse_feature_augment.fit(train_X[-200:], train_y[-200:]) feature_augment_predict_y = reuse_feature_augment.predict(user_data=data_X) @@ -227,8 +229,8 @@ class TestWorkflow(unittest.TestCase): def suite(): _suite = unittest.TestSuite() - _suite.addTest(TestWorkflow("test_prepare_learnware_randomly")) - _suite.addTest(TestWorkflow("test_upload_delete_learnware")) + #_suite.addTest(TestWorkflow("test_prepare_learnware_randomly")) + #_suite.addTest(TestWorkflow("test_upload_delete_learnware")) _suite.addTest(TestWorkflow("test_search_semantics")) _suite.addTest(TestWorkflow("test_stat_search")) _suite.addTest(TestWorkflow("test_learnware_reuse"))