From 745d129da3e16c37edebfbb225d0a80ca8d9b836 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 19 Oct 2023 16:26:35 +0800 Subject: [PATCH] [ENH] add exact_search before fuzz_search in semantic search --- learnware/market/easy.py | 90 +++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 37 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 8893d92..26f8fb0 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -679,7 +679,7 @@ class EasyMarket(BaseMarket): return match_learnwares def _search_by_semantic_spec_fuzz( - self, learnware_list: List[Learnware], user_info: BaseUserInfo, max_num: int = 50000, min_score: float = 30.0 + self, learnware_list: List[Learnware], user_info: BaseUserInfo, max_num: int = 50000, min_score: float = 75.0 ) -> List[Learnware]: """Search learnware by fuzzy matching of semantic spec @@ -699,9 +699,8 @@ class EasyMarket(BaseMarket): List[Learnware] The list of returned learnwares """ - - def match_semantic_spec_fuzz(semantic_spec1, semantic_spec2) -> float: - """Calculate the fuzzy matching score of two semantic specs + def _match_semantic_spec_tag(semantic_spec1, semantic_spec2) -> bool: + """Judge if tags of two semantic specs are consistent Parameters ---------- @@ -712,8 +711,8 @@ class EasyMarket(BaseMarket): Returns ------- - float - matching score ranged from [0, 100] + bool + consistent (True) or not consistent (False) """ for key in semantic_spec1.keys(): v1 = semantic_spec1[key]["Values"] @@ -726,7 +725,7 @@ class EasyMarket(BaseMarket): if key not in "Name": if len(v2) == 0: # user input contains some key that is not in database - return 0 + return False if semantic_spec1[key]["Type"] == "Class": if isinstance(v1, list): @@ -734,42 +733,59 @@ class EasyMarket(BaseMarket): if isinstance(v2, list): v2 = v2[0] if v1 != v2: - return 0 + return False elif semantic_spec1[key]["Type"] == "Tag": if not (set(v1) & set(v2)): - return 0 - pass - pass - pass - - name2 = semantic_spec2["Name"]["Values"].lower() - description2 = semantic_spec2["Description"]["Values"].lower() - - if "Name" in semantic_spec1: - name1 = semantic_spec1["Name"]["Values"].lower() - if len(name1) > 0: - score_name = fuzz.partial_ratio(name1, name2) - score_des = fuzz.partial_ratio(name1, description2) - return score_name * 0.7 + score_des * 0.3 - - return 100 + return False + return True + + matched_learnware_tag = [] + final_result = [] + user_semantic_spec = user_info.get_semantic_spec() - matched_learnwares, matched_scores = [], [] for learnware in learnware_list: learnware_semantic_spec = learnware.get_specification().get_semantic_spec() - user_semantic_spec = user_info.get_semantic_spec() - match_score = match_semantic_spec_fuzz(user_semantic_spec, learnware_semantic_spec) - if match_score >= min_score: - matched_learnwares.append(learnware) - matched_scores.append(match_score) - - sort_idx = sorted(list(range(len(matched_scores))), key=lambda k: matched_scores[k], reverse=True)[:max_num] - matched_learnwares = [matched_learnwares[idx] for idx in sort_idx] + if _match_semantic_spec_tag(user_semantic_spec, learnware_semantic_spec): + matched_learnware_tag.append(learnware) + + if len(matched_learnware_tag) > 0: + if "Name" in user_semantic_spec: + name_user = user_semantic_spec["Name"]["Values"].lower() + if len(name_user) > 0: + # Exact search + name_list = [learnware.get_specification().get_semantic_spec()["Name"]["Values"].lower() for learnware in matched_learnware_tag] + des_list = [learnware.get_specification().get_semantic_spec()["Description"]["Values"].lower() for learnware in matched_learnware_tag] + + matched_learnware_exact = [] + for i in range(len(name_list)): + if name_user in name_list[i] or name_user in des_list[i]: + matched_learnware_exact.append(matched_learnware_tag[i]) + + if len(matched_learnware_exact) == 0: + # Fuzzy search + matched_learnware_fuzz, fuzz_scores = [], [] + for i in range(len(name_list)): + score_name = fuzz.partial_ratio(name_user, name_list[i]) + score_des = fuzz.partial_ratio(name_user, des_list[i]) + final_score = max(score_name, score_des) + if final_score >= min_score: + matched_learnware_fuzz.append(matched_learnware_tag[i]) + fuzz_scores.append(final_score) + + # Sort by score + sort_idx = sorted(list(range(len(fuzz_scores))), key=lambda k: fuzz_scores[k], reverse=True)[:max_num] + final_result = [matched_learnware_fuzz[idx] for idx in sort_idx] + else: + final_result = matched_learnware_exact + else: + final_result = matched_learnware_tag + else: + final_result = matched_learnware_tag logger.info( - "semantic_spec search: choose %d from %d learnwares" % (len(matched_learnwares), len(learnware_list)) + "semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list)) ) - return matched_learnwares + return final_result def search_learnware( self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" @@ -792,8 +808,9 @@ class EasyMarket(BaseMarket): the fourth is the list of Learnware (mixture), the size is search_num """ learnware_list = [self.learnware_list[key] for key in self.learnware_list] + # learnware_list = self._search_by_semantic_spec_exact(learnware_list, user_info) + # if len(learnware_list) == 0: learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info) - # learnware_list = list(set(learnware_list_tags + learnware_list_description)) if "RKMEStatSpecification" not in user_info.stat_info: return None, learnware_list, 0.0, None @@ -910,7 +927,6 @@ class EasyMarket(BaseMarket): str: id of targer learware List[str]: A list of ids of target learnwares - Returns ------- Union[Learnware, List[Learnware]]