diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index bdcbe53..e620e66 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -14,10 +14,7 @@ curr_root = os.path.dirname(os.path.abspath(__file__)) semantic_specs = [ { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class"}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Nature"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -25,10 +22,7 @@ semantic_specs = [ }, { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class"}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -36,10 +30,7 @@ semantic_specs = [ }, { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Regression"], "Type": "Class"}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, @@ -49,14 +40,11 @@ semantic_specs = [ user_senmantic = { "Data": {"Values": ["Tabular"], "Type": "Class"}, - "Task": { - "Values": ["Classification"], - "Type": "Class", - }, + "Task": {"Values": ["Classification"], "Type": "Class",}, "Device": {"Values": ["GPU"], "Type": "Tag"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "Description"}, - "Name": {"Values": "", "Type": "Name"}, + "Name": {"Values": "learnware_4", "Type": "Name"}, } @@ -117,7 +105,7 @@ def test_market(): print("Available ids:", curr_inds) -def test_search_sementics(): +def test_search_semantics(): easy_market = EasyMarket() print("Total Item:", len(easy_market)) @@ -129,15 +117,20 @@ def test_search_sementics(): test_folder = "./test_stat" zip_path_list = get_zip_path_list() - for idx, zip_path in enumerate(zip_path_list): - unzip_dir = os.path.join(test_folder, f"{idx}") - os.makedirs(unzip_dir, exist_ok=True) - os.system(f"unzip -o -q {zip_path} -d {unzip_dir}") + idx, zip_path = 1, zip_path_list[1] + unzip_dir = os.path.join(test_folder, f"{idx}") + os.makedirs(unzip_dir, exist_ok=True) + os.system(f"unzip -o -q {zip_path} -d {unzip_dir}") - user_spec = specification.rkme.RKMEStatSpecification() - user_spec.load(os.path.join(unzip_dir, "svm.json")) - user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic, stat_info={"RKME": user_spec}) - sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) + user_spec = specification.rkme.RKMEStatSpecification() + user_spec.load(os.path.join(unzip_dir, "svm.json")) + user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic) + _, single_learnware_list, _ = easy_market.search_learnware(user_info) + + print("User info:", user_info.get_semantic_spec()) + print(f"search result of user{idx}:") + for learnware in single_learnware_list: + print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) os.system(f"rm -r {test_folder}") @@ -174,5 +167,5 @@ if __name__ == "__main__": learnware_num = 10 prepare_learnware(learnware_num) test_market() - test_stat_search() - test_search_sementics() + # test_stat_search() + test_search_semantics() diff --git a/learnware/config.py b/learnware/config.py index f2c650b..5458374 100644 --- a/learnware/config.py +++ b/learnware/config.py @@ -57,10 +57,7 @@ os.makedirs(LEARNWARE_ZIP_POOL_PATH, exist_ok=True) os.makedirs(LEARNWARE_FOLDER_POOL_PATH, exist_ok=True) semantic_config = { - "Data": { - "Values": ["Tabular", "Image", "Video", "Text", "Audio"], - "Type": "Class", # Choose only one class - }, + "Data": {"Values": ["Tabular", "Image", "Video", "Text", "Audio"], "Type": "Class",}, # Choose only one class "Task": { "Values": [ "Classification", @@ -73,10 +70,7 @@ semantic_config = { ], "Type": "Class", # Choose only one class }, - "Device": { - "Values": ["CPU", "GPU"], - "Type": "Tag", # Choose one or more tags - }, + "Device": {"Values": ["CPU", "GPU"], "Type": "Tag",}, # Choose one or more tags "Scenario": { "Values": [ "Business", @@ -96,14 +90,8 @@ semantic_config = { ], "Type": "Tag", # Choose one or more tags }, - "Description": { - "Values": None, - "Type": "Description", - }, - "Name": { - "Values": None, - "Type": "Name", - }, + "Description": {"Values": None, "Type": "Description",}, + "Name": {"Values": None, "Type": "Name",}, } _DEFAULT_CONFIG = { diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index d1d3b8e..acd719b 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -28,10 +28,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath: The contructed learnware object, return None if build failed """ learnware_config = { - "model": { - "class_name": "Model", - "kwargs": {}, - }, + "model": {"class_name": "Model", "kwargs": {},}, "stat_specifications": [ { "module_path": "learnware.specification", diff --git a/learnware/market/easy.py b/learnware/market/easy.py index a472a85..1472e95 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -117,10 +117,7 @@ class EasyMarket(BaseMarket): self.learnware_folder_list[id] = target_folder_dir self.count += 1 add_learnware_to_db( - id, - semantic_spec=semantic_spec, - zip_path=target_folder_dir, - folder_path=target_folder_dir, + id, semantic_spec=semantic_spec, zip_path=target_folder_dir, folder_path=target_folder_dir, ) return id, True @@ -333,21 +330,6 @@ class EasyMarket(BaseMarket): return sorted_dist_list, sorted_learnware_list - def _search_by_semantic_description( - self, learnware_list: List[Learnware], user_info: BaseUserInfo - ) -> List[Learnware]: - user_semantic_spec = user_info.get_semantic_spec() - user_input_description = user_semantic_spec["Description"]["Values"] - if not user_input_description: - return [] - match_learnwares = [] - for learnware in learnware_list: - learnware_semantic_spec = learnware.get_specification().get_semantic_spec() - learnware_name = learnware_semantic_spec["Name"]["Values"] - if user_input_description in learnware_name: - match_learnwares.append(learnware) - return match_learnwares - def _search_by_semantic_tags(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]: def match_semantic_tags(semantic_spec1, semantic_spec2): if semantic_spec1.keys() != semantic_spec2.keys(): @@ -355,12 +337,23 @@ class EasyMarket(BaseMarket): logger.warning("semantic_spec key error!") return False for key in semantic_spec1.keys(): + if len(semantic_spec1[key]["Values"]) == 0: + continue + if len(semantic_spec2[key]["Values"]) == 0: + continue if semantic_spec1[key]["Type"] == "Class": + if isinstance(semantic_spec1[key]["Values"], list): + semantic_spec1[key]["Values"] = semantic_spec1[key]["Values"][0] + if isinstance(semantic_spec2[key]["Values"], list): + semantic_spec2[key]["Values"] = semantic_spec2[key]["Values"][0] if semantic_spec1[key]["Values"] != semantic_spec2[key]["Values"]: return False elif semantic_spec1[key]["Type"] == "Tag": if not (set(semantic_spec1[key]["Values"]) & set(semantic_spec2[key]["Values"])): return False + elif semantic_spec1[key]["Type"] == "Name": + if semantic_spec2[key]["Values"] not in semantic_spec1[key]["Values"]: + return False return True match_learnwares = [] @@ -391,9 +384,8 @@ class EasyMarket(BaseMarket): the third is the list of Learnware (mixture), the size is search_num """ learnware_list = [self.learnware_list[key] for key in self.learnware_list] - learnware_list_tags = self._search_by_semantic_tags(learnware_list, user_info) - learnware_list_description = self._search_by_semantic_description(learnware_list, user_info) - learnware_list = list(set(learnware_list_tags + learnware_list_description)) + learnware_list = self._search_by_semantic_tags(learnware_list, user_info) + # learnware_list = list(set(learnware_list_tags + learnware_list_description)) if "RKMEStatSpecification" not in user_info.stat_info: return None, learnware_list, None diff --git a/learnware/specification/base.py b/learnware/specification/base.py index e5042b4..449cfd3 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -6,7 +6,12 @@ class BaseStatSpecification: def __init__(self): pass - def generate_stat_spec_from_data(self, X: np.ndarray): + def generate_stat_spec_from_data(self, **kwargs): + """Construct statistical specification from raw dataset + + - kwargs may include the feature, label and model + - kwargs also can include hyperparameters of specific method for specifaction generation + """ raise NotImplementedError("generate_stat_spec_from_data is not implemented") def save(self, filepath: str): diff --git a/learnware/specification/rkme.py b/learnware/specification/rkme.py index 9668396..568c410 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/rkme.py @@ -255,9 +255,7 @@ class RKMEStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" json.dump( - rkme_to_save, - codecs.open(save_path, "w", encoding="utf-8"), - separators=(",", ":"), + rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), separators=(",", ":"), ) def load(self, filepath: str) -> bool: @@ -345,7 +343,7 @@ def torch_rbf_kernel(x1, x2, gamma) -> torch.Tensor: """ x1 = x1.double() x2 = x2.double() - X12norm = torch.sum(x1**2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2**2, 1, keepdim=True).T + X12norm = torch.sum(x1 ** 2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2 ** 2, 1, keepdim=True).T return torch.exp(-X12norm * gamma)