diff --git a/examples/dataset_table_workflow/homo.py b/examples/dataset_table_workflow/homo.py index 92305f5..71445ec 100644 --- a/examples/dataset_table_workflow/homo.py +++ b/examples/dataset_table_workflow/homo.py @@ -105,22 +105,13 @@ class HomogeneousDatasetWorkflow(TableWorkflow): ) - def labeled_homo_table_example(self, skip_test=False): + def labeled_homo_table_example(self, skip_test=True): logger.info("Total Item: %d" % (len(self.market))) methods = ["user_model", "homo_single_aug", "homo_ensemble_pruning"] methods_to_retest = [] recorders = {method: Recorder() for method in methods} user = self.benchmark.name - if not skip_test: - for idx in range(self.benchmark.user_num): - test_x, test_y = self.benchmark.get_test_data(user_ids=idx) - test_x, test_y = test_x.values, test_y.values - - train_x, train_y = self.benchmark.get_train_data(user_ids=idx) - train_x, train_y = train_x.values, train_y.values - train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y) - if not skip_test: for idx in range(self.benchmark.user_num): test_x, test_y = self.benchmark.get_test_data(user_ids=idx) @@ -130,10 +121,6 @@ class HomogeneousDatasetWorkflow(TableWorkflow): train_x, train_y = train_x.values, train_y.values train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y) - user_stat_spec = generate_stat_spec(type="table", X=test_x) - user_info = BaseUserInfo( - semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec} - ) logger.info(f"Searching Market for user: {user}_{idx}") user_stat_spec = generate_stat_spec(type="table", X=test_x) user_info = BaseUserInfo( @@ -141,28 +128,15 @@ class HomogeneousDatasetWorkflow(TableWorkflow): ) logger.info(f"Searching Market for user: {user}_{idx}") - search_result = self.market.search_learnware(user_info) - single_result = search_result.get_single_results() - multiple_result = search_result.get_multiple_results() search_result = self.market.search_learnware(user_info) single_result = search_result.get_single_results() multiple_result = search_result.get_multiple_results() - logger.info(f"search result of user {user}_{idx}:") - logger.info( - f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" - ) logger.info(f"search result of user {user}_{idx}:") logger.info( f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" ) - if len(multiple_result) > 0: - mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) - logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") - mixture_learnware_list = multiple_result[0].learnwares - else: - mixture_learnware_list = [single_result[0].learnware] if len(multiple_result) > 0: mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") @@ -177,13 +151,6 @@ class HomogeneousDatasetWorkflow(TableWorkflow): "homo_single_aug": {"single_learnware": [single_result[0].learnware]}, "homo_ensemble_pruning": common_config } - test_info = {"user": user, "idx": idx, "train_subsets": train_subsets, "test_x": test_x, "test_y": test_y} - common_config = {"learnwares": mixture_learnware_list} - method_configs = { - "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"}, - "homo_single_aug": {"single_learnware": [single_result[0].learnware]}, - "homo_ensemble_pruning": common_config - } for method_name in methods: logger.info(f"Testing method {method_name}") @@ -194,15 +161,5 @@ class HomogeneousDatasetWorkflow(TableWorkflow): for method, recorder in recorders.items(): recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) - for method_name in methods: - logger.info(f"Testing method {method_name}") - test_info["method_name"] = method_name - test_info["force"] = method_name in methods_to_retest - test_info.update(method_configs[method_name]) - self.test_method(test_info, recorders, loss_func=loss_func_rmse) - - for method, recorder in recorders.items(): - recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) - - plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list) + plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list) \ No newline at end of file