diff --git a/learnware/learnware/utils.py b/learnware/learnware/utils.py index b6dccd0..d5cdc14 100644 --- a/learnware/learnware/utils.py +++ b/learnware/learnware/utils.py @@ -47,4 +47,4 @@ def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification: if stat_spec_inst.load(stat_spec["file_name"]) is False: raise ValueError("Load statistic specification failed!") - return stat_spec["class_name"], stat_spec_inst + return stat_spec_inst.type, stat_spec_inst diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 957efda..098887b 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -817,7 +817,9 @@ class EasyMarket(LearnwareMarket): learnware_list = [self.learnware_list[key] for key in self.learnware_list] # learnware_list = self._search_by_semantic_spec_exact(learnware_list, user_info) # if len(learnware_list) == 0: + logger.info(f"stat_info in user_info: {user_info.stat_info}") learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info) + logger.info(f"Number of learnwares after semantic fuzzy search: {len(learnware_list)}") if "RKMETableSpecification" not in user_info.stat_info: return None, learnware_list, 0.0, None diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 1c340fa..28212d6 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -83,7 +83,7 @@ class Specification: or use class name as default name """ for _v in args: - self.stat_spec[_v.__class__.__name__] = _v + self.stat_spec[_v.type] = _v for _k, _v in kwargs.items(): self.stat_spec[_k] = _v diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index 82c81a2..17aedc1 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -468,7 +468,9 @@ class RKMEStatSpecification(RKMETableSpecification): TODO: modify all learnware in database and remove this nickname """ - pass + def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): + super(RKMEStatSpecification, self).__init__(gamma=gamma, cuda_idx=cuda_idx) + super(RKMETableSpecification, self).__init__(type=RKMETableSpecification.__name__) def setup_seed(seed):