| @@ -9,7 +9,7 @@ from tqdm import trange | |||
| from loguru import logger | |||
| from learnware.learnware import Learnware | |||
| from learnware.specification import RKMEStatSpecification | |||
| from learnware.specification import RKMETableSpecification | |||
| from learnware.specification.regular.table.rkme import choose_device | |||
| from ..base import BaseReuser | |||
| @@ -26,7 +26,7 @@ class FeatureAligner(BaseReuser): | |||
| self.device = choose_device(cuda_idx=cuda_idx) | |||
| def fit(self, user_rkme): | |||
| target_rkme=self.learnware.specification.get_stat_spec()["RKMEStatSpecification"] | |||
| target_rkme=self.learnware.specification.get_stat_spec()["RKMETableSpecification"] | |||
| trainer=FeatureAlignmentTrainer(target_rkme=target_rkme, user_rkme=user_rkme, cuda_idx=self.cuda_idx, **self.align_arguments) | |||
| self.align_model=trainer.model | |||
| self.align_model.eval() | |||
| @@ -91,8 +91,8 @@ class FeatureAlignmentTrainer(): | |||
| def __init__( | |||
| self, | |||
| target_rkme: RKMEStatSpecification, # (X, weight) | |||
| user_rkme: RKMEStatSpecification, # (X, weight) | |||
| target_rkme: RKMETableSpecification, # (X, weight) | |||
| user_rkme: RKMETableSpecification, # (X, weight) | |||
| extra_labeled_data: Any = None, | |||
| target_learnware: Learnware = None, | |||
| num_epoch: int = 50, | |||
| @@ -78,4 +78,27 @@ output_description_list=[ | |||
| }, | |||
| }, | |||
| ] | |||
| user_description_list=[ | |||
| { | |||
| "Dimension": 15, | |||
| "Description": { # medical description | |||
| "0": "Whether the patient is on thyroxine medication (0: No, 1: Yes)", | |||
| "1": "Whether the patient has been queried about thyroxine medication (0: No, 1: Yes)", | |||
| "2": "Whether the patient is on antithyroid medication (0: No, 1: Yes)", | |||
| "3": "Whether the patient has undergone thyroid surgery (0: No, 1: Yes)", | |||
| "4": "Whether the patient has been queried about hypothyroidism (0: No, 1: Yes)", | |||
| "5": "Whether the patient has been queried about hyperthyroidism (0: No, 1: Yes)", | |||
| "6": "Whether the patient is pregnant (0: No, 1: Yes)", | |||
| "7": "Whether the patient is sick (0: No, 1: Yes)", | |||
| "8": "Whether the patient has a tumor (0: No, 1: Yes)", | |||
| "9": "Whether the patient is taking lithium (0: No, 1: Yes)", | |||
| "10": "Whether the patient has a goitre (enlarged thyroid gland) (0: No, 1: Yes)", | |||
| "11": "Whether TSH (Thyroid Stimulating Hormone) level has been measured (0: No, 1: Yes)", | |||
| "12": "Whether T3 (Triiodothyronine) level has been measured (0: No, 1: Yes)", | |||
| "13": "Whether TT4 (Total Thyroxine) level has been measured (0: No, 1: Yes)", | |||
| "14": "Whether T4U (Thyroxine Utilization) level has been measured (0: No, 1: Yes)" | |||
| }, | |||
| } | |||
| ] | |||
| @@ -10,11 +10,13 @@ from sklearn.datasets import make_regression | |||
| from shutil import copyfile, rmtree | |||
| from multiprocessing import Pool | |||
| from learnware.client import LearnwareClient | |||
| from sklearn.metrics import mean_squared_error | |||
| import learnware | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_spec | |||
| from example_learnwares.config import input_shape_list, input_description_list, output_description_list | |||
| from learnware.reuse import HeteroMapTableReuser | |||
| from example_learnwares.config import input_shape_list, input_description_list, output_description_list, user_description_list | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| @@ -286,6 +288,43 @@ class TestMarket(unittest.TestCase): | |||
| rmtree(test_folder) # rm -r test_folder | |||
| def test_model_reuse(self, learnware_num=5): | |||
| # generate toy regression problem | |||
| X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0) | |||
| # generate rkme | |||
| user_spec = generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) | |||
| # generate specification | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Input"] = user_description_list[0] | |||
| user_info=BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| # learnware market search | |||
| hetero_market = self.test_train_market_model(learnware_num) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = hetero_market.search_learnware(user_info) | |||
| # model reuse | |||
| print([learnware.id for learnware in single_learnware_list]) | |||
| reuser=HeteroMapTableReuser(single_learnware_list[0], task_type='regression') | |||
| reuser.fit(user_spec) | |||
| y_pred=reuser.predict(X) | |||
| # calculate rmse | |||
| rmse=mean_squared_error(y, y_pred, squared=False) | |||
| print(f"rmse not finetune: {rmse}") | |||
| # finetune | |||
| reuser.finetune(X[:100], y[:100]) | |||
| y_pred=reuser.predict(X) | |||
| rmse=mean_squared_error(y, y_pred, squared=False) | |||
| print(f"rmse finetune: {rmse}") | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| @@ -294,7 +333,8 @@ def suite(): | |||
| # _suite.addTest(TestMarket("test_upload_delete_learnware")) | |||
| # _suite.addTest(TestMarket("test_train_market_model")) | |||
| # _suite.addTest(TestMarket("test_search_semantics")) | |||
| _suite.addTest(TestMarket("test_stat_search")) | |||
| # _suite.addTest(TestMarket("test_stat_search")) | |||
| _suite.addTest(TestMarket("test_model_reuse")) | |||
| return _suite | |||