diff --git a/learnware/market/easy/searcher.py b/learnware/market/easy/searcher.py index eb43809..87dc16a 100644 --- a/learnware/market/easy/searcher.py +++ b/learnware/market/easy/searcher.py @@ -1,3 +1,4 @@ +import math import torch import numpy as np from rapidfuzz import fuzz @@ -186,7 +187,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): class EasyStatSearcher(BaseSearcher): def _convert_dist_to_score( - self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92 + self, dist_list: List[float], dist_ratio: float = 0.5, min_score: float = 0.92, improve_score: float = 0.7 ) -> List[float]: """Convert mmd dist list into min_max score list @@ -194,10 +195,12 @@ class EasyStatSearcher(BaseSearcher): ---------- dist_list : List[float] The list of mmd distances from learnware rkmes to user rkme - dist_epsilon: float + dist_ratio: float The paramter for converting mmd dist to score min_score: float The minimum score for maximum returned score + improve_score: float + The learnware score lower than improve_score will be improved Returns ------- @@ -211,6 +214,7 @@ class EasyStatSearcher(BaseSearcher): if min_dist == max_dist: return [1 for dist in dist_list] else: + dist_epsilon = max_dist * dist_ratio max_score = (max_dist - min_dist) / (max_dist - dist_epsilon) if min_dist < dist_epsilon: @@ -218,7 +222,14 @@ class EasyStatSearcher(BaseSearcher): elif max_score < min_score: dist_epsilon = max_dist - (max_dist - min_dist) / min_score - return [(max_dist - dist) / (max_dist - dist_epsilon) for dist in dist_list] + score_list = [] + for dist in dist_list: + score = (max_dist - dist) / (max_dist - dist_epsilon) + if score < improve_score: + score = min(math.sqrt(score), improve_score) + score_list.append(score) + + return score_list def _calculate_rkme_spec_mixture_weight( self, @@ -371,8 +382,8 @@ class EasyStatSearcher(BaseSearcher): self, sorted_score_list: List[float], learnware_list: List[Learnware], - filter_score: float = 0.5, - min_num: int = 15, + filter_score: float = 0.6, + min_num: int = 1, ) -> Tuple[List[float], List[Learnware]]: """Filter search result of _search_by_rkme_spec_single @@ -442,7 +453,7 @@ class EasyStatSearcher(BaseSearcher): learnware_list: List[Learnware], user_rkme: RKMETableSpecification, max_search_num: int, - score_cutoff: float = 0.001, + decay_rate: float = 0.95, ) -> Tuple[float, List[float], List[Learnware]]: """Greedily match learnwares such that their mixture become closer and closer to user's rkme @@ -454,8 +465,8 @@ class EasyStatSearcher(BaseSearcher): User RKME statistical specification max_search_num : int The maximum number of the returned learnwares - score_cutof: float - The minimum mmd dist as threshold to stop further rkme_spec matching + decay_rate: float + The decrease ratio of minimum mmd dist to stop further rkme_spec matching Returns ------- @@ -472,11 +483,11 @@ class EasyStatSearcher(BaseSearcher): max_search_num = learnware_num flag_list = [0 for _ in range(learnware_num)] - mixture_list, mmd_dist = [], None + mixture_list, weight_list, mmd_dist = [], None, None intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1)) for k in range(max_search_num): - idx_min, score_min = -1, -1 + idx_min, score_min = None, None weight_min = None mixture_list.append(None) @@ -494,20 +505,21 @@ class EasyStatSearcher(BaseSearcher): weight, score = self._calculate_rkme_spec_mixture_weight( mixture_list, user_rkme, intermediate_K, intermediate_C ) - if idx_min == -1 or score < score_min: + if score_min is None or score < score_min: idx_min, score_min, weight_min = idx, score, weight - mmd_dist = score_min - mixture_list[-1] = learnware_list[idx_min] - if score_min < score_cutoff: - break - else: + if mmd_dist is None or score_min <= mmd_dist * decay_rate: + mmd_dist, weight_list = score_min, weight_min + mixture_list[-1] = learnware_list[idx_min] flag_list[idx_min] = 1 intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( mixture_list, user_rkme, intermediate_K, intermediate_C ) + else: + mixture_list = mixture_list[:-1] + break - return mmd_dist, weight_min, mixture_list + return mmd_dist, weight_list, mixture_list def _search_by_rkme_spec_single( self, @@ -558,15 +570,14 @@ class EasyStatSearcher(BaseSearcher): logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) + processed_learnware_list = single_learnware_list[: max_search_num * max_search_num] if search_method == "auto": mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_auto( - learnware_list, user_rkme, max_search_num + processed_learnware_list, user_rkme, max_search_num ) elif search_method == "greedy": - score_cutoff = sorted_dist_list[0] * 0.05 if \ - len(sorted_dist_list) > 0 and self.stat_spec_type == "RKMEImageSpecification" else 0.001 mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_greedy( - learnware_list, user_rkme, max_search_num, score_cutoff=score_cutoff + processed_learnware_list, user_rkme, max_search_num ) else: logger.warning("f{search_method} not supported!") @@ -581,14 +592,23 @@ class EasyStatSearcher(BaseSearcher): merge_score_list = self._convert_dist_to_score(sorted_dist_list + [mixture_dist]) sorted_score_list = merge_score_list[:-1] mixture_score = merge_score_list[-1] + if int(mixture_score * 100) == int(sorted_score_list[0] * 100): + mixture_score = None + mixture_learnware_list = [] + logger.info( + f"After search by rkme spec, learnware_list length is {len(learnware_list)}, mixture_learnware_list length is {len(mixture_learnware_list)}" + ) - logger.info(f"After search by rkme spec, learnware_list length is {len(learnware_list)}") - # filter learnware with low score + # Filter learnware with low score sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single( sorted_score_list, single_learnware_list ) - + if len(single_learnware_list) == 1 and sorted_score_list[0] < 0.6: + ratio = 0.6 / sorted_score_list[0] + sorted_score_list[0] = 0.6 + mixture_score = min(1, mixture_score * ratio) if mixture_score is not None else None logger.info(f"After filter by rkme spec, learnware_list length is {len(learnware_list)}") + return sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list