# learnware/market/easy.py -- post-patch state (a687bab..e00d54f) of the
# EasyMarket methods visible in this chunk. Imports and methods that fall
# outside the visible hunks are not reproduced here.
import os
import torch
import numpy as np
import pandas as pd
from typing import Tuple, Any, List, Union, Dict


class EasyMarket(BaseMarket):
    # (The method that ends with ``self.count += 1; return id, True`` starts
    # before this chunk and is therefore not reproduced.)

    def _calculate_rkme_spec_mixture_weight(
        self,
        learnware_list: List[Learnware],
        user_rkme: RKMEStatSpecification,
        intermediate_K: np.ndarray = None,
        intermediate_C: np.ndarray = None,
    ) -> Tuple[List[float], float]:
        """Calculate mixture weight for the learnware_list based on a user's rkme

        The weights solve the ridge-regularised linear system
        ``(K + 1e-5 * I) w = C`` where ``K[i, j]`` is the inner product of the
        i-th and j-th learnware RKME and ``C[i]`` is the inner product of the
        user's RKME with the i-th learnware RKME. The weights are
        unconstrained, so individual entries may be negative.

        Parameters
        ----------
        learnware_list : List[Learnware]
            A list of existing learnwares
        user_rkme : RKMEStatSpecification
            User RKME statistical specification
        intermediate_K : np.ndarray, optional
            Precomputed kernel matrix K; recomputed from learnware_list when
            it is not a numpy array, by default None
        intermediate_C : np.ndarray, optional
            Precomputed inner product vector C (one column); recomputed from
            learnware_list when it is not a numpy array, by default None

        Returns
        -------
        Tuple[List[float], float]
            The first is the mixture weights (a flat numpy array at runtime)
            The second is the mmd dist between the mixture of learnware rkmes and the user's rkme
        """
        learnware_num = len(learnware_list)
        rkme_list = [learnware.specification.get_stat_spec_by_name('RKME') for learnware in learnware_list]

        if isinstance(intermediate_K, np.ndarray):
            K = intermediate_K
        else:
            K = np.zeros((learnware_num, learnware_num))
            for i, rkme_i in enumerate(rkme_list):
                for j, rkme_j in enumerate(rkme_list):
                    K[i, j] = rkme_i.inner_prod(rkme_j)

        if isinstance(intermediate_C, np.ndarray):
            C = intermediate_C
        else:
            C = np.zeros((learnware_num, 1))
            for i, rkme in enumerate(rkme_list):
                C[i, 0] = user_rkme.inner_prod(rkme)

        device = user_rkme.device
        K = torch.from_numpy(K).double().to(device)
        C = torch.from_numpy(C).double().to(device)

        # A small ridge keeps the system well conditioned when learnware RKMEs
        # are nearly collinear. Solving the system directly is more accurate
        # and cheaper than forming the explicit inverse.
        ridge = 1e-5 * torch.eye(K.shape[0], dtype=K.dtype, device=K.device)
        weight = torch.linalg.solve(K + ridge, C)

        # score = term1 - 2 * w^T C + w^T K w
        # NOTE(review): term1 is presumably the user's RKME inner product with
        # itself -- confirm against RKMEStatSpecification.eval_Phi.
        term1 = user_rkme.eval_Phi(user_rkme)
        term2 = weight.T @ C
        term3 = weight.T @ K @ weight
        score = float(term1 - 2 * term2 + term3)

        return weight.detach().cpu().numpy().reshape(-1), score

    def _calculate_intermediate_K_and_C(
        self,
        learnware_list: List[Learnware],
        user_rkme: RKMEStatSpecification,
        intermediate_K: np.ndarray = None,
        intermediate_C: np.ndarray = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Incrementally update the values of intermediate_K and intermediate_C

        Only the entries that involve the last learnware in learnware_list are
        (re)computed: the last row AND last column of K, and the last entry of
        C. The arrays are modified in place and also returned. When either
        array is None, both are built from scratch for the whole list.

        Parameters
        ----------
        learnware_list : List[Learnware]
            The list of learnwares up till now
        user_rkme : RKMEStatSpecification
            User RKME statistical specification
        intermediate_K : np.ndarray, optional
            Intermediate kernel matrix K whose size matches learnware_list, by default None
        intermediate_C : np.ndarray, optional
            Intermediate inner product vector C (one column), by default None

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            The first is the intermediate value of K
            The second is the intermediate value of C
        """
        rkme_list = [learnware.specification.get_stat_spec_by_name('RKME') for learnware in learnware_list]

        if intermediate_K is None or intermediate_C is None:
            # Nothing cached yet: compute every entry.
            size = len(rkme_list)
            intermediate_K = np.zeros((size, size))
            intermediate_C = np.zeros((size, 1))
            for i, rkme_i in enumerate(rkme_list):
                intermediate_C[i, 0] = user_rkme.inner_prod(rkme_i)
                for j, rkme_j in enumerate(rkme_list):
                    intermediate_K[i, j] = rkme_i.inner_prod(rkme_j)
            return intermediate_K, intermediate_C

        last = intermediate_K.shape[0] - 1
        newest_rkme = rkme_list[-1]
        for i in range(intermediate_K.shape[0]):
            value = newest_rkme.inner_prod(rkme_list[i])
            # K is a Gram matrix, so it must stay symmetric: write the value
            # into both the last row and the last column.
            intermediate_K[last, i] = value
            intermediate_K[i, last] = value
        intermediate_C[last, 0] = user_rkme.inner_prod(newest_rkme)
        return intermediate_K, intermediate_C

    def _search_by_rkme_spec_mixture(
        self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, search_num: int
    ) -> Tuple[List[float], List[Learnware]]:
        """Get search_num learnwares with their mixture weight from the given learnware_list

        Learnwares are picked greedily: at every step each not-yet-chosen
        learnware is tried as the next mixture member and the one giving the
        smallest mixture score is kept.

        Parameters
        ----------
        learnware_list : List[Learnware]
            The list of learnwares whose mixture approximates the user's rkme
        user_rkme : RKMEStatSpecification
            User RKME statistical specification
        search_num : int
            The number of the returned learnwares; values larger than
            len(learnware_list) are clamped to it

        Returns
        -------
        Tuple[List[float], List[Learnware]]
            The first is the list of weight
            The second is the list of Learnware
            The size of both list equals min(search_num, len(learnware_list))
        """
        learnware_num = len(learnware_list)
        # Never ask for more learnwares than exist; otherwise no candidate is
        # left in later rounds and the selection index would be invalid.
        search_num = min(search_num, learnware_num)
        if search_num <= 0:
            return np.array([]), []

        _, sorted_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
        chosen = [False] * learnware_num
        mixture_list = []
        weight_min = None
        intermediate_K, intermediate_C = np.zeros((0, 0)), np.zeros((0, 1))

        for k in range(search_num):
            # Grow K and C by one slot for the learnware picked in this round.
            grown_K = np.zeros((k + 1, k + 1))
            grown_K[:k, :k] = intermediate_K
            intermediate_K = grown_K
            intermediate_C = np.vstack([intermediate_C, np.zeros((1, 1))])
            mixture_list.append(None)

            idx_min, score_min, weight_min = -1, None, None
            for idx, candidate in enumerate(sorted_learnware_list):
                if chosen[idx]:
                    continue
                mixture_list[-1] = candidate
                intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C(
                    mixture_list, user_rkme, intermediate_K, intermediate_C
                )
                weight, score = self._calculate_rkme_spec_mixture_weight(
                    mixture_list, user_rkme, intermediate_K, intermediate_C
                )
                if idx_min == -1 or score < score_min:
                    idx_min, score_min, weight_min = idx, score, weight

            chosen[idx_min] = True
            mixture_list[-1] = sorted_learnware_list[idx_min]
            # The last slot currently holds the values of the last candidate
            # tried; refresh it for the learnware that was actually selected.
            intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C(
                mixture_list, user_rkme, intermediate_K, intermediate_C
            )

        return weight_min, mixture_list

    def _search_by_rkme_spec_single(
        self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification
    ) -> Tuple[List[float], List[Learnware]]:
        """Calculate the distances between learnwares in the given learnware_list and user_rkme

        Parameters
        ----------
        learnware_list : List[Learnware]
            The list of learnwares to be ranked against the user's rkme
        user_rkme : RKMEStatSpecification
            user RKME statistical specification

        Returns
        -------
        Tuple[List[float], List[Learnware]]
            the first is the list of mmd dist
            the second is the list of Learnware
            both lists are sorted by mmd dist (ascending); learnwares with
            equal distance keep their input order
        """
        rkme_list = [learnware.specification.get_stat_spec_by_name('RKME') for learnware in learnware_list]
        mmd_dist_list = [rkme.dist(user_rkme) for rkme in rkme_list]

        # Sort on the distance only: comparing (dist, learnware) tuples would
        # fall back to comparing Learnware objects whenever distances tie.
        order = sorted(range(len(learnware_list)), key=mmd_dist_list.__getitem__)
        sorted_dist_list = [mmd_dist_list[i] for i in order]
        sorted_learnware_list = [learnware_list[i] for i in order]

        return sorted_dist_list, sorted_learnware_list

    # (``search_learnware(self, user_info: BaseUserInfo)`` follows in the
    # original file; its body lies beyond this chunk and is not reproduced.)