Browse Source

[MNT] update hetero_single_aug

tags/v0.3.2
liuht 2 years ago
parent
commit
df2aaf0351
4 changed files with 10 additions and 102 deletions
  1. +4
    -66
      examples/dataset_table_workflow/base.py
  2. +3
    -9
      examples/dataset_table_workflow/hetero.py
  3. +1
    -25
      examples/dataset_table_workflow/utils.py
  4. +2
    -2
      examples/dataset_table_workflow/workflow.py

+ 4
- 66
examples/dataset_table_workflow/base.py View File

@@ -6,27 +6,17 @@ import requests
import tempfile
import traceback
import numpy as np
from queue import Empty
from tqdm import tqdm
from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import instantiate_learnware_market
from learnware.reuse.utils import fill_data_with_mean
from learnware.tests.benchmarks import LearnwareBenchmark
from torch.multiprocessing import Process, Queue, set_start_method

from config import *
from methods import *
from utils import process_single_aug

logger = get_module_logger("base_table", level="INFO")

try:
set_start_method('spawn')
except RuntimeError:
pass
torch.multiprocessing.set_sharing_strategy("file_system")


class TableWorkflow:
def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False):
@@ -99,16 +89,6 @@ class TableWorkflow:
logger.info(f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying...")
time.sleep(1)
continue
@staticmethod
def process_learnware_chunk(cuda_idx, method, test_info, loss_func, learnware_chunk, queue):
torch.cuda.set_device(cuda_idx)
for learnware in learnware_chunk:
learnware_index = test_info['learnwares'].index(learnware)
test_info['single_learnware'] = learnware
scores = TableWorkflow._limited_data(method, test_info, loss_func)
torch.cuda.empty_cache()
queue.put((learnware_index, scores))
def test_method(self, test_info, recorders, loss_func=loss_func_rmse):
method_name_full = test_info["method_name"]
@@ -121,51 +101,9 @@ class TableWorkflow:
os.makedirs(save_root_path, exist_ok=True)
save_path = os.path.join(save_root_path, f"{method_name}.json")
if method_name_full == "hetero_single_aug":
if recorder.should_test_method(user, idx, save_path):
# * single-process
# bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}")
# for learnware in test_info['learnwares']:
# test_info['single_learnware'] = learnware
# scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
# recorder.record(user, idx, scores)
# bar.update(1)
# * multi-process
queue = Queue()
processes = []
bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}", unit="learnware")
learnware_chunks = [test_info["learnwares"][i:len(test_info["learnwares"]):len(self.cuda_idx)] for i in self.cuda_idx]
for cuda_idx, learnware_chunk in zip(self.cuda_idx, learnware_chunks):
p = Process(target=TableWorkflow.process_learnware_chunk, args=(cuda_idx, method, test_info, loss_func, learnware_chunk, queue))
processes.append(p)
p.start()
all_results = []
while any(p.is_alive() for p in processes) or not queue.empty():
try:
result = queue.get(timeout=0.1)
all_results.append(result)
bar.update(1)
except Empty:
time.sleep(0.1)
continue
bar.close()

for p in processes:
p.join()
all_results.sort(key=lambda x: x[0])
all_scores = [result[1] for result in all_results]
recorder.record(user, all_scores)
recorder.save(save_path)
process_single_aug(user, idx, recorder.data[user][idx], recorders, save_root_path)
else:
if recorder.should_test_method(user, idx, save_path):
scores = self._limited_data(method, test_info, loss_func)
recorder.record(user, scores)
recorder.save(save_path)
if recorder.should_test_method(user, idx, save_path):
scores = self._limited_data(method, test_info, loss_func)
recorder.record(user, scores)
recorder.save(save_path)
logger.info(f"Method {method_name} on {user}_{idx} finished")

+ 3
- 9
examples/dataset_table_workflow/hetero.py View File

@@ -109,7 +109,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
def labeled_hetero_table_example(self):
logger.info("Total Items: %d" % len(self.market))
methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"]
recorders = {method: Recorder() for method in methods + ["select_score", "oracle_score", "mean_score"]}
recorders = {method: Recorder() for method in methods}

user = self.benchmark.name
for idx in range(self.benchmark.user_num):
@@ -134,10 +134,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
search_result = self.market.search_learnware(user_info)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()
rank_map = {item.learnware.id: index for index, item in enumerate(single_result)}
all_learnwares = self.market.get_learnwares()
all_learnwares.sort(key=lambda learnware: rank_map.get(learnware.id, float('inf')))

if len(multiple_result) > 0:
mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
@@ -152,8 +148,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list}
method_configs = {
"user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
"hetero_single_aug": {"user_rkme": user_stat_spec, "learnwares": all_learnwares},
"hetero_multiple_aug": common_config,
"hetero_single_aug": {"user_rkme": user_stat_spec, "single_learnware": single_result[0].learnware},
"hetero_multiple_avg": common_config,
"hetero_ensemble_pruning": common_config
}
@@ -167,5 +162,4 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
for method, recorder in recorders.items():
recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))
methods_to_plot = ["user_model", "select_score", "mean_score", "hetero_multiple_avg", "hetero_ensemble_pruning"]
plot_performance_curves(self.curves_result_path, user, {method: recorders[method] for method in methods_to_plot}, task="Hetero", n_labeled_list=hetero_n_labeled_list)
plot_performance_curves(self.curves_result_path, user, recorders, task="Hetero", n_labeled_list=hetero_n_labeled_list)

+ 1
- 25
examples/dataset_table_workflow/utils.py View File

@@ -37,31 +37,7 @@ class Recorder:
self.load(path)
return user not in self.data or idx > len(self.data[user]) - 1
return True


def process_single_aug(user, idx, scores, recorders, root_path):
try:
n_labeled = len(scores[0])
select_scores, mean_scores, oracle_scores = [], [], []
for i in range(n_labeled):
sub_scores_array = np.vstack([lst[i] for lst in scores])
sub_scores_select = np.squeeze(sub_scores_array[0])
sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
sub_scores_min = np.squeeze(np.min(sub_scores_array, axis=0))
select_scores.append(sub_scores_select.tolist())
mean_scores.append(sub_scores_mean.tolist())
oracle_scores.append(sub_scores_min.tolist())

for method_name, scores in zip(["select_score", "mean_score", "oracle_score"],
[select_scores, mean_scores, oracle_scores]):
recorders[method_name].record(user, scores)
save_path = os.path.join(root_path, f"{method_name}.json")
recorders[method_name].save(save_path)
except Exception:
error_message = traceback.format_exc()
logger.error(f"Error in process_single_aug for user {user}, idx {idx}: {error_message}")


def plot_performance_curves(path, user, recorders, task, n_labeled_list):
plt.figure(figsize=(10, 6))


+ 2
- 2
examples/dataset_table_workflow/workflow.py View File

@@ -38,8 +38,8 @@ class TableDatasetWorkflow:
workflow = HeterogeneousDatasetWorkflow(
benchmark_config=hetero_cross_task_benchmark_config,
name="hetero",
rebuild=False,
retrain=False
rebuild=True,
retrain=True
)
workflow.labeled_hetero_table_example()



Loading…
Cancel
Save