Fit Many Specificationtags/v0.3.2
| @@ -1,6 +1,6 @@ | |||
| import numpy as np | |||
| import torch | |||
| from get_data import * | |||
| from get_data import get_sst2 | |||
| import os | |||
| import random | |||
| from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction | |||
| @@ -5,6 +5,7 @@ import random | |||
| import string | |||
| from ..base import BaseChecker | |||
| from ..utils import parse_specification_type | |||
| from ...config import C | |||
| from ...logger import get_module_logger | |||
| @@ -61,6 +62,22 @@ class EasySemanticChecker(BaseChecker): | |||
| class EasyStatChecker(BaseChecker): | |||
| @staticmethod | |||
| def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): | |||
| text_list = [] | |||
| for i in range(num): | |||
| length = random.randint(min_len, max_len) | |||
| if text_type == "en": | |||
| characters = string.ascii_letters + string.digits + string.punctuation | |||
| result_str = "".join(random.choice(characters) for i in range(length)) | |||
| text_list.append(result_str) | |||
| elif text_type == "zh": | |||
| result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) | |||
| text_list.append(result_str) | |||
| else: | |||
| raise ValueError("Type should be en or zh") | |||
| return text_list | |||
| def __call__(self, learnware): | |||
| semantic_spec = learnware.get_specification().get_semantic_spec() | |||
| @@ -76,41 +93,32 @@ class EasyStatChecker(BaseChecker): | |||
| try: | |||
| learnware_model = learnware.get_model() | |||
| # Check input shape | |||
| if semantic_spec["Data"]["Values"][0] == "Table": | |||
| input_shape = (semantic_spec["Input"]["Dimension"],) | |||
| else: | |||
| input_shape = learnware_model.input_shape | |||
| input_shape = learnware_model.input_shape | |||
| # Check rkme dimension | |||
| is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec | |||
| if is_text: | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification") | |||
| else: | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") | |||
| if stat_spec is not None and not is_text: | |||
| ## WHY: why write this? | |||
| if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( | |||
| int(semantic_spec["Input"]["Dimension"]), | |||
| ): | |||
| logger.warning("input shapes of model and semantic specifications are different") | |||
| return self.INVALID_LEARNWARE | |||
| spec_type = parse_specification_type(learnware.get_specification()) | |||
| if spec_type is None: | |||
| logger.warning(f"No valid specification is found in stat spec {spec_type}") | |||
| return self.INVALID_LEARNWARE | |||
| if spec_type == "RKMETableSpecification": | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||
| if stat_spec.get_z().shape[1:] != input_shape: | |||
| logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") | |||
| return self.INVALID_LEARNWARE | |||
| def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): | |||
| text_list = [] | |||
| for i in range(num): | |||
| length = random.randint(min_len, max_len) | |||
| if text_type == "en": | |||
| characters = string.ascii_letters + string.digits + string.punctuation | |||
| result_str = "".join(random.choice(characters) for i in range(length)) | |||
| text_list.append(result_str) | |||
| elif text_type == "zh": | |||
| result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) | |||
| text_list.append(result_str) | |||
| else: | |||
| raise ValueError("Type should be en or zh") | |||
| return text_list | |||
| if is_text: | |||
| inputs = generate_random_text_list(10) | |||
| else: | |||
| inputs = np.random.randn(10, *input_shape) | |||
| elif spec_type == "RKMETextSpecification": | |||
| inputs = EasyStatChecker._generate_random_text_list(10) | |||
| elif spec_type == "RKMEImageSpecification": | |||
| inputs = np.random.randint(0, 255, size=(10, *input_shape)) | |||
| else: | |||
| raise ValueError(f"not supported spec type for spec_type = {spec_type}") | |||
| outputs = learnware.predict(inputs) | |||
| # Check output | |||
| if outputs.ndim == 1: | |||
| @@ -129,8 +137,9 @@ class EasyStatChecker(BaseChecker): | |||
| return self.INVALID_LEARNWARE | |||
| # Check output shape | |||
| output_dim = int(semantic_spec["Output"]["Dimension"]) | |||
| if outputs[0].shape[0] != output_dim: | |||
| if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int( | |||
| semantic_spec["Output"]["Dimension"] | |||
| ): | |||
| logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") | |||
| return self.INVALID_LEARNWARE | |||
| @@ -5,9 +5,10 @@ from cvxopt import solvers, matrix | |||
| from typing import Tuple, List, Union | |||
| from .organizer import EasyOrganizer | |||
| from ..utils import parse_specification_type | |||
| from ..base import BaseUserInfo, BaseSearcher | |||
| from ...learnware import Learnware | |||
| from ...specification import RKMETableSpecification, RKMEImageSpecification | |||
| from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("easy_seacher") | |||
| @@ -251,7 +252,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| The second is the mmd dist between the mixture of learnware rkmes and the user's rkme | |||
| """ | |||
| learnware_num = len(learnware_list) | |||
| RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] | |||
| RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] | |||
| if type(intermediate_K) == np.ndarray: | |||
| K = intermediate_K | |||
| @@ -318,7 +319,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| The second is the intermediate value of C | |||
| """ | |||
| num = intermediate_K.shape[0] - 1 | |||
| RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] | |||
| RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] | |||
| for i in range(intermediate_K.shape[0]): | |||
| intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) | |||
| intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1]) | |||
| @@ -373,7 +374,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| if len(mixture_list) <= 1: | |||
| mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] | |||
| mixture_weight = [1] | |||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_info_name)) | |||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_spec_type)) | |||
| else: | |||
| if len(mixture_list) > max_search_num: | |||
| mixture_list = mixture_list[:max_search_num] | |||
| @@ -414,16 +415,18 @@ class EasyStatSearcher(BaseSearcher): | |||
| idx = idx + 1 | |||
| return sorted_score_list[:idx], learnware_list[:idx] | |||
| def _filter_by_rkme_spec_dimension( | |||
| self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] | |||
| def _filter_by_rkme_spec_metadata( | |||
| self, | |||
| learnware_list: List[Learnware], | |||
| user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification], | |||
| ) -> List[Learnware]: | |||
| """Filter learnwares whose rkme dimension different from user_rkme | |||
| """Filter learnwares whose rkme metadata different from user_rkme | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] | |||
| user_rkme : Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] | |||
| User RKME statistical specification | |||
| Returns | |||
| @@ -435,12 +438,15 @@ class EasyStatSearcher(BaseSearcher): | |||
| user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) | |||
| for learnware in learnware_list: | |||
| if self.stat_info_name not in learnware.specification.stat_spec: | |||
| if self.stat_spec_type not in learnware.specification.stat_spec: | |||
| continue | |||
| rkme = learnware.specification.get_stat_spec_by_name(self.stat_spec_type) | |||
| if self.stat_spec_type == "RKMETextSpecification" and not set(user_rkme.language).issubset( | |||
| set(rkme.language) | |||
| ): | |||
| continue | |||
| rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name) | |||
| if self.stat_info_name == "RKMETextSpecification": | |||
| if not set(user_rkme.language).issubset(set(rkme.language)): | |||
| continue | |||
| # TODO: must we check dim for Text and Image specification? | |||
| rkme_dim = str(list(rkme.get_z().shape)[1:]) | |||
| if rkme_dim == user_rkme_dim: | |||
| filtered_learnware_list.append(learnware) | |||
| @@ -520,7 +526,9 @@ class EasyStatSearcher(BaseSearcher): | |||
| return mmd_dist, weight_min, mixture_list | |||
| def _search_by_rkme_spec_single( | |||
| self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] | |||
| self, | |||
| learnware_list: List[Learnware], | |||
| user_rkme: Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification], | |||
| ) -> Tuple[List[float], List[Learnware]]: | |||
| """Calculate the distances between learnwares in the given learnware_list and user_rkme | |||
| @@ -528,7 +536,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] | |||
| user_rkme : Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification] | |||
| user RKME statistical specification | |||
| Returns | |||
| @@ -538,7 +546,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| the second is the list of Learnware | |||
| both lists are sorted by mmd dist | |||
| """ | |||
| RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_info_name) for learnware in learnware_list] | |||
| RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] | |||
| mmd_dist_list = [] | |||
| for RKME in RKME_list: | |||
| mmd_dist = RKME.dist(user_rkme) | |||
| @@ -557,12 +565,12 @@ class EasyStatSearcher(BaseSearcher): | |||
| max_search_num: int = 5, | |||
| search_method: str = "greedy", | |||
| ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: | |||
| if "RKMETextSpecification" in user_info.stat_info: | |||
| self.stat_info_name = "RKMETextSpecification" | |||
| else: | |||
| self.stat_info_name = "RKMETableSpecification" | |||
| user_rkme = user_info.stat_info[self.stat_info_name] | |||
| learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) | |||
| self.stat_spec_type = parse_specification_type(stat_spec=user_info.stat_info) | |||
| if self.stat_spec_type is None: | |||
| raise KeyError("No supported stat specification is given in the user info") | |||
| user_rkme = user_info.stat_info[self.stat_spec_type] | |||
| learnware_list = self._filter_by_rkme_spec_metadata(learnware_list, user_rkme) | |||
| logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") | |||
| sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) | |||
| @@ -637,9 +645,8 @@ class EasySearcher(BaseSearcher): | |||
| if len(learnware_list) == 0: | |||
| return [], [], 0.0, [] | |||
| elif "RKMETableSpecification" in user_info.stat_info: | |||
| return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) | |||
| elif "RKMETextSpecification" in user_info.stat_info: | |||
| if parse_specification_type(stat_spec=user_info.stat_info) is not None: | |||
| return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) | |||
| else: | |||
| return None, learnware_list, 0.0, None | |||
| @@ -0,0 +1,11 @@ | |||
| from ..specification import Specification | |||
| def parse_specification_type( | |||
| stat_spec: Specification, spec_list=["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"] | |||
| ): | |||
| stat_specs = stat_spec.stat_spec | |||
| for spec in spec_list: | |||
| if spec in stat_specs: | |||
| return spec | |||
| return None | |||
| @@ -1,15 +1,17 @@ | |||
| import torch | |||
| import numpy as np | |||
| from typing import List | |||
| from typing import List, Union | |||
| from cvxopt import matrix, solvers | |||
| from lightgbm import LGBMClassifier, early_stopping | |||
| from sklearn.metrics import accuracy_score | |||
| from learnware.learnware import Learnware | |||
| import learnware.specification as specification | |||
| from .base import BaseReuser | |||
| from ..market.utils import parse_specification_type | |||
| from ..learnware import Learnware | |||
| from ..specification import RKMETableSpecification, RKMETextSpecification | |||
| from ..specification.utils import generate_rkme_spec | |||
| from ..logger import get_module_logger | |||
| logger = get_module_logger("job_selector_reuse") | |||
| @@ -32,7 +34,7 @@ class JobSelectorReuser(BaseReuser): | |||
| self.herding_num = herding_num | |||
| self.use_herding = use_herding | |||
| def predict(self, user_data: np.ndarray) -> np.ndarray: | |||
| def predict(self, user_data: Union[np.ndarray, List[str]]) -> np.ndarray: | |||
| """Give prediction for user data using baseline job-selector method | |||
| Parameters | |||
| @@ -41,12 +43,16 @@ class JobSelectorReuser(BaseReuser): | |||
| User's unlabeled raw data. | |||
| Returns | |||
| ------- | |||
| ------ | |||
| np.ndarray | |||
| Prediction given by job-selector method | |||
| """ | |||
| ori_user_data = user_data | |||
| raw_user_data = user_data | |||
| if isinstance(user_data[0], str): | |||
| stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) | |||
| assert ( | |||
| stat_spec_type == "RKMETextSpecification" | |||
| ), "stat_spec_type must be 'RKMETextSpecification' when user data is the List of string." | |||
| user_data = RKMETextSpecification.get_sentence_embedding(user_data) | |||
| select_result = self.job_selector(user_data) | |||
| @@ -56,8 +62,8 @@ class JobSelectorReuser(BaseReuser): | |||
| for idx in range(len(self.learnware_list)): | |||
| data_idx_list = np.where(select_result == idx)[0] | |||
| if len(data_idx_list) > 0: | |||
| # pred_y = self.learnware_list[idx].predict(ori_user_data[data_idx_list]) | |||
| pred_y = self.learnware_list[idx].predict([ori_user_data[i] for i in data_idx_list]) | |||
| # pred_y = self.learnware_list[idx].predict(raw_user_data[data_idx_list]) | |||
| pred_y = self.learnware_list[idx].predict([raw_user_data[i] for i in data_idx_list]) | |||
| if isinstance(pred_y, torch.Tensor): | |||
| pred_y = pred_y.detach().cpu().numpy() | |||
| # elif isinstance(pred_y, tf.Tensor): | |||
| @@ -91,14 +97,9 @@ class JobSelectorReuser(BaseReuser): | |||
| user_data_num = len(user_data) | |||
| return np.array([0] * user_data_num) | |||
| else: | |||
| ori_user_data = user_data | |||
| if isinstance(user_data[0], str): | |||
| user_data = RKMETextSpecification.get_sentence_embedding(user_data) | |||
| spec_name = "RKMETableSpecification" | |||
| if len(self.learnware_list) and "RKMETextSpecification" in self.learnware_list[0].specification.stat_spec: | |||
| spec_name = "RKMETextSpecification" | |||
| stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) | |||
| learnware_rkme_spec_list = [ | |||
| learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list | |||
| learnware.specification.get_stat_spec_by_name(stat_spec_type) for learnware in self.learnware_list | |||
| ] | |||
| if self.use_herding: | |||
| @@ -179,9 +180,7 @@ class JobSelectorReuser(BaseReuser): | |||
| Inner product matrix calculated from task_rkme_list. | |||
| """ | |||
| task_num = len(task_rkme_list) | |||
| if isinstance(user_data[0], str): | |||
| user_data = RKMETextSpecification.get_sentence_embedding(user_data) | |||
| user_rkme_spec = specification.utils.generate_rkme_spec(X=user_data, reduce=False) | |||
| user_rkme_spec = generate_rkme_spec(X=user_data, reduce=False) | |||
| K = task_rkme_matrix | |||
| v = np.array([user_rkme_spec.inner_prod(task_rkme) for task_rkme in task_rkme_list]) | |||
| @@ -1,8 +1,8 @@ | |||
| from sentence_transformers import SentenceTransformer | |||
| from ..table import RKMETableSpecification | |||
| import numpy as np | |||
| import os | |||
| import langdetect | |||
| import numpy as np | |||
| from sentence_transformers import SentenceTransformer | |||
| from ..table import RKMETableSpecification | |||
| from ....logger import get_module_logger | |||
| logger = get_module_logger("RKMETextSpecification", "INFO") | |||
| @@ -18,7 +18,7 @@ import learnware.specification as specification | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| user_semantic = { | |||
| "Data": {"Values": ["Image"], "Type": "Class"}, | |||
| "Data": {"Values": ["Table"], "Type": "Class"}, | |||
| "Task": { | |||
| "Values": ["Classification"], | |||
| "Type": "Class", | |||
| @@ -96,6 +96,10 @@ class TestAllWorkflow(unittest.TestCase): | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) | |||
| semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) | |||
| semantic_spec["Input"] = {"Dimension": 64} | |||
| semantic_spec["Input"].update( | |||
| {f"{i}": f"The value in the digit image with row is {i // 8} and col is {i % 8}." for i in range(64)} | |||
| ) | |||
| semantic_spec["Output"] = {"Dimension": 1, "Description": {"0": "The label of the hand-written digit."}} | |||
| easy_market.add_learnware(zip_path, semantic_spec) | |||