From 445930c4021a6c7333a6b9cf394ee2753a4ec93f Mon Sep 17 00:00:00 2001 From: xiey Date: Wed, 19 Apr 2023 18:52:36 +0800 Subject: [PATCH 1/6] [MNT] Add a log info in semantic search --- learnware/market/easy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 0133fa1..1e4d291 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -581,6 +581,7 @@ class EasyMarket(BaseMarket): user_semantic_spec = user_info.get_semantic_spec() if match_semantic_spec(learnware_semantic_spec, user_semantic_spec): match_learnwares.append(learnware) + logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list))) return match_learnwares def search_learnware( From 180747757284a7ae91ff9d827bbea5e0c24bd848 Mon Sep 17 00:00:00 2001 From: liuht Date: Wed, 19 Apr 2023 19:06:59 +0800 Subject: [PATCH 2/6] [] --- learnware/market/easy.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 0133fa1..19e2c6f 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -23,7 +23,7 @@ class EasyMarket(BaseMarket): NOPREDICTION_LEARNWARE = 0 PREDICTION_LEARWARE = 1 - def __init__(self, market_id: str = None, rebuild: bool = False): + def __init__(self, rebuild: bool = False): """Initialize Learnware Market. Automatically reload from db if available. Build an empty db otherwise. @@ -333,7 +333,7 @@ class EasyMarket(BaseMarket): learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int, - weight_cutoff: float = 0.95, + weight_cutoff: float = 0.98 ) -> Tuple[List[float], List[Learnware]]: """Select learnwares based on a total mixture ratio, then recalculate their mixture weights @@ -372,15 +372,15 @@ class EasyMarket(BaseMarket): mixture_list.append(learnware_list[idx]) else: break - + if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] else: if len(mixture_list) > max_search_num: - mixture_list = mixture_list[:max_search_num] + mixture_list = mixture_list[:max_search_num] mixture_weight, _ = self._calculate_rkme_spec_mixture_weight(mixture_list, user_rkme) - + return mixture_weight, mixture_list def _filter_by_rkme_spec_single( @@ -444,12 +444,12 @@ class EasyMarket(BaseMarket): return filtered_learnware_list - def _search_by_rkme_spec_mixture_greedy( + def _search_by_rkme_spec_mixture( self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int, - score_cutoff: float = 0.01, + score_cutoff: float = 0.001, ) -> Tuple[List[float], List[Learnware]]: """Greedily match learnwares such that their mixture become more and more closer to user's rkme @@ -618,11 +618,11 @@ class EasyMarket(BaseMarket): sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single( sorted_score_list, single_learnware_list ) - if search_method == "auto": + if search_method == 'auto': weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_auto( learnware_list, user_rkme, max_search_num ) - elif search_method == "greedy": + elif search_method == 'greedy': weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_greedy( learnware_list, user_rkme, max_search_num ) From eb31bf15b2d70df3871161c395461b2b4ad92728 Mon Sep 17 00:00:00 2001 From: liuht Date: Wed, 19 Apr 2023 19:11:52 +0800 Subject: [PATCH 3/6] [MNT] change parameters --- learnware/market/easy.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index fee3af9..0efd5bb 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -23,7 +23,7 @@ class EasyMarket(BaseMarket): NOPREDICTION_LEARNWARE = 0 PREDICTION_LEARWARE = 1 - def __init__(self, rebuild: bool = False): + def __init__(self, market_id: str = None, rebuild: bool = False): """Initialize Learnware Market. Automatically reload from db if available. Build an empty db otherwise. @@ -333,7 +333,7 @@ class EasyMarket(BaseMarket): learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int, - weight_cutoff: float = 0.98 + weight_cutoff: float = 0.98, ) -> Tuple[List[float], List[Learnware]]: """Select learnwares based on a total mixture ratio, then recalculate their mixture weights @@ -372,15 +372,15 @@ class EasyMarket(BaseMarket): mixture_list.append(learnware_list[idx]) else: break - + if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] else: if len(mixture_list) > max_search_num: - mixture_list = mixture_list[:max_search_num] + mixture_list = mixture_list[:max_search_num] mixture_weight, _ = self._calculate_rkme_spec_mixture_weight(mixture_list, user_rkme) - + return mixture_weight, mixture_list def _filter_by_rkme_spec_single( @@ -444,7 +444,7 @@ class EasyMarket(BaseMarket): return filtered_learnware_list - def _search_by_rkme_spec_mixture( + def _search_by_rkme_spec_mixture_greedy( self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, @@ -619,11 +619,11 @@ class EasyMarket(BaseMarket): sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single( sorted_score_list, single_learnware_list ) - if search_method == 'auto': + if search_method == "auto": weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_auto( learnware_list, user_rkme, max_search_num ) - elif search_method == 'greedy': + elif search_method == "greedy": weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_greedy( learnware_list, user_rkme, max_search_num ) From 48c54223203d60ea9a846bc636fbeb818d156a92 Mon Sep 17 00:00:00 2001 From: xiey Date: Wed, 19 Apr 2023 19:15:54 +0800 Subject: [PATCH 4/6] [MNT] add a raise by bixd --- learnware/market/easy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 1e4d291..ddcbb2e 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -73,6 +73,7 @@ class EasyMarket(BaseMarket): learnware.instantiate_model() except Exception as e: logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {repr(e)}") + raise return cls.INVALID_LEARNWARE try: From ae7ad798744dcb032e1832e2d5de054377baf1ce Mon Sep 17 00:00:00 2001 From: Gene Date: Wed, 19 Apr 2023 19:26:41 +0800 Subject: [PATCH 5/6] [MNT] Add save spec in example --- examples/example_pfs/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/example_pfs/main.py b/examples/example_pfs/main.py index fd943c9..c6708db 100644 --- a/examples/example_pfs/main.py +++ b/examples/example_pfs/main.py @@ -154,6 +154,8 @@ class PFSDatasetWorkflow: for idx in idx_list: train_x, train_y, test_x, test_y = pfs.get_idx_data(idx) user_spec = specification.utils.generate_rkme_spec(X=test_x, gamma=0.1, cuda_idx=0) + user_spec_path = f"./user_spec/user_{idx}.json" + user_spec.save(user_spec_path) user_info = BaseUserInfo( id=f"user_{idx}", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec} From 367ec6c3e80190944a99b46a429e90675626c96f Mon Sep 17 00:00:00 2001 From: Gene Date: Wed, 19 Apr 2023 19:29:13 +0800 Subject: [PATCH 6/6] [FIX] Add makedir --- examples/example_pfs/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/example_pfs/main.py b/examples/example_pfs/main.py index f9b7bc1..ddcb0c2 100644 --- a/examples/example_pfs/main.py +++ b/examples/example_pfs/main.py @@ -122,6 +122,7 @@ class PFSDatasetWorkflow: pfs = Dataloader() idx_list = pfs.get_idx_list() + os.makedirs("./user_spec", exist_ok=True) for idx in idx_list: train_x, train_y, test_x, test_y = pfs.get_idx_data(idx)