diff --git a/examples/example_pfs/main.py b/examples/example_pfs/main.py index ae20c98..65e247c 100644 --- a/examples/example_pfs/main.py +++ b/examples/example_pfs/main.py @@ -141,7 +141,7 @@ class PFSDatasetWorkflow: single_learnware_list, mixture_score, mixture_learnware_list, - ) = easy_market.search_learnware(user_info) + ) = easy_market.search_learnware(user_info, search_method = "auto") print(f"search result of user{idx}:") print( diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 944342d..4865552 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -385,7 +385,7 @@ class EasyMarket(BaseMarket): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list) + mmd_dist = user_rkme.dist(mixture_list[0]) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num]