| @@ -6,27 +6,17 @@ import requests | |||
| import tempfile | |||
| import traceback | |||
| import numpy as np | |||
| from queue import Empty | |||
| from tqdm import tqdm | |||
| from learnware.client import LearnwareClient | |||
| from learnware.logger import get_module_logger | |||
| from learnware.market import instantiate_learnware_market | |||
| from learnware.reuse.utils import fill_data_with_mean | |||
| from learnware.tests.benchmarks import LearnwareBenchmark | |||
| from torch.multiprocessing import Process, Queue, set_start_method | |||
| from config import * | |||
| from methods import * | |||
| from utils import process_single_aug | |||
| logger = get_module_logger("base_table", level="INFO") | |||
# Module-level multiprocessing configuration; runs once at import time.
try:
    # Ask torch.multiprocessing to create worker processes with "spawn".
    set_start_method("spawn")
except RuntimeError:
    # Best-effort: a RuntimeError here is deliberately ignored.
    # NOTE(review): presumably this is the "start method already set"
    # error raised on repeated configuration — confirm.
    pass

# Select the "file_system" sharing strategy for torch.multiprocessing.
torch.multiprocessing.set_sharing_strategy("file_system")
| class TableWorkflow: | |||
| def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False): | |||
| @@ -99,16 +89,6 @@ class TableWorkflow: | |||
| logger.info(f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying...") | |||
| time.sleep(1) | |||
| continue | |||
@staticmethod
def process_learnware_chunk(cuda_idx, method, test_info, loss_func, learnware_chunk, queue):
    """Evaluate each learnware in ``learnware_chunk`` and push the results to ``queue``.

    Args:
        cuda_idx: CUDA device index passed to ``torch.cuda.set_device``.
        method: Evaluation method forwarded to ``TableWorkflow._limited_data``.
        test_info: Dict that must contain a ``"learnwares"`` list. Its
            ``"single_learnware"`` entry is overwritten for every learnware
            processed here.
        loss_func: Loss function forwarded to ``TableWorkflow._limited_data``.
        learnware_chunk: Iterable of learnwares to evaluate; each one must be
            present in ``test_info["learnwares"]`` (``list.index`` raises
            ``ValueError`` otherwise).
        queue: Receives one ``(learnware_index, scores)`` tuple per learnware,
            where ``learnware_index`` is the position of the learnware inside
            ``test_info["learnwares"]``.
    """
    # Select the GPU before any evaluation runs.
    torch.cuda.set_device(cuda_idx)

    for current in learnware_chunk:
        # Position in the full list, reported alongside the scores.
        position = test_info["learnwares"].index(current)
        test_info["single_learnware"] = current
        result = (position, TableWorkflow._limited_data(method, test_info, loss_func))
        # Release cached GPU memory before reporting and moving on.
        torch.cuda.empty_cache()
        queue.put(result)
| def test_method(self, test_info, recorders, loss_func=loss_func_rmse): | |||
| method_name_full = test_info["method_name"] | |||
| @@ -121,51 +101,9 @@ class TableWorkflow: | |||
| os.makedirs(save_root_path, exist_ok=True) | |||
| save_path = os.path.join(save_root_path, f"{method_name}.json") | |||
| if method_name_full == "hetero_single_aug": | |||
| if recorder.should_test_method(user, idx, save_path): | |||
| # * single-process | |||
| # bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}") | |||
| # for learnware in test_info['learnwares']: | |||
| # test_info['single_learnware'] = learnware | |||
| # scores = self._limited_data(test_methods[method_name_full], test_info, loss_func) | |||
| # recorder.record(user, idx, scores) | |||
| # bar.update(1) | |||
| # * multi-process | |||
| queue = Queue() | |||
| processes = [] | |||
| bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}", unit="learnware") | |||
| learnware_chunks = [test_info["learnwares"][i:len(test_info["learnwares"]):len(self.cuda_idx)] for i in self.cuda_idx] | |||
| for cuda_idx, learnware_chunk in zip(self.cuda_idx, learnware_chunks): | |||
| p = Process(target=TableWorkflow.process_learnware_chunk, args=(cuda_idx, method, test_info, loss_func, learnware_chunk, queue)) | |||
| processes.append(p) | |||
| p.start() | |||
| all_results = [] | |||
| while any(p.is_alive() for p in processes) or not queue.empty(): | |||
| try: | |||
| result = queue.get(timeout=0.1) | |||
| all_results.append(result) | |||
| bar.update(1) | |||
| except Empty: | |||
| time.sleep(0.1) | |||
| continue | |||
| bar.close() | |||
| for p in processes: | |||
| p.join() | |||
| all_results.sort(key=lambda x: x[0]) | |||
| all_scores = [result[1] for result in all_results] | |||
| recorder.record(user, all_scores) | |||
| recorder.save(save_path) | |||
| process_single_aug(user, idx, recorder.data[user][idx], recorders, save_root_path) | |||
| else: | |||
| if recorder.should_test_method(user, idx, save_path): | |||
| scores = self._limited_data(method, test_info, loss_func) | |||
| recorder.record(user, scores) | |||
| recorder.save(save_path) | |||
| if recorder.should_test_method(user, idx, save_path): | |||
| scores = self._limited_data(method, test_info, loss_func) | |||
| recorder.record(user, scores) | |||
| recorder.save(save_path) | |||
| logger.info(f"Method {method_name} on {user}_{idx} finished") | |||
| @@ -109,7 +109,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||
| def labeled_hetero_table_example(self): | |||
| logger.info("Total Items: %d" % len(self.market)) | |||
| methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"] | |||
| recorders = {method: Recorder() for method in methods + ["select_score", "oracle_score", "mean_score"]} | |||
| recorders = {method: Recorder() for method in methods} | |||
| user = self.benchmark.name | |||
| for idx in range(self.benchmark.user_num): | |||
| @@ -134,10 +134,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||
| search_result = self.market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| rank_map = {item.learnware.id: index for index, item in enumerate(single_result)} | |||
| all_learnwares = self.market.get_learnwares() | |||
| all_learnwares.sort(key=lambda learnware: rank_map.get(learnware.id, float('inf'))) | |||
| if len(multiple_result) > 0: | |||
| mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | |||
| @@ -152,8 +148,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||
| common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list} | |||
| method_configs = { | |||
| "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"}, | |||
| "hetero_single_aug": {"user_rkme": user_stat_spec, "learnwares": all_learnwares}, | |||
| "hetero_multiple_aug": common_config, | |||
| "hetero_single_aug": {"user_rkme": user_stat_spec, "single_learnware": single_result[0].learnware}, | |||
| "hetero_multiple_avg": common_config, | |||
| "hetero_ensemble_pruning": common_config | |||
| } | |||
| @@ -167,5 +162,4 @@ class HeterogeneousDatasetWorkflow(TableWorkflow): | |||
| for method, recorder in recorders.items(): | |||
| recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) | |||
| methods_to_plot = ["user_model", "select_score", "mean_score", "hetero_multiple_avg", "hetero_ensemble_pruning"] | |||
| plot_performance_curves(self.curves_result_path, user, {method: recorders[method] for method in methods_to_plot}, task="Hetero", n_labeled_list=hetero_n_labeled_list) | |||
| plot_performance_curves(self.curves_result_path, user, recorders, task="Hetero", n_labeled_list=hetero_n_labeled_list) | |||
| @@ -37,31 +37,7 @@ class Recorder: | |||
| self.load(path) | |||
| return user not in self.data or idx > len(self.data[user]) - 1 | |||
| return True | |||
def process_single_aug(user, idx, scores, recorders, root_path):
    """Derive select/mean/oracle score curves from per-learnware scores and record them.

    For every labeled-size index ``i``, the ``i``-th entry of each element of
    ``scores`` is stacked row-wise with ``np.vstack`` and three statistics are
    taken from the stacked array:

    * ``select_score`` -- row 0 of the stacked array.
      NOTE(review): presumably the caller orders ``scores`` so that the
      searched (top-ranked) learnware comes first -- confirm.
    * ``mean_score``   -- element-wise mean over the rows.
    * ``oracle_score`` -- element-wise minimum over the rows.

    Each statistic is passed through ``np.squeeze(...).tolist()``, recorded via
    ``recorders[name].record(user, values)`` and saved to
    ``<root_path>/<name>.json``.

    Args:
        user: User key forwarded to ``Recorder.record``; also used in the error log.
        idx: Only used in the error log message.
        scores: Non-empty sequence; every element must be indexable by
            ``0 .. len(scores[0]) - 1`` and yield something ``np.vstack`` accepts.
            It is not modified.
        recorders: Mapping containing the keys ``"select_score"``,
            ``"mean_score"`` and ``"oracle_score"``; each value needs
            ``record(user, values)`` and ``save(path)``.
        root_path: Directory under which the three JSON files are saved.

    Returns:
        None. Any exception is caught and logged with its traceback instead of
        being propagated, so a failure may leave only some statistics recorded.
    """
    try:
        n_labeled = len(scores[0])
        # Insertion order fixes the record/save order: select, mean, oracle.
        derived = {"select_score": [], "mean_score": [], "oracle_score": []}

        for i in range(n_labeled):
            stacked = np.vstack([per_learnware[i] for per_learnware in scores])
            derived["select_score"].append(np.squeeze(stacked[0]).tolist())
            derived["mean_score"].append(np.squeeze(np.mean(stacked, axis=0)).tolist())
            derived["oracle_score"].append(np.squeeze(np.min(stacked, axis=0)).tolist())

        # A separate loop variable keeps the ``scores`` argument from being rebound.
        for method_name, method_scores in derived.items():
            recorders[method_name].record(user, method_scores)
            save_path = os.path.join(root_path, f"{method_name}.json")
            recorders[method_name].save(save_path)
    except Exception:
        error_message = traceback.format_exc()
        logger.error(f"Error in process_single_aug for user {user}, idx {idx}: {error_message}")
| def plot_performance_curves(path, user, recorders, task, n_labeled_list): | |||
| plt.figure(figsize=(10, 6)) | |||
| @@ -38,8 +38,8 @@ class TableDatasetWorkflow: | |||
| workflow = HeterogeneousDatasetWorkflow( | |||
| benchmark_config=hetero_cross_task_benchmark_config, | |||
| name="hetero", | |||
| rebuild=False, | |||
| retrain=False | |||
| rebuild=True, | |||
| retrain=True | |||
| ) | |||
| workflow.labeled_hetero_table_example() | |||