From 44c63fc2a8fa0bb12f97c8aef94889bcae3ce9f2 Mon Sep 17 00:00:00 2001 From: liuht Date: Tue, 18 Apr 2023 20:55:45 +0800 Subject: [PATCH] [MNT] Add another search multiple learnware method --- learnware/market/easy.py | 52 ++++++++++++++++++++++++++++++++++++++-- setup.py | 1 + 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 0b4630a..3d6c00b 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -276,6 +276,55 @@ class EasyMarket(BaseMarket): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1]) return intermediate_K, intermediate_C + + def _search_by_rkme_spec_mixture_auto( + self, + learnware_list: List[Learnware], + user_rkme: RKMEStatSpecification, + max_search_num: int, + weight_cutoff: float = 0.9 + ) -> Tuple[List[float], List[Learnware]]: + """Select learnwares based on a total mixture ratio, then recalculate their mixture weights + + Parameters + ---------- + learnware_list : List[Learnware] + The list of learnwares whose mixture approximates the user's rkme + user_rkme : RKMEStatSpecification + User RKME statistical specification + max_search_num : int + The maximum number of the returned learnwares + weight_cutoff : float, optional + The ratio for selecting out the mose relevant learnwares, by default 0.9 + + Returns + ------- + Tuple[List[float], List[Learnware]] + The first is the list of weight + The second is the list of Learnware + """ + learnware_num = len(learnware_list) + if learnware_num == 0: + return [], [] + if learnware_num < max_search_num: + logger.warning("Available Learnware num less than search_num!") + max_search_num = learnware_num + + weight, _ = self._calculate_rkme_spec_mixture_weight(learnware_list, user_rkme) + sort_by_weight_idx_list = sorted(range(learnware_num), key=lambda k: weight[k]) + + weight_sum = 0 + mixture_list = [] + for idx in sort_by_weight_idx_list: + weight_sum += sort_by_weight_idx_list[idx] + if weight_sum <= weight_cutoff: + mixture_list.append(learnware_list[idx]) + + if len(mixture_list) > max_search_num: + mixture_list = mixture_list[:max_search_num] + + mixture_weight, _ = self._calculate_rkme_spec_mixture_weight(mixture_list, user_rkme) + return mixture_weight, mixture_list def _search_by_rkme_spec_mixture( self, @@ -284,7 +333,7 @@ class EasyMarket(BaseMarket): max_search_num: int, score_cutoff: float = 0.01, ) -> Tuple[List[float], List[Learnware]]: - """Get learnwares with their mixture weight from the given learnware_list + """Greedily match learnwares such that their mixture become more and more closer to user's rkme Parameters ---------- @@ -338,7 +387,6 @@ class EasyMarket(BaseMarket): mixture_list[-1] = learnware_list[idx_min] if score_min < score_cutoff: - print(score_min) break else: flag_list[idx_min] = 1 diff --git a/setup.py b/setup.py index b63ce24..144f4cb 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ REQUIRED = [ "fire>=0.3.1", "lightgbm>=3.3.0", "psutil>=5.9.4", + "torchvision>=0.15.1" ] here = os.path.abspath(os.path.dirname(__file__))