From 2fbfb3948bd597a645bcc69dc629e5499ca81aa2 Mon Sep 17 00:00:00 2001 From: Peng Tan Date: Wed, 8 Nov 2023 21:12:57 +0800 Subject: [PATCH] [ENH] add test_model_reuse and fix bug in reuser package --- .../reuse/hetero_reuser/feature_alignment.py | 8 ++-- .../example_learnwares/config.py | 23 ++++++++++ .../test_hetero_market/test_hetero.py | 44 ++++++++++++++++++- 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/learnware/reuse/hetero_reuser/feature_alignment.py b/learnware/reuse/hetero_reuser/feature_alignment.py index 4cbe71d..112d749 100644 --- a/learnware/reuse/hetero_reuser/feature_alignment.py +++ b/learnware/reuse/hetero_reuser/feature_alignment.py @@ -9,7 +9,7 @@ from tqdm import trange from loguru import logger from learnware.learnware import Learnware -from learnware.specification import RKMEStatSpecification +from learnware.specification import RKMETableSpecification from learnware.specification.regular.table.rkme import choose_device from ..base import BaseReuser @@ -26,7 +26,7 @@ class FeatureAligner(BaseReuser): self.device = choose_device(cuda_idx=cuda_idx) def fit(self, user_rkme): - target_rkme=self.learnware.specification.get_stat_spec()["RKMEStatSpecification"] + target_rkme=self.learnware.specification.get_stat_spec()["RKMETableSpecification"] trainer=FeatureAlignmentTrainer(target_rkme=target_rkme, user_rkme=user_rkme, cuda_idx=self.cuda_idx, **self.align_arguments) self.align_model=trainer.model self.align_model.eval() @@ -91,8 +91,8 @@ class FeatureAlignmentTrainer(): def __init__( self, - target_rkme: RKMEStatSpecification, # (X, weight) - user_rkme: RKMEStatSpecification, # (X, weight) + target_rkme: RKMETableSpecification, # (X, weight) + user_rkme: RKMETableSpecification, # (X, weight) extra_labeled_data: Any = None, target_learnware: Learnware = None, num_epoch: int = 50, diff --git a/tests/test_market/test_hetero_market/example_learnwares/config.py b/tests/test_market/test_hetero_market/example_learnwares/config.py index 941109a..b4d4fb4 100644 --- a/tests/test_market/test_hetero_market/example_learnwares/config.py +++ b/tests/test_market/test_hetero_market/example_learnwares/config.py @@ -78,4 +78,27 @@ output_description_list=[ }, }, +] + +user_description_list=[ + { + "Dimension": 15, + "Description": { # medical description + "0": "Whether the patient is on thyroxine medication (0: No, 1: Yes)", + "1": "Whether the patient has been queried about thyroxine medication (0: No, 1: Yes)", + "2": "Whether the patient is on antithyroid medication (0: No, 1: Yes)", + "3": "Whether the patient has undergone thyroid surgery (0: No, 1: Yes)", + "4": "Whether the patient has been queried about hypothyroidism (0: No, 1: Yes)", + "5": "Whether the patient has been queried about hyperthyroidism (0: No, 1: Yes)", + "6": "Whether the patient is pregnant (0: No, 1: Yes)", + "7": "Whether the patient is sick (0: No, 1: Yes)", + "8": "Whether the patient has a tumor (0: No, 1: Yes)", + "9": "Whether the patient is taking lithium (0: No, 1: Yes)", + "10": "Whether the patient has a goitre (enlarged thyroid gland) (0: No, 1: Yes)", + "11": "Whether TSH (Thyroid Stimulating Hormone) level has been measured (0: No, 1: Yes)", + "12": "Whether T3 (Triiodothyronine) level has been measured (0: No, 1: Yes)", + "13": "Whether TT4 (Total Thyroxine) level has been measured (0: No, 1: Yes)", + "14": "Whether T4U (Thyroxine Utilization) level has been measured (0: No, 1: Yes)" + }, + } ] \ No newline at end of file diff --git a/tests/test_market/test_hetero_market/test_hetero.py b/tests/test_market/test_hetero_market/test_hetero.py index 81b6493..19c5ad5 100644 --- a/tests/test_market/test_hetero_market/test_hetero.py +++ b/tests/test_market/test_hetero_market/test_hetero.py @@ -10,11 +10,13 @@ from sklearn.datasets import make_regression from shutil import copyfile, rmtree from multiprocessing import Pool from learnware.client import LearnwareClient +from sklearn.metrics import mean_squared_error import learnware from learnware.market import instantiate_learnware_market, BaseUserInfo from learnware.specification import RKMETableSpecification, generate_rkme_spec -from example_learnwares.config import input_shape_list, input_description_list, output_description_list +from learnware.reuse import HeteroMapTableReuser +from example_learnwares.config import input_shape_list, input_description_list, output_description_list, user_description_list curr_root = os.path.dirname(os.path.abspath(__file__)) @@ -286,6 +288,43 @@ class TestMarket(unittest.TestCase): rmtree(test_folder) # rm -r test_folder + def test_model_reuse(self, learnware_num=5): + # generate toy regression problem + X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0) + + # generate rkme + user_spec = generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) + + # generate specification + semantic_spec = copy.deepcopy(user_semantic) + semantic_spec["Input"] = user_description_list[0] + user_info=BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) + + # learnware market search + hetero_market = self.test_train_market_model(learnware_num) + ( + sorted_score_list, + single_learnware_list, + mixture_score, + mixture_learnware_list, + ) = hetero_market.search_learnware(user_info) + + # model reuse + print([learnware.id for learnware in single_learnware_list]) + reuser=HeteroMapTableReuser(single_learnware_list[0], task_type='regression') + reuser.fit(user_spec) + y_pred=reuser.predict(X) + + # calculate rmse + rmse=mean_squared_error(y, y_pred, squared=False) + print(f"rmse not finetune: {rmse}") + + # finetune + reuser.finetune(X[:100], y[:100]) + y_pred=reuser.predict(X) + rmse=mean_squared_error(y, y_pred, squared=False) + print(f"rmse finetune: {rmse}") + def suite(): _suite = unittest.TestSuite() @@ -294,7 +333,8 @@ def suite(): # _suite.addTest(TestMarket("test_upload_delete_learnware")) # _suite.addTest(TestMarket("test_train_market_model")) # _suite.addTest(TestMarket("test_search_semantics")) - _suite.addTest(TestMarket("test_stat_search")) + # _suite.addTest(TestMarket("test_stat_search")) + _suite.addTest(TestMarket("test_model_reuse")) return _suite