| @@ -1,3 +1,4 @@ | |||
| import math | |||
| import torch | |||
| import numpy as np | |||
| from rapidfuzz import fuzz | |||
| @@ -186,7 +187,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): | |||
| class EasyStatSearcher(BaseSearcher): | |||
| def _convert_dist_to_score( | |||
| self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92 | |||
| self, dist_list: List[float], dist_ratio: float = 0.5, min_score: float = 0.92, improve_score: float = 0.7 | |||
| ) -> List[float]: | |||
| """Convert mmd dist list into min_max score list | |||
| @@ -194,10 +195,12 @@ class EasyStatSearcher(BaseSearcher): | |||
| ---------- | |||
| dist_list : List[float] | |||
| The list of mmd distances from learnware rkmes to user rkme | |||
| dist_epsilon: float | |||
| dist_ratio: float | |||
| The parameter for converting mmd dist to score | |||
| min_score: float | |||
| The minimum score for maximum returned score | |||
| improve_score: float | |||
| Learnware scores lower than improve_score are raised to min(sqrt(score), improve_score) | |||
| Returns | |||
| ------- | |||
| @@ -211,6 +214,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| if min_dist == max_dist: | |||
| return [1 for dist in dist_list] | |||
| else: | |||
| dist_epsilon = max_dist * dist_ratio | |||
| max_score = (max_dist - min_dist) / (max_dist - dist_epsilon) | |||
| if min_dist < dist_epsilon: | |||
| @@ -218,7 +222,14 @@ class EasyStatSearcher(BaseSearcher): | |||
| elif max_score < min_score: | |||
| dist_epsilon = max_dist - (max_dist - min_dist) / min_score | |||
| return [(max_dist - dist) / (max_dist - dist_epsilon) for dist in dist_list] | |||
| score_list = [] | |||
| for dist in dist_list: | |||
| score = (max_dist - dist) / (max_dist - dist_epsilon) | |||
| if score < improve_score: | |||
| score = min(math.sqrt(score), improve_score) | |||
| score_list.append(score) | |||
| return score_list | |||
| def _calculate_rkme_spec_mixture_weight( | |||
| self, | |||
| @@ -371,8 +382,8 @@ class EasyStatSearcher(BaseSearcher): | |||
| self, | |||
| sorted_score_list: List[float], | |||
| learnware_list: List[Learnware], | |||
| filter_score: float = 0.5, | |||
| min_num: int = 15, | |||
| filter_score: float = 0.6, | |||
| min_num: int = 1, | |||
| ) -> Tuple[List[float], List[Learnware]]: | |||
| """Filter search result of _search_by_rkme_spec_single | |||
| @@ -442,7 +453,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| learnware_list: List[Learnware], | |||
| user_rkme: RKMETableSpecification, | |||
| max_search_num: int, | |||
| score_cutoff: float = 0.001, | |||
| decay_rate: float = 0.95, | |||
| ) -> Tuple[float, List[float], List[Learnware]]: | |||
| """Greedily match learnwares such that their mixture become closer and closer to user's rkme | |||
| @@ -454,8 +465,8 @@ class EasyStatSearcher(BaseSearcher): | |||
| User RKME statistical specification | |||
| max_search_num : int | |||
| The maximum number of the returned learnwares | |||
| score_cutof: float | |||
| The minimum mmd dist as threshold to stop further rkme_spec matching | |||
| decay_rate: float | |||
| The decay ratio of the minimum mmd dist: matching stops unless the new minimum mmd dist is at most decay_rate times the previous one | |||
| Returns | |||
| ------- | |||
| @@ -472,11 +483,11 @@ class EasyStatSearcher(BaseSearcher): | |||
| max_search_num = learnware_num | |||
| flag_list = [0 for _ in range(learnware_num)] | |||
| mixture_list, mmd_dist = [], None | |||
| mixture_list, weight_list, mmd_dist = [], None, None | |||
| intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1)) | |||
| for k in range(max_search_num): | |||
| idx_min, score_min = -1, -1 | |||
| idx_min, score_min = None, None | |||
| weight_min = None | |||
| mixture_list.append(None) | |||
| @@ -494,20 +505,21 @@ class EasyStatSearcher(BaseSearcher): | |||
| weight, score = self._calculate_rkme_spec_mixture_weight( | |||
| mixture_list, user_rkme, intermediate_K, intermediate_C | |||
| ) | |||
| if idx_min == -1 or score < score_min: | |||
| if score_min is None or score < score_min: | |||
| idx_min, score_min, weight_min = idx, score, weight | |||
| mmd_dist = score_min | |||
| mixture_list[-1] = learnware_list[idx_min] | |||
| if score_min < score_cutoff: | |||
| break | |||
| else: | |||
| if mmd_dist is None or score_min <= mmd_dist * decay_rate: | |||
| mmd_dist, weight_list = score_min, weight_min | |||
| mixture_list[-1] = learnware_list[idx_min] | |||
| flag_list[idx_min] = 1 | |||
| intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( | |||
| mixture_list, user_rkme, intermediate_K, intermediate_C | |||
| ) | |||
| else: | |||
| mixture_list = mixture_list[:-1] | |||
| break | |||
| return mmd_dist, weight_min, mixture_list | |||
| return mmd_dist, weight_list, mixture_list | |||
| def _search_by_rkme_spec_single( | |||
| self, | |||
| @@ -558,15 +570,14 @@ class EasyStatSearcher(BaseSearcher): | |||
| logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") | |||
| sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) | |||
| processed_learnware_list = single_learnware_list[: max_search_num * max_search_num] | |||
| if search_method == "auto": | |||
| mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_auto( | |||
| learnware_list, user_rkme, max_search_num | |||
| processed_learnware_list, user_rkme, max_search_num | |||
| ) | |||
| elif search_method == "greedy": | |||
| score_cutoff = sorted_dist_list[0] * 0.05 if \ | |||
| len(sorted_dist_list) > 0 and self.stat_spec_type == "RKMEImageSpecification" else 0.001 | |||
| mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_greedy( | |||
| learnware_list, user_rkme, max_search_num, score_cutoff=score_cutoff | |||
| processed_learnware_list, user_rkme, max_search_num | |||
| ) | |||
| else: | |||
| logger.warning("f{search_method} not supported!") | |||
| @@ -581,14 +592,23 @@ class EasyStatSearcher(BaseSearcher): | |||
| merge_score_list = self._convert_dist_to_score(sorted_dist_list + [mixture_dist]) | |||
| sorted_score_list = merge_score_list[:-1] | |||
| mixture_score = merge_score_list[-1] | |||
| if int(mixture_score * 100) == int(sorted_score_list[0] * 100): | |||
| mixture_score = None | |||
| mixture_learnware_list = [] | |||
| logger.info( | |||
| f"After search by rkme spec, learnware_list length is {len(learnware_list)}, mixture_learnware_list length is {len(mixture_learnware_list)}" | |||
| ) | |||
| logger.info(f"After search by rkme spec, learnware_list length is {len(learnware_list)}") | |||
| # filter learnware with low score | |||
| # Filter learnware with low score | |||
| sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single( | |||
| sorted_score_list, single_learnware_list | |||
| ) | |||
| if len(single_learnware_list) == 1 and sorted_score_list[0] < 0.6: | |||
| ratio = 0.6 / sorted_score_list[0] | |||
| sorted_score_list[0] = 0.6 | |||
| mixture_score = min(1, mixture_score * ratio) if mixture_score is not None else None | |||
| logger.info(f"After filter by rkme spec, learnware_list length is {len(learnware_list)}") | |||
| return sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list | |||