|
|
|
@@ -117,7 +117,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): |
|
|
|
def labeled_hetero_table_example(self): |
|
|
|
logger.info("Total Items: %d" % len(self.market)) |
|
|
|
methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"] |
|
|
|
methods_to_test = [] |
|
|
|
recorders = {method: Recorder() for method in methods + ["select_score", "oracle_score", "mean_score"]} |
|
|
|
|
|
|
|
user = self.benchmark.name |
|
|
|
@@ -170,7 +169,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): |
|
|
|
for method_name in methods: |
|
|
|
logger.info(f"Testing method {method_name}") |
|
|
|
test_info["method_name"] = method_name |
|
|
|
test_info["force"] = method_name in methods_to_test |
|
|
|
test_info.update(method_configs[method_name]) |
|
|
|
self.test_method(test_info, recorders, loss_func=loss_func_rmse) |
|
|
|
|
|
|
|
|