From 939e6cfab4184c011d195cce4639e95d2c96bedc Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 2 Nov 2023 00:48:48 +0800 Subject: [PATCH 01/10] [MNT] move text workflow to example --- .../dataset_text_workflow}/example_files/example_init.py | 0 .../dataset_text_workflow}/example_files/example_yaml.yaml | 0 .../dataset_text_workflow}/get_data.py | 0 .../dataset_text_workflow}/main.py | 2 +- .../dataset_text_workflow}/utils.py | 0 5 files changed, 1 insertion(+), 1 deletion(-) rename {tests/test_text_workflow => examples/dataset_text_workflow}/example_files/example_init.py (100%) rename {tests/test_text_workflow => examples/dataset_text_workflow}/example_files/example_yaml.yaml (100%) rename {tests/test_text_workflow => examples/dataset_text_workflow}/get_data.py (100%) rename {tests/test_text_workflow => examples/dataset_text_workflow}/main.py (99%) rename {tests/test_text_workflow => examples/dataset_text_workflow}/utils.py (100%) diff --git a/tests/test_text_workflow/example_files/example_init.py b/examples/dataset_text_workflow/example_files/example_init.py similarity index 100% rename from tests/test_text_workflow/example_files/example_init.py rename to examples/dataset_text_workflow/example_files/example_init.py diff --git a/tests/test_text_workflow/example_files/example_yaml.yaml b/examples/dataset_text_workflow/example_files/example_yaml.yaml similarity index 100% rename from tests/test_text_workflow/example_files/example_yaml.yaml rename to examples/dataset_text_workflow/example_files/example_yaml.yaml diff --git a/tests/test_text_workflow/get_data.py b/examples/dataset_text_workflow/get_data.py similarity index 100% rename from tests/test_text_workflow/get_data.py rename to examples/dataset_text_workflow/get_data.py diff --git a/tests/test_text_workflow/main.py b/examples/dataset_text_workflow/main.py similarity index 99% rename from tests/test_text_workflow/main.py rename to examples/dataset_text_workflow/main.py index baa54f4..6c712b7 100644 --- a/tests/test_text_workflow/main.py +++ b/examples/dataset_text_workflow/main.py @@ -1,6 +1,6 @@ import numpy as np import torch -from get_data import * +from get_data import get_sst2 import os import random from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction diff --git a/tests/test_text_workflow/utils.py b/examples/dataset_text_workflow/utils.py similarity index 100% rename from tests/test_text_workflow/utils.py rename to examples/dataset_text_workflow/utils.py From 01836616b9626784d45d5579542152b14c4e9b72 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 2 Nov 2023 01:39:58 +0800 Subject: [PATCH 02/10] [MNT] update typehint in market --- learnware/market/easy2/searcher.py | 10 +++++----- learnware/specification/regular/text/rkme.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 8934fda..4349869 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -7,7 +7,7 @@ from typing import Tuple, List, Union from .organizer import EasyOrganizer from ..base import BaseUserInfo, BaseSearcher from ...learnware import Learnware -from ...specification import RKMETableSpecification, RKMEImageSpecification +from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification from ...logger import get_module_logger logger = get_module_logger("easy_seacher") @@ -415,7 +415,7 @@ class EasyStatSearcher(BaseSearcher): return sorted_score_list[:idx], learnware_list[:idx] def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] + self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] ) -> List[Learnware]: """Filter learnwares whose rkme dimension different from user_rkme @@ -423,7 +423,7 @@ class EasyStatSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] User RKME statistical specification Returns @@ -520,7 +520,7 @@ class EasyStatSearcher(BaseSearcher): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] + self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme @@ -528,7 +528,7 @@ class EasyStatSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] user RKME statistical specification Returns diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index cc8659e..0152f56 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -1,8 +1,8 @@ -from sentence_transformers import SentenceTransformer -from ..table import RKMETableSpecification -import numpy as np import os import langdetect +import numpy as np +from sentence_transformers import SentenceTransformer +from ..table import RKMETableSpecification from ....logger import get_module_logger logger = get_module_logger("RKMETextSpecification", "INFO") From d54a56d5aed2cfcfbbae27c8c5b58c8bc979f3a5 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 02:59:21 +0800 Subject: [PATCH 03/10] [FIX] fix checker logic with many specification --- learnware/market/easy2/checker.py | 72 +++++++++++++++++-------------- learnware/market/utils.py | 12 ++++++ 2 files changed, 51 insertions(+), 33 deletions(-) create mode 100644 learnware/market/utils.py diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index fa8f26c..2ab04f5 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -5,6 +5,7 @@ import random import string from ..base import BaseChecker +from ..utils import parse_specification_type from ...config import C from ...logger import get_module_logger @@ -61,6 +62,23 @@ class EasySemanticChecker(BaseChecker): class EasyStatisticalChecker(BaseChecker): + + @staticmethod + def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): + text_list = [] + for i in range(num): + length = random.randint(min_len, max_len) + if text_type == "en": + characters = string.ascii_letters + string.digits + string.punctuation + result_str = "".join(random.choice(characters) for i in range(length)) + text_list.append(result_str) + elif text_type == "zh": + result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) + text_list.append(result_str) + else: + raise ValueError("Type should be en or zh") + return text_list + def __call__(self, learnware): semantic_spec = learnware.get_specification().get_semantic_spec() @@ -76,41 +94,30 @@ class EasyStatisticalChecker(BaseChecker): try: learnware_model = learnware.get_model() # Check input shape - if semantic_spec["Data"]["Values"][0] == "Table": - input_shape = (semantic_spec["Input"]["Dimension"],) - else: - input_shape = learnware_model.input_shape - - # Check rkme dimension - is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec - if is_text: - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification") - else: - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") - if stat_spec is not None and not is_text: + input_shape = learnware_model.input_shape + + ## WHY: why write this? + if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (int(semantic_spec["Input"]["Dimension"]),): + logger.warning("input shapes of model and semantic specifications are different") + return self.INVALID_LEARNWARE + + spec_type = parse_specification_type(learnware.get_specification()) + if spec_type is None: + logger.warning(f"No valid specification is found in stat spec {stat_spec}") + return self.INVALID_LEARNWARE + + if spec_type == "RKMETableSpecification": + stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") return self.INVALID_LEARNWARE - - def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): - text_list = [] - for i in range(num): - length = random.randint(min_len, max_len) - if text_type == "en": - characters = string.ascii_letters + string.digits + string.punctuation - result_str = "".join(random.choice(characters) for i in range(length)) - text_list.append(result_str) - elif text_type == "zh": - result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) - text_list.append(result_str) - else: - raise ValueError("Type should be en or zh") - return text_list - - if is_text: - inputs = generate_random_text_list(10) - else: inputs = np.random.randn(10, *input_shape) + elif spec_type == "RKMETextSpecification": + inputs = EasyStatisticalChecker._generate_random_text_list(10) + elif spec_type == "RKMEImageSpecification": + inputs = np.random.randint(0, 255, size=(10, *input_shape)) + else: + raise ValueError(f"not supported spec type for spec_type = {spec_type}") outputs = learnware.predict(inputs) # Check output if outputs.ndim == 1: @@ -129,8 +136,7 @@ class EasyStatisticalChecker(BaseChecker): return self.INVALID_LEARNWARE # Check output shape - output_dim = int(semantic_spec["Output"]["Dimension"]) - if outputs[0].shape[0] != output_dim: + if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]): logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") return self.INVALID_LEARNWARE diff --git a/learnware/market/utils.py b/learnware/market/utils.py new file mode 100644 index 0000000..f3d4e35 --- /dev/null +++ b/learnware/market/utils.py @@ -0,0 +1,12 @@ +from ..specification import Specification +from ..logger import get_module_logger + +logger = get_module_logger('market_utils') + +def parse_specification_type(stat_spec: Specification): + stat_specs = stat_spec.stat_spec + spec_list =['RKMETableSpecification', 'RKMETextSpecification', 'RKMEImageSpecification'] + for spec in spec_list: + if spec in stat_specs: + return spec + return None \ No newline at end of file From 77f8dffa15c0e494fb4d9e1d7f62d7627e386384 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 03:05:40 +0800 Subject: [PATCH 04/10] [FIX] fix test workflow --- tests/test_workflow/test_workflow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index f4507c5..718f65c 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -18,7 +18,7 @@ import learnware.specification as specification curr_root = os.path.dirname(os.path.abspath(__file__)) user_semantic = { - "Data": {"Values": ["Image"], "Type": "Class"}, + "Data": {"Values": ["Table"], "Type": "Class"}, "Task": { "Values": ["Classification"], "Type": "Class", @@ -96,6 +96,8 @@ class TestAllWorkflow(unittest.TestCase): semantic_spec = copy.deepcopy(user_semantic) semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) + semantic_spec["Input"] = {"Dimension": 64} + semantic_spec["Input"].update({f"{i}": f"The value in the digit image with row is {i // 8} and col is {i % 8}." for i in range(64)}) semantic_spec["Output"] = {"Dimension": 1, "Description": {"0": "The label of the hand-written digit."}} easy_market.add_learnware(zip_path, semantic_spec) From d837ea5635b17e9978212a95db86d7cbf637e07f Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 03:06:12 +0800 Subject: [PATCH 05/10] [MNT] black format --- learnware/market/easy2/checker.py | 19 +++++++++++-------- learnware/market/easy2/searcher.py | 8 ++++++-- learnware/market/utils.py | 7 ++++--- tests/test_workflow/test_workflow.py | 4 +++- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 2ab04f5..ca8c25f 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -62,7 +62,6 @@ class EasySemanticChecker(BaseChecker): class EasyStatisticalChecker(BaseChecker): - @staticmethod def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): text_list = [] @@ -78,7 +77,7 @@ class EasyStatisticalChecker(BaseChecker): else: raise ValueError("Type should be en or zh") return text_list - + def __call__(self, learnware): semantic_spec = learnware.get_specification().get_semantic_spec() @@ -97,15 +96,17 @@ class EasyStatisticalChecker(BaseChecker): input_shape = learnware_model.input_shape ## WHY: why write this? - if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (int(semantic_spec["Input"]["Dimension"]),): - logger.warning("input shapes of model and semantic specifications are different") - return self.INVALID_LEARNWARE - + if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( + int(semantic_spec["Input"]["Dimension"]), + ): + logger.warning("input shapes of model and semantic specifications are different") + return self.INVALID_LEARNWARE + spec_type = parse_specification_type(learnware.get_specification()) if spec_type is None: logger.warning(f"No valid specification is found in stat spec {stat_spec}") return self.INVALID_LEARNWARE - + if spec_type == "RKMETableSpecification": stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) if stat_spec.get_z().shape[1:] != input_shape: @@ -136,7 +137,9 @@ class EasyStatisticalChecker(BaseChecker): return self.INVALID_LEARNWARE # Check output shape - if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]): + if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int( + semantic_spec["Output"]["Dimension"] + ): logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") return self.INVALID_LEARNWARE diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 4349869..4276e86 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -415,7 +415,9 @@ class EasyStatSearcher(BaseSearcher): return sorted_score_list[:idx], learnware_list[:idx] def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] + self, + learnware_list: List[Learnware], + user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification], ) -> List[Learnware]: """Filter learnwares whose rkme dimension different from user_rkme @@ -520,7 +522,9 @@ class EasyStatSearcher(BaseSearcher): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] + self, + learnware_list: List[Learnware], + user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification], ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme diff --git a/learnware/market/utils.py b/learnware/market/utils.py index f3d4e35..90cdab3 100644 --- a/learnware/market/utils.py +++ b/learnware/market/utils.py @@ -1,12 +1,13 @@ from ..specification import Specification from ..logger import get_module_logger -logger = get_module_logger('market_utils') +logger = get_module_logger("market_utils") + def parse_specification_type(stat_spec: Specification): stat_specs = stat_spec.stat_spec - spec_list =['RKMETableSpecification', 'RKMETextSpecification', 'RKMEImageSpecification'] + spec_list = ["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"] for spec in spec_list: if spec in stat_specs: return spec - return None \ No newline at end of file + return None diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 718f65c..fea00d9 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -97,7 +97,9 @@ class TestAllWorkflow(unittest.TestCase): semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) semantic_spec["Input"] = {"Dimension": 64} - semantic_spec["Input"].update({f"{i}": f"The value in the digit image with row is {i // 8} and col is {i % 8}." for i in range(64)}) + semantic_spec["Input"].update( + {f"{i}": f"The value in the digit image with row is {i // 8} and col is {i % 8}." for i in range(64)} + ) semantic_spec["Output"] = {"Dimension": 1, "Description": {"0": "The label of the hand-written digit."}} easy_market.add_learnware(zip_path, semantic_spec) From daef252f00334a0e3a277cba1faa46e3bc382529 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 03:22:54 +0800 Subject: [PATCH 06/10] [MNT] modify searcher to fit all specs --- learnware/market/easy2/searcher.py | 42 ++++++++++++++++-------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 4276e86..e9f69fd 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -5,6 +5,7 @@ from cvxopt import solvers, matrix from typing import Tuple, List, Union from .organizer import EasyOrganizer +from ..utils import parse_specification_type from ..base import BaseUserInfo, BaseSearcher from ...learnware import Learnware from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification @@ -251,7 +252,7 @@ class EasyStatSearcher(BaseSearcher): The second is the mmd dist between the mixture of learnware rkmes and the user's rkme """ learnware_num = len(learnware_list) - RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] if type(intermediate_K) == np.ndarray: K = intermediate_K @@ -318,7 +319,7 @@ class EasyStatSearcher(BaseSearcher): The second is the intermediate value of C """ num = intermediate_K.shape[0] - 1 - RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1]) @@ -373,7 +374,7 @@ class EasyStatSearcher(BaseSearcher): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_info_name)) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_spec_type)) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num] @@ -414,12 +415,12 @@ class EasyStatSearcher(BaseSearcher): idx = idx + 1 return sorted_score_list[:idx], learnware_list[:idx] - def _filter_by_rkme_spec_dimension( + def _filter_by_rkme_spec_dim( self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification], ) -> List[Learnware]: - """Filter learnwares whose rkme dimension different from user_rkme + """Filter learnwares whose rkme dimention different from user_rkme Parameters ---------- @@ -437,12 +438,13 @@ class EasyStatSearcher(BaseSearcher): user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) for learnware in learnware_list: - if self.stat_info_name not in learnware.specification.stat_spec: + if self.stat_spec_type not in learnware.specification.stat_spec: continue - rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name) - if self.stat_info_name == "RKMETextSpecification": - if not set(user_rkme.language).issubset(set(rkme.language)): - continue + rkme = learnware.specification.get_stat_spec_by_name(self.stat_spec_type) + if self.stat_spec_type == "RKMETextSpecification" and not set(user_rkme.language).issubset(set(rkme.language)): + continue + + # TODO: must we check dim for Text and Image specification? rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -542,7 +544,7 @@ class EasyStatSearcher(BaseSearcher): the second is the list of Learnware both lists are sorted by mmd dist """ - RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] mmd_dist_list = [] for RKME in RKME_list: mmd_dist = RKME.dist(user_rkme) @@ -561,12 +563,13 @@ class EasyStatSearcher(BaseSearcher): max_search_num: int = 5, search_method: str = "greedy", ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - if "RKMETextSpecification" in user_info.stat_info: - self.stat_info_name = "RKMETextSpecification" - else: - self.stat_info_name = "RKMETableSpecification" - user_rkme = user_info.stat_info[self.stat_info_name] - learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) + + self.stat_spec_type = parse_specification_type(stat_spec=user_info.stat_info) + if self.stat_spec_type is None: + raise KeyError("No supported stat specification is given in the user info") + + user_rkme = user_info.stat_info[self.stat_spec_type] + learnware_list = self._filter_by_rkme_spec_dim(learnware_list, user_rkme) logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) @@ -638,9 +641,8 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] - elif "RKMETableSpecification" in user_info.stat_info: - return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) - elif "RKMETextSpecification" in user_info.stat_info: + + if parse_specification_type(stat_spec=user_info.stat_info) is not None: return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None From bd4c226ade66405868db1aa93cb54c53ab52582b Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 03:25:02 +0800 Subject: [PATCH 07/10] [MNT] update searcher filer method name --- learnware/market/easy2/searcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index e9f69fd..0a2089e 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -415,12 +415,12 @@ class EasyStatSearcher(BaseSearcher): idx = idx + 1 return sorted_score_list[:idx], learnware_list[:idx] - def _filter_by_rkme_spec_dim( + def _filter_by_rkme_spec_metadata( self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification], ) -> List[Learnware]: - """Filter learnwares whose rkme dimention different from user_rkme + """Filter learnwares whose rkme metadata different from user_rkme Parameters ---------- @@ -569,7 +569,7 @@ class EasyStatSearcher(BaseSearcher): raise KeyError("No supported stat specification is given in the user info") user_rkme = user_info.stat_info[self.stat_spec_type] - learnware_list = self._filter_by_rkme_spec_dim(learnware_list, user_rkme) + learnware_list = self._filter_by_rkme_spec_metadata(learnware_list, user_rkme) logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) From a53e618fdeaf3de92460f94ff3b81e05d2e36c93 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 03:40:23 +0800 Subject: [PATCH 08/10] [MNT] fix reuser with many specifications --- learnware/market/utils.py | 7 +------ learnware/reuse/job_selector.py | 35 +++++++++++++++------------------ 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/learnware/market/utils.py b/learnware/market/utils.py index 90cdab3..95b27a9 100644 --- a/learnware/market/utils.py +++ b/learnware/market/utils.py @@ -1,12 +1,7 @@ from ..specification import Specification -from ..logger import get_module_logger -logger = get_module_logger("market_utils") - - -def parse_specification_type(stat_spec: Specification): +def parse_specification_type(stat_spec: Specification, spec_list = ["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"]): stat_specs = stat_spec.stat_spec - spec_list = ["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"] for spec in spec_list: if spec in stat_specs: return spec diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 6c37c8a..aba6dd9 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -1,15 +1,17 @@ import torch import numpy as np -from typing import List +from typing import List, Union from cvxopt import matrix, solvers from lightgbm import LGBMClassifier, early_stopping from sklearn.metrics import accuracy_score -from learnware.learnware import Learnware -import learnware.specification as specification + from .base import BaseReuser +from ..market.utils import parse_specification_type +from ..learnware import Learnware from ..specification import RKMETableSpecification, RKMETextSpecification +from ..specification.utils import generate_rkme_spec from ..logger import get_module_logger logger = get_module_logger("job_selector_reuse") @@ -32,7 +34,7 @@ class JobSelectorReuser(BaseReuser): self.herding_num = herding_num self.use_herding = use_herding - def predict(self, user_data: np.ndarray) -> np.ndarray: + def predict(self, user_data: Union[np.ndarray, List[str]]) -> np.ndarray: """Give prediction for user data using baseline job-selector method Parameters @@ -41,12 +43,14 @@ class JobSelectorReuser(BaseReuser): User's unlabeled raw data. Returns - ------- + ------ np.ndarray Prediction given by job-selector method """ - ori_user_data = user_data + raw_user_data = user_data if isinstance(user_data[0], str): + stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) + assert stat_spec_type == 'RKMETextSpecification', "stat_spec_type must be 'RKMETextSpecification' when user data is the List of string." user_data = RKMETextSpecification.get_sentence_embedding(user_data) select_result = self.job_selector(user_data) @@ -56,8 +60,8 @@ class JobSelectorReuser(BaseReuser): for idx in range(len(self.learnware_list)): data_idx_list = np.where(select_result == idx)[0] if len(data_idx_list) > 0: - # pred_y = self.learnware_list[idx].predict(ori_user_data[data_idx_list]) - pred_y = self.learnware_list[idx].predict([ori_user_data[i] for i in data_idx_list]) + # pred_y = self.learnware_list[idx].predict(raw_user_data[data_idx_list]) + pred_y = self.learnware_list[idx].predict([raw_user_data[i] for i in data_idx_list]) if isinstance(pred_y, torch.Tensor): pred_y = pred_y.detach().cpu().numpy() # elif isinstance(pred_y, tf.Tensor): @@ -90,15 +94,10 @@ class JobSelectorReuser(BaseReuser): # user_data_num = user_data.shape[0] user_data_num = len(user_data) return np.array([0] * user_data_num) - else: - ori_user_data = user_data - if isinstance(user_data[0], str): - user_data = RKMETextSpecification.get_sentence_embedding(user_data) - spec_name = "RKMETableSpecification" - if len(self.learnware_list) and "RKMETextSpecification" in self.learnware_list[0].specification.stat_spec: - spec_name = "RKMETextSpecification" + else: + stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) learnware_rkme_spec_list = [ - learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list + learnware.specification.get_stat_spec_by_name(stat_spec_type) for learnware in self.learnware_list ] if self.use_herding: @@ -179,9 +178,7 @@ class JobSelectorReuser(BaseReuser): Inner product matrix calculated from task_rkme_list. """ task_num = len(task_rkme_list) - if isinstance(user_data[0], str): - user_data = RKMETextSpecification.get_sentence_embedding(user_data) - user_rkme_spec = specification.utils.generate_rkme_spec(X=user_data, reduce=False) + user_rkme_spec = generate_rkme_spec(X=user_data, reduce=False) K = task_rkme_matrix v = np.array([user_rkme_spec.inner_prod(task_rkme) for task_rkme in task_rkme_list]) From b4f001e982d5123ba9035a39e5c1851a91fadc71 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 03:40:54 +0800 Subject: [PATCH 09/10] [MNT] black format --- learnware/market/easy2/searcher.py | 9 +++++---- learnware/market/utils.py | 5 ++++- learnware/reuse/job_selector.py | 6 ++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 0a2089e..e07f861 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -441,9 +441,11 @@ class EasyStatSearcher(BaseSearcher): if self.stat_spec_type not in learnware.specification.stat_spec: continue rkme = learnware.specification.get_stat_spec_by_name(self.stat_spec_type) - if self.stat_spec_type == "RKMETextSpecification" and not set(user_rkme.language).issubset(set(rkme.language)): + if self.stat_spec_type == "RKMETextSpecification" and not set(user_rkme.language).issubset( + set(rkme.language) + ): continue - + # TODO: must we check dim for Text and Image specification? rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: @@ -563,7 +565,6 @@ class EasyStatSearcher(BaseSearcher): max_search_num: int = 5, search_method: str = "greedy", ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - self.stat_spec_type = parse_specification_type(stat_spec=user_info.stat_info) if self.stat_spec_type is None: raise KeyError("No supported stat specification is given in the user info") @@ -641,7 +642,7 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] - + if parse_specification_type(stat_spec=user_info.stat_info) is not None: return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: diff --git a/learnware/market/utils.py b/learnware/market/utils.py index 95b27a9..c0cc319 100644 --- a/learnware/market/utils.py +++ b/learnware/market/utils.py @@ -1,6 +1,9 @@ from ..specification import Specification -def parse_specification_type(stat_spec: Specification, spec_list = ["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"]): + +def parse_specification_type( + stat_spec: Specification, spec_list=["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"] +): stat_specs = stat_spec.stat_spec for spec in spec_list: if spec in stat_specs: diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index aba6dd9..7503b4a 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -50,7 +50,9 @@ class JobSelectorReuser(BaseReuser): raw_user_data = user_data if isinstance(user_data[0], str): stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) - assert stat_spec_type == 'RKMETextSpecification', "stat_spec_type must be 'RKMETextSpecification' when user data is the List of string." + assert ( + stat_spec_type == "RKMETextSpecification" + ), "stat_spec_type must be 'RKMETextSpecification' when user data is the List of string." user_data = RKMETextSpecification.get_sentence_embedding(user_data) select_result = self.job_selector(user_data) @@ -94,7 +96,7 @@ class JobSelectorReuser(BaseReuser): # user_data_num = user_data.shape[0] user_data_num = len(user_data) return np.array([0] * user_data_num) - else: + else: stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) learnware_rkme_spec_list = [ learnware.specification.get_stat_spec_by_name(stat_spec_type) for learnware in self.learnware_list From d742a217c38f64c0df7b69eddb218b252aeea348 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 03:46:04 +0800 Subject: [PATCH 10/10] [FIX] fix semantic error with flake8 --- learnware/market/easy2/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index ca8c25f..dadf2a4 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -104,7 +104,7 @@ class EasyStatisticalChecker(BaseChecker): spec_type = parse_specification_type(learnware.get_specification()) if spec_type is None: - logger.warning(f"No valid specification is found in stat spec {stat_spec}") + logger.warning(f"No valid specification is found in stat spec {spec_type}") return self.INVALID_LEARNWARE if spec_type == "RKMETableSpecification":