diff --git a/tests/test_text_workflow/example_files/example_init.py b/examples/dataset_text_workflow/example_files/example_init.py similarity index 100% rename from tests/test_text_workflow/example_files/example_init.py rename to examples/dataset_text_workflow/example_files/example_init.py diff --git a/tests/test_text_workflow/example_files/example_yaml.yaml b/examples/dataset_text_workflow/example_files/example_yaml.yaml similarity index 100% rename from tests/test_text_workflow/example_files/example_yaml.yaml rename to examples/dataset_text_workflow/example_files/example_yaml.yaml diff --git a/tests/test_text_workflow/get_data.py b/examples/dataset_text_workflow/get_data.py similarity index 100% rename from tests/test_text_workflow/get_data.py rename to examples/dataset_text_workflow/get_data.py diff --git a/tests/test_text_workflow/main.py b/examples/dataset_text_workflow/main.py similarity index 99% rename from tests/test_text_workflow/main.py rename to examples/dataset_text_workflow/main.py index 9ae39ba..e7e1c38 100644 --- a/tests/test_text_workflow/main.py +++ b/examples/dataset_text_workflow/main.py @@ -1,6 +1,6 @@ import numpy as np import torch -from get_data import * +from get_data import get_sst2 import os import random from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction diff --git a/tests/test_text_workflow/utils.py b/examples/dataset_text_workflow/utils.py similarity index 100% rename from tests/test_text_workflow/utils.py rename to examples/dataset_text_workflow/utils.py diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 4b67157..f9ae5df 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -5,6 +5,7 @@ import random import string from ..base import BaseChecker +from ..utils import parse_specification_type from ...config import C from ...logger import get_module_logger @@ -61,6 +62,22 @@ class EasySemanticChecker(BaseChecker): class EasyStatChecker(BaseChecker): + @staticmethod + def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): + text_list = [] + for i in range(num): + length = random.randint(min_len, max_len) + if text_type == "en": + characters = string.ascii_letters + string.digits + string.punctuation + result_str = "".join(random.choice(characters) for i in range(length)) + text_list.append(result_str) + elif text_type == "zh": + result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) + text_list.append(result_str) + else: + raise ValueError("Type should be en or zh") + return text_list + def __call__(self, learnware): semantic_spec = learnware.get_specification().get_semantic_spec() @@ -76,41 +93,32 @@ class EasyStatChecker(BaseChecker): try: learnware_model = learnware.get_model() # Check input shape - if semantic_spec["Data"]["Values"][0] == "Table": - input_shape = (semantic_spec["Input"]["Dimension"],) - else: - input_shape = learnware_model.input_shape + input_shape = learnware_model.input_shape - # Check rkme dimension - is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec - if is_text: - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification") - else: - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") - if stat_spec is not None and not is_text: + ## WHY: why write this? + if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( + int(semantic_spec["Input"]["Dimension"]), + ): + logger.warning("input shapes of model and semantic specifications are different") + return self.INVALID_LEARNWARE + + spec_type = parse_specification_type(learnware.get_specification()) + if spec_type is None: + logger.warning(f"No valid specification is found in stat spec {spec_type}") + return self.INVALID_LEARNWARE + + if spec_type == "RKMETableSpecification": + stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") return self.INVALID_LEARNWARE - - def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): - text_list = [] - for i in range(num): - length = random.randint(min_len, max_len) - if text_type == "en": - characters = string.ascii_letters + string.digits + string.punctuation - result_str = "".join(random.choice(characters) for i in range(length)) - text_list.append(result_str) - elif text_type == "zh": - result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) - text_list.append(result_str) - else: - raise ValueError("Type should be en or zh") - return text_list - - if is_text: - inputs = generate_random_text_list(10) - else: inputs = np.random.randn(10, *input_shape) + elif spec_type == "RKMETextSpecification": + inputs = EasyStatChecker._generate_random_text_list(10) + elif spec_type == "RKMEImageSpecification": + inputs = np.random.randint(0, 255, size=(10, *input_shape)) + else: + raise ValueError(f"not supported spec type for spec_type = {spec_type}") outputs = learnware.predict(inputs) # Check output if outputs.ndim == 1: @@ -129,8 +137,9 @@ class EasyStatChecker(BaseChecker): return self.INVALID_LEARNWARE # Check output shape - output_dim = int(semantic_spec["Output"]["Dimension"]) - if outputs[0].shape[0] != output_dim: + if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int( + semantic_spec["Output"]["Dimension"] + ): logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") return self.INVALID_LEARNWARE diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 106d47b..8feb5d9 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -5,9 +5,10 @@ from cvxopt import solvers, matrix from typing import Tuple, List, Union from .organizer import EasyOrganizer +from ..utils import parse_specification_type from ..base import BaseUserInfo, BaseSearcher from ...learnware import Learnware -from ...specification import RKMETableSpecification, RKMEImageSpecification +from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification from ...logger import get_module_logger logger = get_module_logger("easy_seacher") @@ -251,7 +252,7 @@ class EasyStatSearcher(BaseSearcher): The second is the mmd dist between the mixture of learnware rkmes and the user's rkme """ learnware_num = len(learnware_list) - RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] if type(intermediate_K) == np.ndarray: K = intermediate_K @@ -318,7 +319,7 @@ class EasyStatSearcher(BaseSearcher): The second is the intermediate value of C """ num = intermediate_K.shape[0] - 1 - RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1]) @@ -373,7 +374,7 @@ class EasyStatSearcher(BaseSearcher): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_info_name)) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_spec_type)) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num] @@ -414,16 +415,18 @@ class EasyStatSearcher(BaseSearcher): idx = idx + 1 return sorted_score_list[:idx], learnware_list[:idx] - def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] + def _filter_by_rkme_spec_metadata( + self, + learnware_list: List[Learnware], + user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification], ) -> List[Learnware]: - """Filter learnwares whose rkme dimension different from user_rkme + """Filter learnwares whose rkme metadata different from user_rkme Parameters ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] User RKME statistical specification Returns @@ -435,12 +438,15 @@ class EasyStatSearcher(BaseSearcher): user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) for learnware in learnware_list: - if self.stat_info_name not in learnware.specification.stat_spec: + if self.stat_spec_type not in learnware.specification.stat_spec: + continue + rkme = learnware.specification.get_stat_spec_by_name(self.stat_spec_type) + if self.stat_spec_type == "RKMETextSpecification" and not set(user_rkme.language).issubset( + set(rkme.language) + ): continue - rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name) - if self.stat_info_name == "RKMETextSpecification": - if not set(user_rkme.language).issubset(set(rkme.language)): - continue + + # TODO: must we check dim for Text and Image specification? rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -520,7 +526,9 @@ class EasyStatSearcher(BaseSearcher): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] + self, + learnware_list: List[Learnware], + user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification], ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme @@ -528,7 +536,7 @@ class EasyStatSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] user RKME statistical specification Returns @@ -538,7 +546,7 @@ class EasyStatSearcher(BaseSearcher): the second is the list of Learnware both lists are sorted by mmd dist """ - RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] + RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] mmd_dist_list = [] for RKME in RKME_list: mmd_dist = RKME.dist(user_rkme) @@ -557,12 +565,12 @@ class EasyStatSearcher(BaseSearcher): max_search_num: int = 5, search_method: str = "greedy", ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - if "RKMETextSpecification" in user_info.stat_info: - self.stat_info_name = "RKMETextSpecification" - else: - self.stat_info_name = "RKMETableSpecification" - user_rkme = user_info.stat_info[self.stat_info_name] - learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) + self.stat_spec_type = parse_specification_type(stat_spec=user_info.stat_info) + if self.stat_spec_type is None: + raise KeyError("No supported stat specification is given in the user info") + + user_rkme = user_info.stat_info[self.stat_spec_type] + learnware_list = self._filter_by_rkme_spec_metadata(learnware_list, user_rkme) logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) @@ -637,9 +645,8 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] - elif "RKMETableSpecification" in user_info.stat_info: - return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) - elif "RKMETextSpecification" in user_info.stat_info: + + if parse_specification_type(stat_spec=user_info.stat_info) is not None: return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/market/utils.py b/learnware/market/utils.py new file mode 100644 index 0000000..c0cc319 --- /dev/null +++ b/learnware/market/utils.py @@ -0,0 +1,11 @@ +from ..specification import Specification + + +def parse_specification_type( + stat_spec: Specification, spec_list=["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"] +): + stat_specs = stat_spec.stat_spec + for spec in spec_list: + if spec in stat_specs: + return spec + return None diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 6c37c8a..7503b4a 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -1,15 +1,17 @@ import torch import numpy as np -from typing import List +from typing import List, Union from cvxopt import matrix, solvers from lightgbm import LGBMClassifier, early_stopping from sklearn.metrics import accuracy_score -from learnware.learnware import Learnware -import learnware.specification as specification + from .base import BaseReuser +from ..market.utils import parse_specification_type +from ..learnware import Learnware from ..specification import RKMETableSpecification, RKMETextSpecification +from ..specification.utils import generate_rkme_spec from ..logger import get_module_logger logger = get_module_logger("job_selector_reuse") @@ -32,7 +34,7 @@ class JobSelectorReuser(BaseReuser): self.herding_num = herding_num self.use_herding = use_herding - def predict(self, user_data: np.ndarray) -> np.ndarray: + def predict(self, user_data: Union[np.ndarray, List[str]]) -> np.ndarray: """Give prediction for user data using baseline job-selector method Parameters @@ -41,12 +43,16 @@ class JobSelectorReuser(BaseReuser): User's unlabeled raw data. Returns - ------- + ------ np.ndarray Prediction given by job-selector method """ - ori_user_data = user_data + raw_user_data = user_data if isinstance(user_data[0], str): + stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) + assert ( + stat_spec_type == "RKMETextSpecification" + ), "stat_spec_type must be 'RKMETextSpecification' when user data is the List of string." user_data = RKMETextSpecification.get_sentence_embedding(user_data) select_result = self.job_selector(user_data) @@ -56,8 +62,8 @@ class JobSelectorReuser(BaseReuser): for idx in range(len(self.learnware_list)): data_idx_list = np.where(select_result == idx)[0] if len(data_idx_list) > 0: - # pred_y = self.learnware_list[idx].predict(ori_user_data[data_idx_list]) - pred_y = self.learnware_list[idx].predict([ori_user_data[i] for i in data_idx_list]) + # pred_y = self.learnware_list[idx].predict(raw_user_data[data_idx_list]) + pred_y = self.learnware_list[idx].predict([raw_user_data[i] for i in data_idx_list]) if isinstance(pred_y, torch.Tensor): pred_y = pred_y.detach().cpu().numpy() # elif isinstance(pred_y, tf.Tensor): @@ -91,14 +97,9 @@ class JobSelectorReuser(BaseReuser): user_data_num = len(user_data) return np.array([0] * user_data_num) else: - ori_user_data = user_data - if isinstance(user_data[0], str): - user_data = RKMETextSpecification.get_sentence_embedding(user_data) - spec_name = "RKMETableSpecification" - if len(self.learnware_list) and "RKMETextSpecification" in self.learnware_list[0].specification.stat_spec: - spec_name = "RKMETextSpecification" + stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) learnware_rkme_spec_list = [ - learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list + learnware.specification.get_stat_spec_by_name(stat_spec_type) for learnware in self.learnware_list ] if self.use_herding: @@ -179,9 +180,7 @@ class JobSelectorReuser(BaseReuser): Inner product matrix calculated from task_rkme_list. """ task_num = len(task_rkme_list) - if isinstance(user_data[0], str): - user_data = RKMETextSpecification.get_sentence_embedding(user_data) - user_rkme_spec = specification.utils.generate_rkme_spec(X=user_data, reduce=False) + user_rkme_spec = generate_rkme_spec(X=user_data, reduce=False) K = task_rkme_matrix v = np.array([user_rkme_spec.inner_prod(task_rkme) for task_rkme in task_rkme_list]) diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index 117b032..95dc7f5 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -1,8 +1,8 @@ -from sentence_transformers import SentenceTransformer -from ..table import RKMETableSpecification -import numpy as np import os import langdetect +import numpy as np +from sentence_transformers import SentenceTransformer +from ..table import RKMETableSpecification from ....logger import get_module_logger logger = get_module_logger("RKMETextSpecification", "INFO") diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index f4507c5..fea00d9 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -18,7 +18,7 @@ import learnware.specification as specification curr_root = os.path.dirname(os.path.abspath(__file__)) user_semantic = { - "Data": {"Values": ["Image"], "Type": "Class"}, + "Data": {"Values": ["Table"], "Type": "Class"}, "Task": { "Values": ["Classification"], "Type": "Class", @@ -96,6 +96,10 @@ class TestAllWorkflow(unittest.TestCase): semantic_spec = copy.deepcopy(user_semantic) semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) + semantic_spec["Input"] = {"Dimension": 64} + semantic_spec["Input"].update( + {f"{i}": f"The value in the digit image with row is {i // 8} and col is {i % 8}." for i in range(64)} + ) semantic_spec["Output"] = {"Dimension": 1, "Description": {"0": "The label of the hand-written digit."}} easy_market.add_learnware(zip_path, semantic_spec)