Browse Source

[ENH] Implement search_by_rkme_spec

tags/v0.3.2
liuht 2 years ago
parent
commit
13b7eee147
1 changed files with 133 additions and 8 deletions
  1. +133
    -8
      learnware/market/easy.py

+ 133
- 8
learnware/market/easy.py View File

@@ -1,4 +1,5 @@
import os
import torch
import numpy as np
import pandas as pd
from typing import Tuple, Any, List, Union, Dict
@@ -116,31 +117,148 @@ class EasyMarket(BaseMarket):
self.count += 1

return id, True
def _calculate_rkme_spec_mixture_weight(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None
) -> Tuple[List[float], float]:
"""Calculate mixture weight for the learnware_list based on a user's rkme

Parameters
----------
learnware_list : List[Learnware]
A list of existing learnwares
user_rkme : RKMEStatSpecification
User RKME statistical specification
intermediate_K : np.ndarray, optional
Intermediate kernel matrix K, by default None
intermediate_C : np.ndarray, optional
Intermediate inner product vector C, by default None

Returns
-------
Tuple[List[float], float]
The first is the list of mixture weights
The second is the mmd dist between the mixture of learnware rkmes and the user's rkme
"""
learnware_num = len(learnware_list)
RKME_list = [learnware.specification.get_stat_spec_by_name('RKME') for learnware in learnware_list]

if type(intermediate_K) == np.ndarray:
K = intermediate_K
else:
K = np.zeros((learnware_num, learnware_num))
for i in range(K.shape[0]):
for j in range(K.shape[1]):
K[i, j] = RKME_list[i].inner_prod(RKME_list[j])

if type(intermediate_C) == np.ndarray:
C = intermediate_C
else:
C = np.zeros((learnware_num, 1))
for i in range(C.shape[0]):
C[i, 0] = user_rkme.inner_prod(RKME_list[i])

K = torch.from_numpy(K).double().to(user_rkme.device)
C = torch.from_numpy(C).double().to(user_rkme.device)

#if nonnegative_beta:
# w = solve_qp(K, C).double().to(Phi_t.device)
#else:
weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C

term1 = user_rkme.eval_Phi(user_rkme)
term2 = weight.T @ C
term3 = weight.T @ K @ weight
score = float(term1 - 2 * term2 + term3)

return weight.detach().cpu().numpy().reshape(-1), score
def _calculate_intermediate_K_and_C(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None
) -> Tuple[np.ndarray, np.ndarray]:
"""Incrementally update the values of intermediate_K and intermediate_C

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares up till now
user_rkme : RKMEStatSpecification
User RKME statistical specification
intermediate_K : np.ndarray, optional
Intermediate kernel matrix K, by default None
intermediate_C : np.ndarray, optional
Intermediate inner product vector C, by default None

Returns
-------
Tuple[np.ndarray, np.ndarray]
The first is the intermediate value of K
The second is the intermediate value of C
"""
num = intermediate_K.shape[0] - 1
RKME_list = [learnware.specification.get_stat_spec_by_name('RKME') for learnware in learnware_list]
for i in range(intermediate_K.shape[0]):
intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i])
intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1])
return intermediate_K, intermediate_C

def _search_by_rkme_spec_mixture(self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, search_num: int) -> Tuple[List[float], List[Learnware]]:
"""Get search_num learnwares with their mixture weight from the given learnware_list

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user RKME statistical specification
learnware_num : int
the number of the returned learnwares
User RKME statistical specification
search_num : int
The number of the returned learnwares

Returns
-------
Tuple[List[float], List[Learnware]]
the first is the list of weight
the second is the list of Learnware
the size of both list equals search_num
The first is the list of weight
The second is the list of Learnware
The size of both list equals search_num
"""
pass
learnware_num = len(learnware_list)
_, sorted_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
flag_list = [0 for i in range(learnware_num)]
mixture_list = []
intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1))

for k in range(search_num):
idx_min, score_min = -1, -1
weight_min = None
mixture_list.append(None)

if k != 0:
intermediate_K = np.c_[intermediate_K, np.zeros((k, 1))]
intermediate_K = np.r_[intermediate_K, np.zeros((1, k + 1))]
intermediate_C = np.r_[intermediate_C, np.zeros((1, 1))]
for idx in range(len(sorted_learnware_list)):
if flag_list[idx] == 0:
mixture_list[-1] = sorted_learnware_list[idx]
intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C(mixture_list, user_rkme, intermediate_K, intermediate_C)
weight, score = self._calculate_rkme_spec_mixture_weight(mixture_list, user_rkme, intermediate_K, intermediate_C)
if idx_min == -1 or score < score_min:
idx_min, score_min, weight_min = idx, score, weight
flag_list[idx_min] = 1
mixture_list[-1] = sorted_learnware_list[idx_min]
intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C(mixture_list, user_rkme, intermediate_K, intermediate_C)
return weight_min, mixture_list

def _search_by_rkme_spec_single(self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification) -> Tuple[List[float], List[Learnware]]:
"""Calculate the distances between learnwares in the given learnware_list and user_rkme

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user RKME statistical specification

@@ -151,7 +269,14 @@ class EasyMarket(BaseMarket):
the second is the list of Learnware
both lists are sorted by mmd dist
"""
pass
RKME_list = [learnware.specification.get_stat_spec_by_name('RKME') for learnware in learnware_list]
mmd_dist_list = []
for RKME in RKME_list:
mmd_dist = RKME.dist(user_rkme)
mmd_dist_list.append(mmd_dist)
sorted_dist_list, sorted_learnware_list = (list(t) for t in zip(*sorted(zip(mmd_dist_list, learnware_list))))
return sorted_dist_list, sorted_learnware_list
def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]:
def search_by_semantic_spec():


Loading…
Cancel
Save