Browse Source

[FIX] Fix the bug of reload_market

tags/v0.3.2
Gene 2 years ago
parent
commit
1ffd853c01
4 changed files with 9 additions and 5 deletions
  1. +2
    -1
      examples/example_market_db/example_db.py
  2. +1
    -0
      learnware/learnware/utils.py
  3. +4
    -2
      learnware/market/database_ops.py
  4. +2
    -2
      learnware/market/easy.py

+ 2
- 1
examples/example_market_db/example_db.py View File

@@ -13,6 +13,7 @@ curr_root = os.path.dirname(os.path.abspath(__file__))


def prepare_learnware(learnware_num=10):
np.random.seed(2023)
for i in range(learnware_num):
dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i))
os.makedirs(dir_path, exist_ok=True)
@@ -171,7 +172,7 @@ def test_stat_search():


if __name__ == "__main__":
learnware_num = 5
learnware_num = 10
# prepare_learnware(learnware_num)

# test_market()


+ 1
- 0
learnware/learnware/utils.py View File

@@ -50,4 +50,5 @@ def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification:
f"Statistic specification must be type of BaseStatSpecification, not {BaseStatSpecification.__class__.__name__}"
)
stat_spec_inst.load(stat_spec["file_name"])
return stat_spec["class_name"], stat_spec_inst

+ 4
- 2
learnware/market/database_ops.py View File

@@ -1,6 +1,7 @@
import os
import json
import sqlite3
from copy import deepcopy

from ..logger import get_module_logger
from ..learnware import get_learnware_from_dirpath
@@ -79,9 +80,10 @@ def load_market_from_db(cur):
new_learnware = get_learnware_from_dirpath(
id=id, semantic_spec=semantic_spec_dict, learnware_dirpath=folder_path
)
learnware_list[id] = new_learnware
learnware_list[id] = deepcopy(new_learnware)
zip_list[id] = zip_path
folder_list = folder_path
folder_list[id] = folder_path
max_count = max(max_count, int(id))

LOGGER.info("Market Reloaded from DB.")
return learnware_list, zip_list, folder_list, max_count + 1

+ 2
- 2
learnware/market/easy.py View File

@@ -25,7 +25,7 @@ class EasyMarket(BaseMarket):
self.learnware_folder_list = {}
self.count = 0
self.semantic_spec_list = C.semantic_specs
self.reload_market()
self.reload_market()
logger.info("Market Initialized!")

def reload_market(self) -> bool:
@@ -373,7 +373,7 @@ class EasyMarket(BaseMarket):
if "RKMEStatSpecification" not in user_info.stat_info:
return None, learnware_list, None
else:
user_rkme = user_info.stat_info["RKMEStatSpecification"]
user_rkme = user_info.stat_info["RKMEStatSpecification"]
sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(
learnware_list, user_rkme, search_num


Loading…
Cancel
Save