
[MNT] refactor hetero workflow and make it runnable

tags/v0.3.2
bxdd · 2 years ago · commit e95b6cc5e7
3 changed files with 423 additions and 5 deletions
  1. tests/test_workflow/hetero_config.py  +100 -0
  2. tests/test_workflow/test_hetero_workflow.py  +321 -0
  3. tests/test_workflow/test_workflow.py  +2 -5

tests/test_workflow/hetero_config.py  +100 -0

@@ -0,0 +1,100 @@
input_shape_list = [20, 30]  # input shapes: 20 for example learnware 0, 30 for example learnware 1

input_description_list = [
    {
        "Dimension": 20,
        "Description": {  # medical description
            "0": "baseline value: Baseline Fetal Heart Rate (FHR)",
            "1": "accelerations: Number of accelerations per second",
            "2": "fetal_movement: Number of fetal movements per second",
            "3": "uterine_contractions: Number of uterine contractions per second",
            "4": "light_decelerations: Number of LDs per second",
            "5": "severe_decelerations: Number of SDs per second",
            "6": "prolongued_decelerations: Number of PDs per second",
            "7": "abnormal_short_term_variability: Percentage of time with abnormal short term variability",
            "8": "mean_value_of_short_term_variability: Mean value of short term variability",
            "9": "percentage_of_time_with_abnormal_long_term_variability: Percentage of time with abnormal long term variability",
            "10": "mean_value_of_long_term_variability: Mean value of long term variability",
            "11": "histogram_width: Width of the histogram made using all values from a record",
            "12": "histogram_min: Histogram minimum value",
            "13": "histogram_max: Histogram maximum value",
            "14": "histogram_number_of_peaks: Number of peaks in the exam histogram",
            "15": "histogram_number_of_zeroes: Number of zeroes in the exam histogram",
            "16": "histogram_mode: Hist mode",
            "17": "histogram_mean: Hist mean",
            "18": "histogram_median: Hist Median",
            "19": "histogram_variance: Hist variance",
        },
    },
    {
        "Dimension": 30,
        "Description": {  # business description
            "0": "This is a consecutive month number, used for convenience. For example, January 2013 is 0, February 2013 is 1, ..., October 2015 is 33.",
            "1": "This is the unique identifier for each shop.",
            "2": "This is the unique identifier for each item.",
            "3": "This is the code representing the city where the shop is located.",
            "4": "This is the unique identifier for the category of the item.",
            "5": "This is the code representing the type of the item.",
            "6": "This is the code representing the subtype of the item.",
            "7": "This is the number of this type of item sold in the shop one month ago.",
            "8": "This is the number of this type of item sold in the shop two months ago.",
            "9": "This is the number of this type of item sold in the shop three months ago.",
            "10": "This is the number of this type of item sold in the shop six months ago.",
            "11": "This is the number of this type of item sold in the shop twelve months ago.",
            "12": "This is the average count of items sold one month ago.",
            "13": "This is the average count of this type of item sold one month ago.",
            "14": "This is the average count of this type of item sold two months ago.",
            "15": "This is the average count of this type of item sold three months ago.",
            "16": "This is the average count of this type of item sold six months ago.",
            "17": "This is the average count of this type of item sold twelve months ago.",
            "18": "This is the average count of items sold in the shop one month ago.",
            "19": "This is the average count of items sold in the shop two months ago.",
            "20": "This is the average count of items sold in the shop three months ago.",
            "21": "This is the average count of items sold in the shop six months ago.",
            "22": "This is the average count of items sold in the shop twelve months ago.",
            "23": "This is the average count of items in the same category sold one month ago.",
            "24": "This is the average count of items in the same category sold in the shop one month ago.",
            "25": "This is the average count of items of the same type sold in the shop one month ago.",
            "26": "This is the average count of items of the same subtype sold in the shop one month ago.",
            "27": "This is the average count of items sold in the same city one month ago.",
            "28": "This is the average count of this type of item sold in the same city one month ago.",
            "29": "This is the average count of items of the same type sold one month ago.",
        },
    },
]

output_description_list = [
    {
        "Dimension": 1,
        "Description": {"0": "length of stay: Length of hospital stay (days)"},  # medical description
    },
    {
        "Dimension": 1,
        "Description": {  # business description
            "0": "sales of the item in the next day: Number of items sold in the next day"
        },
    },
]

user_description_list = [
    {
        "Dimension": 15,
        "Description": {  # medical description
            "0": "Whether the patient is on thyroxine medication (0: No, 1: Yes)",
            "1": "Whether the patient has been queried about thyroxine medication (0: No, 1: Yes)",
            "2": "Whether the patient is on antithyroid medication (0: No, 1: Yes)",
            "3": "Whether the patient has undergone thyroid surgery (0: No, 1: Yes)",
            "4": "Whether the patient has been queried about hypothyroidism (0: No, 1: Yes)",
            "5": "Whether the patient has been queried about hyperthyroidism (0: No, 1: Yes)",
            "6": "Whether the patient is pregnant (0: No, 1: Yes)",
            "7": "Whether the patient is sick (0: No, 1: Yes)",
            "8": "Whether the patient has a tumor (0: No, 1: Yes)",
            "9": "Whether the patient is taking lithium (0: No, 1: Yes)",
            "10": "Whether the patient has a goitre (enlarged thyroid gland) (0: No, 1: Yes)",
            "11": "Whether TSH (Thyroid Stimulating Hormone) level has been measured (0: No, 1: Yes)",
            "12": "Whether T3 (Triiodothyronine) level has been measured (0: No, 1: Yes)",
            "13": "Whether TT4 (Total Thyroxine) level has been measured (0: No, 1: Yes)",
            "14": "Whether T4U (Thyroxine Utilization) level has been measured (0: No, 1: Yes)",
        },
    }
]

tests/test_workflow/test_hetero_workflow.py  +321 -0

@@ -0,0 +1,321 @@
import torch
import pickle
import unittest
import os
import logging
import tempfile
import zipfile
from sklearn.linear_model import Ridge
from sklearn.datasets import make_regression
from shutil import copyfile, rmtree
from sklearn.metrics import mean_squared_error

import learnware
learnware.init(logging_level=logging.WARNING)

from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate

from hetero_config import input_shape_list, input_description_list, output_description_list, user_description_list


curr_root = os.path.dirname(os.path.abspath(__file__))

class TestHeteroWorkflow(unittest.TestCase):
    universal_semantic_config = {
        "data_type": "Table",
        "task_type": "Regression",
        "library_type": "Scikit-learn",
        "scenarios": "Education",
        "license": "MIT",
    }

    def _init_learnware_market(self, organizer_kwargs=None):
        """initialize learnware market"""
        hetero_market = instantiate_learnware_market(
            market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs
        )
        return hetero_market

    def test_prepare_learnware_randomly(self, learnware_num=5):
        self.zip_path_list = []

        for i in range(learnware_num):
            learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero")
            os.makedirs(learnware_pool_dirpath, exist_ok=True)
            learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i))
            print("Preparing Learnware: %d" % (i))

            X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42)
            clf = Ridge(alpha=1.0)
            clf.fit(X, y)
            pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl")
            with open(pickle_filepath, "wb") as fout:
                pickle.dump(clf, fout)

            spec = generate_rkme_table_spec(X=X, gamma=0.1)
            spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json")
            spec.save(spec_filepath)

            LearnwareTemplate.generate_learnware_zipfile(
                learnware_zippath=learnware_zippath,
                model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)}),
                stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
                requirements=["scikit-learn==0.22"],
            )
            self.zip_path_list.append(learnware_zippath)

    def _upload_delete_learnware(self, hetero_market, learnware_num, delete):
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == 0, f"The market should be empty!"

        for idx, zip_path in enumerate(self.zip_path_list):
            semantic_spec = generate_semantic_spec(
                name=f"learnware_{idx}",
                description=f"test_learnware_number_{idx}",
                input_description=input_description_list[idx % 2],
                output_description=output_description_list[idx % 2],
                **self.universal_semantic_config,
            )
            hetero_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        curr_inds = hetero_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        if delete:
            for learnware_id in curr_inds:
                hetero_market.delete_learnware(learnware_id)
                self.learnware_num -= 1
                assert (
                    len(hetero_market) == self.learnware_num
                ), f"The number of learnwares must be {self.learnware_num}!"

            curr_inds = hetero_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)
            assert len(curr_inds) == 0, f"The market should be empty!"

        return hetero_market

    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
        hetero_market = self._init_learnware_market()
        return self._upload_delete_learnware(hetero_market, learnware_num, delete)

    def test_train_market_model(self, learnware_num=5, delete=False):
        hetero_market = self._init_learnware_market(
            organizer_kwargs={"auto_update": True, "auto_update_limit": learnware_num}
        )
        hetero_market = self._upload_delete_learnware(hetero_market, learnware_num, delete)
        # organizer = hetero_market.learnware_organizer
        # organizer.train(hetero_market.learnware_organizer.learnware_list.values())
        return hetero_market

    def test_search_semantics(self, learnware_num=5):
        hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        semantic_spec = generate_semantic_spec(
            name=f"learnware_{learnware_num - 1}",
            **self.universal_semantic_config,
        )
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print(f"Search result1:")
        assert len(single_result) == 1, f"Exact semantic search failed!"
        for search_item in single_result:
            semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
            print("Choose learnware:", search_item.learnware.id)
            assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!"

        semantic_spec["Name"]["Values"] = "laernwaer"  # intentionally misspelled to exercise fuzzy semantic search
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print(f"Search result2:")
        assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!"
        for search_item in single_result:
            print("Choose learnware:", search_item.learnware.id)

    def test_hetero_stat_search(self, learnware_num=5):
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        user_dim = 15

        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=test_folder)

                # build a user RKME spec that keeps only the first user_dim columns of the
                # learnware's reduced-set points, simulating a heterogeneous user table
                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                z = user_spec.get_z()
                z = z[:, :user_dim]
                device = user_spec.device
                z = torch.tensor(z, device=device)
                user_spec.z = z

                print(">> normal case test:")
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim,
                        "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)},
                    },
                    **self.universal_semantic_config,
                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()
                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

                for multiple_item in multiple_result:
                    print(
                        f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
                    )

                # improper value for key "Task" in semantic_spec: fall back to homogeneous search and print the invalid semantic_spec
                print(">> test for key 'Task' with improper 'Values':")
                semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"}
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()

                assert len(single_result) == 0, f"Statistical search failed!"

                # delete key "Task" in semantic_spec: fall back to homogeneous search and print WARNING with "User doesn't provide correct task type"
                print(">> delete key 'Task' test:")
                semantic_spec.pop("Task")
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()

                assert len(single_result) == 0, f"Statistical search failed!"

                # modify semantic info with mismatched dimension: fall back to homogeneous search and print "User data feature dimensions mismatch with semantic specification."
                print(">> mismatch dim test")
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim - 2,
                        "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)},
                    },
                    **self.universal_semantic_config,
                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()

                assert len(single_result) == 0, f"Statistical search failed!"

    def test_homo_stat_search(self, learnware_num=5):
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=test_folder)

                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                user_semantic = generate_semantic_spec(**self.universal_semantic_config)
                user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                assert len(single_result) >= 1, f"Statistical search failed!"
                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

                for multiple_item in multiple_result:
                    print(f"mixture_score: {multiple_item.score}\n")
                    mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares])
                    print(f"mixture_learnware: {mixture_id}\n")

    def test_model_reuse(self, learnware_num=5):
        # generate toy regression problem
        X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0)

        # generate rkme
        user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)

        # generate semantic specification
        semantic_spec = generate_semantic_spec(input_description=user_description_list[0], **self.universal_semantic_config)
        user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

        # learnware market search
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()
        multiple_result = search_result.get_multiple_results()
        # print search results
        for single_item in single_result:
            print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

        for multiple_item in multiple_result:
            print(
                f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
            )

        # single model reuse: align the best-matched learnware to the user's feature space with a small labeled sample
        hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression")
        hetero_learnware.align(user_spec, X[:100], y[:100])
        single_predict_y = hetero_learnware.predict(X)

        # multi model reuse: align every learnware in the best mixture before ensembling
        hetero_learnware_list = []
        for learnware in multiple_result[0].learnwares:
            hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression")
            hetero_learnware.align(user_spec, X[:100], y[:100])
            hetero_learnware_list.append(hetero_learnware)

        # Use averaging ensemble reuser to reuse the searched learnwares to make prediction
        reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean")
        ensemble_predict_y = reuse_ensemble.predict(user_data=X)

        # Use ensemble pruning reuser to reuse the searched learnwares to make prediction
        reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression")
        reuse_ensemble.fit(X[:100], y[:100])
        ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X)

        print("Single model RMSE by finetune:", mean_squared_error(y, single_predict_y, squared=False))
        print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y, squared=False))
        print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y, squared=False))


def suite():
    _suite = unittest.TestSuite()
    # _suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly"))
    # _suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware"))
    # _suite.addTest(TestHeteroWorkflow("test_train_market_model"))
    _suite.addTest(TestHeteroWorkflow("test_search_semantics"))
    _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search"))
    _suite.addTest(TestHeteroWorkflow("test_homo_stat_search"))
    _suite.addTest(TestHeteroWorkflow("test_model_reuse"))
    return _suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite())
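The suite above can also be driven from a separate script rather than through the `__main__` guard. A minimal sketch of such a runner (the file name run_hetero_workflow.py is hypothetical; it assumes the learnware package is installed and that it is launched from tests/test_workflow, so that test_hetero_workflow and the hetero_config module it imports are resolvable):

# run_hetero_workflow.py -- hypothetical convenience runner for the hetero workflow suite.
# Assumes `learnware` is installed and the current working directory is tests/test_workflow,
# so that `test_hetero_workflow` (and the `hetero_config` it imports) can be found.
import unittest

from test_hetero_workflow import suite

if __name__ == "__main__":
    result = unittest.TextTestRunner(verbosity=2).run(suite())
    raise SystemExit(0 if result.wasSuccessful() else 1)  # propagate failures to CI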

tests/test_workflow/test_workflow.py  +2 -5

@@ -29,10 +29,6 @@ class TestWorkflow(unittest.TestCase):
"license": "MIT",
}
@classmethod
def setUpClass(cls):
pass
def _init_learnware_market(self):
"""initialize learnware market"""
easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True)
@@ -62,7 +58,8 @@ class TestWorkflow(unittest.TestCase):
LearnwareTemplate.generate_learnware_zipfile(
learnware_zippath=learnware_zippath,
model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(64,), "output_shape": (10,), "predict_method": "predict_proba"}),
stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification")
stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
requirements=["scikit-learn==0.22"],
)
self.zip_path_list.append(learnware_zippath)

