diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 125af04..a9cbc2e 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -11,6 +11,53 @@ from learnware.utils import get_module_by_module_path curr_root = os.path.dirname(os.path.abspath(__file__)) +semantic_specs = [ + { + "Data": {"Values": ["Tabular"], "Type": "Class"}, + "Task": { + "Values": ["Classification"], + "Type": "Class", + }, + "Device": {"Values": ["GPU"], "Type": "Tag"}, + "Scenario": {"Values": ["Nature"], "Type": "Tag"}, + "Description": {"Values": "", "Type": "Description"}, + "Name": {"Values": "learnware_1", "Type": "Name"}, + }, + { + "Data": {"Values": ["Tabular"], "Type": "Class"}, + "Task": { + "Values": ["Classification"], + "Type": "Class", + }, + "Device": {"Values": ["GPU"], "Type": "Tag"}, + "Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"}, + "Description": {"Values": "", "Type": "Description"}, + "Name": {"Values": "learnware_2", "Type": "Name"}, + }, + { + "Data": {"Values": ["Tabular"], "Type": "Class"}, + "Task": { + "Values": ["Classification"], + "Type": "Class", + }, + "Device": {"Values": ["GPU"], "Type": "Tag"}, + "Scenario": {"Values": ["Business"], "Type": "Tag"}, + "Description": {"Values": "", "Type": "Description"}, + "Name": {"Values": "learnware_3", "Type": "Name"}, + }, +] + +user_senmantic = { + "Data": {"Values": ["Tabular"], "Type": "Class"}, + "Task": { + "Values": ["Classification"], + "Type": "Class", + }, + "Device": {"Values": ["GPU"], "Type": "Tag"}, + "Scenario": {"Values": ["Business"], "Type": "Tag"}, + "Description": {"Values": "", "Type": "Description"}, + "Name": {"Values": "", "Type": "Name"}, +} def prepare_learnware(learnware_num=10): np.random.seed(2023) @@ -56,8 +103,11 @@ def test_market(): for idx, zip_path in enumerate(zip_path_list): print(zip_path) + semantic_spec = semantic_specs[idx % 3] + semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) + semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) easy_market.add_learnware( - zip_path, {"name": "learnware_%d" % (idx), "desc": "test_learnware_number_%d" % (idx)} + zip_path, semantic_spec ) print("Total Item:", len(easy_market)) curr_inds = easy_market._get_ids() @@ -78,69 +128,25 @@ def test_search_sementics(): test_learnware_num = 3 prepare_learnware(test_learnware_num) - semantic_specs = [ - { - "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": [ - "Classification", - ], - "Type": "Class", - }, - "Device": {"Values": ["GPU"], "Type": "Tag"}, - "Scenario": {"Values": ["Nature"], "Type": "Tag"}, - "Description": {"Values": "", "Type": "Description"}, - }, - { - "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": [ - "Classification", - ], - "Type": "Class", - }, - "Device": {"Values": ["GPU"], "Type": "Tag"}, - "Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"}, - "Description": {"Values": "", "Type": "Description"}, - }, - { - "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": [ - "Classification", - ], - "Type": "Class", - }, - "Device": {"Values": ["GPU"], "Type": "Tag"}, - "Scenario": {"Values": ["Business"], "Type": "Tag"}, - "Description": {"Values": "", "Type": "Description"}, - }, - ] - user_senmantic = { - "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": [ - "Classification", - ], - "Type": "Class", - }, - "Device": {"Values": ["GPU"], "Type": "Tag"}, - "Scenario": {"Values": ["Business"], "Type": "Tag"}, - "Description": {"Values": "learnware_1", "Type": "Description"}, - } - - for i in range(test_learnware_num): - dir_path = f"./learnware_pool/svm{i}" - model_path = os.path.join(dir_path, "__init__.py") - stat_spec_path = os.path.join(dir_path, "spec.json") - easy_market.add_learnware("learnware_%d" % (i), model_path, stat_spec_path, semantic_specs[i]) - print("Total Item:", len(easy_market)) - curr_inds = easy_market._get_ids() - print("Available ids:", curr_inds) - user_info = BaseUserInfo(id="user", semantic_spec=user_senmantic, stat_info=dict()) - learnware_list = easy_market.search_learnware(user_info) - print(learnware_list) + test_folder = "./test_stat" + zip_path_list = get_zip_path_list() + + for idx, zip_path in enumerate(zip_path_list): + unzip_dir = os.path.join(test_folder, f"{idx}") + os.makedirs(unzip_dir, exist_ok=True) + os.system(f"unzip -o -q {zip_path} -d {unzip_dir}") + + user_spec = specification.rkme.RKMEStatSpecification() + user_spec.load(os.path.join(unzip_dir, "svm.json")) + user_info = BaseUserInfo( + id="user_0", semantic_spec=user_senmantic, stat_info={"RKME": user_spec} + ) + sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) + + os.system(f"rm -r {test_folder}") + + def test_stat_search(): @@ -176,4 +182,5 @@ if __name__ == "__main__": # prepare_learnware(learnware_num) # test_market() - test_stat_search() + # test_stat_search() + test_search_sementics() diff --git a/learnware/market/easy.py b/learnware/market/easy.py index dee122f..479929a 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -314,7 +314,6 @@ class EasyMarket(BaseMarket): ) -> List[Learnware]: user_semantic_spec = user_info.get_semantic_spec() user_input_description = user_semantic_spec["Description"]["Values"] - learnware_semantic_spec = learnware.get_specification().get_semantic_spec() if not user_input_description: return [] match_learnwares = [] @@ -328,7 +327,7 @@ class EasyMarket(BaseMarket): def _search_by_semantic_tags(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]: def match_semantic_tags(semantic_spec1, semantic_spec2): if semantic_spec1.keys() != semantic_spec2.keys(): - raise Exception("semantic_spec key error".format(semantic_spec1.keys(), semantic_spec2.keys())) + raise Exception("semantic_spec key error") for key in semantic_spec1.keys(): if semantic_spec1[key]["Type"] == "Class": if semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"]: @@ -369,6 +368,7 @@ class EasyMarket(BaseMarket): learnware_list_tags = self._search_by_semantic_tags(learnware_list, user_info) learnware_list_description = self._search_by_semantic_description(learnware_list, user_info) learnware_list = list(set(learnware_list_tags + learnware_list_description)) + # print(learnware_list_tags, learnware_list_description) if "RKMEStatSpecification" not in user_info.stat_info: return None, learnware_list, None