diff --git a/learnware/learnware/base.py b/learnware/learnware/base.py index 4500c71..8578f91 100644 --- a/learnware/learnware/base.py +++ b/learnware/learnware/base.py @@ -29,7 +29,7 @@ class Learnware: Raises ------ TypeError - The type of model must be dict or BaseModel, else raise error + The type of model must be str or BaseModel, else raise error """ if isinstance(model, BaseModel): return model @@ -42,7 +42,7 @@ class Learnware: model_module = get_module_by_module_path(model_dict["module_path"]) return getattr(model_module, model_dict["class_name"])() else: - raise TypeError("model must be BaseModel or dict") + raise TypeError("model must be BaseModel or str") def predict(self, X: np.ndarray) -> np.ndarray: return self.model.predict(X) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 8fa78ae..a752081 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -153,7 +153,7 @@ class EasyMarket(BaseMarket): # else: weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C - term1 = user_rkme.eval_Phi(user_rkme) + term1 = user_rkme.inner_prod(user_rkme) term2 = weight.T @ C term3 = weight.T @ K @ weight score = float(term1 - 2 * term2 + term3) @@ -274,7 +274,10 @@ class EasyMarket(BaseMarket): for RKME in RKME_list: mmd_dist = RKME.dist(user_rkme) mmd_dist_list.append(mmd_dist) - sorted_dist_list, sorted_learnware_list = (list(t) for t in zip(*sorted(zip(mmd_dist_list, learnware_list)))) + + sorted_idx_list = sorted(range(len(learnware_list)), key=lambda k: mmd_dist_list[k]) + sorted_dist_list = [mmd_dist_list[idx] for idx in sorted_idx_list] + sorted_learnware_list = [learnware_list[idx] for idx in sorted_idx_list] return sorted_dist_list, sorted_learnware_list @@ -304,9 +307,32 @@ class EasyMarket(BaseMarket): learnware_list = [self.learnware_list[key] for key in self.learnware_list] return learnware_list - def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]: + def search_learnware(self, user_info: BaseUserInfo, search_num=3) -> Tuple[List[float], List[Learnware], List[Learnware]]: + """Search learnwares based on user_info + + Parameters + ---------- + user_info : BaseUserInfo + user_info contains semantic_spec and stat_info + search_num : int + The number of the returned learnwares + + Returns + ------- + Tuple[List[float], List[Learnware], List[float], List[Learnware]] + the first is the sorted list of rkme dist + the second is the sorted list of Learnware (single) by the rkme dist + the third is the list of Learnware (mixture), the size is search_num + """ learnware_list = self._search_by_semantic_spec(user_info) - return learnware_list + + if "RKME" not in user_info.stat_info: + return None, learnware_list, None + else: + user_rkme = user_info.stat_info["RKME"] + sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) + weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(learnware_list, user_rkme, search_num) + return sorted_dist_list, single_learnware_list, mixture_learnware_list def delete_learnware(self, id: str) -> bool: if not id in self.learnware_list: