diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 4a35c8d..bdcbe53 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -159,11 +159,11 @@ def test_stat_search(): user_info = BaseUserInfo( id="user_0", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec} ) - sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) + sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) print(f"search result of user{idx}:") - for dist, learnware in zip(sorted_dist_list, single_learnware_list): - print(f"dist: {dist}, learnware_id: {learnware.id}") + for score, learnware in zip(sorted_score_list, single_learnware_list): + print(f"score: {score}, learnware_id: {learnware.id}") mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) print(f"mixture_learnware: {mixture_id}\n") diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 42827b2..a472a85 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -123,6 +123,24 @@ class EasyMarket(BaseMarket): folder_path=target_folder_dir, ) return id, True + + def _convert_dist_to_score(self, dist_list: List[float]) -> List[float]: + """Convert mmd dist list into min_max score list + + Parameters + ---------- + dist_list : List[float] + The list of mmd distances from learnware rkmes to user rkme + + Returns + ------- + List[float] + The list of min_max scores of each learnware + """ + score_list = [(max(dist_list) - dist) / (max(dist_list) - min(dist_list)) for dist in dist_list] + + return score_list + def _calculate_rkme_spec_mixture_weight( self, @@ -244,7 +262,7 @@ class EasyMarket(BaseMarket): learnware_num = len(learnware_list) if learnware_num == 0: return [], [] - if learnware_num < search_num: + if learnware_num < search_num: logger.warning("Available Learnware num less than search_num") search_num = learnware_num @@ -382,10 +400,11 @@ class EasyMarket(BaseMarket): else: user_rkme = user_info.stat_info["RKMEStatSpecification"] sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) + sorted_score_list = self._convert_dist_to_score(sorted_dist_list) weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture( learnware_list, user_rkme, search_num ) - return sorted_dist_list, single_learnware_list, mixture_learnware_list + return sorted_score_list, single_learnware_list, mixture_learnware_list def delete_learnware(self, id: str) -> bool: """Delete Learnware from market