From 9ae6ff6972f46c4f0bd23028e819f2ce72a2571b Mon Sep 17 00:00:00 2001 From: Gene Date: Tue, 18 Apr 2023 21:31:18 +0800 Subject: [PATCH] [MNT] Add _filter_by_rkme_spec_single in EasyMarket --- examples/example_image/main.py | 4 ++-- learnware/market/easy.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/examples/example_image/main.py b/examples/example_image/main.py index 48b4a26..1a65272 100644 --- a/examples/example_image/main.py +++ b/examples/example_image/main.py @@ -183,6 +183,6 @@ def test_search(load_market=True): if __name__ == "__main__": - # prepare_data() - # prepare_model() + prepare_data() + prepare_model() test_search() diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 0b4630a..25bb2f8 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -277,6 +277,29 @@ class EasyMarket(BaseMarket): intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1]) return intermediate_K, intermediate_C + def _filter_by_rkme_spec_single(self, sorted_score_list: List[float], learnware_list: List[Learnware], filter_score=60, min_num=15) -> Tuple[List[float], List[Learnware]]: + """Filter search result of _search_by_rkme_spec_single + + Parameters + ---------- + sorted_score_list : List[float] + The list of score transformed by mmd dist + learnware_list : List[Learnware] + The list of learnwares whose mixture approximates the user's rkme + + Returns + ------- + Tuple[List[float], List[Learnware]] + the first is the list of score + the second is the list of Learnware + """ + idx = min(min_num, len(learnware_list)) + while idx < len(learnware_list): + if sorted_score_list[idx] < filter_score: + break + idx = idx + 1 + return sorted_score_list[:idx], learnware_list[:idx] + def _search_by_rkme_spec_mixture( self, learnware_list: List[Learnware], @@ -448,6 +471,7 @@ class EasyMarket(BaseMarket): user_rkme = user_info.stat_info["RKMEStatSpecification"] sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) sorted_score_list = self._convert_dist_to_score(sorted_dist_list) + sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single(sorted_score_list, single_learnware_list) weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture( learnware_list, user_rkme, max_search_num )