Browse Source

[MNT] Modify the pipeline

tags/v0.3.2
Gene 3 years ago
parent
commit
de97a6624b
2 changed files with 32 additions and 6 deletions
  1. +2
    -2
      learnware/learnware/base.py
  2. +30
    -4
      learnware/market/easy.py

+ 2
- 2
learnware/learnware/base.py View File

@@ -29,7 +29,7 @@ class Learnware:
Raises
------
TypeError
The type of model must be dict or BaseModel, else raise error
The type of model must be str or BaseModel, else raise error
"""
if isinstance(model, BaseModel):
return model
@@ -42,7 +42,7 @@ class Learnware:
model_module = get_module_by_module_path(model_dict["module_path"])
return getattr(model_module, model_dict["class_name"])()
else:
raise TypeError("model must be BaseModel or dict")
raise TypeError("model must be BaseModel or str")

def predict(self, X: np.ndarray) -> np.ndarray:
return self.model.predict(X)


+ 30
- 4
learnware/market/easy.py View File

@@ -153,7 +153,7 @@ class EasyMarket(BaseMarket):
# else:
weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C

term1 = user_rkme.eval_Phi(user_rkme)
term1 = user_rkme.inner_prod(user_rkme)
term2 = weight.T @ C
term3 = weight.T @ K @ weight
score = float(term1 - 2 * term2 + term3)
@@ -274,7 +274,10 @@ class EasyMarket(BaseMarket):
for RKME in RKME_list:
mmd_dist = RKME.dist(user_rkme)
mmd_dist_list.append(mmd_dist)
sorted_dist_list, sorted_learnware_list = (list(t) for t in zip(*sorted(zip(mmd_dist_list, learnware_list))))

sorted_idx_list = sorted(range(len(learnware_list)), key=lambda k: mmd_dist_list[k])
sorted_dist_list = [mmd_dist_list[idx] for idx in sorted_idx_list]
sorted_learnware_list = [learnware_list[idx] for idx in sorted_idx_list]

return sorted_dist_list, sorted_learnware_list

@@ -304,9 +307,32 @@ class EasyMarket(BaseMarket):
learnware_list = [self.learnware_list[key] for key in self.learnware_list]
return learnware_list
def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]:
def search_learnware(self, user_info: BaseUserInfo, search_num=3) -> Tuple[List[float], List[Learnware], List[Learnware]]:
"""Search learnwares based on user_info

Parameters
----------
user_info : BaseUserInfo
user_info contains semantic_spec and stat_info
search_num : int
The number of the returned learnwares

Returns
-------
Tuple[List[float], List[Learnware], List[float], List[Learnware]]
the first is the sorted list of rkme dist
the second is the sorted list of Learnware (single) by the rkme dist
the third is the list of Learnware (mixture), the size is search_num
"""
learnware_list = self._search_by_semantic_spec(user_info)
return learnware_list

if "RKME" not in user_info.stat_info:
return None, learnware_list, None
else:
user_rkme = user_info.stat_info["RKME"]
sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(learnware_list, user_rkme, search_num)
return sorted_dist_list, single_learnware_list, mixture_learnware_list

def delete_learnware(self, id: str) -> bool:
if not id in self.learnware_list:


Loading…
Cancel
Save