Browse Source

[MNT] update hetero_single_aug

tags/v0.3.2
liuht 2 years ago
parent
commit
df2aaf0351
4 changed files with 10 additions and 102 deletions
  1. +4
    -66
      examples/dataset_table_workflow/base.py
  2. +3
    -9
      examples/dataset_table_workflow/hetero.py
  3. +1
    -25
      examples/dataset_table_workflow/utils.py
  4. +2
    -2
      examples/dataset_table_workflow/workflow.py

+ 4
- 66
examples/dataset_table_workflow/base.py View File

@@ -6,27 +6,17 @@ import requests
import tempfile import tempfile
import traceback import traceback
import numpy as np import numpy as np
from queue import Empty
from tqdm import tqdm
from learnware.client import LearnwareClient from learnware.client import LearnwareClient
from learnware.logger import get_module_logger from learnware.logger import get_module_logger
from learnware.market import instantiate_learnware_market from learnware.market import instantiate_learnware_market
from learnware.reuse.utils import fill_data_with_mean from learnware.reuse.utils import fill_data_with_mean
from learnware.tests.benchmarks import LearnwareBenchmark from learnware.tests.benchmarks import LearnwareBenchmark
from torch.multiprocessing import Process, Queue, set_start_method


from config import * from config import *
from methods import * from methods import *
from utils import process_single_aug


logger = get_module_logger("base_table", level="INFO") logger = get_module_logger("base_table", level="INFO")


try:
set_start_method('spawn')
except RuntimeError:
pass
torch.multiprocessing.set_sharing_strategy("file_system")



class TableWorkflow: class TableWorkflow:
def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False): def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False):
@@ -99,16 +89,6 @@ class TableWorkflow:
logger.info(f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying...") logger.info(f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying...")
time.sleep(1) time.sleep(1)
continue continue
@staticmethod
def process_learnware_chunk(cuda_idx, method, test_info, loss_func, learnware_chunk, queue):
torch.cuda.set_device(cuda_idx)
for learnware in learnware_chunk:
learnware_index = test_info['learnwares'].index(learnware)
test_info['single_learnware'] = learnware
scores = TableWorkflow._limited_data(method, test_info, loss_func)
torch.cuda.empty_cache()
queue.put((learnware_index, scores))
def test_method(self, test_info, recorders, loss_func=loss_func_rmse): def test_method(self, test_info, recorders, loss_func=loss_func_rmse):
method_name_full = test_info["method_name"] method_name_full = test_info["method_name"]
@@ -121,51 +101,9 @@ class TableWorkflow:
os.makedirs(save_root_path, exist_ok=True) os.makedirs(save_root_path, exist_ok=True)
save_path = os.path.join(save_root_path, f"{method_name}.json") save_path = os.path.join(save_root_path, f"{method_name}.json")
if method_name_full == "hetero_single_aug":
if recorder.should_test_method(user, idx, save_path):
# * single-process
# bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}")
# for learnware in test_info['learnwares']:
# test_info['single_learnware'] = learnware
# scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
# recorder.record(user, idx, scores)
# bar.update(1)
# * multi-process
queue = Queue()
processes = []
bar = tqdm(total=len(test_info["learnwares"]), desc=f"Test {method_name}", unit="learnware")
learnware_chunks = [test_info["learnwares"][i:len(test_info["learnwares"]):len(self.cuda_idx)] for i in self.cuda_idx]
for cuda_idx, learnware_chunk in zip(self.cuda_idx, learnware_chunks):
p = Process(target=TableWorkflow.process_learnware_chunk, args=(cuda_idx, method, test_info, loss_func, learnware_chunk, queue))
processes.append(p)
p.start()
all_results = []
while any(p.is_alive() for p in processes) or not queue.empty():
try:
result = queue.get(timeout=0.1)
all_results.append(result)
bar.update(1)
except Empty:
time.sleep(0.1)
continue
bar.close()

for p in processes:
p.join()
all_results.sort(key=lambda x: x[0])
all_scores = [result[1] for result in all_results]
recorder.record(user, all_scores)
recorder.save(save_path)
process_single_aug(user, idx, recorder.data[user][idx], recorders, save_root_path)
else:
if recorder.should_test_method(user, idx, save_path):
scores = self._limited_data(method, test_info, loss_func)
recorder.record(user, scores)
recorder.save(save_path)
if recorder.should_test_method(user, idx, save_path):
scores = self._limited_data(method, test_info, loss_func)
recorder.record(user, scores)
recorder.save(save_path)
logger.info(f"Method {method_name} on {user}_{idx} finished") logger.info(f"Method {method_name} on {user}_{idx} finished")

+ 3
- 9
examples/dataset_table_workflow/hetero.py View File

@@ -109,7 +109,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
def labeled_hetero_table_example(self): def labeled_hetero_table_example(self):
logger.info("Total Items: %d" % len(self.market)) logger.info("Total Items: %d" % len(self.market))
methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"] methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"]
recorders = {method: Recorder() for method in methods + ["select_score", "oracle_score", "mean_score"]}
recorders = {method: Recorder() for method in methods}


user = self.benchmark.name user = self.benchmark.name
for idx in range(self.benchmark.user_num): for idx in range(self.benchmark.user_num):
@@ -134,10 +134,6 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
search_result = self.market.search_learnware(user_info) search_result = self.market.search_learnware(user_info)
single_result = search_result.get_single_results() single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results() multiple_result = search_result.get_multiple_results()
rank_map = {item.learnware.id: index for index, item in enumerate(single_result)}
all_learnwares = self.market.get_learnwares()
all_learnwares.sort(key=lambda learnware: rank_map.get(learnware.id, float('inf')))


if len(multiple_result) > 0: if len(multiple_result) > 0:
mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
@@ -152,8 +148,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list} common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list}
method_configs = { method_configs = {
"user_model": {"dataset": self.benchmark.name, "model_type": "lgb"}, "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
"hetero_single_aug": {"user_rkme": user_stat_spec, "learnwares": all_learnwares},
"hetero_multiple_aug": common_config,
"hetero_single_aug": {"user_rkme": user_stat_spec, "single_learnware": single_result[0].learnware},
"hetero_multiple_avg": common_config, "hetero_multiple_avg": common_config,
"hetero_ensemble_pruning": common_config "hetero_ensemble_pruning": common_config
} }
@@ -167,5 +162,4 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
for method, recorder in recorders.items(): for method, recorder in recorders.items():
recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))
methods_to_plot = ["user_model", "select_score", "mean_score", "hetero_multiple_avg", "hetero_ensemble_pruning"]
plot_performance_curves(self.curves_result_path, user, {method: recorders[method] for method in methods_to_plot}, task="Hetero", n_labeled_list=hetero_n_labeled_list)
plot_performance_curves(self.curves_result_path, user, recorders, task="Hetero", n_labeled_list=hetero_n_labeled_list)

+ 1
- 25
examples/dataset_table_workflow/utils.py View File

@@ -37,31 +37,7 @@ class Recorder:
self.load(path) self.load(path)
return user not in self.data or idx > len(self.data[user]) - 1 return user not in self.data or idx > len(self.data[user]) - 1
return True return True


def process_single_aug(user, idx, scores, recorders, root_path):
try:
n_labeled = len(scores[0])
select_scores, mean_scores, oracle_scores = [], [], []
for i in range(n_labeled):
sub_scores_array = np.vstack([lst[i] for lst in scores])
sub_scores_select = np.squeeze(sub_scores_array[0])
sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
sub_scores_min = np.squeeze(np.min(sub_scores_array, axis=0))
select_scores.append(sub_scores_select.tolist())
mean_scores.append(sub_scores_mean.tolist())
oracle_scores.append(sub_scores_min.tolist())

for method_name, scores in zip(["select_score", "mean_score", "oracle_score"],
[select_scores, mean_scores, oracle_scores]):
recorders[method_name].record(user, scores)
save_path = os.path.join(root_path, f"{method_name}.json")
recorders[method_name].save(save_path)
except Exception:
error_message = traceback.format_exc()
logger.error(f"Error in process_single_aug for user {user}, idx {idx}: {error_message}")



def plot_performance_curves(path, user, recorders, task, n_labeled_list): def plot_performance_curves(path, user, recorders, task, n_labeled_list):
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))


+ 2
- 2
examples/dataset_table_workflow/workflow.py View File

@@ -38,8 +38,8 @@ class TableDatasetWorkflow:
workflow = HeterogeneousDatasetWorkflow( workflow = HeterogeneousDatasetWorkflow(
benchmark_config=hetero_cross_task_benchmark_config, benchmark_config=hetero_cross_task_benchmark_config,
name="hetero", name="hetero",
rebuild=False,
retrain=False
rebuild=True,
retrain=True
) )
workflow.labeled_hetero_table_example() workflow.labeled_hetero_table_example()




Loading…
Cancel
Save