diff --git a/examples/dataset_table_workflow/base.py b/examples/dataset_table_workflow/base.py index 0381f28..571f4fa 100644 --- a/examples/dataset_table_workflow/base.py +++ b/examples/dataset_table_workflow/base.py @@ -19,9 +19,6 @@ from utils import set_seed logger = get_module_logger("base_table", level="INFO") -# for quick test only -from learnware.market.heterogeneous import HeteroMapTableOrganizer - class TableWorkflow: def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False): self.root_path = os.path.abspath(os.path.join(__file__, "..")) @@ -29,10 +26,7 @@ class TableWorkflow: self.curves_result_path = os.path.join(self.result_path, "curves") os.makedirs(self.result_path, exist_ok=True) os.makedirs(self.curves_result_path, exist_ok=True) - # if name == "hetero": - # set_seed(42) - self._prepare_market(benchmark_config, name, rebuild, retrain) - + self._prepare_market(benchmark_config, name, rebuild, retrain) @staticmethod def _limited_data(method, test_info, loss_func): @@ -81,14 +75,6 @@ class TableWorkflow: self.user_semantic = client.get_semantic_specification(self.benchmark.learnware_ids[0]) self.user_semantic["Name"]["Values"] = "" - # if retrain == True and rebuild == False: - # logger.info(f"training learnwares: {self.market.get_learnware_ids()[::-1]}") - # market_mapping = HeteroMapTableOrganizer.train(self.market.get_learnwares()[::-1], save_dir='test_model.bin', **market_mapping_params) - # self.market.learnware_organizer.market_mapping = market_mapping - # self.market.learnware_organizer._update_learnware_hetero_spec(self.market.get_learnware_ids()[::-1]) - - # return - if len(self.market) == 0 or rebuild == True: for learnware_id in self.benchmark.learnware_ids: with tempfile.TemporaryDirectory(prefix="table_benchmark_") as tempdir: diff --git a/examples/dataset_table_workflow/config.py b/examples/dataset_table_workflow/config.py index 5b6a74f..7f0d5de 100644 --- a/examples/dataset_table_workflow/config.py +++ b/examples/dataset_table_workflow/config.py @@ -216,7 +216,7 @@ hetero_cross_feat_eng_benchmark_config = BenchmarkConfig( "00000912" ], test_data_path="PFS/test_data.zip", - # train_data_path="PFS/train_data.zip", + train_data_path="PFS/train_data.zip", extra_info_path="PFS/extra_info.zip" ) diff --git a/examples/dataset_table_workflow/hetero.py b/examples/dataset_table_workflow/hetero.py index db1c1a7..3caa5d3 100644 --- a/examples/dataset_table_workflow/hetero.py +++ b/examples/dataset_table_workflow/hetero.py @@ -39,7 +39,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): ) logger.info(f"Searching Market for user: {user}_{idx}") - search_result = self.market.search_learnware(user_info) + search_result = self.market.search_learnware(user_info, max_search_num=10) single_result = search_result.get_single_results() multiple_result = search_result.get_multiple_results() @@ -53,15 +53,15 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): pred_y = single_hetero_learnware.predict(test_x) single_score_list.append(loss_func_rmse(pred_y, test_y)) - # rmse_list = [] - # for learnware in all_learnwares: - # hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params) - # hetero_learnware.align(user_rkme=user_stat_spec) - # pred_y = hetero_learnware.predict(test_x) - # rmse_list.append(loss_func_rmse(pred_y, test_y)) - # logger.info( - # f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[0]}" - # ) + rmse_list = [] + for learnware in all_learnwares: + hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params) + hetero_learnware.align(user_rkme=user_stat_spec) + pred_y = hetero_learnware.predict(test_x) + rmse_list.append(loss_func_rmse(pred_y, test_y)) + logger.info( + f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[0]}" + ) if len(multiple_result) > 0: mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) @@ -83,21 +83,21 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): ensemble_score_list.append(ensemble_score) logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}") - # learnware_rmse_list.append(rmse_list) + learnware_rmse_list.append(rmse_list) - # single_list = np.array(learnware_rmse_list) - # avg_score_list = [np.mean(lst, axis=0) for lst in single_list] - # oracle_score_list = [np.min(lst, axis=0) for lst in single_list] + single_list = np.array(learnware_rmse_list) + avg_score_list = [np.mean(lst, axis=0) for lst in single_list] + oracle_score_list = [np.min(lst, axis=0) for lst in single_list] logger.info( - "RMSE of selected learnware: %.3f +/- %.3f" # , Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f" + "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f" % ( np.mean(single_score_list), np.std(single_score_list), - # np.mean(avg_score_list), - # np.std(avg_score_list), - # np.mean(oracle_score_list), - # np.std(oracle_score_list), + np.mean(avg_score_list), + np.std(avg_score_list), + np.mean(oracle_score_list), + np.std(oracle_score_list) ) ) logger.info( diff --git a/examples/dataset_table_workflow/workflow.py b/examples/dataset_table_workflow/workflow.py index e03dca3..e079b6e 100644 --- a/examples/dataset_table_workflow/workflow.py +++ b/examples/dataset_table_workflow/workflow.py @@ -4,6 +4,7 @@ from learnware.logger import get_module_logger from homo import HomogeneousDatasetWorkflow from hetero import HeterogeneousDatasetWorkflow from config import homo_table_benchmark_config, hetero_cross_feat_eng_benchmark_config, hetero_cross_task_benchmark_config +from utils import set_seed logger = get_module_logger("base_table", level="INFO") @@ -26,15 +27,17 @@ class TableDatasetWorkflow: workflow.labeled_homo_table_example() def cross_feat_eng_hetero_table_example(self): + set_seed(0) workflow = HeterogeneousDatasetWorkflow( benchmark_config=hetero_cross_feat_eng_benchmark_config, name="hetero", - rebuild=True, - retrain=True + rebuild=False, + retrain=False ) workflow.unlabeled_hetero_table_example() def cross_task_hetero_table_example(self): + set_seed(0) workflow = HeterogeneousDatasetWorkflow( benchmark_config=hetero_cross_task_benchmark_config, name="hetero",