From ed0c4bd7fd5f04b8d7ee64a149509a70917f045d Mon Sep 17 00:00:00 2001 From: liuht Date: Fri, 5 Jan 2024 21:45:24 +0800 Subject: [PATCH] [FIX] delete must test --- examples/dataset_table_workflow/base.py | 2 +- examples/dataset_table_workflow/hetero.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/dataset_table_workflow/base.py b/examples/dataset_table_workflow/base.py index 1a1229e..3bcb7a8 100644 --- a/examples/dataset_table_workflow/base.py +++ b/examples/dataset_table_workflow/base.py @@ -111,7 +111,7 @@ class TableWorkflow: save_path = os.path.join(save_root_path, f"{method_name}.json") if method_name_full == "hetero_single_aug": - if test_info["force"] or recorder.should_test_method(user, idx, save_path): + if recorder.should_test_method(user, idx, save_path): # * multi-process queue = Queue() processes = [] diff --git a/examples/dataset_table_workflow/hetero.py b/examples/dataset_table_workflow/hetero.py index f365971..2f00cb9 100644 --- a/examples/dataset_table_workflow/hetero.py +++ b/examples/dataset_table_workflow/hetero.py @@ -117,7 +117,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): def labeled_hetero_table_example(self): logger.info("Total Items: %d" % len(self.market)) methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"] - methods_to_test = [] recorders = {method: Recorder() for method in methods + ["select_score", "oracle_score", "mean_score"]} user = self.benchmark.name @@ -170,7 +169,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): for method_name in methods: logger.info(f"Testing method {method_name}") test_info["method_name"] = method_name - test_info["force"] = method_name in methods_to_test test_info.update(method_configs[method_name]) self.test_method(test_info, recorders, loss_func=loss_func_rmse)