|
|
|
@@ -108,7 +108,6 @@ class HomogeneousDatasetWorkflow(TableWorkflow): |
|
|
|
def labeled_homo_table_example(self, skip_test=False): |
|
|
|
logger.info("Total Item: %d" % (len(self.market))) |
|
|
|
methods = ["user_model", "homo_single_aug", "homo_ensemble_pruning"] |
|
|
|
methods_to_retest = [] |
|
|
|
recorders = {method: Recorder() for method in methods} |
|
|
|
user = self.benchmark.name |
|
|
|
|
|
|
|
@@ -116,16 +115,7 @@ class HomogeneousDatasetWorkflow(TableWorkflow): |
|
|
|
for idx in range(self.benchmark.user_num): |
|
|
|
test_x, test_y = self.benchmark.get_test_data(user_ids=idx) |
|
|
|
test_x, test_y = test_x.values, test_y.values |
|
|
|
|
|
|
|
train_x, train_y = self.benchmark.get_train_data(user_ids=idx) |
|
|
|
train_x, train_y = train_x.values, train_y.values |
|
|
|
train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y) |
|
|
|
|
|
|
|
if not skip_test: |
|
|
|
for idx in range(self.benchmark.user_num): |
|
|
|
test_x, test_y = self.benchmark.get_test_data(user_ids=idx) |
|
|
|
test_x, test_y = test_x.values, test_y.values |
|
|
|
|
|
|
|
|
|
|
|
train_x, train_y = self.benchmark.get_train_data(user_ids=idx) |
|
|
|
train_x, train_y = train_x.values, train_y.values |
|
|
|
train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y) |
|
|
|
@@ -134,35 +124,17 @@ class HomogeneousDatasetWorkflow(TableWorkflow): |
|
|
|
user_info = BaseUserInfo( |
|
|
|
semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec} |
|
|
|
) |
|
|
|
|
|
|
|
logger.info(f"Searching Market for user: {user}_{idx}") |
|
|
|
user_stat_spec = generate_stat_spec(type="table", X=test_x) |
|
|
|
user_info = BaseUserInfo( |
|
|
|
semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec} |
|
|
|
) |
|
|
|
logger.info(f"Searching Market for user: {user}_{idx}") |
|
|
|
|
|
|
|
search_result = self.market.search_learnware(user_info) |
|
|
|
single_result = search_result.get_single_results() |
|
|
|
multiple_result = search_result.get_multiple_results() |
|
|
|
search_result = self.market.search_learnware(user_info) |
|
|
|
single_result = search_result.get_single_results() |
|
|
|
multiple_result = search_result.get_multiple_results() |
|
|
|
|
|
|
|
logger.info(f"search result of user {user}_{idx}:") |
|
|
|
logger.info( |
|
|
|
f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" |
|
|
|
) |
|
|
|
logger.info(f"search result of user {user}_{idx}:") |
|
|
|
logger.info( |
|
|
|
f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" |
|
|
|
) |
|
|
|
|
|
|
|
if len(multiple_result) > 0: |
|
|
|
mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) |
|
|
|
logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") |
|
|
|
mixture_learnware_list = multiple_result[0].learnwares |
|
|
|
else: |
|
|
|
mixture_learnware_list = [single_result[0].learnware] |
|
|
|
if len(multiple_result) > 0: |
|
|
|
mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) |
|
|
|
logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") |
|
|
|
@@ -177,32 +149,14 @@ class HomogeneousDatasetWorkflow(TableWorkflow): |
|
|
|
"homo_single_aug": {"single_learnware": [single_result[0].learnware]}, |
|
|
|
"homo_ensemble_pruning": common_config |
|
|
|
} |
|
|
|
test_info = {"user": user, "idx": idx, "train_subsets": train_subsets, "test_x": test_x, "test_y": test_y} |
|
|
|
common_config = {"learnwares": mixture_learnware_list} |
|
|
|
method_configs = { |
|
|
|
"user_model": {"dataset": self.benchmark.name, "model_type": "lgb"}, |
|
|
|
"homo_single_aug": {"single_learnware": [single_result[0].learnware]}, |
|
|
|
"homo_ensemble_pruning": common_config |
|
|
|
} |
|
|
|
|
|
|
|
for method_name in methods: |
|
|
|
logger.info(f"Testing method {method_name}") |
|
|
|
test_info["method_name"] = method_name |
|
|
|
test_info["force"] = method_name in methods_to_retest |
|
|
|
test_info.update(method_configs[method_name]) |
|
|
|
self.test_method(test_info, recorders, loss_func=loss_func_rmse) |
|
|
|
|
|
|
|
for method, recorder in recorders.items(): |
|
|
|
recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) |
|
|
|
for method_name in methods: |
|
|
|
logger.info(f"Testing method {method_name}") |
|
|
|
test_info["method_name"] = method_name |
|
|
|
test_info["force"] = method_name in methods_to_retest |
|
|
|
test_info.update(method_configs[method_name]) |
|
|
|
self.test_method(test_info, recorders, loss_func=loss_func_rmse) |
|
|
|
|
|
|
|
|
|
|
|
for method, recorder in recorders.items(): |
|
|
|
recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) |
|
|
|
|
|
|
|
plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list) |
|
|
|
|
|
|
|
plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list) |