| @@ -6,27 +6,17 @@ import requests | |||||
| import tempfile | import tempfile | ||||
| import traceback | import traceback | ||||
| import numpy as np | import numpy as np | ||||
| from queue import Empty | |||||
| from tqdm import tqdm | |||||
| from learnware.client import LearnwareClient | from learnware.client import LearnwareClient | ||||
| from learnware.logger import get_module_logger | from learnware.logger import get_module_logger | ||||
| from learnware.market import instantiate_learnware_market | from learnware.market import instantiate_learnware_market | ||||
| from learnware.reuse.utils import fill_data_with_mean | from learnware.reuse.utils import fill_data_with_mean | ||||
| from learnware.tests.benchmarks import LearnwareBenchmark | from learnware.tests.benchmarks import LearnwareBenchmark | ||||
| from torch.multiprocessing import Process, Queue, set_start_method | |||||
| from config import * | from config import * | ||||
| from methods import * | from methods import * | ||||
| from utils import process_single_aug | |||||
| logger = get_module_logger("base_table", level="INFO") | logger = get_module_logger("base_table", level="INFO") | ||||
| try: | |||||
| set_start_method('spawn') | |||||
| except RuntimeError: | |||||
| pass | |||||
| torch.multiprocessing.set_sharing_strategy("file_system") | |||||
| class TableWorkflow: | class TableWorkflow: | ||||
| def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False): | def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False): | ||||
| @@ -99,16 +89,6 @@ class TableWorkflow: | |||||
| logger.info(f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying...") | logger.info(f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying...") | ||||
| time.sleep(1) | time.sleep(1) | ||||
| continue | continue | ||||
| @staticmethod | |||||
| def process_learnware_chunk(cuda_idx, method, test_info, loss_func, learnware_chunk, queue): | |||||
| torch.cuda.set_device(cuda_idx) | |||||
| for learnware in learnware_chunk: | |||||
| learnware_index = test_info['learnwares'].index(learnware) | |||||
| test_info['single_learnware'] = learnware | |||||
| scores = TableWorkflow._limited_data(method, test_info, loss_func) | |||||
| torch.cuda.empty_cache() | |||||
| queue.put((learnware_index, scores)) | |||||
| def test_method(self, test_info, recorders, loss_func=loss_func_rmse): | def test_method(self, test_info, recorders, loss_func=loss_func_rmse): | ||||
| method_name_full = test_info["method_name"] | method_name_full = test_info["method_name"] | ||||
| @@ -121,51 +101,9 @@ class TableWorkflow: | |||||
| os.makedirs(save_root_path, exist_ok=True) | os.makedirs(save_root_path, exist_ok=True) | ||||
| save_path = os.path.join(save_root_path, f"{method_name}.json") | save_path = os.path.join(save_root_path, f"{method_name}.json") | ||||
| if method_name_full == "hetero_single_aug": | |||||
| if recorder.should_test_method(user, idx, save_path): | |||||
| # * single-process | |||||
| # bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}") | |||||
| # for learnware in test_info['learnwares']: | |||||
| # test_info['single_learnware'] = learnware | |||||
| # scores = self._limited_data(test_methods[method_name_full], test_info, loss_func) | |||||
| # recorder.record(user, idx, scores) | |||||
| # bar.update(1) | |||||
| # * multi-process | |||||
| queue = Queue() | |||||
| processes = [] | |||||
| bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}", unit="learnware") | |||||
| learnware_chunks = [test_info["learnwares"][i:len(test_info["learnwares"]):len(self.cuda_idx)] for i in self.cuda_idx] | |||||
| for cuda_idx, learnware_chunk in zip(self.cuda_idx, learnware_chunks): | |||||
| p = Process(target=TableWorkflow.process_learnware_chunk, args=(cuda_idx, method, test_info, loss_func, learnware_chunk, queue)) | |||||
| processes.append(p) | |||||
| p.start() | |||||
| all_results = [] | |||||
| while any(p.is_alive() for p in processes) or not queue.empty(): | |||||
| try: | |||||
| result = queue.get(timeout=0.1) | |||||
| all_results.append(result) | |||||
| bar.update(1) | |||||
| except Empty: | |||||
| time.sleep(0.1) | |||||
| continue | |||||
| bar.close() | |||||
| for p in processes: | |||||
| p.join() | |||||
| all_results.sort(key=lambda x: x[0]) | |||||
| all_scores = [result[1] for result in all_results] | |||||
| recorder.record(user, all_scores) | |||||
| recorder.save(save_path) | |||||
| process_single_aug(user, idx, recorder.data[user][idx], recorders, save_root_path) | |||||
| else: | |||||
| if recorder.should_test_method(user, idx, save_path): | |||||
| scores = self._limited_data(method, test_info, loss_func) | |||||
| recorder.record(user, scores) | |||||
| recorder.save(save_path) | |||||
| if recorder.should_test_method(user, idx, save_path): | |||||
| scores = self._limited_data(method, test_info, loss_func) | |||||
| recorder.record(user, scores) | |||||
| recorder.save(save_path) | |||||
| logger.info(f"Method {method_name} on {user}_{idx} finished") | logger.info(f"Method {method_name} on {user}_{idx} finished") | ||||
| @@ -109,7 +109,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||||
| def labeled_hetero_table_example(self): | def labeled_hetero_table_example(self): | ||||
| logger.info("Total Items: %d" % len(self.market)) | logger.info("Total Items: %d" % len(self.market)) | ||||
| methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"] | methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"] | ||||
| recorders = {method: Recorder() for method in methods + ["select_score", "oracle_score", "mean_score"]} | |||||
| recorders = {method: Recorder() for method in methods} | |||||
| user = self.benchmark.name | user = self.benchmark.name | ||||
| for idx in range(self.benchmark.user_num): | for idx in range(self.benchmark.user_num): | ||||
| @@ -134,10 +134,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||||
| search_result = self.market.search_learnware(user_info) | search_result = self.market.search_learnware(user_info) | ||||
| single_result = search_result.get_single_results() | single_result = search_result.get_single_results() | ||||
| multiple_result = search_result.get_multiple_results() | multiple_result = search_result.get_multiple_results() | ||||
| rank_map = {item.learnware.id: index for index, item in enumerate(single_result)} | |||||
| all_learnwares = self.market.get_learnwares() | |||||
| all_learnwares.sort(key=lambda learnware: rank_map.get(learnware.id, float('inf'))) | |||||
| if len(multiple_result) > 0: | if len(multiple_result) > 0: | ||||
| mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | ||||
| @@ -152,8 +148,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||||
| common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list} | common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list} | ||||
| method_configs = { | method_configs = { | ||||
| "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"}, | "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"}, | ||||
| "hetero_single_aug": {"user_rkme": user_stat_spec, "learnwares": all_learnwares}, | |||||
| "hetero_multiple_aug": common_config, | |||||
| "hetero_single_aug": {"user_rkme": user_stat_spec, "single_learnware": single_result[0].learnware}, | |||||
| "hetero_multiple_avg": common_config, | "hetero_multiple_avg": common_config, | ||||
| "hetero_ensemble_pruning": common_config | "hetero_ensemble_pruning": common_config | ||||
| } | } | ||||
| @@ -167,5 +162,4 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||||
| for method, recorder in recorders.items(): | for method, recorder in recorders.items(): | ||||
| recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) | recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) | ||||
| methods_to_plot = ["user_model", "select_score", "mean_score", "hetero_multiple_avg", "hetero_ensemble_pruning"] | |||||
| plot_performance_curves(self.curves_result_path, user, {method: recorders[method] for method in methods_to_plot}, task="Hetero", n_labeled_list=hetero_n_labeled_list) | |||||
| plot_performance_curves(self.curves_result_path, user, recorders, task="Hetero", n_labeled_list=hetero_n_labeled_list) | |||||
| @@ -37,31 +37,7 @@ class Recorder: | |||||
| self.load(path) | self.load(path) | ||||
| return user not in self.data or idx > len(self.data[user]) - 1 | return user not in self.data or idx > len(self.data[user]) - 1 | ||||
| return True | return True | ||||
| def process_single_aug(user, idx, scores, recorders, root_path): | |||||
| try: | |||||
| n_labeled = len(scores[0]) | |||||
| select_scores, mean_scores, oracle_scores = [], [], [] | |||||
| for i in range(n_labeled): | |||||
| sub_scores_array = np.vstack([lst[i] for lst in scores]) | |||||
| sub_scores_select = np.squeeze(sub_scores_array[0]) | |||||
| sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0)) | |||||
| sub_scores_min = np.squeeze(np.min(sub_scores_array, axis=0)) | |||||
| select_scores.append(sub_scores_select.tolist()) | |||||
| mean_scores.append(sub_scores_mean.tolist()) | |||||
| oracle_scores.append(sub_scores_min.tolist()) | |||||
| for method_name, scores in zip(["select_score", "mean_score", "oracle_score"], | |||||
| [select_scores, mean_scores, oracle_scores]): | |||||
| recorders[method_name].record(user, scores) | |||||
| save_path = os.path.join(root_path, f"{method_name}.json") | |||||
| recorders[method_name].save(save_path) | |||||
| except Exception: | |||||
| error_message = traceback.format_exc() | |||||
| logger.error(f"Error in process_single_aug for user {user}, idx {idx}: {error_message}") | |||||
| def plot_performance_curves(path, user, recorders, task, n_labeled_list): | def plot_performance_curves(path, user, recorders, task, n_labeled_list): | ||||
| plt.figure(figsize=(10, 6)) | plt.figure(figsize=(10, 6)) | ||||
| @@ -38,8 +38,8 @@ class TableDatasetWorkflow: | |||||
| workflow = HeterogeneousDatasetWorkflow( | workflow = HeterogeneousDatasetWorkflow( | ||||
| benchmark_config=hetero_cross_task_benchmark_config, | benchmark_config=hetero_cross_task_benchmark_config, | ||||
| name="hetero", | name="hetero", | ||||
| rebuild=False, | |||||
| retrain=False | |||||
| rebuild=True, | |||||
| retrain=True | |||||
| ) | ) | ||||
| workflow.labeled_hetero_table_example() | workflow.labeled_hetero_table_example() | ||||