Browse Source

Merge pull request #36 from Learnware-LAMDA/fix_market_bug

[FIX] fix bug in easy market
tags/v0.3.2
Gene GitHub 2 years ago
parent
commit
eb4bab385b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 7 additions and 3 deletions
  1. +1
    -1
      learnware/learnware/utils.py
  2. +2
    -0
      learnware/market/easy.py
  3. +1
    -1
      learnware/specification/base.py
  4. +3
    -1
      learnware/specification/regular/table/rkme.py

+ 1
- 1
learnware/learnware/utils.py View File

@@ -47,4 +47,4 @@ def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification:
if stat_spec_inst.load(stat_spec["file_name"]) is False:
raise ValueError("Load statistic specification failed!")

return stat_spec["class_name"], stat_spec_inst
return stat_spec_inst.type, stat_spec_inst

+ 2
- 0
learnware/market/easy.py View File

@@ -817,7 +817,9 @@ class EasyMarket(LearnwareMarket):
learnware_list = [self.learnware_list[key] for key in self.learnware_list]
# learnware_list = self._search_by_semantic_spec_exact(learnware_list, user_info)
# if len(learnware_list) == 0:
logger.info(f"stat_info in user_info: {user_info.stat_info}")
learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info)
logger.info(f"Number of learnwares after semantic fuzzy search: {len(learnware_list)}")

if "RKMETableSpecification" not in user_info.stat_info:
return None, learnware_list, 0.0, None


+ 1
- 1
learnware/specification/base.py View File

@@ -83,7 +83,7 @@ class Specification:
or use class name as default name
"""
for _v in args:
self.stat_spec[_v.__class__.__name__] = _v
self.stat_spec[_v.type] = _v

for _k, _v in kwargs.items():
self.stat_spec[_k] = _v


+ 3
- 1
learnware/specification/regular/table/rkme.py View File

@@ -468,7 +468,9 @@ class RKMEStatSpecification(RKMETableSpecification):
TODO: modify all learnware in database and remove this nickname
"""

pass
def __init__(self, gamma: float = 0.1, cuda_idx: int = -1):
super(RKMEStatSpecification, self).__init__(gamma=gamma, cuda_idx=cuda_idx)
super(RKMETableSpecification, self).__init__(type=RKMETableSpecification.__name__)


def setup_seed(seed):


Loading…
Cancel
Save