Browse Source

[MNT] Fix bugs in pfs example

tags/v0.3.2
Gene 2 years ago
parent
commit
20afa2229d
3 changed files with 3 additions and 5 deletions
  1. +1
    -3
      examples/example_image/main.py
  2. +1
    -1
      examples/example_pfs/main.py
  3. +1
    -1
      learnware/market/easy.py

+ 1
- 3
examples/example_image/main.py View File

@@ -167,13 +167,11 @@ def test_search(gamma=0.1, load_market=True):
acc_list.append(acc)
logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc))
# test reuse
"""
reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list)
reuse_predict = reuse_baseline.predict(user_data=user_data)
reuse_score = eval_prediction(reuse_predict, user_label)
job_selector_score_list.append(reuse_score)
print(f"mixture reuse loss: {reuse_score}\n")
"""

reuse_ensemble = EnsembleReuser(learnware_list=mixture_learnware_list, mode="vote")
ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
@@ -188,7 +186,7 @@ def test_search(gamma=0.1, load_market=True):
% (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
)
logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
# logger.info("Average Job Selector Reuse Performance: %.3f +/- %.3f"%(np.mean(job_selector_score_list), np.std(job_selector_score_list)))
logger.info("Average Job Selector Reuse Performance: %.3f +/- %.3f"%(np.mean(job_selector_score_list), np.std(job_selector_score_list)))
logger.info(
"Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
)


+ 1
- 1
examples/example_pfs/main.py View File

@@ -141,7 +141,7 @@ class PFSDatasetWorkflow:
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = easy_market.search_learnware(user_info, search_method = "auto")
) = easy_market.search_learnware(user_info)

print(f"search result of user{idx}:")
print(


+ 1
- 1
learnware/market/easy.py View File

@@ -378,7 +378,7 @@ class EasyMarket(BaseMarket):
if len(mixture_list) <= 1:
mixture_list = [learnware_list[sort_by_weight_idx_list[0]]]
mixture_weight = [1]
mmd_dist = user_rkme.dist(mixture_list[0])
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification"))
else:
if len(mixture_list) > max_search_num:
mixture_list = mixture_list[:max_search_num]


Loading…
Cancel
Save