| @@ -19,9 +19,6 @@ from utils import set_seed | |||
| logger = get_module_logger("base_table", level="INFO") | |||
| # for quick test only | |||
| from learnware.market.heterogeneous import HeteroMapTableOrganizer | |||
| class TableWorkflow: | |||
| def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False): | |||
| self.root_path = os.path.abspath(os.path.join(__file__, "..")) | |||
| @@ -29,10 +26,7 @@ class TableWorkflow: | |||
| self.curves_result_path = os.path.join(self.result_path, "curves") | |||
| os.makedirs(self.result_path, exist_ok=True) | |||
| os.makedirs(self.curves_result_path, exist_ok=True) | |||
| # if name == "hetero": | |||
| # set_seed(42) | |||
| self._prepare_market(benchmark_config, name, rebuild, retrain) | |||
| self._prepare_market(benchmark_config, name, rebuild, retrain) | |||
| @staticmethod | |||
| def _limited_data(method, test_info, loss_func): | |||
| @@ -81,14 +75,6 @@ class TableWorkflow: | |||
| self.user_semantic = client.get_semantic_specification(self.benchmark.learnware_ids[0]) | |||
| self.user_semantic["Name"]["Values"] = "" | |||
| # if retrain == True and rebuild == False: | |||
| # logger.info(f"training learnwares: {self.market.get_learnware_ids()[::-1]}") | |||
| # market_mapping = HeteroMapTableOrganizer.train(self.market.get_learnwares()[::-1], save_dir='test_model.bin', **market_mapping_params) | |||
| # self.market.learnware_organizer.market_mapping = market_mapping | |||
| # self.market.learnware_organizer._update_learnware_hetero_spec(self.market.get_learnware_ids()[::-1]) | |||
| # return | |||
| if len(self.market) == 0 or rebuild == True: | |||
| for learnware_id in self.benchmark.learnware_ids: | |||
| with tempfile.TemporaryDirectory(prefix="table_benchmark_") as tempdir: | |||
| @@ -216,7 +216,7 @@ hetero_cross_feat_eng_benchmark_config = BenchmarkConfig( | |||
| "00000912" | |||
| ], | |||
| test_data_path="PFS/test_data.zip", | |||
| # train_data_path="PFS/train_data.zip", | |||
| train_data_path="PFS/train_data.zip", | |||
| extra_info_path="PFS/extra_info.zip" | |||
| ) | |||
| @@ -39,7 +39,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||
| ) | |||
| logger.info(f"Searching Market for user: {user}_{idx}") | |||
| search_result = self.market.search_learnware(user_info) | |||
| search_result = self.market.search_learnware(user_info, max_search_num=10) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| @@ -53,15 +53,15 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||
| pred_y = single_hetero_learnware.predict(test_x) | |||
| single_score_list.append(loss_func_rmse(pred_y, test_y)) | |||
| # rmse_list = [] | |||
| # for learnware in all_learnwares: | |||
| # hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params) | |||
| # hetero_learnware.align(user_rkme=user_stat_spec) | |||
| # pred_y = hetero_learnware.predict(test_x) | |||
| # rmse_list.append(loss_func_rmse(pred_y, test_y)) | |||
| # logger.info( | |||
| # f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[0]}" | |||
| # ) | |||
| rmse_list = [] | |||
| for learnware in all_learnwares: | |||
| hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params) | |||
| hetero_learnware.align(user_rkme=user_stat_spec) | |||
| pred_y = hetero_learnware.predict(test_x) | |||
| rmse_list.append(loss_func_rmse(pred_y, test_y)) | |||
| logger.info( | |||
| f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[0]}" | |||
| ) | |||
| if len(multiple_result) > 0: | |||
| mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | |||
| @@ -83,21 +83,21 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||
| ensemble_score_list.append(ensemble_score) | |||
| logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}") | |||
| # learnware_rmse_list.append(rmse_list) | |||
| learnware_rmse_list.append(rmse_list) | |||
| # single_list = np.array(learnware_rmse_list) | |||
| # avg_score_list = [np.mean(lst, axis=0) for lst in single_list] | |||
| # oracle_score_list = [np.min(lst, axis=0) for lst in single_list] | |||
| single_list = np.array(learnware_rmse_list) | |||
| avg_score_list = [np.mean(lst, axis=0) for lst in single_list] | |||
| oracle_score_list = [np.min(lst, axis=0) for lst in single_list] | |||
| logger.info( | |||
| "RMSE of selected learnware: %.3f +/- %.3f" # , Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f" | |||
| "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f" | |||
| % ( | |||
| np.mean(single_score_list), | |||
| np.std(single_score_list), | |||
| # np.mean(avg_score_list), | |||
| # np.std(avg_score_list), | |||
| # np.mean(oracle_score_list), | |||
| # np.std(oracle_score_list), | |||
| np.mean(avg_score_list), | |||
| np.std(avg_score_list), | |||
| np.mean(oracle_score_list), | |||
| np.std(oracle_score_list) | |||
| ) | |||
| ) | |||
| logger.info( | |||
| @@ -4,6 +4,7 @@ from learnware.logger import get_module_logger | |||
| from homo import HomogeneousDatasetWorkflow | |||
| from hetero import HeterogeneousDatasetWorkflow | |||
| from config import homo_table_benchmark_config, hetero_cross_feat_eng_benchmark_config, hetero_cross_task_benchmark_config | |||
| from utils import set_seed | |||
| logger = get_module_logger("base_table", level="INFO") | |||
| @@ -26,15 +27,17 @@ class TableDatasetWorkflow: | |||
| workflow.labeled_homo_table_example() | |||
| def cross_feat_eng_hetero_table_example(self): | |||
| set_seed(0) | |||
| workflow = HeterogeneousDatasetWorkflow( | |||
| benchmark_config=hetero_cross_feat_eng_benchmark_config, | |||
| name="hetero", | |||
| rebuild=True, | |||
| retrain=True | |||
| rebuild=False, | |||
| retrain=False | |||
| ) | |||
| workflow.unlabeled_hetero_table_example() | |||
| def cross_task_hetero_table_example(self): | |||
| set_seed(0) | |||
| workflow = HeterogeneousDatasetWorkflow( | |||
| benchmark_config=hetero_cross_task_benchmark_config, | |||
| name="hetero", | |||