Browse Source

[ENH] add test_model_reuse and fix bug in reuser package

tags/v0.3.2
Peng Tan 2 years ago
parent
commit
2fbfb3948b
3 changed files with 69 additions and 6 deletions
  1. +4
    -4
      learnware/reuse/hetero_reuser/feature_alignment.py
  2. +23
    -0
      tests/test_market/test_hetero_market/example_learnwares/config.py
  3. +42
    -2
      tests/test_market/test_hetero_market/test_hetero.py

+ 4
- 4
learnware/reuse/hetero_reuser/feature_alignment.py View File

@@ -9,7 +9,7 @@ from tqdm import trange
from loguru import logger

from learnware.learnware import Learnware
from learnware.specification import RKMEStatSpecification
from learnware.specification import RKMETableSpecification
from learnware.specification.regular.table.rkme import choose_device

from ..base import BaseReuser
@@ -26,7 +26,7 @@ class FeatureAligner(BaseReuser):
self.device = choose_device(cuda_idx=cuda_idx)

def fit(self, user_rkme):
target_rkme=self.learnware.specification.get_stat_spec()["RKMEStatSpecification"]
target_rkme=self.learnware.specification.get_stat_spec()["RKMETableSpecification"]
trainer=FeatureAlignmentTrainer(target_rkme=target_rkme, user_rkme=user_rkme, cuda_idx=self.cuda_idx, **self.align_arguments)
self.align_model=trainer.model
self.align_model.eval()
@@ -91,8 +91,8 @@ class FeatureAlignmentTrainer():

def __init__(
self,
target_rkme: RKMEStatSpecification, # (X, weight)
user_rkme: RKMEStatSpecification, # (X, weight)
target_rkme: RKMETableSpecification, # (X, weight)
user_rkme: RKMETableSpecification, # (X, weight)
extra_labeled_data: Any = None,
target_learnware: Learnware = None,
num_epoch: int = 50,


+ 23
- 0
tests/test_market/test_hetero_market/example_learnwares/config.py View File

@@ -78,4 +78,27 @@ output_description_list=[
},
},
]

user_description_list=[
{
"Dimension": 15,
"Description": { # medical description
"0": "Whether the patient is on thyroxine medication (0: No, 1: Yes)",
"1": "Whether the patient has been queried about thyroxine medication (0: No, 1: Yes)",
"2": "Whether the patient is on antithyroid medication (0: No, 1: Yes)",
"3": "Whether the patient has undergone thyroid surgery (0: No, 1: Yes)",
"4": "Whether the patient has been queried about hypothyroidism (0: No, 1: Yes)",
"5": "Whether the patient has been queried about hyperthyroidism (0: No, 1: Yes)",
"6": "Whether the patient is pregnant (0: No, 1: Yes)",
"7": "Whether the patient is sick (0: No, 1: Yes)",
"8": "Whether the patient has a tumor (0: No, 1: Yes)",
"9": "Whether the patient is taking lithium (0: No, 1: Yes)",
"10": "Whether the patient has a goitre (enlarged thyroid gland) (0: No, 1: Yes)",
"11": "Whether TSH (Thyroid Stimulating Hormone) level has been measured (0: No, 1: Yes)",
"12": "Whether T3 (Triiodothyronine) level has been measured (0: No, 1: Yes)",
"13": "Whether TT4 (Total Thyroxine) level has been measured (0: No, 1: Yes)",
"14": "Whether T4U (Thyroxine Utilization) level has been measured (0: No, 1: Yes)"
},
}
]

+ 42
- 2
tests/test_market/test_hetero_market/test_hetero.py View File

@@ -10,11 +10,13 @@ from sklearn.datasets import make_regression
from shutil import copyfile, rmtree
from multiprocessing import Pool
from learnware.client import LearnwareClient
from sklearn.metrics import mean_squared_error

import learnware
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.specification import RKMETableSpecification, generate_rkme_spec
from example_learnwares.config import input_shape_list, input_description_list, output_description_list
from learnware.reuse import HeteroMapTableReuser
from example_learnwares.config import input_shape_list, input_description_list, output_description_list, user_description_list

curr_root = os.path.dirname(os.path.abspath(__file__))

@@ -286,6 +288,43 @@ class TestMarket(unittest.TestCase):

rmtree(test_folder) # rm -r test_folder

def test_model_reuse(self, learnware_num=5):
# generate toy regression problem
X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0)

# generate rkme
user_spec = generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0)

# generate specification
semantic_spec = copy.deepcopy(user_semantic)
semantic_spec["Input"] = user_description_list[0]
user_info=BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

# learnware market search
hetero_market = self.test_train_market_model(learnware_num)
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = hetero_market.search_learnware(user_info)

# model reuse
print([learnware.id for learnware in single_learnware_list])
reuser=HeteroMapTableReuser(single_learnware_list[0], task_type='regression')
reuser.fit(user_spec)
y_pred=reuser.predict(X)
# calculate rmse
rmse=mean_squared_error(y, y_pred, squared=False)
print(f"rmse not finetune: {rmse}")

# finetune
reuser.finetune(X[:100], y[:100])
y_pred=reuser.predict(X)
rmse=mean_squared_error(y, y_pred, squared=False)
print(f"rmse finetune: {rmse}")


def suite():
_suite = unittest.TestSuite()
@@ -294,7 +333,8 @@ def suite():
# _suite.addTest(TestMarket("test_upload_delete_learnware"))
# _suite.addTest(TestMarket("test_train_market_model"))
# _suite.addTest(TestMarket("test_search_semantics"))
_suite.addTest(TestMarket("test_stat_search"))
# _suite.addTest(TestMarket("test_stat_search"))
_suite.addTest(TestMarket("test_model_reuse"))
return _suite




Loading…
Cancel
Save