Browse Source

[MNT] Add another search multiple learnware method

tags/v0.3.2
liuht 2 years ago
parent
commit
44c63fc2a8
2 changed files with 51 additions and 2 deletions
  1. +50
    -2
      learnware/market/easy.py
  2. +1
    -0
      setup.py

+ 50
- 2
learnware/market/easy.py View File

@@ -276,6 +276,55 @@ class EasyMarket(BaseMarket):
intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i])
intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1])
return intermediate_K, intermediate_C
def _search_by_rkme_spec_mixture_auto(
self,
learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
max_search_num: int,
weight_cutoff: float = 0.9
) -> Tuple[List[float], List[Learnware]]:
"""Select learnwares based on a total mixture ratio, then recalculate their mixture weights

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
User RKME statistical specification
max_search_num : int
The maximum number of the returned learnwares
weight_cutoff : float, optional
The ratio for selecting out the mose relevant learnwares, by default 0.9

Returns
-------
Tuple[List[float], List[Learnware]]
The first is the list of weight
The second is the list of Learnware
"""
learnware_num = len(learnware_list)
if learnware_num == 0:
return [], []
if learnware_num < max_search_num:
logger.warning("Available Learnware num less than search_num!")
max_search_num = learnware_num

weight, _ = self._calculate_rkme_spec_mixture_weight(learnware_list, user_rkme)
sort_by_weight_idx_list = sorted(range(learnware_num), key=lambda k: weight[k])
weight_sum = 0
mixture_list = []
for idx in sort_by_weight_idx_list:
weight_sum += sort_by_weight_idx_list[idx]
if weight_sum <= weight_cutoff:
mixture_list.append(learnware_list[idx])

if len(mixture_list) > max_search_num:
mixture_list = mixture_list[:max_search_num]
mixture_weight, _ = self._calculate_rkme_spec_mixture_weight(mixture_list, user_rkme)
return mixture_weight, mixture_list

def _search_by_rkme_spec_mixture(
self,
@@ -284,7 +333,7 @@ class EasyMarket(BaseMarket):
max_search_num: int,
score_cutoff: float = 0.01,
) -> Tuple[List[float], List[Learnware]]:
"""Get learnwares with their mixture weight from the given learnware_list
"""Greedily match learnwares such that their mixture become more and more closer to user's rkme

Parameters
----------
@@ -338,7 +387,6 @@ class EasyMarket(BaseMarket):

mixture_list[-1] = learnware_list[idx_min]
if score_min < score_cutoff:
print(score_min)
break
else:
flag_list[idx_min] = 1


+ 1
- 0
setup.py View File

@@ -47,6 +47,7 @@ REQUIRED = [
"fire>=0.3.1",
"lightgbm>=3.3.0",
"psutil>=5.9.4",
"torchvision>=0.15.1"
]

here = os.path.abspath(os.path.dirname(__file__))


Loading…
Cancel
Save