Browse Source

Merge branch 'dev' of git.nju.edu.cn:learnware/learnware-market into dev

tags/v0.3.2
bxdd 3 years ago
parent
commit
365314aa2c
2 changed files with 7 additions and 2 deletions
  1. +3
    -0
      examples/example_pfs/main.py
  2. +4
    -2
      learnware/market/easy.py

+ 3
- 0
examples/example_pfs/main.py View File

@@ -122,10 +122,13 @@ class PFSDatasetWorkflow:

pfs = Dataloader()
idx_list = pfs.get_idx_list()
os.makedirs("./user_spec", exist_ok=True)

for idx in idx_list:
train_x, train_y, test_x, test_y = pfs.get_idx_data(idx)
user_spec = specification.utils.generate_rkme_spec(X=test_x, gamma=0.1, cuda_idx=0)
user_spec_path = f"./user_spec/user_{idx}.json"
user_spec.save(user_spec_path)

user_info = BaseUserInfo(
id=f"user_{idx}", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec}


+ 4
- 2
learnware/market/easy.py View File

@@ -79,6 +79,7 @@ class EasyMarket(BaseMarket):
learnware.instantiate_model()
except Exception as e:
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {repr(e)}")
raise
return cls.INVALID_LEARNWARE

try:
@@ -340,7 +341,7 @@ class EasyMarket(BaseMarket):
learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
max_search_num: int,
weight_cutoff: float = 0.95,
weight_cutoff: float = 0.98,
) -> Tuple[List[float], List[Learnware]]:
"""Select learnwares based on a total mixture ratio, then recalculate their mixture weights

@@ -456,7 +457,7 @@ class EasyMarket(BaseMarket):
learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
max_search_num: int,
score_cutoff: float = 0.01,
score_cutoff: float = 0.001,
) -> Tuple[List[float], List[Learnware]]:
"""Greedily match learnwares such that their mixture become more and more closer to user's rkme

@@ -588,6 +589,7 @@ class EasyMarket(BaseMarket):
user_semantic_spec = user_info.get_semantic_spec()
if match_semantic_spec(learnware_semantic_spec, user_semantic_spec):
match_learnwares.append(learnware)
logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list)))
return match_learnwares

def search_learnware(


Loading…
Cancel
Save