From ee25c45945bd131fff983cd9549b45ad6553e931 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 10 Apr 2023 16:22:49 +0800 Subject: [PATCH 1/5] [MNT] Fix typo, modify generate spec interface --- examples/example_market_db/example_db.py | 4 ++-- learnware/market/easy.py | 2 +- learnware/specification/base.py | 7 ++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 4a35c8d..32aeb07 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -117,7 +117,7 @@ def test_market(): print("Available ids:", curr_inds) -def test_search_sementics(): +def test_search_semantics(): easy_market = EasyMarket() print("Total Item:", len(easy_market)) @@ -175,4 +175,4 @@ if __name__ == "__main__": prepare_learnware(learnware_num) test_market() test_stat_search() - test_search_sementics() + test_search_semantics() diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 42827b2..1748bce 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -244,7 +244,7 @@ class EasyMarket(BaseMarket): learnware_num = len(learnware_list) if learnware_num == 0: return [], [] - if learnware_num < search_num: + if learnware_num < search_num: logger.warning("Available Learnware num less than search_num") search_num = learnware_num diff --git a/learnware/specification/base.py b/learnware/specification/base.py index e5042b4..d9e208a 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -6,7 +6,12 @@ class BaseStatSpecification: def __init__(self): pass - def generate_stat_spec_from_data(self, X: np.ndarray): + def generate_stat_spec_from_data(self, **kwargs): + """Construct reduced set from raw dataset using iterative optimization + + - kwargs may include the feature, label and model + - kwargs also can include hyperparameter for specifaction generation + """ raise NotImplementedError("generate_stat_spec_from_data is not implemented") def save(self, filepath: str): From 06754a9ca4af91fc90d031138800c6331dd8418c Mon Sep 17 00:00:00 2001 From: xiey Date: Mon, 10 Apr 2023 16:59:32 +0800 Subject: [PATCH 2/5] [MNT] Update test_search_semantics --- examples/example_market_db/example_db.py | 45 ++++++++++-------------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 32aeb07..1cb1425 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -14,10 +14,7 @@ curr_root = os.path.dirname(os.path.abspath(__file__)) semantic_specs = [ { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class"}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Nature"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -25,10 +22,7 @@ semantic_specs = [ }, { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class"}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -36,10 +30,7 @@ semantic_specs = [ }, { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Regression"], "Type": "Class"}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -49,10 +40,7 @@ semantic_specs = [ user_senmantic = { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class",}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -129,15 +117,20 @@ def test_search_semantics(): test_folder = "./test_stat" zip_path_list = get_zip_path_list() - for idx, zip_path in enumerate(zip_path_list): - unzip_dir = os.path.join(test_folder, f"{idx}") - os.makedirs(unzip_dir, exist_ok=True) - os.system(f"unzip -o -q {zip_path} -d {unzip_dir}") - - user_spec = specification.rkme.RKMEStatSpecification() - user_spec.load(os.path.join(unzip_dir, "svm.json")) - user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic, stat_info={"RKME": user_spec}) - sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) + idx, zip_path = 1, zip_path_list[1] + unzip_dir = os.path.join(test_folder, f"{idx}") + os.makedirs(unzip_dir, exist_ok=True) + os.system(f"unzip -o -q {zip_path} -d {unzip_dir}") + + user_spec = specification.rkme.RKMEStatSpecification() + user_spec.load(os.path.join(unzip_dir, "svm.json")) + user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic) + _, single_learnware_list, _ = easy_market.search_learnware(user_info) + + print("User info:", user_info.get_semantic_spec()) + print(f"search result of user{idx}:") + for learnware in single_learnware_list: + print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) os.system(f"rm -r {test_folder}") @@ -174,5 +167,5 @@ if __name__ == "__main__": learnware_num = 10 prepare_learnware(learnware_num) test_market() - test_stat_search() + # test_stat_search() test_search_semantics() From 15839d302c34ed860cdb930b6b71213ce099e921 Mon Sep 17 00:00:00 2001 From: Peng Tan Date: Mon, 10 Apr 2023 18:54:27 +0800 Subject: [PATCH 3/5] [DOC] modify description of generate_stat_spec_from_data --- learnware/specification/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/learnware/specification/base.py b/learnware/specification/base.py index d9e208a..449cfd3 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -7,10 +7,10 @@ class BaseStatSpecification: pass def generate_stat_spec_from_data(self, **kwargs): - """Construct reduced set from raw dataset using iterative optimization + """Construct statistical specification from raw dataset - kwargs may include the feature, label and model - - kwargs also can include hyperparameter for specifaction generation + - kwargs also can include hyperparameters of specific method for specifaction generation """ raise NotImplementedError("generate_stat_spec_from_data is not implemented") From 68e1964ef8e8a59e416778e26f86d1fc95b6c615 Mon Sep 17 00:00:00 2001 From: xiey Date: Mon, 10 Apr 2023 20:21:24 +0800 Subject: [PATCH 4/5] [MNT] Change the logic of search by semantic --- examples/example_market_db/example_db.py | 4 ++-- learnware/config.py | 20 ++++------------ learnware/learnware/__init__.py | 5 +--- learnware/market/easy.py | 29 +++++++++++++++--------- learnware/specification/rkme.py | 6 ++--- 5 files changed, 27 insertions(+), 37 deletions(-) diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 1cb1425..eb86980 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -44,7 +44,7 @@ user_senmantic = { "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, - "Name": {"Values": "", "Type": "Name"}, + "Name": {"Values": "learnware_4", "Type": "Name"}, } @@ -126,7 +126,7 @@ def test_search_semantics(): user_spec.load(os.path.join(unzip_dir, "svm.json")) user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic) _, single_learnware_list, _ = easy_market.search_learnware(user_info) - + print("User info:", user_info.get_semantic_spec()) print(f"search result of user{idx}:") for learnware in single_learnware_list: diff --git a/learnware/config.py b/learnware/config.py index f2c650b..5458374 100644 --- a/learnware/config.py +++ b/learnware/config.py @@ -57,10 +57,7 @@ os.makedirs(LEARNWARE_ZIP_POOL_PATH, exist_ok=True) os.makedirs(LEARNWARE_FOLDER_POOL_PATH, exist_ok=True) semantic_config = { - "Data": { - "Values": ["Tabular", "Image", "Video", "Text", "Audio"], - "Type": "Class", # Choose only one class - }, + "Data": {"Values": ["Tabular", "Image", "Video", "Text", "Audio"], "Type": "Class",}, # Choose only one class "Task": { "Values": [ "Classification", @@ -73,10 +70,7 @@ semantic_config = { ], "Type": "Class", # Choose only one class }, - "Device": { - "Values": ["CPU", "GPU"], - "Type": "Tag", # Choose one or more tags - }, + "Device": {"Values": ["CPU", "GPU"], "Type": "Tag",}, # Choose one or more tags "Scenario": { "Values": [ "Business", @@ -96,14 +90,8 @@ semantic_config = { ], "Type": "Tag", # Choose one or more tags }, - "Description": { - "Values": None, - "Type": "Description", - }, - "Name": { - "Values": None, - "Type": "Name", - }, + "Description": {"Values": None, "Type": "Description",}, + "Name": {"Values": None, "Type": "Name",}, } _DEFAULT_CONFIG = { diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index d1d3b8e..acd719b 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -28,10 +28,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath: The contructed learnware object, return None if build failed """ learnware_config = { - "model": { - "class_name": "Model", - "kwargs": {}, - }, + "model": {"class_name": "Model", "kwargs": {},}, "stat_specifications": [ { "module_path": "learnware.specification", diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 1748bce..ea13f3d 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -117,10 +117,7 @@ class EasyMarket(BaseMarket): self.learnware_folder_list[id] = target_folder_dir self.count += 1 add_learnware_to_db( - id, - semantic_spec=semantic_spec, - zip_path=target_folder_dir, - folder_path=target_folder_dir, + id, semantic_spec=semantic_spec, zip_path=target_folder_dir, folder_path=target_folder_dir, ) return id, True @@ -319,9 +316,9 @@ class EasyMarket(BaseMarket): self, learnware_list: List[Learnware], user_info: BaseUserInfo ) -> List[Learnware]: user_semantic_spec = user_info.get_semantic_spec() - user_input_description = user_semantic_spec["Description"]["Values"] + user_input_description = user_semantic_spec["Name"]["Values"] if not user_input_description: - return [] + return learnware_list match_learnwares = [] for learnware in learnware_list: learnware_semantic_spec = learnware.get_specification().get_semantic_spec() @@ -338,10 +335,21 @@ class EasyMarket(BaseMarket): return False for key in semantic_spec1.keys(): if semantic_spec1[key]["Type"] == "Class": - if semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"]: + if ( + len(semantic_spec2[key]["Values"]) > 0 + and semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"] + ): return False elif semantic_spec1[key]["Type"] == "Tag": - if not (set(semantic_spec1[key]["Values"]) & set(semantic_spec2[key]["Values"])): + if len(semantic_spec2[key]["Values"]) > 0 and not ( + set(semantic_spec1[key]["Values"]) & set(semantic_spec2[key]["Values"]) + ): + return False + elif semantic_spec1[key]["Type"] == "Name": + if ( + len(semantic_spec2[key]["Values"]) > 0 + and semantic_spec2[key]["Values"] not in semantic_spec1[key]["Values"] + ): return False return True @@ -373,9 +381,8 @@ class EasyMarket(BaseMarket): the third is the list of Learnware (mixture), the size is search_num """ learnware_list = [self.learnware_list[key] for key in self.learnware_list] - learnware_list_tags = self._search_by_semantic_tags(learnware_list, user_info) - learnware_list_description = self._search_by_semantic_description(learnware_list, user_info) - learnware_list = list(set(learnware_list_tags + learnware_list_description)) + learnware_list = self._search_by_semantic_tags(learnware_list, user_info) + # learnware_list = list(set(learnware_list_tags + learnware_list_description)) if "RKMEStatSpecification" not in user_info.stat_info: return None, learnware_list, None diff --git a/learnware/specification/rkme.py b/learnware/specification/rkme.py index 9668396..568c410 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/rkme.py @@ -255,9 +255,7 @@ class RKMEStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" json.dump( - rkme_to_save, - codecs.open(save_path, "w", encoding="utf-8"), - separators=(",", ":"), + rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), separators=(",", ":"), ) def load(self, filepath: str) -> bool: @@ -345,7 +343,7 @@ def torch_rbf_kernel(x1, x2, gamma) -> torch.Tensor: """ x1 = x1.double() x2 = x2.double() - X12norm = torch.sum(x1**2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2**2, 1, keepdim=True).T + X12norm = torch.sum(x1 ** 2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2 ** 2, 1, keepdim=True).T return torch.exp(-X12norm * gamma) From 7b871c5f0d12008ed645896f986a258031118cab Mon Sep 17 00:00:00 2001 From: xiey Date: Mon, 10 Apr 2023 21:27:20 +0800 Subject: [PATCH 5/5] [MNT] Allow semantic values to be eithor str or list --- learnware/market/easy.py | 37 +++++++++++-------------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index ea13f3d..13abcac 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -312,21 +312,6 @@ class EasyMarket(BaseMarket): return sorted_dist_list, sorted_learnware_list - def _search_by_semantic_description( - self, learnware_list: List[Learnware], user_info: BaseUserInfo - ) -> List[Learnware]: - user_semantic_spec = user_info.get_semantic_spec() - user_input_description = user_semantic_spec["Name"]["Values"] - if not user_input_description: - return learnware_list - match_learnwares = [] - for learnware in learnware_list: - learnware_semantic_spec = learnware.get_specification().get_semantic_spec() - learnware_name = learnware_semantic_spec["Name"]["Values"] - if user_input_description in learnware_name: - match_learnwares.append(learnware) - return match_learnwares - def _search_by_semantic_tags(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]: def match_semantic_tags(semantic_spec1, semantic_spec2): if semantic_spec1.keys() != semantic_spec2.keys(): @@ -334,22 +319,22 @@ class EasyMarket(BaseMarket): logger.warning("semantic_spec key error!") return False for key in semantic_spec1.keys(): + if len(semantic_spec1[key]["Values"]) == 0: + continue + if len(semantic_spec2[key]["Values"]) == 0: + continue if semantic_spec1[key]["Type"] == "Class": - if ( - len(semantic_spec2[key]["Values"]) > 0 - and semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"] - ): + if isinstance(semantic_spec1[key]["Values"], list): + semantic_spec1[key]["Values"] = semantic_spec1[key]["Values"][0] + if isinstance(semantic_spec2[key]["Values"], list): + semantic_spec2[key]["Values"] = semantic_spec2[key]["Values"][0] + if semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"]: return False elif semantic_spec1[key]["Type"] == "Tag": - if len(semantic_spec2[key]["Values"]) > 0 and not ( - set(semantic_spec1[key]["Values"]) & set(semantic_spec2[key]["Values"]) - ): + if not (set(semantic_spec1[key]["Values"]) & set(semantic_spec2[key]["Values"])): return False elif semantic_spec1[key]["Type"] == "Name": - if ( - len(semantic_spec2[key]["Values"]) > 0 - and semantic_spec2[key]["Values"] not in semantic_spec1[key]["Values"] - ): + if semantic_spec2[key]["Values"] not in semantic_spec1[key]["Values"]: return False return True