| @@ -57,36 +57,31 @@ class BaseUserInfo: | |||
| @dataclass | |||
| class SearchItem: | |||
| learnwares: Union[List[Learnware] ,Learnware] | |||
| score: float = 1.0 | |||
| class SingleSearchItem: | |||
| learnware: Learnware | |||
| score: Optional[float] = None | |||
| @dataclass | |||
| class MultipleSearchItem: | |||
| learnwares: List[Learnware] | |||
| score: float | |||
| class SearchResults: | |||
| def __init__(self, single_results: Optional[List[SearchItem]] = None, multiple_results: Optional[List[SearchItem]] = None): | |||
| self.search_mapping: Dict[str, List[SearchItem]] = { | |||
| "single": [] if single_results is None else single_results, | |||
| "multiple": [] if multiple_results is None else multiple_results, | |||
| } | |||
| def get_results(self, name): | |||
| if name in self.search_mapping: | |||
| return self.search_mapping[name] | |||
| else: | |||
| raise KeyError(f"name '{name}' is not supported for search results") | |||
| def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None): | |||
| self.update_single_results([] if single_results is None else single_results) | |||
| self.update_multiple_results([] if multiple_results is None else multiple_results) | |||
| def get_single_results(self) -> List[SingleSearchItem]: | |||
| return self.single_results | |||
| def update_results(self, name, search_result: List[SearchItem]): | |||
| if name in self.search_mapping: | |||
| self.search_mapping[name] = search_result | |||
| else: | |||
| raise KeyError(f"name '{name}' is not supported for search results") | |||
| def get_multiple_results(self) -> List[MultipleSearchItem]: | |||
| return self.multiple_results | |||
| def sort(self, name=None, ascending: bool=False): | |||
| assert name is None or name in self.search_mapping, f"name '{name}' is not supported for search results" | |||
| for key, search_result in self.search_mapping.items(): | |||
| if name is None or name == key: | |||
| #search_result = [x for x in search_result] | |||
| self.search_mapping[key] = sorted(search_result, key=lambda x:x.score, reverse=(not ascending)) | |||
| def update_single_results(self, single_results: List[SingleSearchItem]): | |||
| self.single_results = single_results | |||
| def update_multiple_results(self, multiple_results: List[MultipleSearchItem]): | |||
| self.multiple_results = multiple_results | |||
| class LearnwareMarket: | |||
| """Base interface for market, it provide the interface of search/add/detele/update learnwares""" | |||
| @@ -6,7 +6,7 @@ from typing import Tuple, List, Union, Optional | |||
| from .organizer import EasyOrganizer | |||
| from ..utils import parse_specification_type | |||
| from ..base import BaseUserInfo, BaseSearcher, SearchResults, SearchItem | |||
| from ..base import BaseUserInfo, BaseSearcher, SearchResults, SingleSearchItem, MultipleSearchItem | |||
| from ...learnware import Learnware | |||
| from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification, rkme_solve_qp | |||
| from ...logger import get_module_logger | |||
| @@ -65,7 +65,7 @@ class EasyExactSemanticSearcher(BaseSearcher): | |||
| if self._match_semantic_spec(user_semantic_spec, learnware_semantic_spec): | |||
| match_learnwares.append(learnware) | |||
| logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list))) | |||
| return SearchResults(single_results=[SearchItem(learnwares=_learnware) for _learnware in match_learnwares]) | |||
| return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in match_learnwares]) | |||
| class EasyFuzzSemanticSearcher(BaseSearcher): | |||
| def _match_semantic_spec_tag(self, semantic_spec1, semantic_spec2) -> bool: | |||
| @@ -181,7 +181,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): | |||
| final_result = matched_learnware_tag | |||
| logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list))) | |||
| return SearchResults(single_results=[SearchItem(learnwares=_learnware) for _learnware in final_result]) | |||
| return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in final_result]) | |||
| class EasyStatSearcher(BaseSearcher): | |||
| @@ -623,14 +623,12 @@ class EasyStatSearcher(BaseSearcher): | |||
| search_results = SearchResults() | |||
| search_results.update_results( | |||
| name="single", | |||
| search_result=[SearchItem(learnwares=_learnware, score=_score) for _score, _learnware in zip(sorted_score_list, single_learnware_list)] | |||
| search_results.update_single_results( | |||
| [SingleSearchItem(learnware=_learnware, score=_score) for _score, _learnware in zip(sorted_score_list, single_learnware_list)] | |||
| ) | |||
| if mixture_score is not None and len(mixture_learnware_list) > 0: | |||
| search_results.update_results( | |||
| name="multiple", | |||
| search_result=[SearchItem(learnwares=mixture_learnware_list, score=mixture_score)] | |||
| search_results.update_multiple_results( | |||
| [MultipleSearchItem(learnwares=mixture_learnware_list, score=mixture_score)] | |||
| ) | |||
| return search_results | |||
| @@ -672,7 +670,7 @@ class EasySearcher(BaseSearcher): | |||
| learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) | |||
| semantic_search_result = self.semantic_searcher(learnware_list, user_info) | |||
| learnware_list = semantic_search_result.get_results(name="single") | |||
| learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] | |||
| if len(learnware_list) == 0: | |||
| return SearchResults() | |||
| @@ -40,7 +40,7 @@ class HeteroSearcher(EasySearcher): | |||
| learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) | |||
| semantic_search_result = self.semantic_searcher(learnware_list, user_info) | |||
| learnware_list = semantic_search_result.get_results(name="single") | |||
| learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] | |||
| if len(learnware_list) == 0: | |||
| return SearchResults() | |||
| @@ -199,26 +199,28 @@ class TestMarket(unittest.TestCase): | |||
| semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| _, single_learnware_list, _, _ = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| assert len(single_learnware_list) == 1, f"Exact semantic search failed!" | |||
| for learnware in single_learnware_list: | |||
| semantic_spec1 = learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", learnware.id, semantic_spec1) | |||
| assert len(single_result) == 1, f"Exact semantic search failed!" | |||
| for search_item in single_result: | |||
| semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", search_item.learnware.id, semantic_spec1) | |||
| assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!" | |||
| semantic_spec["Name"]["Values"] = "laernwaer" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| _, single_learnware_list, _, _ = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| assert len(single_learnware_list) == self.learnware_num, f"Fuzzy semantic search failed!" | |||
| for learnware in single_learnware_list: | |||
| semantic_spec1 = learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", learnware.id, semantic_spec1) | |||
| assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!" | |||
| for search_item in single_result: | |||
| semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", search_item.learnware.id, semantic_spec1) | |||
| def test_stat_search(self, learnware_num=5): | |||
| hetero_market = self.test_train_market_model(learnware_num) | |||
| @@ -143,12 +143,13 @@ class TestWorkflow(unittest.TestCase): | |||
| semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| _, single_learnware_list, _, _ = easy_market.search_learnware(user_info) | |||
| search_result = easy_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| for learnware in single_learnware_list: | |||
| print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) | |||
| for search_item in single_result: | |||
| print("Choose learnware:", search_item.learnware.id, search_item.learnware.get_specification().get_semantic_spec()) | |||
| rmtree(test_folder) # rm -r test_folder | |||
| @@ -171,20 +172,21 @@ class TestWorkflow(unittest.TestCase): | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = easy_market.search_learnware(user_info) | |||
| assert len(single_learnware_list) >= 1, f"Statistical search failed!" | |||
| search_results = easy_market.search_learnware(user_info) | |||
| single_result = search_results.get_single_results() | |||
| multiple_result = search_results.get_multiple_results() | |||
| assert len(single_result) >= 1, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||
| print(f"mixture_score: {mixture_score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| for search_item in single_result: | |||
| print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}") | |||
| if len(multiple_result) > 0: | |||
| mixture_item = multiple_result[0] | |||
| print(f"mixture_score: {mixture_item.score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| rmtree(test_folder) # rm -r test_folder | |||
| @@ -198,24 +200,25 @@ class TestWorkflow(unittest.TestCase): | |||
| stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) | |||
| _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) | |||
| search_results = easy_market.search_learnware(user_info) | |||
| multiple_result = search_results.get_multiple_results() | |||
| mixture_item = multiple_result[0] | |||
| # Based on user information, the learnware market returns a list of learnwares (learnware_list) | |||
| # Use jobselector reuser to reuse the searched learnwares to make prediction | |||
| reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list) | |||
| reuse_job_selector = JobSelectorReuser(learnware_list=mixture_item.learnwares) | |||
| job_selector_predict_y = reuse_job_selector.predict(user_data=data_X) | |||
| # Use averaging ensemble reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob") | |||
| reuse_ensemble = AveragingReuser(learnware_list=mixture_item.learnwares, mode="vote_by_prob") | |||
| ensemble_predict_y = reuse_ensemble.predict(user_data=data_X) | |||
| # Use ensemble pruning reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification") | |||
| reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_item.learnwares, mode="classification") | |||
| reuse_ensemble.fit(train_X[-200:], train_y[-200:]) | |||
| ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X) | |||
| # Use feature augment reuser to reuse the searched learnwares to make prediction | |||
| reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_learnware_list, mode="classification") | |||
| reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_item.learnwares, mode="classification") | |||
| reuse_feature_augment.fit(train_X[-200:], train_y[-200:]) | |||
| feature_augment_predict_y = reuse_feature_augment.predict(user_data=data_X) | |||