Browse Source

Merge branch 'dev' of git.nju.edu.cn:learnware/learnware-market into dev

tags/v0.3.2
Gene 3 years ago
parent
commit
0dadcb56e2
2 changed files with 73 additions and 66 deletions
  1. +71
    -64
      examples/example_market_db/example_db.py
  2. +2
    -2
      learnware/market/easy.py

+ 71
- 64
examples/example_market_db/example_db.py View File

@@ -11,6 +11,53 @@ from learnware.utils import get_module_by_module_path

curr_root = os.path.dirname(os.path.abspath(__file__))

semantic_specs = [
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Nature"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
"Name": {"Values": "learnware_1", "Type": "Name"},
},
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
"Name": {"Values": "learnware_2", "Type": "Name"},
},
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
"Name": {"Values": "learnware_3", "Type": "Name"},
},
]

user_senmantic = {
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
"Name": {"Values": "", "Type": "Name"},
}

def prepare_learnware(learnware_num=10):
np.random.seed(2023)
@@ -56,8 +103,11 @@ def test_market():

for idx, zip_path in enumerate(zip_path_list):
print(zip_path)
semantic_spec = semantic_specs[idx % 3]
semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
easy_market.add_learnware(
zip_path, {"name": "learnware_%d" % (idx), "desc": "test_learnware_number_%d" % (idx)}
zip_path, semantic_spec
)
print("Total Item:", len(easy_market))
curr_inds = easy_market._get_ids()
@@ -78,69 +128,25 @@ def test_search_sementics():
test_learnware_num = 3
prepare_learnware(test_learnware_num)

semantic_specs = [
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": [
"Classification",
],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Nature"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
},
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": [
"Classification",
],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
},
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": [
"Classification",
],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
},
]
user_senmantic = {
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": [
"Classification",
],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "learnware_1", "Type": "Description"},
}

for i in range(test_learnware_num):
dir_path = f"./learnware_pool/svm{i}"
model_path = os.path.join(dir_path, "__init__.py")
stat_spec_path = os.path.join(dir_path, "spec.json")
easy_market.add_learnware("learnware_%d" % (i), model_path, stat_spec_path, semantic_specs[i])
print("Total Item:", len(easy_market))
curr_inds = easy_market._get_ids()
print("Available ids:", curr_inds)

user_info = BaseUserInfo(id="user", semantic_spec=user_senmantic, stat_info=dict())
learnware_list = easy_market.search_learnware(user_info)
print(learnware_list)
test_folder = "./test_stat"
zip_path_list = get_zip_path_list()

for idx, zip_path in enumerate(zip_path_list):
unzip_dir = os.path.join(test_folder, f"{idx}")
os.makedirs(unzip_dir, exist_ok=True)
os.system(f"unzip -o -q {zip_path} -d {unzip_dir}")

user_spec = specification.rkme.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(
id="user_0", semantic_spec=user_senmantic, stat_info={"RKME": user_spec}
)
sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)

os.system(f"rm -r {test_folder}")



def test_stat_search():
@@ -176,4 +182,5 @@ if __name__ == "__main__":
# prepare_learnware(learnware_num)

# test_market()
test_stat_search()
# test_stat_search()
test_search_sementics()

+ 2
- 2
learnware/market/easy.py View File

@@ -314,7 +314,6 @@ class EasyMarket(BaseMarket):
) -> List[Learnware]:
user_semantic_spec = user_info.get_semantic_spec()
user_input_description = user_semantic_spec["Description"]["Values"]
learnware_semantic_spec = learnware.get_specification().get_semantic_spec()
if not user_input_description:
return []
match_learnwares = []
@@ -328,7 +327,7 @@ class EasyMarket(BaseMarket):
def _search_by_semantic_tags(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]:
def match_semantic_tags(semantic_spec1, semantic_spec2):
if semantic_spec1.keys() != semantic_spec2.keys():
raise Exception("semantic_spec key error".format(semantic_spec1.keys(), semantic_spec2.keys()))
raise Exception("semantic_spec key error")
for key in semantic_spec1.keys():
if semantic_spec1[key]["Type"] == "Class":
if semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"]:
@@ -369,6 +368,7 @@ class EasyMarket(BaseMarket):
learnware_list_tags = self._search_by_semantic_tags(learnware_list, user_info)
learnware_list_description = self._search_by_semantic_description(learnware_list, user_info)
learnware_list = list(set(learnware_list_tags + learnware_list_description))
# print(learnware_list_tags, learnware_list_description)

if "RKMEStatSpecification" not in user_info.stat_info:
return None, learnware_list, None


Loading…
Cancel
Save