From 20afa2229dc95cc031cdeb520a4ff8e2c2ab0348 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 21 Apr 2023 15:52:09 +0800 Subject: [PATCH] [MNT] Fix bugs in pfs example --- examples/example_image/main.py | 4 +--- examples/example_pfs/main.py | 2 +- learnware/market/easy.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/example_image/main.py b/examples/example_image/main.py index 82ec117..89ad40d 100644 --- a/examples/example_image/main.py +++ b/examples/example_image/main.py @@ -167,13 +167,11 @@ def test_search(gamma=0.1, load_market=True): acc_list.append(acc) logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) # test reuse - """ reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list) reuse_predict = reuse_baseline.predict(user_data=user_data) reuse_score = eval_prediction(reuse_predict, user_label) job_selector_score_list.append(reuse_score) print(f"mixture reuse loss: {reuse_score}\n") - """ reuse_ensemble = EnsembleReuser(learnware_list=mixture_learnware_list, mode="vote") ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) @@ -188,7 +186,7 @@ def test_search(gamma=0.1, load_market=True): % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list)) ) logger.info("Average performance improvement: %.3f" % (np.mean(improve_list))) - # logger.info("Average Job Selector Reuse Performance: %.3f +/- %.3f"%(np.mean(job_selector_score_list), np.std(job_selector_score_list))) + logger.info("Average Job Selector Reuse Performance: %.3f +/- %.3f"%(np.mean(job_selector_score_list), np.std(job_selector_score_list))) logger.info( "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list)) ) diff --git a/examples/example_pfs/main.py b/examples/example_pfs/main.py index 65e247c..ae20c98 100644 --- a/examples/example_pfs/main.py +++ b/examples/example_pfs/main.py @@ -141,7 +141,7 @@ class PFSDatasetWorkflow: single_learnware_list, mixture_score, mixture_learnware_list, - ) = easy_market.search_learnware(user_info, search_method = "auto") + ) = easy_market.search_learnware(user_info) print(f"search result of user{idx}:") print( diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 173c509..adcc5da 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -378,7 +378,7 @@ class EasyMarket(BaseMarket): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0]) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num]