diff --git a/examples/dataset_table_workflow/base.py b/examples/dataset_table_workflow/base.py index 93a66cd..1634e97 100644 --- a/examples/dataset_table_workflow/base.py +++ b/examples/dataset_table_workflow/base.py @@ -6,27 +6,17 @@ import requests import tempfile import traceback import numpy as np -from queue import Empty -from tqdm import tqdm from learnware.client import LearnwareClient from learnware.logger import get_module_logger from learnware.market import instantiate_learnware_market from learnware.reuse.utils import fill_data_with_mean from learnware.tests.benchmarks import LearnwareBenchmark -from torch.multiprocessing import Process, Queue, set_start_method from config import * from methods import * -from utils import process_single_aug logger = get_module_logger("base_table", level="INFO") -try: - set_start_method('spawn') -except RuntimeError: - pass -torch.multiprocessing.set_sharing_strategy("file_system") - class TableWorkflow: def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False): @@ -99,16 +89,6 @@ class TableWorkflow: logger.info(f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying...") time.sleep(1) continue - - @staticmethod - def process_learnware_chunk(cuda_idx, method, test_info, loss_func, learnware_chunk, queue): - torch.cuda.set_device(cuda_idx) - for learnware in learnware_chunk: - learnware_index = test_info['learnwares'].index(learnware) - test_info['single_learnware'] = learnware - scores = TableWorkflow._limited_data(method, test_info, loss_func) - torch.cuda.empty_cache() - queue.put((learnware_index, scores)) def test_method(self, test_info, recorders, loss_func=loss_func_rmse): method_name_full = test_info["method_name"] @@ -121,51 +101,9 @@ class TableWorkflow: os.makedirs(save_root_path, exist_ok=True) save_path = os.path.join(save_root_path, f"{method_name}.json") - if method_name_full == "hetero_single_aug": - if recorder.should_test_method(user, idx, save_path): - # * single-process - # bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}") - # for learnware in test_info['learnwares']: - # test_info['single_learnware'] = learnware - # scores = self._limited_data(test_methods[method_name_full], test_info, loss_func) - # recorder.record(user, idx, scores) - # bar.update(1) - - # * multi-process - queue = Queue() - processes = [] - bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}", unit="learnware") - learnware_chunks = [test_info["learnwares"][i:len(test_info["learnwares"]):len(self.cuda_idx)] for i in self.cuda_idx] - - for cuda_idx, learnware_chunk in zip(self.cuda_idx, learnware_chunks): - p = Process(target=TableWorkflow.process_learnware_chunk, args=(cuda_idx, method, test_info, loss_func, learnware_chunk, queue)) - processes.append(p) - p.start() - - all_results = [] - while any(p.is_alive() for p in processes) or not queue.empty(): - try: - result = queue.get(timeout=0.1) - all_results.append(result) - bar.update(1) - except Empty: - time.sleep(0.1) - continue - bar.close() - - for p in processes: - p.join() - - all_results.sort(key=lambda x: x[0]) - all_scores = [result[1] for result in all_results] - recorder.record(user, all_scores) - recorder.save(save_path) - - process_single_aug(user, idx, recorder.data[user][idx], recorders, save_root_path) - else: - if recorder.should_test_method(user, idx, save_path): - scores = self._limited_data(method, test_info, loss_func) - recorder.record(user, scores) - recorder.save(save_path) + if recorder.should_test_method(user, idx, save_path): + scores = self._limited_data(method, test_info, loss_func) + recorder.record(user, scores) + recorder.save(save_path) logger.info(f"Method {method_name} on {user}_{idx} finished") \ No newline at end of file diff --git a/examples/dataset_table_workflow/hetero.py b/examples/dataset_table_workflow/hetero.py index 21b87e9..b845b3b 100644 --- a/examples/dataset_table_workflow/hetero.py +++ b/examples/dataset_table_workflow/hetero.py @@ -109,7 +109,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): def labeled_hetero_table_example(self): logger.info("Total Items: %d" % len(self.market)) methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"] - recorders = {method: Recorder() for method in methods + ["select_score", "oracle_score", "mean_score"]} + recorders = {method: Recorder() for method in methods} user = self.benchmark.name for idx in range(self.benchmark.user_num): @@ -134,10 +134,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): search_result = self.market.search_learnware(user_info) single_result = search_result.get_single_results() multiple_result = search_result.get_multiple_results() - - rank_map = {item.learnware.id: index for index, item in enumerate(single_result)} - all_learnwares = self.market.get_learnwares() - all_learnwares.sort(key=lambda learnware: rank_map.get(learnware.id, float('inf'))) if len(multiple_result) > 0: mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) @@ -152,8 +148,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list} method_configs = { "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"}, - "hetero_single_aug": {"user_rkme": user_stat_spec, "learnwares": all_learnwares}, - "hetero_multiple_aug": common_config, + "hetero_single_aug": {"user_rkme": user_stat_spec, "single_learnware": single_result[0].learnware}, "hetero_multiple_avg": common_config, "hetero_ensemble_pruning": common_config } @@ -167,5 +162,4 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): for method, recorder in recorders.items(): recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) - methods_to_plot = ["user_model", "select_score", "mean_score", "hetero_multiple_avg", "hetero_ensemble_pruning"] - plot_performance_curves(self.curves_result_path, user, {method: recorders[method] for method in methods_to_plot}, task="Hetero", n_labeled_list=hetero_n_labeled_list) \ No newline at end of file + plot_performance_curves(self.curves_result_path, user, recorders, task="Hetero", n_labeled_list=hetero_n_labeled_list) \ No newline at end of file diff --git a/examples/dataset_table_workflow/utils.py b/examples/dataset_table_workflow/utils.py index 1e23e8e..730ef3a 100644 --- a/examples/dataset_table_workflow/utils.py +++ b/examples/dataset_table_workflow/utils.py @@ -37,31 +37,7 @@ class Recorder: self.load(path) return user not in self.data or idx > len(self.data[user]) - 1 return True - - -def process_single_aug(user, idx, scores, recorders, root_path): - try: - n_labeled = len(scores[0]) - select_scores, mean_scores, oracle_scores = [], [], [] - for i in range(n_labeled): - sub_scores_array = np.vstack([lst[i] for lst in scores]) - sub_scores_select = np.squeeze(sub_scores_array[0]) - sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0)) - sub_scores_min = np.squeeze(np.min(sub_scores_array, axis=0)) - - select_scores.append(sub_scores_select.tolist()) - mean_scores.append(sub_scores_mean.tolist()) - oracle_scores.append(sub_scores_min.tolist()) - - for method_name, scores in zip(["select_score", "mean_score", "oracle_score"], - [select_scores, mean_scores, oracle_scores]): - recorders[method_name].record(user, scores) - save_path = os.path.join(root_path, f"{method_name}.json") - recorders[method_name].save(save_path) - except Exception: - error_message = traceback.format_exc() - logger.error(f"Error in process_single_aug for user {user}, idx {idx}: {error_message}") - + def plot_performance_curves(path, user, recorders, task, n_labeled_list): plt.figure(figsize=(10, 6)) diff --git a/examples/dataset_table_workflow/workflow.py b/examples/dataset_table_workflow/workflow.py index 5dbe80d..73ebc86 100644 --- a/examples/dataset_table_workflow/workflow.py +++ b/examples/dataset_table_workflow/workflow.py @@ -38,8 +38,8 @@ class TableDatasetWorkflow: workflow = HeterogeneousDatasetWorkflow( benchmark_config=hetero_cross_task_benchmark_config, name="hetero", - rebuild=False, - retrain=False + rebuild=True, + retrain=True ) workflow.labeled_hetero_table_example()