From 1d5bbce2ff53a1a7b584ea2be674478bab2a45c0 Mon Sep 17 00:00:00 2001 From: bxdd Date: Sat, 25 Nov 2023 22:10:06 +0800 Subject: [PATCH] [MNT] now the test workflow is runnable --- learnware/market/base.py | 47 ++++++++++----------- learnware/market/easy/searcher.py | 18 ++++---- learnware/market/heterogeneous/searcher.py | 2 +- tests/test_hetero_market/test_hetero.py | 22 +++++----- tests/test_workflow/test_workflow.py | 49 ++++++++++++---------- 5 files changed, 68 insertions(+), 70 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index a71dad0..d584696 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -57,36 +57,31 @@ class BaseUserInfo: @dataclass -class SearchItem: - learnwares: Union[List[Learnware] ,Learnware] - score: float = 1.0 +class SingleSearchItem: + learnware: Learnware + score: Optional[float] = None +@dataclass +class MultipleSearchItem: + learnwares: List[Learnware] + score: float + class SearchResults: - def __init__(self, single_results: Optional[List[SearchItem]] = None, multiple_results: Optional[List[SearchItem]] = None): - self.search_mapping: Dict[str, List[SearchItem]] = { - "single": [] if single_results is None else single_results, - "multiple": [] if multiple_results is None else multiple_results, - } - - def get_results(self, name): - if name in self.search_mapping: - return self.search_mapping[name] - else: - raise KeyError(f"name '{name}' is not supported for search results") + def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None): + self.update_single_results([] if single_results is None else single_results) + self.update_multiple_results([] if multiple_results is None else multiple_results) + + def get_single_results(self) -> List[SingleSearchItem]: + return self.single_results - def update_results(self, name, search_result: List[SearchItem]): - if name in self.search_mapping: - self.search_mapping[name] = search_result - else: - raise KeyError(f"name '{name}' is not supported for search results") + def get_multiple_results(self) -> List[MultipleSearchItem]: + return self.multiple_results - def sort(self, name=None, ascending: bool=False): - assert name is None or name in self.search_mapping, f"name '{name}' is not supported for search results" - for key, search_result in self.search_mapping.items(): - if name is None or name == key: - #search_result = [x for x in search_result] - self.search_mapping[key] = sorted(search_result, key=lambda x:x.score, reverse=(not ascending)) - + def update_single_results(self, single_results: List[SingleSearchItem]): + self.single_results = single_results + + def update_multiple_results(self, multiple_results: List[MultipleSearchItem]): + self.multiple_results = multiple_results class LearnwareMarket: """Base interface for market, it provide the interface of search/add/detele/update learnwares""" diff --git a/learnware/market/easy/searcher.py b/learnware/market/easy/searcher.py index bd5fa93..99e1e2d 100644 --- a/learnware/market/easy/searcher.py +++ b/learnware/market/easy/searcher.py @@ -6,7 +6,7 @@ from typing import Tuple, List, Union, Optional from .organizer import EasyOrganizer from ..utils import parse_specification_type -from ..base import BaseUserInfo, BaseSearcher, SearchResults, SearchItem +from ..base import BaseUserInfo, BaseSearcher, SearchResults, SingleSearchItem, MultipleSearchItem from ...learnware import Learnware from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification, rkme_solve_qp from ...logger import get_module_logger @@ -65,7 +65,7 @@ class EasyExactSemanticSearcher(BaseSearcher): if self._match_semantic_spec(user_semantic_spec, learnware_semantic_spec): match_learnwares.append(learnware) logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list))) - return SearchResults(single_results=[SearchItem(learnwares=_learnware) for _learnware in match_learnwares]) + return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in match_learnwares]) class EasyFuzzSemanticSearcher(BaseSearcher): def _match_semantic_spec_tag(self, semantic_spec1, semantic_spec2) -> bool: @@ -181,7 +181,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): final_result = matched_learnware_tag logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list))) - return SearchResults(single_results=[SearchItem(learnwares=_learnware) for _learnware in final_result]) + return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in final_result]) class EasyStatSearcher(BaseSearcher): @@ -623,14 +623,12 @@ class EasyStatSearcher(BaseSearcher): search_results = SearchResults() - search_results.update_results( - name="single", - search_result=[SearchItem(learnwares=_learnware, score=_score) for _score, _learnware in zip(sorted_score_list, single_learnware_list)] + search_results.update_single_results( + [SingleSearchItem(learnware=_learnware, score=_score) for _score, _learnware in zip(sorted_score_list, single_learnware_list)] ) if mixture_score is not None and len(mixture_learnware_list) > 0: - search_results.update_results( - name="multiple", - search_result=[SearchItem(learnwares=mixture_learnware_list, score=mixture_score)] + search_results.update_multiple_results( + [MultipleSearchItem(learnwares=mixture_learnware_list, score=mixture_score)] ) return search_results @@ -672,7 +670,7 @@ class EasySearcher(BaseSearcher): learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) semantic_search_result = self.semantic_searcher(learnware_list, user_info) - learnware_list = semantic_search_result.get_results(name="single") + learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] if len(learnware_list) == 0: return SearchResults() diff --git a/learnware/market/heterogeneous/searcher.py b/learnware/market/heterogeneous/searcher.py index 4e806bd..f261703 100644 --- a/learnware/market/heterogeneous/searcher.py +++ b/learnware/market/heterogeneous/searcher.py @@ -40,7 +40,7 @@ class HeteroSearcher(EasySearcher): learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) semantic_search_result = self.semantic_searcher(learnware_list, user_info) - learnware_list = semantic_search_result.get_results(name="single") + learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] if len(learnware_list) == 0: return SearchResults() diff --git a/tests/test_hetero_market/test_hetero.py b/tests/test_hetero_market/test_hetero.py index 686fc52..a98ab60 100644 --- a/tests/test_hetero_market/test_hetero.py +++ b/tests/test_hetero_market/test_hetero.py @@ -199,26 +199,28 @@ class TestMarket(unittest.TestCase): semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" user_info = BaseUserInfo(semantic_spec=semantic_spec) - _, single_learnware_list, _, _ = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() print("User info:", user_info.get_semantic_spec()) print(f"Search result:") - assert len(single_learnware_list) == 1, f"Exact semantic search failed!" - for learnware in single_learnware_list: - semantic_spec1 = learnware.get_specification().get_semantic_spec() - print("Choose learnware:", learnware.id, semantic_spec1) + assert len(single_result) == 1, f"Exact semantic search failed!" + for search_item in single_result: + semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() + print("Choose learnware:", search_item.learnware.id, semantic_spec1) assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!" semantic_spec["Name"]["Values"] = "laernwaer" user_info = BaseUserInfo(semantic_spec=semantic_spec) - _, single_learnware_list, _, _ = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() print("User info:", user_info.get_semantic_spec()) print(f"Search result:") - assert len(single_learnware_list) == self.learnware_num, f"Fuzzy semantic search failed!" - for learnware in single_learnware_list: - semantic_spec1 = learnware.get_specification().get_semantic_spec() - print("Choose learnware:", learnware.id, semantic_spec1) + assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!" + for search_item in single_result: + semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() + print("Choose learnware:", search_item.learnware.id, semantic_spec1) def test_stat_search(self, learnware_num=5): hetero_market = self.test_train_market_model(learnware_num) diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 38ed69f..07d5456 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -143,12 +143,13 @@ class TestWorkflow(unittest.TestCase): semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}" user_info = BaseUserInfo(semantic_spec=semantic_spec) - _, single_learnware_list, _, _ = easy_market.search_learnware(user_info) - + search_result = easy_market.search_learnware(user_info) + single_result = search_result.get_single_results() + print("User info:", user_info.get_semantic_spec()) print(f"Search result:") - for learnware in single_learnware_list: - print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) + for search_item in single_result: + print("Choose learnware:", search_item.learnware.id, search_item.learnware.get_specification().get_semantic_spec()) rmtree(test_folder) # rm -r test_folder @@ -171,20 +172,21 @@ class TestWorkflow(unittest.TestCase): user_spec = RKMETableSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = easy_market.search_learnware(user_info) - - assert len(single_learnware_list) >= 1, f"Statistical search failed!" + search_results = easy_market.search_learnware(user_info) + + single_result = search_results.get_single_results() + multiple_result = search_results.get_multiple_results() + + assert len(single_result) >= 1, f"Statistical search failed!" print(f"search result of user{idx}:") - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print(f"mixture_score: {mixture_score}\n") - mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) - print(f"mixture_learnware: {mixture_id}\n") + for search_item in single_result: + print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}") + + if len(multiple_result) > 0: + mixture_item = multiple_result[0] + print(f"mixture_score: {mixture_item.score}\n") + mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares]) + print(f"mixture_learnware: {mixture_id}\n") rmtree(test_folder) # rm -r test_folder @@ -198,24 +200,25 @@ class TestWorkflow(unittest.TestCase): stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) - _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) - + search_results = easy_market.search_learnware(user_info) + multiple_result = search_results.get_multiple_results() + mixture_item = multiple_result[0] # Based on user information, the learnware market returns a list of learnwares (learnware_list) # Use jobselector reuser to reuse the searched learnwares to make prediction - reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list) + reuse_job_selector = JobSelectorReuser(learnware_list=mixture_item.learnwares) job_selector_predict_y = reuse_job_selector.predict(user_data=data_X) # Use averaging ensemble reuser to reuse the searched learnwares to make prediction - reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob") + reuse_ensemble = AveragingReuser(learnware_list=mixture_item.learnwares, mode="vote_by_prob") ensemble_predict_y = reuse_ensemble.predict(user_data=data_X) # Use ensemble pruning reuser to reuse the searched learnwares to make prediction - reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification") + reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_item.learnwares, mode="classification") reuse_ensemble.fit(train_X[-200:], train_y[-200:]) ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X) # Use feature augment reuser to reuse the searched learnwares to make prediction - reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_learnware_list, mode="classification") + reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_item.learnwares, mode="classification") reuse_feature_augment.fit(train_X[-200:], train_y[-200:]) feature_augment_predict_y = reuse_feature_augment.predict(user_data=data_X)