Browse Source

[MNT] set default values

tags/v0.3.2
liuht 2 years ago
parent
commit
3d61d993d4
4 changed files with 26 additions and 37 deletions
  1. +1
    -15
      examples/dataset_table_workflow/base.py
  2. +1
    -1
      examples/dataset_table_workflow/config.py
  3. +19
    -19
      examples/dataset_table_workflow/hetero.py
  4. +5
    -2
      examples/dataset_table_workflow/workflow.py

+ 1
- 15
examples/dataset_table_workflow/base.py View File

@@ -19,9 +19,6 @@ from utils import set_seed
logger = get_module_logger("base_table", level="INFO")


# for quick test only
from learnware.market.heterogeneous import HeteroMapTableOrganizer

class TableWorkflow:
def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False):
self.root_path = os.path.abspath(os.path.join(__file__, ".."))
@@ -29,10 +26,7 @@ class TableWorkflow:
self.curves_result_path = os.path.join(self.result_path, "curves")
os.makedirs(self.result_path, exist_ok=True)
os.makedirs(self.curves_result_path, exist_ok=True)
# if name == "hetero":
# set_seed(42)
self._prepare_market(benchmark_config, name, rebuild, retrain)
self._prepare_market(benchmark_config, name, rebuild, retrain)
@staticmethod
def _limited_data(method, test_info, loss_func):
@@ -81,14 +75,6 @@ class TableWorkflow:
self.user_semantic = client.get_semantic_specification(self.benchmark.learnware_ids[0])
self.user_semantic["Name"]["Values"] = ""
# if retrain == True and rebuild == False:
# logger.info(f"training learnwares: {self.market.get_learnware_ids()[::-1]}")
# market_mapping = HeteroMapTableOrganizer.train(self.market.get_learnwares()[::-1], save_dir='test_model.bin', **market_mapping_params)
# self.market.learnware_organizer.market_mapping = market_mapping
# self.market.learnware_organizer._update_learnware_hetero_spec(self.market.get_learnware_ids()[::-1])
# return
if len(self.market) == 0 or rebuild == True:
for learnware_id in self.benchmark.learnware_ids:
with tempfile.TemporaryDirectory(prefix="table_benchmark_") as tempdir:


+ 1
- 1
examples/dataset_table_workflow/config.py View File

@@ -216,7 +216,7 @@ hetero_cross_feat_eng_benchmark_config = BenchmarkConfig(
"00000912"
],
test_data_path="PFS/test_data.zip",
# train_data_path="PFS/train_data.zip",
train_data_path="PFS/train_data.zip",
extra_info_path="PFS/extra_info.zip"
)



+ 19
- 19
examples/dataset_table_workflow/hetero.py View File

@@ -39,7 +39,7 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
)
logger.info(f"Searching Market for user: {user}_{idx}")
search_result = self.market.search_learnware(user_info)
search_result = self.market.search_learnware(user_info, max_search_num=10)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()
@@ -53,15 +53,15 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
pred_y = single_hetero_learnware.predict(test_x)
single_score_list.append(loss_func_rmse(pred_y, test_y))

# rmse_list = []
# for learnware in all_learnwares:
# hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params)
# hetero_learnware.align(user_rkme=user_stat_spec)
# pred_y = hetero_learnware.predict(test_x)
# rmse_list.append(loss_func_rmse(pred_y, test_y))
# logger.info(
# f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[0]}"
# )
rmse_list = []
for learnware in all_learnwares:
hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params)
hetero_learnware.align(user_rkme=user_stat_spec)
pred_y = hetero_learnware.predict(test_x)
rmse_list.append(loss_func_rmse(pred_y, test_y))
logger.info(
f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[0]}"
)
if len(multiple_result) > 0:
mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
@@ -83,21 +83,21 @@ class HeterogeneousDatasetWorkflow(TableWorkflow):
ensemble_score_list.append(ensemble_score)
logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}")

# learnware_rmse_list.append(rmse_list)
learnware_rmse_list.append(rmse_list)
# single_list = np.array(learnware_rmse_list)
# avg_score_list = [np.mean(lst, axis=0) for lst in single_list]
# oracle_score_list = [np.min(lst, axis=0) for lst in single_list]
single_list = np.array(learnware_rmse_list)
avg_score_list = [np.mean(lst, axis=0) for lst in single_list]
oracle_score_list = [np.min(lst, axis=0) for lst in single_list]
logger.info(
"RMSE of selected learnware: %.3f +/- %.3f" # , Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f"
"RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f"
% (
np.mean(single_score_list),
np.std(single_score_list),
# np.mean(avg_score_list),
# np.std(avg_score_list),
# np.mean(oracle_score_list),
# np.std(oracle_score_list),
np.mean(avg_score_list),
np.std(avg_score_list),
np.mean(oracle_score_list),
np.std(oracle_score_list)
)
)
logger.info(


+ 5
- 2
examples/dataset_table_workflow/workflow.py View File

@@ -4,6 +4,7 @@ from learnware.logger import get_module_logger
from homo import HomogeneousDatasetWorkflow
from hetero import HeterogeneousDatasetWorkflow
from config import homo_table_benchmark_config, hetero_cross_feat_eng_benchmark_config, hetero_cross_task_benchmark_config
from utils import set_seed

logger = get_module_logger("base_table", level="INFO")

@@ -26,15 +27,17 @@ class TableDatasetWorkflow:
workflow.labeled_homo_table_example()
def cross_feat_eng_hetero_table_example(self):
set_seed(0)
workflow = HeterogeneousDatasetWorkflow(
benchmark_config=hetero_cross_feat_eng_benchmark_config,
name="hetero",
rebuild=True,
retrain=True
rebuild=False,
retrain=False
)
workflow.unlabeled_hetero_table_example()

def cross_task_hetero_table_example(self):
set_seed(0)
workflow = HeterogeneousDatasetWorkflow(
benchmark_config=hetero_cross_task_benchmark_config,
name="hetero",


Loading…
Cancel
Save