Browse Source

[MNT] Convert dist_list to score_list

tags/v0.3.2
liuht 3 years ago
parent
commit
15b2e4824e
2 changed files with 24 additions and 5 deletions
  1. +3
    -3
      examples/example_market_db/example_db.py
  2. +21
    -2
      learnware/market/easy.py

+ 3
- 3
examples/example_market_db/example_db.py View File

@@ -159,11 +159,11 @@ def test_stat_search():
user_info = BaseUserInfo(
id="user_0", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec}
)
sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)
sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)

print(f"search result of user{idx}:")
for dist, learnware in zip(sorted_dist_list, single_learnware_list):
print(f"dist: {dist}, learnware_id: {learnware.id}")
for score, learnware in zip(sorted_score_list, single_learnware_list):
print(f"score: {score}, learnware_id: {learnware.id}")
mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
print(f"mixture_learnware: {mixture_id}\n")



+ 21
- 2
learnware/market/easy.py View File

@@ -123,6 +123,24 @@ class EasyMarket(BaseMarket):
folder_path=target_folder_dir,
)
return id, True
def _convert_dist_to_score(self, dist_list: List[float]) -> List[float]:
"""Convert mmd dist list into min_max score list

Parameters
----------
dist_list : List[float]
The list of mmd distances from learnware rkmes to user rkme

Returns
-------
List[float]
The list of min_max scores of each learnware
"""
score_list = [(max(dist_list) - dist) / (max(dist_list) - min(dist_list)) for dist in dist_list]

return score_list

def _calculate_rkme_spec_mixture_weight(
self,
@@ -244,7 +262,7 @@ class EasyMarket(BaseMarket):
learnware_num = len(learnware_list)
if learnware_num == 0:
return [], []
if learnware_num < search_num:
if learnware_num < search_num:
logger.warning("Available Learnware num less than search_num")
search_num = learnware_num

@@ -382,10 +400,11 @@ class EasyMarket(BaseMarket):
else:
user_rkme = user_info.stat_info["RKMEStatSpecification"]
sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
sorted_score_list = self._convert_dist_to_score(sorted_dist_list)
weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(
learnware_list, user_rkme, search_num
)
return sorted_dist_list, single_learnware_list, mixture_learnware_list
return sorted_score_list, single_learnware_list, mixture_learnware_list

def delete_learnware(self, id: str) -> bool:
"""Delete Learnware from market


Loading…
Cancel
Save