|
|
|
@@ -153,7 +153,7 @@ class EasyMarket(BaseMarket): |
|
|
|
# else: |
|
|
|
weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C |
|
|
|
|
|
|
|
term1 = user_rkme.eval_Phi(user_rkme) |
|
|
|
term1 = user_rkme.inner_prod(user_rkme) |
|
|
|
term2 = weight.T @ C |
|
|
|
term3 = weight.T @ K @ weight |
|
|
|
score = float(term1 - 2 * term2 + term3) |
|
|
|
@@ -274,7 +274,10 @@ class EasyMarket(BaseMarket): |
|
|
|
for RKME in RKME_list: |
|
|
|
mmd_dist = RKME.dist(user_rkme) |
|
|
|
mmd_dist_list.append(mmd_dist) |
|
|
|
sorted_dist_list, sorted_learnware_list = (list(t) for t in zip(*sorted(zip(mmd_dist_list, learnware_list)))) |
|
|
|
|
|
|
|
sorted_idx_list = sorted(range(len(learnware_list)), key=lambda k: mmd_dist_list[k]) |
|
|
|
sorted_dist_list = [mmd_dist_list[idx] for idx in sorted_idx_list] |
|
|
|
sorted_learnware_list = [learnware_list[idx] for idx in sorted_idx_list] |
|
|
|
|
|
|
|
return sorted_dist_list, sorted_learnware_list |
|
|
|
|
|
|
|
@@ -304,9 +307,32 @@ class EasyMarket(BaseMarket): |
|
|
|
learnware_list = [self.learnware_list[key] for key in self.learnware_list] |
|
|
|
return learnware_list |
|
|
|
|
|
|
|
def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]: |
|
|
|
def search_learnware(self, user_info: BaseUserInfo, search_num=3) -> Tuple[List[float], List[Learnware], List[Learnware]]: |
|
|
|
"""Search learnwares based on user_info |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
user_info : BaseUserInfo |
|
|
|
user_info contains semantic_spec and stat_info |
|
|
|
search_num : int |
|
|
|
The number of the returned learnwares |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
Tuple[List[float], List[Learnware], List[float], List[Learnware]] |
|
|
|
the first is the sorted list of rkme dist |
|
|
|
the second is the sorted list of Learnware (single) by the rkme dist |
|
|
|
the third is the list of Learnware (mixture), the size is search_num |
|
|
|
""" |
|
|
|
learnware_list = self._search_by_semantic_spec(user_info) |
|
|
|
return learnware_list |
|
|
|
|
|
|
|
if "RKME" not in user_info.stat_info: |
|
|
|
return None, learnware_list, None |
|
|
|
else: |
|
|
|
user_rkme = user_info.stat_info["RKME"] |
|
|
|
sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) |
|
|
|
weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(learnware_list, user_rkme, search_num) |
|
|
|
return sorted_dist_list, single_learnware_list, mixture_learnware_list |
|
|
|
|
|
|
|
def delete_learnware(self, id: str) -> bool: |
|
|
|
if not id in self.learnware_list: |
|
|
|
|