Browse Source

[MNT] modify score function in EasyStatSearcher

tags/v0.3.2
Gene 2 years ago
parent
commit
f9abb48ae5
1 changed file with 44 additions and 24 deletions
  1. +44
    -24
      learnware/market/easy/searcher.py

+ 44
- 24
learnware/market/easy/searcher.py View File

@@ -1,3 +1,4 @@
import math
import torch
import numpy as np
from rapidfuzz import fuzz
@@ -186,7 +187,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher):

class EasyStatSearcher(BaseSearcher):
def _convert_dist_to_score(
self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92
self, dist_list: List[float], dist_ratio: float = 0.5, min_score: float = 0.92, improve_score: float = 0.7
) -> List[float]:
"""Convert mmd dist list into min_max score list

@@ -194,10 +195,12 @@ class EasyStatSearcher(BaseSearcher):
----------
dist_list : List[float]
The list of mmd distances from learnware rkmes to user rkme
dist_epsilon: float
dist_ratio: float
The parameter for converting mmd dist to score
min_score: float
The minimum value allowed for the maximum returned score
improve_score: float
The learnware score lower than improve_score will be improved

Returns
-------
@@ -211,6 +214,7 @@ class EasyStatSearcher(BaseSearcher):
if min_dist == max_dist:
return [1 for dist in dist_list]
else:
dist_epsilon = max_dist * dist_ratio
max_score = (max_dist - min_dist) / (max_dist - dist_epsilon)

if min_dist < dist_epsilon:
@@ -218,7 +222,14 @@ class EasyStatSearcher(BaseSearcher):
elif max_score < min_score:
dist_epsilon = max_dist - (max_dist - min_dist) / min_score

return [(max_dist - dist) / (max_dist - dist_epsilon) for dist in dist_list]
score_list = []
for dist in dist_list:
score = (max_dist - dist) / (max_dist - dist_epsilon)
if score < improve_score:
score = min(math.sqrt(score), improve_score)
score_list.append(score)

return score_list

def _calculate_rkme_spec_mixture_weight(
self,
@@ -371,8 +382,8 @@ class EasyStatSearcher(BaseSearcher):
self,
sorted_score_list: List[float],
learnware_list: List[Learnware],
filter_score: float = 0.5,
min_num: int = 15,
filter_score: float = 0.6,
min_num: int = 1,
) -> Tuple[List[float], List[Learnware]]:
"""Filter search result of _search_by_rkme_spec_single

@@ -442,7 +453,7 @@ class EasyStatSearcher(BaseSearcher):
learnware_list: List[Learnware],
user_rkme: RKMETableSpecification,
max_search_num: int,
score_cutoff: float = 0.001,
decay_rate: float = 0.95,
) -> Tuple[float, List[float], List[Learnware]]:
"""Greedily match learnwares such that their mixture become closer and closer to user's rkme

@@ -454,8 +465,8 @@ class EasyStatSearcher(BaseSearcher):
User RKME statistical specification
max_search_num : int
The maximum number of the returned learnwares
score_cutof: float
The minimum mmd dist as threshold to stop further rkme_spec matching
decay_rate: float
The ratio by which the minimum mmd dist must decrease at each step; otherwise further rkme_spec matching stops

Returns
-------
@@ -472,11 +483,11 @@ class EasyStatSearcher(BaseSearcher):
max_search_num = learnware_num

flag_list = [0 for _ in range(learnware_num)]
mixture_list, mmd_dist = [], None
mixture_list, weight_list, mmd_dist = [], None, None
intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1))

for k in range(max_search_num):
idx_min, score_min = -1, -1
idx_min, score_min = None, None
weight_min = None
mixture_list.append(None)

@@ -494,20 +505,21 @@ class EasyStatSearcher(BaseSearcher):
weight, score = self._calculate_rkme_spec_mixture_weight(
mixture_list, user_rkme, intermediate_K, intermediate_C
)
if idx_min == -1 or score < score_min:
if score_min is None or score < score_min:
idx_min, score_min, weight_min = idx, score, weight

mmd_dist = score_min
mixture_list[-1] = learnware_list[idx_min]
if score_min < score_cutoff:
break
else:
if mmd_dist is None or score_min <= mmd_dist * decay_rate:
mmd_dist, weight_list = score_min, weight_min
mixture_list[-1] = learnware_list[idx_min]
flag_list[idx_min] = 1
intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C(
mixture_list, user_rkme, intermediate_K, intermediate_C
)
else:
mixture_list = mixture_list[:-1]
break

return mmd_dist, weight_min, mixture_list
return mmd_dist, weight_list, mixture_list

def _search_by_rkme_spec_single(
self,
@@ -558,15 +570,14 @@ class EasyStatSearcher(BaseSearcher):
logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}")

sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
processed_learnware_list = single_learnware_list[: max_search_num * max_search_num]
if search_method == "auto":
mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_auto(
learnware_list, user_rkme, max_search_num
processed_learnware_list, user_rkme, max_search_num
)
elif search_method == "greedy":
score_cutoff = sorted_dist_list[0] * 0.05 if \
len(sorted_dist_list) > 0 and self.stat_spec_type == "RKMEImageSpecification" else 0.001
mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_greedy(
learnware_list, user_rkme, max_search_num, score_cutoff=score_cutoff
processed_learnware_list, user_rkme, max_search_num
)
else:
logger.warning("f{search_method} not supported!")
@@ -581,14 +592,23 @@ class EasyStatSearcher(BaseSearcher):
merge_score_list = self._convert_dist_to_score(sorted_dist_list + [mixture_dist])
sorted_score_list = merge_score_list[:-1]
mixture_score = merge_score_list[-1]
if int(mixture_score * 100) == int(sorted_score_list[0] * 100):
mixture_score = None
mixture_learnware_list = []
logger.info(
f"After search by rkme spec, learnware_list length is {len(learnware_list)}, mixture_learnware_list length is {len(mixture_learnware_list)}"
)

logger.info(f"After search by rkme spec, learnware_list length is {len(learnware_list)}")
# filter learnware with low score
# Filter learnware with low score
sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single(
sorted_score_list, single_learnware_list
)

if len(single_learnware_list) == 1 and sorted_score_list[0] < 0.6:
ratio = 0.6 / sorted_score_list[0]
sorted_score_list[0] = 0.6
mixture_score = min(1, mixture_score * ratio) if mixture_score is not None else None
logger.info(f"After filter by rkme spec, learnware_list length is {len(learnware_list)}")

return sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list




Loading…
Cancel
Save