From 988189903f2aeb0f0425300820751bbc900861ed Mon Sep 17 00:00:00 2001 From: liuht Date: Thu, 13 Apr 2023 10:18:02 +0800 Subject: [PATCH 1/7] [FIX] Fix "fintune" to "finetune" --- examples/example_market_db/example_init.py | 2 +- examples/examples2/svm/__init__.py | 2 +- examples/learnware_config/svm/__init__.py | 2 +- examples/workflow_by_code/example_init.py | 2 +- learnware/model/base.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/example_market_db/example_init.py b/examples/example_market_db/example_init.py index 5ad99bb..82d0cb4 100644 --- a/examples/example_market_db/example_init.py +++ b/examples/example_market_db/example_init.py @@ -15,5 +15,5 @@ class SVM(BaseModel): def predict(self, X: np.ndarray) -> np.ndarray: return self.model.predict(X) - def fintune(self, X: np.ndarray, y: np.ndarray): + def finetune(self, X: np.ndarray, y: np.ndarray): pass diff --git a/examples/examples2/svm/__init__.py b/examples/examples2/svm/__init__.py index 5ad99bb..82d0cb4 100644 --- a/examples/examples2/svm/__init__.py +++ b/examples/examples2/svm/__init__.py @@ -15,5 +15,5 @@ class SVM(BaseModel): def predict(self, X: np.ndarray) -> np.ndarray: return self.model.predict(X) - def fintune(self, X: np.ndarray, y: np.ndarray): + def finetune(self, X: np.ndarray, y: np.ndarray): pass diff --git a/examples/learnware_config/svm/__init__.py b/examples/learnware_config/svm/__init__.py index 5ad99bb..82d0cb4 100644 --- a/examples/learnware_config/svm/__init__.py +++ b/examples/learnware_config/svm/__init__.py @@ -15,5 +15,5 @@ class SVM(BaseModel): def predict(self, X: np.ndarray) -> np.ndarray: return self.model.predict(X) - def fintune(self, X: np.ndarray, y: np.ndarray): + def finetune(self, X: np.ndarray, y: np.ndarray): pass diff --git a/examples/workflow_by_code/example_init.py b/examples/workflow_by_code/example_init.py index 5ad99bb..82d0cb4 100644 --- a/examples/workflow_by_code/example_init.py +++ b/examples/workflow_by_code/example_init.py @@ -15,5 +15,5 @@ class SVM(BaseModel): def predict(self, X: np.ndarray) -> np.ndarray: return self.model.predict(X) - def fintune(self, X: np.ndarray, y: np.ndarray): + def finetune(self, X: np.ndarray, y: np.ndarray): pass diff --git a/learnware/model/base.py b/learnware/model/base.py index 82bc15f..5c32f11 100644 --- a/learnware/model/base.py +++ b/learnware/model/base.py @@ -11,5 +11,5 @@ class BaseModel: def predict(self, X: np.ndarray) -> np.ndarray: pass - def fintune(self, X: np.ndarray, y: np.ndarray): + def finetune(self, X: np.ndarray, y: np.ndarray): pass From 068c6ce3581c8bc173e9f31a0e02f80c9bad354f Mon Sep 17 00:00:00 2001 From: xiey Date: Thu, 13 Apr 2023 16:42:12 +0800 Subject: [PATCH 2/7] [MNT] Change details in _search_by_semantic_spec --- learnware/market/easy.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 79aaa3d..547fe01 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -335,8 +335,8 @@ class EasyMarket(BaseMarket): return sorted_dist_list, sorted_learnware_list - def _search_by_semantic_tags(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]: - def match_semantic_tags(semantic_spec1, semantic_spec2): + def _search_by_semantic_spec(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]: + def match_semantic_spec(semantic_spec1, semantic_spec2): if semantic_spec1.keys() != semantic_spec2.keys(): # raise Exception("semantic_spec key error") logger.warning("semantic_spec key error!") @@ -346,18 +346,20 @@ class EasyMarket(BaseMarket): continue if len(semantic_spec2[key]["Values"]) == 0: continue + v1 = semantic_spec1[key]["Values"] + v2 = semantic_spec2[key]["Values"] if semantic_spec1[key]["Type"] == "Class": - if isinstance(semantic_spec1[key]["Values"], list): - semantic_spec1[key]["Values"] = semantic_spec1[key]["Values"][0] - if isinstance(semantic_spec2[key]["Values"], list): - semantic_spec2[key]["Values"] = semantic_spec2[key]["Values"][0] - if semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"]: + if isinstance(v1, list): + v1 = v1[0] + if isinstance(v2, list): + v2 = v2[0] + if v1 != v2: return False elif semantic_spec1[key]["Type"] == "Tag": - if not (set(semantic_spec1[key]["Values"]) & set(semantic_spec2[key]["Values"])): + if not (set(v1) & set(v2)): return False elif semantic_spec1[key]["Type"] == "Name": - if semantic_spec2[key]["Values"] not in semantic_spec1[key]["Values"]: + if v2 not in v1 and v2 not in semantic_spec1["Description"]["Values"]: return False return True @@ -365,7 +367,7 @@ class EasyMarket(BaseMarket): for learnware in learnware_list: learnware_semantic_spec = learnware.get_specification().get_semantic_spec() user_semantic_spec = user_info.get_semantic_spec() - if match_semantic_tags(learnware_semantic_spec, user_semantic_spec): + if match_semantic_spec(learnware_semantic_spec, user_semantic_spec): match_learnwares.append(learnware) return match_learnwares @@ -389,7 +391,7 @@ class EasyMarket(BaseMarket): the third is the list of Learnware (mixture), the size is search_num """ learnware_list = [self.learnware_list[key] for key in self.learnware_list] - learnware_list = self._search_by_semantic_tags(learnware_list, user_info) + learnware_list = self._search_by_semantic_spec(learnware_list, user_info) # learnware_list = list(set(learnware_list_tags + learnware_list_description)) if "RKMEStatSpecification" not in user_info.stat_info: From faf85607715f2c1dc7a3a34fb2fa6ce5a47a7341 Mon Sep 17 00:00:00 2001 From: xiey Date: Thu, 13 Apr 2023 16:43:18 +0800 Subject: [PATCH 3/7] [MNT] black -l --- examples/example_market_db/example_db.py | 5 +---- examples/workflow_by_code/main.py | 20 ++++--------------- learnware/config.py | 25 +++++------------------- learnware/learnware/__init__.py | 5 +---- learnware/market/easy.py | 5 +---- learnware/specification/rkme.py | 6 ++---- 6 files changed, 14 insertions(+), 52 deletions(-) diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index ce8aec3..474b05e 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -40,10 +40,7 @@ semantic_specs = [ user_senmantic = { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class",}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, diff --git a/examples/workflow_by_code/main.py b/examples/workflow_by_code/main.py index 54fb60b..b1d872a 100644 --- a/examples/workflow_by_code/main.py +++ b/examples/workflow_by_code/main.py @@ -16,10 +16,7 @@ curr_root = os.path.dirname(os.path.abspath(__file__)) semantic_specs = [ { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class",}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Nature"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -27,10 +24,7 @@ semantic_specs = [ }, { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class",}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -38,10 +32,7 @@ semantic_specs = [ }, { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class",}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -51,10 +42,7 @@ semantic_specs = [ user_senmantic = { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class",}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, diff --git a/learnware/config.py b/learnware/config.py index 997c732..3d44a30 100644 --- a/learnware/config.py +++ b/learnware/config.py @@ -66,10 +66,7 @@ os.makedirs(LEARNWARE_FOLDER_POOL_PATH, exist_ok=True) os.makedirs(DATABASE_PATH, exist_ok=True) semantic_config = { - "Data": { - "Values": ["Tabular", "Image", "Video", "Text", "Audio"], - "Type": "Class", - }, # Choose only one class + "Data": {"Values": ["Tabular", "Image", "Video", "Text", "Audio"], "Type": "Class",}, # Choose only one class "Task": { "Values": [ "Classification", @@ -82,10 +79,7 @@ semantic_config = { ], "Type": "Class", # Choose only one class }, - "Device": { - "Values": ["CPU", "GPU"], - "Type": "Tag", - }, # Choose one or more tags + "Device": {"Values": ["CPU", "GPU"], "Type": "Tag",}, # Choose one or more tags "Scenario": { "Values": [ "Business", @@ -105,14 +99,8 @@ semantic_config = { ], "Type": "Tag", # Choose one or more tags }, - "Description": { - "Values": None, - "Type": "Description", - }, - "Name": { - "Values": None, - "Type": "Name", - }, + "Description": {"Values": None, "Type": "Description",}, + "Name": {"Values": None, "Type": "Name",}, } _DEFAULT_CONFIG = { @@ -123,10 +111,7 @@ _DEFAULT_CONFIG = { "learnware_pool_path": LEARNWARE_POOL_PATH, "learnware_zip_pool_path": LEARNWARE_ZIP_POOL_PATH, "learnware_folder_pool_path": LEARNWARE_FOLDER_POOL_PATH, - "learnware_folder_config": { - "yaml_file": "learnware.yaml", - "module_file": "__init__.py", - }, + "learnware_folder_config": {"yaml_file": "learnware.yaml", "module_file": "__init__.py",}, "database_path": DATABASE_PATH, } diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index aa24867..899abd9 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -29,10 +29,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath: The contructed learnware object, return None if build failed """ learnware_config = { - "model": { - "class_name": "Model", - "kwargs": {}, - }, + "model": {"class_name": "Model", "kwargs": {},}, "stat_specifications": [ { "module_path": "learnware.specification", diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 547fe01..c31abaa 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -119,10 +119,7 @@ class EasyMarket(BaseMarket): self.learnware_folder_list[id] = target_folder_dir self.count += 1 add_learnware_to_db( - id, - semantic_spec=semantic_spec, - zip_path=target_zip_dir, - folder_path=target_folder_dir, + id, semantic_spec=semantic_spec, zip_path=target_zip_dir, folder_path=target_folder_dir, ) return id, True diff --git a/learnware/specification/rkme.py b/learnware/specification/rkme.py index 9668396..568c410 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/rkme.py @@ -255,9 +255,7 @@ class RKMEStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" json.dump( - rkme_to_save, - codecs.open(save_path, "w", encoding="utf-8"), - separators=(",", ":"), + rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), separators=(",", ":"), ) def load(self, filepath: str) -> bool: @@ -345,7 +343,7 @@ def torch_rbf_kernel(x1, x2, gamma) -> torch.Tensor: """ x1 = x1.double() x2 = x2.double() - X12norm = torch.sum(x1**2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2**2, 1, keepdim=True).T + X12norm = torch.sum(x1 ** 2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2 ** 2, 1, keepdim=True).T return torch.exp(-X12norm * gamma) From 7962a49e696245cbe8decaf62cae9376392fe365 Mon Sep 17 00:00:00 2001 From: xiey Date: Thu, 13 Apr 2023 19:46:32 +0800 Subject: [PATCH 4/7] [ENH] Add learnware semantics spec check --- examples/example_market_db/example_db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 474b05e..771cf10 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -44,7 +44,7 @@ user_senmantic = { "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, - "Name": {"Values": "learnware_4", "Type": "Name"}, + "Name": {"Values": "learnware", "Type": "Name"}, } From f8f65e99d3ab717abb1ed0c59463f4ef1c427a70 Mon Sep 17 00:00:00 2001 From: liuht Date: Fri, 14 Apr 2023 11:19:19 +0800 Subject: [PATCH 5/7] [MNT] Add auto rkme_spec matching instead of top k --- learnware/market/easy.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index c31abaa..cd1367c 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -238,7 +238,7 @@ class EasyMarket(BaseMarket): return intermediate_K, intermediate_C def _search_by_rkme_spec_mixture( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, search_num: int + self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int = 5, score_cutoff: float = 0.1 ) -> Tuple[List[float], List[Learnware]]: """Get search_num learnwares with their mixture weight from the given learnware_list @@ -248,8 +248,10 @@ class EasyMarket(BaseMarket): The list of learnwares whose mixture approximates the user's rkme user_rkme : RKMEStatSpecification User RKME statistical specification - search_num : int - The number of the returned learnwares + max_search_num : int + The maximum number of the returned learnwares + score_cutof: float + The minimum mmd dist as threshold to stop further rkme_spec matching Returns ------- @@ -262,14 +264,14 @@ class EasyMarket(BaseMarket): if learnware_num == 0: return [], [] if learnware_num < search_num: - logger.warning("Available Learnware num less than search_num") + logger.warning("Available Learnware num less than search_num!") search_num = learnware_num - flag_list = [0 for i in range(learnware_num)] + flag_list = [0 for _ in range(learnware_num)] mixture_list = [] intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1)) - for k in range(search_num): + for k in range(max_search_num): idx_min, score_min = -1, -1 weight_min = None mixture_list.append(None) @@ -291,11 +293,14 @@ class EasyMarket(BaseMarket): if idx_min == -1 or score < score_min: idx_min, score_min, weight_min = idx, score, weight - flag_list[idx_min] = 1 - mixture_list[-1] = learnware_list[idx_min] - intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( - mixture_list, user_rkme, intermediate_K, intermediate_C - ) + if score_min >= score_cutoff: + flag_list[idx_min] = 1 + mixture_list[-1] = learnware_list[idx_min] + intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( + mixture_list, user_rkme, intermediate_K, intermediate_C + ) + else: + break return weight_min, mixture_list From 5397732f5b9411fea46cdb30eed158eaac9457fd Mon Sep 17 00:00:00 2001 From: liuht Date: Fri, 14 Apr 2023 11:20:09 +0800 Subject: [PATCH 6/7] [MNT] Modify setup.py --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 8ef3ace..f32f0c3 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,8 @@ REQUIRED = [ "joblib>=1.2.0", "pyyaml>=6.0", "fire>=0.5.0", + "sklearn>=1.0.2", + "lightgbm>=3.3.5" ] here = os.path.abspath(os.path.dirname(__file__)) From 44659745e9c389ba152db1b379dec1e285fd6b0f Mon Sep 17 00:00:00 2001 From: liuht Date: Fri, 14 Apr 2023 11:22:15 +0800 Subject: [PATCH 7/7] [ENH] Add BaseReuse using job selector --- learnware/learnware/__init__.py | 1 + learnware/learnware/reuse.py | 189 +++++++++++++++++++++++++++++++- learnware/specification/rkme.py | 86 ++++++++++++++- 3 files changed, 273 insertions(+), 3 deletions(-) diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index 899abd9..f52630e 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -2,6 +2,7 @@ import os import copy from .base import Learnware +from .reuse import BaseReuse from .utils import get_stat_spec_from_config, get_model_from_config from ..specification import Specification from ..utils import read_yaml_to_dict diff --git a/learnware/learnware/reuse.py b/learnware/learnware/reuse.py index 18959a8..8b89964 100644 --- a/learnware/learnware/reuse.py +++ b/learnware/learnware/reuse.py @@ -1,3 +1,188 @@ +import numpy as np +from typing import Tuple, Any, List, Union, Dict +from cvxopt import matrix, solvers +from lightgbm import LGBMClassifier +from sklearn.metrics import accuracy_score + +from learnware.learnware import Learnware +import learnware.specification as specification +from ..specification import RKMEStatSpecification +from ..logger import get_module_logger + +logger = get_module_logger("BaseReuse") + class BaseReuse: - def __init__(self): - pass + """Baseline Multiple Learnware Reuse uing Job Selector Method""" + + def __init__(self, learnware_list: List[Learnware], herding_num: int = 100): + self.learnware_list = learnware_list + self.herding_num = herding_num + + def predict(self, user_data: np.ndarray) -> np.ndarray: + """Give prediction for user data using baseline job-selector method + + Parameters + ---------- + user_data : np.ndarray + User's labeled raw data. + + Returns + ------- + np.ndarray + Prediction given by job-selector method + """ + _, select_result = self.job_selector(user_data) + selector_pred_y = np.zeros(len(user_data.shape[0])) + + for idx in range(len(self.learnware_list)): + data_idx_list = np.where(select_result == idx)[0] + if len(data_idx_list) > 0: + selector_pred_y[data_idx_list] = self.learnware_list[idx].predict(data_idx_list) + + return selector_pred_y + + def job_selector(self, user_data: np.ndarray): + """Train job selector based on user's data, which predicts which learnware in the pool should be selected + + Parameters + ---------- + user_data : np.ndarray + _description_ + """ + learnware_rkme_spec_list = [ + learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in self.learnware_list + ] + task_matrix = np.zeros((len(learnware_rkme_spec_list), len(learnware_rkme_spec_list))) + + for i in range(len(self.learnware_list)): + task_rkme1 = learnware_rkme_spec_list[i] + for j in range(i, len(self.learnware_list)): + task_rkme2 = learnware_rkme_spec_list[j] + task_matrix[i][j] = task_matrix[j][i] = task_rkme1.inner_prd(task_rkme2) + + task_mixture_weight = self._calculate_rkme_spec_mixture_weight(user_data, learnware_rkme_spec_list, task_matrix) + + herding_X, train_herding_X, val_herding_X = None, None, None + herding_y, train_herding_y, val_herding_y = [], [], [] + for i in range(len(self.learnware_list)): + task_spec = learnware_rkme_spec_list[i] + task_herding_num = max(5, int(self.herding_num * task_mixture_weight[i])) + task_val_num = task_herding_num // 5 + + herding_X_i = task_spec.herding(task_herding_num).detach().cpu().numpy() + train_X_i = herding_X_i[:-task_val_num] + val_X_i = herding_X_i[task_val_num:] + + herding_X = herding_X_i if herding_X is None else np.concatenate((herding_X, herding_X_i), axis=0) + train_herding_X = train_X_i if train_herding_X is None else np.concatenate((train_herding_X, train_X_i), axis=0) + val_herding_X = val_X_i if val_herding_X is None else np.concatenate((val_herding_X, val_X_i), axis=0) + + herding_y += [i] * task_herding_num + train_herding_y += [i] * (task_herding_num - task_val_num) + val_herding_y += [i] * task_val_num + + herding_y = np.array(herding_y) + train_herding_y = np.array(train_herding_y) + val_herding_y = np.array(val_herding_y) + + # use herding samples to train a job selector + job_selector = self._selector_grid_search(herding_X, herding_y, train_herding_X, train_herding_y, val_herding_X, val_herding_y, len(self.learnware_list)) + job_select_result = np.array(job_selector.predict(user_data)) + + return job_selector, job_select_result + + + def _calculate_rkme_spec_mixture_weight( + self, user_data: np.ndarray, task_rkme_list: List[RKMEStatSpecification], task_rkme_matrix: np.ndarray + ) -> List[float]: + """_summary_ + + Parameters + ---------- + user_data : np.ndarray + _description_ + task_rkme_list : List[RKMEStatSpecification] + _description_ + task_rkme_matrix : np.ndarray + _description_ + """ + task_num = len(task_rkme_list) + user_rkme_spec = specification.utils.generate_rkme_spec(X=user_data, reduce=False) + K = task_rkme_matrix + v = np.array([user_rkme_spec.inner_prod(task_rkme) for task_rkme in task_rkme_list]) + + P = matrix(K) + q = matrix(-v) + G = matrix(-np.eye(task_num)) + h = matrix(np.zeros((task_num, 1))) + A = matrix(np.ones((1, task_num))) + b = matrix(np.ones((1, 1))) + solvers.options["show_progress"] = False + sol = solvers.qp(P, q, G, h, A, b) + task_mixture_weight = np.array(sol["x"]).reshape(-1) + + return task_mixture_weight + + def _selector_grid_search( + org_train_x: np.ndarray, org_train_y: np.ndarray, train_x: np.ndarray, train_y: np.ndarray, val_x: np.ndarray, val_y: np.ndarray, num_class:int + ) -> LGBMClassifier: + """Train a LGBMClassifier as job selector using the herding data as training instances. + + Parameters + ---------- + org_train_x : np.ndarray + The original herding features. + org_train_y : np.ndarray + The original hearding labels(which are learnware indexes). + train_x : np.ndarray + Herding features used for training. + train_y : np.ndarray + Herding labels used for training. + val_x : np.ndarray + Herding features used for validation. + val_y : np.ndarray + Herding labels used for validation. + num_class : int + Total number of classes for the job selector(which is exactly the total number of learnwares to be reused). + + Returns + ------- + LGBMClassifier + The job selector model. + """ + score_best = -1 + learning_rate = [0.01] + max_depth = [66] + params = (0, 0) + + for lr in learning_rate: + for md in max_depth: + model = LGBMClassifier( + max_depth=md, + learning_rate=lr, + n_estimators=2000, + objective="multiclass", + num_class=num_class, + booster="gbtree", + seed=0, + ) + model.fit(train_x, train_y, eval_set=[(val_x, val_y)], verbose=100, early_stopping_rounds=300) + pred_y = model.predict(org_train_x) + score = accuracy_score(pred_y, org_train_y) + + if score > score_best: + score_best = score + params = (lr, md) + + model = LGBMClassifier( + max_depth=params[1], + learning_rate=params[0], + n_estimators=2000, + objective="multiclass", + num_class=num_class, + booster="gbtree", + seed=0, + ) + model.fit(org_train_x, org_train_y, eval_set=[(org_train_x, org_train_y)], verbose=100, early_stopping_rounds=300) + + return model \ No newline at end of file diff --git a/learnware/specification/rkme.py b/learnware/specification/rkme.py index 568c410..c187d1a 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/rkme.py @@ -10,10 +10,13 @@ import codecs import random import numpy as np from cvxopt import solvers, matrix +from collections import Counter from typing import Tuple, Any, List, Union, Dict from .base import BaseStatSpecification +from ..logger import get_module_logger +logger = get_module_logger("rkme") class RKMEStatSpecification(BaseStatSpecification): """Reduced-set Kernel Mean Embedding (RKME) Specification""" @@ -196,6 +199,59 @@ class RKMEStatSpecification(BaseStatSpecification): Z = Z - step_size * grad_Z self.z = Z + + def _inner_prod_with_X(self, X: Any) -> float: + """Compute the inner product between RKME specification and X + + Parameters + ---------- + X : np.ndarray or torch.tensor + Raw data in np.ndarray format or torch.tensor format. + + Returns + ------- + float + The inner product between RKME specification and X + """ + beta = self.beta.reshape(1, -1).double().to(self.device) + Z = self.z.double().to(self.device) + if not torch.is_tensor(X): + X = torch.from_numpy(X) + X = X.to(self.device).double() + + v = torch_rbf_kernel(Z, X, self.gamma) * beta.double + v = torch.sum(v, axis = 0) + return v.detach().cpu().numpy() + + def _sampling_candidates(self, N: int) -> np.ndarray: + """Generate a large set of candidates as preparation for herding + + Parameters + ---------- + N : int + The number of herding candidates. + + Returns + ------- + np.ndarray + The herding candidates. + """ + beta = self.beta + beta[beta < 0] = 0 # currently we cannot use negative weight + beta = beta / torch.sum(beta) + sample_assign = torch.multinomial(beta, N, replacement=True) + + sample_list = [] + for i, n in Counter(np.array(sample_assign.cpu())).items(): + for _ in range(n): + sample_list.append(torch.normal(mean=self.z[i], std=0.25).reshape(1, -1)) + if len(sample_list) > 1: + return torch.cat(sample_list, axis=0) + elif len(sample_list) == 1: + return sample_list[0] + else: + logger.warning("Not enough candidates for herding!") + def inner_prod(self, Phi2: RKMEStatSpecification) -> float: """Compute the inner product between two RKME specifications @@ -226,7 +282,7 @@ class RKMEStatSpecification(BaseStatSpecification): Phi2 : RKMEStatSpecification The other RKME specification. omit_term1 : bool, optional - True if the inner product of self with itself can be omitted, by default False + True if the inner product of self with itself can be omitted, by default False. """ if omit_term1: term1 = 0 @@ -236,6 +292,34 @@ class RKMEStatSpecification(BaseStatSpecification): term3 = Phi2.inner_prod(Phi2) return float(term1 - 2 * term2 + term3) + + def herding(self, T: int) -> np.ndarray: + """Iteratively sample examples from an unknown distribution with the help of its RKME specification + + Parameters + ---------- + T : int + Total iteration number for sampling. + + Returns + ------- + np.ndarray + A collection of examples which approximate the unknown distribution. + """ + Nstart = 100 * T + Xstart = self._sampling_candidates(Nstart).to(self.device) + D = self.z[0].shape[0] + S = torch.zeros((T, D)).to(self.device) + fsX = torch.from_numpy(self._inner_prod_with_X(Xstart)).to(self.device) + fsS = torch.zeros(Nstart).to(self.device) + for i in range(T): + if i > 0: + fsS = torch.sum(torch_rbf_kernel(S[:i, :], Xstart, self.gamma), axis=0) + fs = (i + 1) * fsX - fsS + idx = torch.argmax(fs) + S[i, :] = Xstart[idx, :] + + return S def save(self, filepath: str): """Save the computed RKME specification to a specified path in JSON format.