Browse Source

Merge branch 'dev' of ssh://git.nju.edu.cn/learnware/learnware-market into dev

tags/v0.3.2
liuht 3 years ago
parent
commit
34bbdff6a4
6 changed files with 48 additions and 75 deletions
  1. +21
    -28
      examples/example_market_db/example_db.py
  2. +4
    -16
      learnware/config.py
  3. +1
    -4
      learnware/learnware/__init__.py
  4. +14
    -22
      learnware/market/easy.py
  5. +6
    -1
      learnware/specification/base.py
  6. +2
    -4
      learnware/specification/rkme.py

+ 21
- 28
examples/example_market_db/example_db.py View File

@@ -14,10 +14,7 @@ curr_root = os.path.dirname(os.path.abspath(__file__))
semantic_specs = [
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Task": {"Values": ["Classification"], "Type": "Class"},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Nature"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
@@ -25,10 +22,7 @@ semantic_specs = [
},
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Task": {"Values": ["Classification"], "Type": "Class"},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
@@ -36,10 +30,7 @@ semantic_specs = [
},
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Task": {"Values": ["Regression"], "Type": "Class"},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
@@ -49,14 +40,11 @@ semantic_specs = [

user_senmantic = {
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Task": {"Values": ["Classification"], "Type": "Class",},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
"Name": {"Values": "", "Type": "Name"},
"Name": {"Values": "learnware_4", "Type": "Name"},
}


@@ -117,7 +105,7 @@ def test_market():
print("Available ids:", curr_inds)


def test_search_sementics():
def test_search_semantics():
easy_market = EasyMarket()
print("Total Item:", len(easy_market))

@@ -129,15 +117,20 @@ def test_search_sementics():
test_folder = "./test_stat"
zip_path_list = get_zip_path_list()

for idx, zip_path in enumerate(zip_path_list):
unzip_dir = os.path.join(test_folder, f"{idx}")
os.makedirs(unzip_dir, exist_ok=True)
os.system(f"unzip -o -q {zip_path} -d {unzip_dir}")
idx, zip_path = 1, zip_path_list[1]
unzip_dir = os.path.join(test_folder, f"{idx}")
os.makedirs(unzip_dir, exist_ok=True)
os.system(f"unzip -o -q {zip_path} -d {unzip_dir}")

user_spec = specification.rkme.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic, stat_info={"RKME": user_spec})
sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)
user_spec = specification.rkme.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic)
_, single_learnware_list, _ = easy_market.search_learnware(user_info)

print("User info:", user_info.get_semantic_spec())
print(f"search result of user{idx}:")
for learnware in single_learnware_list:
print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec())

os.system(f"rm -r {test_folder}")

@@ -174,5 +167,5 @@ if __name__ == "__main__":
learnware_num = 10
prepare_learnware(learnware_num)
test_market()
test_stat_search()
test_search_sementics()
# test_stat_search()
test_search_semantics()

+ 4
- 16
learnware/config.py View File

@@ -57,10 +57,7 @@ os.makedirs(LEARNWARE_ZIP_POOL_PATH, exist_ok=True)
os.makedirs(LEARNWARE_FOLDER_POOL_PATH, exist_ok=True)

semantic_config = {
"Data": {
"Values": ["Tabular", "Image", "Video", "Text", "Audio"],
"Type": "Class", # Choose only one class
},
"Data": {"Values": ["Tabular", "Image", "Video", "Text", "Audio"], "Type": "Class",}, # Choose only one class
"Task": {
"Values": [
"Classification",
@@ -73,10 +70,7 @@ semantic_config = {
],
"Type": "Class", # Choose only one class
},
"Device": {
"Values": ["CPU", "GPU"],
"Type": "Tag", # Choose one or more tags
},
"Device": {"Values": ["CPU", "GPU"], "Type": "Tag",}, # Choose one or more tags
"Scenario": {
"Values": [
"Business",
@@ -96,14 +90,8 @@ semantic_config = {
],
"Type": "Tag", # Choose one or more tags
},
"Description": {
"Values": None,
"Type": "Description",
},
"Name": {
"Values": None,
"Type": "Name",
},
"Description": {"Values": None, "Type": "Description",},
"Name": {"Values": None, "Type": "Name",},
}

_DEFAULT_CONFIG = {


+ 1
- 4
learnware/learnware/__init__.py View File

@@ -28,10 +28,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath:
The contructed learnware object, return None if build failed
"""
learnware_config = {
"model": {
"class_name": "Model",
"kwargs": {},
},
"model": {"class_name": "Model", "kwargs": {},},
"stat_specifications": [
{
"module_path": "learnware.specification",


+ 14
- 22
learnware/market/easy.py View File

@@ -117,10 +117,7 @@ class EasyMarket(BaseMarket):
self.learnware_folder_list[id] = target_folder_dir
self.count += 1
add_learnware_to_db(
id,
semantic_spec=semantic_spec,
zip_path=target_folder_dir,
folder_path=target_folder_dir,
id, semantic_spec=semantic_spec, zip_path=target_folder_dir, folder_path=target_folder_dir,
)
return id, True
@@ -333,21 +330,6 @@ class EasyMarket(BaseMarket):

return sorted_dist_list, sorted_learnware_list

def _search_by_semantic_description(
self, learnware_list: List[Learnware], user_info: BaseUserInfo
) -> List[Learnware]:
user_semantic_spec = user_info.get_semantic_spec()
user_input_description = user_semantic_spec["Description"]["Values"]
if not user_input_description:
return []
match_learnwares = []
for learnware in learnware_list:
learnware_semantic_spec = learnware.get_specification().get_semantic_spec()
learnware_name = learnware_semantic_spec["Name"]["Values"]
if user_input_description in learnware_name:
match_learnwares.append(learnware)
return match_learnwares

def _search_by_semantic_tags(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]:
def match_semantic_tags(semantic_spec1, semantic_spec2):
if semantic_spec1.keys() != semantic_spec2.keys():
@@ -355,12 +337,23 @@ class EasyMarket(BaseMarket):
logger.warning("semantic_spec key error!")
return False
for key in semantic_spec1.keys():
if len(semantic_spec1[key]["Values"]) == 0:
continue
if len(semantic_spec2[key]["Values"]) == 0:
continue
if semantic_spec1[key]["Type"] == "Class":
if isinstance(semantic_spec1[key]["Values"], list):
semantic_spec1[key]["Values"] = semantic_spec1[key]["Values"][0]
if isinstance(semantic_spec2[key]["Values"], list):
semantic_spec2[key]["Values"] = semantic_spec2[key]["Values"][0]
if semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"]:
return False
elif semantic_spec1[key]["Type"] == "Tag":
if not (set(semantic_spec1[key]["Values"]) & set(semantic_spec2[key]["Values"])):
return False
elif semantic_spec1[key]["Type"] == "Name":
if semantic_spec2[key]["Values"] not in semantic_spec1[key]["Values"]:
return False
return True

match_learnwares = []
@@ -391,9 +384,8 @@ class EasyMarket(BaseMarket):
the third is the list of Learnware (mixture), the size is search_num
"""
learnware_list = [self.learnware_list[key] for key in self.learnware_list]
learnware_list_tags = self._search_by_semantic_tags(learnware_list, user_info)
learnware_list_description = self._search_by_semantic_description(learnware_list, user_info)
learnware_list = list(set(learnware_list_tags + learnware_list_description))
learnware_list = self._search_by_semantic_tags(learnware_list, user_info)
# learnware_list = list(set(learnware_list_tags + learnware_list_description))

if "RKMEStatSpecification" not in user_info.stat_info:
return None, learnware_list, None


+ 6
- 1
learnware/specification/base.py View File

@@ -6,7 +6,12 @@ class BaseStatSpecification:
def __init__(self):
pass

def generate_stat_spec_from_data(self, X: np.ndarray):
def generate_stat_spec_from_data(self, **kwargs):
"""Construct statistical specification from raw dataset

- kwargs may include the feature, label and model
- kwargs also can include hyperparameters of specific method for specifaction generation
"""
raise NotImplementedError("generate_stat_spec_from_data is not implemented")

def save(self, filepath: str):


+ 2
- 4
learnware/specification/rkme.py View File

@@ -255,9 +255,7 @@ class RKMEStatSpecification(BaseStatSpecification):
rkme_to_save["beta"] = rkme_to_save["beta"].tolist()
rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu"
json.dump(
rkme_to_save,
codecs.open(save_path, "w", encoding="utf-8"),
separators=(",", ":"),
rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), separators=(",", ":"),
)

def load(self, filepath: str) -> bool:
@@ -345,7 +343,7 @@ def torch_rbf_kernel(x1, x2, gamma) -> torch.Tensor:
"""
x1 = x1.double()
x2 = x2.double()
X12norm = torch.sum(x1**2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2**2, 1, keepdim=True).T
X12norm = torch.sum(x1 ** 2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2 ** 2, 1, keepdim=True).T
return torch.exp(-X12norm * gamma)




Loading…
Cancel
Save