diff --git a/examples/example_pfs/main.py b/examples/example_pfs/main.py index 580ce8d..ddcb0c2 100644 --- a/examples/example_pfs/main.py +++ b/examples/example_pfs/main.py @@ -122,10 +122,13 @@ class PFSDatasetWorkflow: pfs = Dataloader() idx_list = pfs.get_idx_list() + os.makedirs("./user_spec", exist_ok=True) for idx in idx_list: train_x, train_y, test_x, test_y = pfs.get_idx_data(idx) user_spec = specification.utils.generate_rkme_spec(X=test_x, gamma=0.1, cuda_idx=0) + user_spec_path = f"./user_spec/user_{idx}.json" + user_spec.save(user_spec_path) user_info = BaseUserInfo( id=f"user_{idx}", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec} diff --git a/learnware/market/easy.py b/learnware/market/easy.py index d15b5b7..b5b772c 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -79,6 +79,7 @@ class EasyMarket(BaseMarket): learnware.instantiate_model() except Exception as e: logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {repr(e)}") + raise return cls.INVALID_LEARNWARE try: @@ -340,7 +341,7 @@ class EasyMarket(BaseMarket): learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int, - weight_cutoff: float = 0.95, + weight_cutoff: float = 0.98, ) -> Tuple[List[float], List[Learnware]]: """Select learnwares based on a total mixture ratio, then recalculate their mixture weights @@ -456,7 +457,7 @@ class EasyMarket(BaseMarket): learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int, - score_cutoff: float = 0.01, + score_cutoff: float = 0.001, ) -> Tuple[List[float], List[Learnware]]: """Greedily match learnwares such that their mixture become more and more closer to user's rkme @@ -588,6 +589,7 @@ class EasyMarket(BaseMarket): user_semantic_spec = user_info.get_semantic_spec() if match_semantic_spec(learnware_semantic_spec, user_semantic_spec): match_learnwares.append(learnware) + logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list))) return match_learnwares def search_learnware(