From 8ae362c04f604e3f9bd223046aed73b40cc0e4bd Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 4 Dec 2023 22:54:16 +0800 Subject: [PATCH] [MNT] del useless file --- .../example_learnwares/config.py | 100 ----- .../example_learnwares/learnware.yaml | 8 - .../example_learnwares/model0.py | 16 - .../example_learnwares/model1.py | 16 - .../example_learnwares/requirements.txt | 1 - tests/test_hetero_market/test_hetero.py | 414 ------------------ 6 files changed, 555 deletions(-) delete mode 100644 tests/test_hetero_market/example_learnwares/config.py delete mode 100644 tests/test_hetero_market/example_learnwares/learnware.yaml delete mode 100644 tests/test_hetero_market/example_learnwares/model0.py delete mode 100644 tests/test_hetero_market/example_learnwares/model1.py delete mode 100644 tests/test_hetero_market/example_learnwares/requirements.txt delete mode 100644 tests/test_hetero_market/test_hetero.py diff --git a/tests/test_hetero_market/example_learnwares/config.py b/tests/test_hetero_market/example_learnwares/config.py deleted file mode 100644 index 1816b4c..0000000 --- a/tests/test_hetero_market/example_learnwares/config.py +++ /dev/null @@ -1,100 +0,0 @@ -input_shape_list = [20, 30] # 20-input shape of example learnware 0, 30-input shape of example learnware 1 - -input_description_list = [ - { - "Dimension": 20, - "Description": { # medical description - "0": "baseline value: Baseline Fetal Heart Rate (FHR)", - "1": "accelerations: Number of accelerations per second", - "2": "fetal_movement: Number of fetal movements per second", - "3": "uterine_contractions: Number of uterine contractions per second", - "4": "light_decelerations: Number of LDs per second", - "5": "severe_decelerations: Number of SDs per second", - "6": "prolongued_decelerations: Number of PDs per second", - "7": "abnormal_short_term_variability: Percentage of time with abnormal short term variability", - "8": "mean_value_of_short_term_variability: Mean value of short term variability", - "9": "percentage_of_time_with_abnormal_long_term_variability: Percentage of time with abnormal long term variability", - "10": "mean_value_of_long_term_variability: Mean value of long term variability", - "11": "histogram_width: Width of the histogram made using all values from a record", - "12": "histogram_min: Histogram minimum value", - "13": "histogram_max: Histogram maximum value", - "14": "histogram_number_of_peaks: Number of peaks in the exam histogram", - "15": "histogram_number_of_zeroes: Number of zeroes in the exam histogram", - "16": "histogram_mode: Hist mode", - "17": "histogram_mean: Hist mean", - "18": "histogram_median: Hist Median", - "19": "histogram_variance: Hist variance", - }, - }, - { - "Dimension": 30, - "Description": { # business description - "0": "This is a consecutive month number, used for convenience. For example, January 2013 is 0, February 2013 is 1,..., October 2015 is 33.", - "1": "This is the unique identifier for each shop.", - "2": "This is the unique identifier for each item.", - "3": "This is the code representing the city where the shop is located.", - "4": "This is the unique identifier for the category of the item.", - "5": "This is the code representing the type of the item.", - "6": "This is the code representing the subtype of the item.", - "7": "This is the number of this type of item sold in the shop one month ago.", - "8": "This is the number of this type of item sold in the shop two months ago.", - "9": "This is the number of this type of item sold in the shop three months ago.", - "10": "This is the number of this type of item sold in the shop six months ago.", - "11": "This is the number of this type of item sold in the shop twelve months ago.", - "12": "This is the average count of items sold one month ago.", - "13": "This is the average count of this type of item sold one month ago.", - "14": "This is the average count of this type of item sold two months ago.", - "15": "This is the average count of this type of item sold three months ago.", - "16": "This is the average count of this type of item sold six months ago.", - "17": "This is the average count of this type of item sold twelve months ago.", - "18": "This is the average count of items sold in the shop one month ago.", - "19": "This is the average count of items sold in the shop two months ago.", - "20": "This is the average count of items sold in the shop three months ago.", - "21": "This is the average count of items sold in the shop six months ago.", - "22": "This is the average count of items sold in the shop twelve months ago.", - "23": "This is the average count of items in the same category sold one month ago.", - "24": "This is the average count of items in the same category sold in the shop one month ago.", - "25": "This is the average count of items of the same type sold in the shop one month ago.", - "26": "This is the average count of items of the same subtype sold in the shop one month ago.", - "27": "This is the average count of items sold in the same city one month ago.", - "28": "This is the average count of this type of item sold in the same city one month ago.", - "29": "This is the average count of items of the same type sold one month ago.", - }, - }, -] - -output_description_list = [ - { - "Dimension": 1, - "Description": {"0": "length of stay: Length of hospital stay (days)"}, # medical description - }, - { - "Dimension": 1, - "Description": { # business description - "0": "sales of the item in the next day: Number of items sold in the next day" - }, - }, -] - -user_description_list = [ - { - "Dimension": 15, - "Description": { # medical description - "0": "Whether the patient is on thyroxine medication (0: No, 1: Yes)", - "1": "Whether the patient has been queried about thyroxine medication (0: No, 1: Yes)", - "2": "Whether the patient is on antithyroid medication (0: No, 1: Yes)", - "3": "Whether the patient has undergone thyroid surgery (0: No, 1: Yes)", - "4": "Whether the patient has been queried about hypothyroidism (0: No, 1: Yes)", - "5": "Whether the patient has been queried about hyperthyroidism (0: No, 1: Yes)", - "6": "Whether the patient is pregnant (0: No, 1: Yes)", - "7": "Whether the patient is sick (0: No, 1: Yes)", - "8": "Whether the patient has a tumor (0: No, 1: Yes)", - "9": "Whether the patient is taking lithium (0: No, 1: Yes)", - "10": "Whether the patient has a goitre (enlarged thyroid gland) (0: No, 1: Yes)", - "11": "Whether TSH (Thyroid Stimulating Hormone) level has been measured (0: No, 1: Yes)", - "12": "Whether T3 (Triiodothyronine) level has been measured (0: No, 1: Yes)", - "13": "Whether TT4 (Total Thyroxine) level has been measured (0: No, 1: Yes)", - "14": "Whether T4U (Thyroxine Utilization) level has been measured (0: No, 1: Yes)", - }, - } -] diff --git a/tests/test_hetero_market/example_learnwares/learnware.yaml b/tests/test_hetero_market/example_learnwares/learnware.yaml deleted file mode 100644 index 4a37a37..0000000 --- a/tests/test_hetero_market/example_learnwares/learnware.yaml +++ /dev/null @@ -1,8 +0,0 @@ -model: - class_name: MyModel - kwargs: {} -stat_specifications: - - module_path: learnware.specification - class_name: RKMETableSpecification - file_name: stat.json - kwargs: {} \ No newline at end of file diff --git a/tests/test_hetero_market/example_learnwares/model0.py b/tests/test_hetero_market/example_learnwares/model0.py deleted file mode 100644 index 45f64b7..0000000 --- a/tests/test_hetero_market/example_learnwares/model0.py +++ /dev/null @@ -1,16 +0,0 @@ -from learnware.model import BaseModel -import numpy as np -import joblib -import os - - -class MyModel(BaseModel): - def __init__(self): - super(MyModel, self).__init__(input_shape=(20,), output_shape=(1,)) - dir_path = os.path.dirname(os.path.abspath(__file__)) - model_path = os.path.join(dir_path, "ridge.pkl") - model = joblib.load(model_path) - self.model = model - - def predict(self, X: np.ndarray) -> np.ndarray: - return self.model.predict(X) diff --git a/tests/test_hetero_market/example_learnwares/model1.py b/tests/test_hetero_market/example_learnwares/model1.py deleted file mode 100644 index aca46b3..0000000 --- a/tests/test_hetero_market/example_learnwares/model1.py +++ /dev/null @@ -1,16 +0,0 @@ -from learnware.model import BaseModel -import numpy as np -import joblib -import os - - -class MyModel(BaseModel): - def __init__(self): - super(MyModel, self).__init__(input_shape=(30,), output_shape=(1,)) - dir_path = os.path.dirname(os.path.abspath(__file__)) - model_path = os.path.join(dir_path, "ridge.pkl") - model = joblib.load(model_path) - self.model = model - - def predict(self, X: np.ndarray) -> np.ndarray: - return self.model.predict(X) diff --git a/tests/test_hetero_market/example_learnwares/requirements.txt b/tests/test_hetero_market/example_learnwares/requirements.txt deleted file mode 100644 index 1da1c5f..0000000 --- a/tests/test_hetero_market/example_learnwares/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -learnware == 0.1.0.999 \ No newline at end of file diff --git a/tests/test_hetero_market/test_hetero.py b/tests/test_hetero_market/test_hetero.py deleted file mode 100644 index 7b0740b..0000000 --- a/tests/test_hetero_market/test_hetero.py +++ /dev/null @@ -1,414 +0,0 @@ -import torch -import unittest -import os -import copy -import joblib -import zipfile -import numpy as np -import multiprocessing -from sklearn.linear_model import Ridge -from sklearn.datasets import make_regression -from shutil import copyfile, rmtree -from learnware.client import LearnwareClient -from sklearn.metrics import mean_squared_error - -import learnware -from learnware.market import instantiate_learnware_market, BaseUserInfo -from learnware.specification import RKMETableSpecification, generate_rkme_table_spec -from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser -from example_learnwares.config import ( - input_shape_list, - input_description_list, - output_description_list, - user_description_list, -) - -curr_root = os.path.dirname(os.path.abspath(__file__)) - -user_semantic = { - "Data": {"Values": ["Table"], "Type": "Class"}, - "Task": { - "Values": ["Regression"], - "Type": "Class", - }, - "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, - "Scenario": {"Values": ["Education"], "Type": "Tag"}, - "Description": {"Values": "", "Type": "String"}, - "Name": {"Values": "", "Type": "String"}, - "License": {"Values": ["MIT"], "Type": "Class"}, -} - - -def check_learnware(learnware_name, dir_path=os.path.join(curr_root, "learnware_pool")): - print(f"Checking Learnware: {learnware_name}") - zip_file_path = os.path.join(dir_path, learnware_name) - client = LearnwareClient() - # if check_learnware doesn't raise an exception, return True, otherwise, return false - try: - client.check_learnware(zip_file_path) - return True - except Exception as e: - print(f"Learnware {learnware_name} failed the check: {e}") - return False - - -class TestMarket(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - np.random.seed(2023) - learnware.init() - - def _init_learnware_market(self, organizer_kwargs=None): - """initialize learnware market""" - hetero_market = instantiate_learnware_market( - market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs - ) - return hetero_market - - def test_prepare_learnware_randomly(self, learnware_num=5): - self.zip_path_list = [] - - for i in range(learnware_num): - dir_path = os.path.join(curr_root, "learnware_pool", "ridge_%d" % (i)) - os.makedirs(dir_path, exist_ok=True) - - print("Preparing Learnware: %d" % (i)) - - example_learnware_idx = i % 2 - input_dim = input_shape_list[example_learnware_idx] - learnware_example_dir = "example_learnwares" - - X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_dim, noise=0.1, random_state=42) - - clf = Ridge(alpha=1.0) - clf.fit(X, y) - - joblib.dump(clf, os.path.join(dir_path, "ridge.pkl")) - - spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) - spec.save(os.path.join(dir_path, "stat.json")) - - init_file = os.path.join(dir_path, "__init__.py") - copyfile( - os.path.join(curr_root, learnware_example_dir, f"model{example_learnware_idx}.py"), init_file - ) # cp example_init.py init_file - - yaml_file = os.path.join(dir_path, "learnware.yaml") - copyfile( - os.path.join(curr_root, learnware_example_dir, "learnware.yaml"), yaml_file - ) # cp example.yaml yaml_file - - env_file = os.path.join(dir_path, "requirements.txt") - copyfile(os.path.join(curr_root, learnware_example_dir, "requirements.txt"), env_file) - - zip_file = dir_path + ".zip" - # zip -q -r -j zip_file dir_path - with zipfile.ZipFile(zip_file, "w") as zip_obj: - for foldername, subfolders, filenames in os.walk(dir_path): - for filename in filenames: - file_path = os.path.join(foldername, filename) - zip_info = zipfile.ZipInfo(filename) - zip_info.compress_type = zipfile.ZIP_STORED - with open(file_path, "rb") as file: - zip_obj.writestr(zip_info, file.read()) - - rmtree(dir_path) # rm -r dir_path - - self.zip_path_list.append(zip_file) - - def test_generated_learnwares(self): - curr_root = os.path.dirname(os.path.abspath(__file__)) - dir_path = os.path.join(curr_root, "learnware_pool") - - # Execute multi-process checking using Pool - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool() as pool: - results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)]) - - # Use an assert statement to ensure that all checks return True - self.assertTrue(all(results), "Not all learnwares passed the check") - - def test_upload_delete_learnware(self, learnware_num=5, delete=True): - hetero_market = self._init_learnware_market() - self.test_prepare_learnware_randomly(learnware_num) - self.learnware_num = learnware_num - - print("Total Item:", len(hetero_market)) - assert len(hetero_market) == 0, f"The market should be empty!" - - for idx, zip_path in enumerate(self.zip_path_list): - semantic_spec = copy.deepcopy(user_semantic) - semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) - semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) - semantic_spec["Input"] = input_description_list[idx % 2] - semantic_spec["Output"] = output_description_list[idx % 2] - hetero_market.add_learnware(zip_path, semantic_spec) - - print("Total Item:", len(hetero_market)) - assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" - curr_inds = hetero_market.get_learnware_ids() - print("Available ids After Uploading Learnwares:", curr_inds) - assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" - - if delete: - for learnware_id in curr_inds: - hetero_market.delete_learnware(learnware_id) - self.learnware_num -= 1 - assert ( - len(hetero_market) == self.learnware_num - ), f"The number of learnwares must be {self.learnware_num}!" - - curr_inds = hetero_market.get_learnware_ids() - print("Available ids After Deleting Learnwares:", curr_inds) - assert len(curr_inds) == 0, f"The market should be empty!" - - return hetero_market - - def test_train_market_model(self, learnware_num=5): - hetero_market = self._init_learnware_market( - organizer_kwargs={"auto_update": False, "auto_update_limit": learnware_num} - ) - self.test_prepare_learnware_randomly(learnware_num) - self.learnware_num = learnware_num - - print("Total Item:", len(hetero_market)) - assert len(hetero_market) == 0, f"The market should be empty!" - - for idx, zip_path in enumerate(self.zip_path_list): - semantic_spec = copy.deepcopy(user_semantic) - semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) - semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) - semantic_spec["Input"] = input_description_list[idx % 2] - semantic_spec["Output"] = output_description_list[idx % 2] - hetero_market.add_learnware(zip_path, semantic_spec) - - print("Total Item:", len(hetero_market)) - assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" - curr_inds = hetero_market.get_learnware_ids() - print("Available ids After Uploading Learnwares:", curr_inds) - assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" - - # organizer=hetero_market.learnware_organizer - # organizer.train(hetero_market.learnware_organizer.learnware_list.values()) - return hetero_market - - def test_search_semantics(self, learnware_num=5): - hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False) - print("Total Item:", len(hetero_market)) - assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" - - semantic_spec = copy.deepcopy(user_semantic) - semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" - - user_info = BaseUserInfo(semantic_spec=semantic_spec) - search_result = hetero_market.search_learnware(user_info) - single_result = search_result.get_single_results() - - print("User info:", user_info.get_semantic_spec()) - print(f"Search result:") - assert len(single_result) == 1, f"Exact semantic search failed!" - for search_item in single_result: - semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() - print("Choose learnware:", search_item.learnware.id, semantic_spec1) - assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!" - - semantic_spec["Name"]["Values"] = "laernwaer" - user_info = BaseUserInfo(semantic_spec=semantic_spec) - search_result = hetero_market.search_learnware(user_info) - single_result = search_result.get_single_results() - - print("User info:", user_info.get_semantic_spec()) - print(f"Search result:") - assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!" - for search_item in single_result: - semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() - print("Choose learnware:", search_item.learnware.id, semantic_spec1) - - def test_stat_search(self, learnware_num=5): - hetero_market = self.test_train_market_model(learnware_num) - print("Total Item:", len(hetero_market)) - - # hetero test - print("+++++ HETERO TEST ++++++") - user_dim = 15 - - test_folder = os.path.join(curr_root, "test_stat") - - for idx, zip_path in enumerate(self.zip_path_list): - unzip_dir = os.path.join(test_folder, f"{idx}") - - # unzip -o -q zip_path -d unzip_dir - if os.path.exists(unzip_dir): - rmtree(unzip_dir) - os.makedirs(unzip_dir, exist_ok=True) - with zipfile.ZipFile(zip_path, "r") as zip_obj: - zip_obj.extractall(path=unzip_dir) - - user_spec = RKMETableSpecification() - user_spec.load(os.path.join(unzip_dir, "stat.json")) - z = user_spec.get_z() - z = z[:, :user_dim] - device = user_spec.device - z = torch.tensor(z, device=device) - user_spec.z = z - - print(">> normal case test:") - semantic_spec = copy.deepcopy(user_semantic) - semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2]) - semantic_spec["Input"]["Dimension"] = user_dim - # keep only the first user_dim descriptions - semantic_spec["Input"]["Description"] = { - str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) - } - user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - - search_result = hetero_market.search_learnware(user_info) - single_result = search_result.get_single_results() - multiple_result = search_result.get_multiple_results() - - print(f"search result of user{idx}:") - for single_item in single_result: - print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") - - for multiple_item in multiple_result: - print( - f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" - ) - - # inproper key "Task" in semantic_spec, use homo search and print invalid semantic_spec - print(">> test for key 'Task' has empty 'Values':") - semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} - - user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - search_result = hetero_market.search_learnware(user_info) - single_result = search_result.get_single_results() - - assert len(single_result) == 0, f"Statistical search failed!" - - # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" - print(">> delele key 'Task' test:") - semantic_spec.pop("Task") - - user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - search_result = hetero_market.search_learnware(user_info) - single_result = search_result.get_single_results() - - assert len(single_result) == 0, f"Statistical search failed!" - - # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." - print(">> mismatch dim test") - semantic_spec = copy.deepcopy(user_semantic) - semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2]) - semantic_spec["Input"]["Dimension"] = user_dim - 2 - semantic_spec["Input"]["Description"] = { - str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) - } - - user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - search_result = hetero_market.search_learnware(user_info) - single_result = search_result.get_single_results() - - assert len(single_result) == 0, f"Statistical search failed!" - - rmtree(test_folder) # rm -r test_folder - - # homo test - print("\n+++++ HOMO TEST ++++++") - test_folder = os.path.join(curr_root, "test_stat") - - for idx, zip_path in enumerate(self.zip_path_list): - unzip_dir = os.path.join(test_folder, f"{idx}") - - # unzip -o -q zip_path -d unzip_dir - if os.path.exists(unzip_dir): - rmtree(unzip_dir) - os.makedirs(unzip_dir, exist_ok=True) - with zipfile.ZipFile(zip_path, "r") as zip_obj: - zip_obj.extractall(path=unzip_dir) - - user_spec = RKMETableSpecification() - user_spec.load(os.path.join(unzip_dir, "stat.json")) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) - search_result = hetero_market.search_learnware(user_info) - single_result = search_result.get_single_results() - multiple_result = search_result.get_multiple_results() - - assert len(single_result) >= 1, f"Statistical search failed!" - print(f"search result of user{idx}:") - for single_item in single_result: - print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") - - for multiple_item in multiple_result: - print(f"mixture_score: {multiple_item.score}\n") - mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares]) - print(f"mixture_learnware: {mixture_id}\n") - - rmtree(test_folder) # rm -r test_folder - - def test_model_reuse(self, learnware_num=5): - # generate toy regression problem - X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0) - - # generate rkme - user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) - - # generate specification - semantic_spec = copy.deepcopy(user_semantic) - semantic_spec["Input"] = user_description_list[0] - user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - - # learnware market search - hetero_market = self.test_train_market_model(learnware_num) - search_result = hetero_market.search_learnware(user_info) - single_result = search_result.get_single_results() - multiple_result = search_result.get_multiple_results() - # print search results - for single_item in single_result: - print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") - - for multiple_item in multiple_result: - print( - f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" - ) - - # single model reuse - hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression") - hetero_learnware.align(user_spec, X[:100], y[:100]) - single_predict_y = hetero_learnware.predict(X) - - # multi model reuse - hetero_learnware_list = [] - for learnware in multiple_result[0].learnwares: - hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") - hetero_learnware.align(user_spec, X[:100], y[:100]) - hetero_learnware_list.append(hetero_learnware) - - # Use averaging ensemble reuser to reuse the searched learnwares to make prediction - reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean") - ensemble_predict_y = reuse_ensemble.predict(user_data=X) - - # Use ensemble pruning reuser to reuse the searched learnwares to make prediction - reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression") - reuse_ensemble.fit(X[:100], y[:100]) - ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X) - - print("Single model RMSE by finetune:", mean_squared_error(y, single_predict_y, squared=False)) - print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y, squared=False)) - print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y, squared=False)) - - -def suite(): - _suite = unittest.TestSuite() - _suite.addTest(TestMarket("test_prepare_learnware_randomly")) - _suite.addTest(TestMarket("test_generated_learnwares")) - _suite.addTest(TestMarket("test_upload_delete_learnware")) - _suite.addTest(TestMarket("test_train_market_model")) - _suite.addTest(TestMarket("test_search_semantics")) - _suite.addTest(TestMarket("test_stat_search")) - _suite.addTest(TestMarket("test_model_reuse")) - return _suite - - -if __name__ == "__main__": - runner = unittest.TextTestRunner() - runner.run(suite())