@@ -107,7 +107,7 @@ class HomogeneousDatasetWorkflow(TableWorkflow):
def labeled_homo_table_example(self):
logger.info("Total Item: %d" % (len(self.market)))
methods = ["user_model", "homo_single_aug", "homo_multiple_avg", "homo_ ensemble_pruning"]
methods = ["user_model", "homo_single_aug", "homo_ensemble_pruning"]
methods_to_retest = []
recorders = {method: Recorder() for method in methods}
@@ -147,8 +147,6 @@ class HomogeneousDatasetWorkflow(TableWorkflow):
method_configs = {
"user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
"homo_single_aug": {"single_learnware": [single_result[0].learnware]},
"homo_multiple_aug": common_config,
"homo_multiple_avg": common_config,
"homo_ensemble_pruning": common_config
}
@@ -162,5 +160,5 @@ class HomogeneousDatasetWorkflow(TableWorkflow):
for method, recorder in recorders.items():
recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))
methods_to_plot = ["user_model", "homo_single_aug", "homo_multiple_avg", "homo_ ensemble_pruning"]
methods_to_plot = ["user_model", "homo_single_aug", "homo_ensemble_pruning"]
plot_performance_curves(self.curves_result_path, user, {method: recorders[method] for method in methods_to_plot}, task="Homo", n_labeled_list=homo_n_labeled_list)