| @@ -63,7 +63,9 @@ def test_search(): | |||||
| for i in range(10): | for i in range(10): | ||||
| user_spec = specification.rkme.RKMEStatSpecification() | user_spec = specification.rkme.RKMEStatSpecification() | ||||
| user_spec.load(f"./learnware_pool/svm{i}/spec.json") | user_spec.load(f"./learnware_pool/svm{i}/spec.json") | ||||
| user_info = BaseUserInfo(id="user_0", semantic_spec={"desc": "test_user_number_0"}, stat_info={"RKME": user_spec}) | |||||
| user_info = BaseUserInfo( | |||||
| id="user_0", semantic_spec={"desc": "test_user_number_0"}, stat_info={"RKME": user_spec} | |||||
| ) | |||||
| sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) | sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) | ||||
| print(f"search result of user{i}:") | print(f"search result of user{i}:") | ||||
| @@ -71,7 +73,7 @@ def test_search(): | |||||
| print(f"dist: {dist}, learnware_id: {learnware.id}, learnware_name: {learnware.name}") | print(f"dist: {dist}, learnware_id: {learnware.id}, learnware_name: {learnware.name}") | ||||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | ||||
| print(f"mixture_learnware: {mixture_id}\n") | print(f"mixture_learnware: {mixture_id}\n") | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_market() | test_market() | ||||
| @@ -17,7 +17,8 @@ def prepare_learnware(): | |||||
| clf.fit(data_X, data_y) | clf.fit(data_X, data_y) | ||||
| joblib.dump(clf, "./svm/svm.pkl") | joblib.dump(clf, "./svm/svm.pkl") | ||||
| spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||||
| spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1) | |||||
| spec.save("./svm/spec.json") | spec.save("./svm/spec.json") | ||||
| @@ -1,2 +1,65 @@ | |||||
| from .base import Learnware | from .base import Learnware | ||||
| from .reuse import BaseReuse | |||||
| from .utils import get_stat_spec_from_config, get_model_from_config | |||||
| from ..specification import RKMEStatSpecification, Specification | |||||
| from ..utils import get_module_by_module_path | |||||
| from ..logger import get_module_logger | |||||
| from typing import Tuple | |||||
| from .base import Learnware | |||||
| logger = get_module_logger("learnware.learnware") | |||||
| def get_learnware_from_config(id: int, file_config: dict, semantic_spec: dict) -> Learnware: | |||||
| """Get the learnware object from config, and provide the manage interface tor Learnware class | |||||
| Parameters | |||||
| ---------- | |||||
| id : int | |||||
| The learnware id that is given by learnware market | |||||
| file_config : dict | |||||
| The learnware file config that demonstrates the name, model, and statistic specification config of learnware | |||||
| semantic_spec : dict | |||||
| The learnware semantice specifactions | |||||
| Returns | |||||
| ------- | |||||
| Learnware | |||||
| The contructed learnware object, return None if build failed | |||||
| """ | |||||
| learnware_config = { | |||||
| "name": "None", | |||||
| "model": { | |||||
| "class_name": "Model", | |||||
| "kwargs": {}, | |||||
| }, | |||||
| "stat_specifications": [ | |||||
| { | |||||
| "module_name": "learnware.specification", | |||||
| "class_name": "RKMEStatSpecification", | |||||
| "kwargs": {}, | |||||
| }, | |||||
| ], | |||||
| } | |||||
| if "name" in file_config: | |||||
| learnware_config["name"] = file_config["name"] | |||||
| if "model" in file_config: | |||||
| learnware_config["model"].update(file_config["model"]) | |||||
| if "stats_specifications" in file_config: | |||||
| learnware_config["stat_specifications"] = file_config["stat_specifications"] | |||||
| try: | |||||
| learnware_spec = Specification() | |||||
| for _stat_spec in learnware_config["stat_specifications"]: | |||||
| stat_spac_name, stat_spec_inst = get_stat_spec_from_config(_stat_spec) | |||||
| learnware_spec.update_stat_spec(**{stat_spac_name: stat_spec_inst}) | |||||
| learnware_spec.upload_semantic_spec(semantic_spec) | |||||
| learnware_model = get_model_from_config(learnware_config["model"]) | |||||
| except Exception: | |||||
| logger.warning(f"Load Learnware {id} failed!") | |||||
| return None | |||||
| return Learnware(id=id, name=learnware_config["name"], model=learnware_model, specification=learnware_spec) | |||||
| @@ -8,42 +8,12 @@ from ..utils import get_module_by_module_path | |||||
| class Learnware: | class Learnware: | ||||
| def __init__(self, id: str, name: str, model: Union[BaseModel, str], specification: Specification): | |||||
| def __init__(self, id: str, name: str, model: BaseModel, specification: Specification): | |||||
| self.id = id | self.id = id | ||||
| self.name = name | self.name = name | ||||
| self.model = self._import_model(model) | |||||
| self.model = model | |||||
| self.specification = specification | self.specification = specification | ||||
| def _import_model(self, model: Union[BaseModel, str]) -> BaseModel: | |||||
| """_summary_ | |||||
| Parameters | |||||
| ---------- | |||||
| model : Union[BaseModel, dict] | |||||
| - If isinstance(model, str), model is the path of the python file | |||||
| - If isinstance(model, BaseModel), return model directly | |||||
| Returns | |||||
| ------- | |||||
| BaseModel | |||||
| The model that is given by user | |||||
| Raises | |||||
| ------ | |||||
| TypeError | |||||
| The type of model must be str or BaseModel, else raise error | |||||
| """ | |||||
| if isinstance(model, BaseModel): | |||||
| return model | |||||
| elif isinstance(model, str): | |||||
| model_dict = { | |||||
| "module_path": model, # path of python file | |||||
| "class_name": "Model" # the name of class in python file, default is "Model", can be changed by yaml | |||||
| } | |||||
| # TODO: test yaml file, change model_dict["class_name"] | |||||
| model_module = get_module_by_module_path(model_dict["module_path"]) | |||||
| return getattr(model_module, model_dict["class_name"])() | |||||
| else: | |||||
| raise TypeError("model must be BaseModel or str") | |||||
| def predict(self, X: np.ndarray) -> np.ndarray: | def predict(self, X: np.ndarray) -> np.ndarray: | ||||
| return self.model.predict(X) | return self.model.predict(X) | ||||
| @@ -0,0 +1,51 @@ | |||||
| from .base import Learnware | |||||
| from .reuse import BaseReuse | |||||
| from typing import Tuple, Union | |||||
| from .base import Learnware | |||||
| from ..model import BaseModel | |||||
| from ..specification import BaseStatSpecification | |||||
| from ..utils import get_module_by_module_path | |||||
| import learnware.specification as specification | |||||
| def get_model_from_config(model: Union[BaseModel, dict]) -> BaseModel: | |||||
| """_summary_ | |||||
| Parameters | |||||
| ---------- | |||||
| model : Union[BaseModel, dict] | |||||
| - If isinstance(model, dict), model is must be the following format: | |||||
| model_dict = { | |||||
| "module_path": str, # path of python file | |||||
| "class_name": str, # the name of class in python file | |||||
| } | |||||
| - If isinstance(model, BaseModel), return model directly | |||||
| Returns | |||||
| ------- | |||||
| BaseModel | |||||
| The model that is given by user | |||||
| Raises | |||||
| ------ | |||||
| TypeError | |||||
| The type of model must be dict or BaseModel, else raise error | |||||
| """ | |||||
| if isinstance(model, BaseModel): | |||||
| return model | |||||
| elif isinstance(model, dict): | |||||
| model_module = get_module_by_module_path(model["module_path"]) | |||||
| return getattr(model_module, model["class_name"])(**model["kwargs"]) | |||||
| else: | |||||
| raise TypeError("model must be type of BaseModel or str") | |||||
| def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification: | |||||
| stat_spec_module = get_module_by_module_path(stat_spec["module_path"]) | |||||
| stat_spec_inst = getattr(stat_spec_module, stat_spec["class_name"])(**stat_spec["kwargs"]) | |||||
| if not isinstance(stat_spec_inst, BaseStatSpecification): | |||||
| raise TypeError( | |||||
| f"Statistic specification must be type of BaseStatSpecification, not {BaseStatSpecification.__class__.__name__}" | |||||
| ) | |||||
| return stat_spec["class_name"], stat_spec_inst | |||||
| @@ -36,6 +36,7 @@ def init_empty_db(func): | |||||
| return wrapper | return wrapper | ||||
| # Clear Learnware Database | # Clear Learnware Database | ||||
| # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! | # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! | ||||
| # !!!!! !!!!! | # !!!!! !!!!! | ||||
| @@ -47,6 +48,7 @@ def clear_learnware_table(cur): | |||||
| LOGGER.warning("!!! Drop Learnware Table !!!") | LOGGER.warning("!!! Drop Learnware Table !!!") | ||||
| cur.execute("DROP TABLE LEARNWARE") | cur.execute("DROP TABLE LEARNWARE") | ||||
| @init_empty_db | @init_empty_db | ||||
| def add_learnware_to_db(id: str, name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, cur): | def add_learnware_to_db(id: str, name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, cur): | ||||
| semantic_spec_str = json.dumps(semantic_spec) | semantic_spec_str = json.dumps(semantic_spec) | ||||
| @@ -12,7 +12,7 @@ from ..specification import RKMEStatSpecification, Specification | |||||
| from ..logger import get_module_logger | from ..logger import get_module_logger | ||||
| from ..config import C | from ..config import C | ||||
| LOGGER = get_module_logger("market", "INFO") | |||||
| logger = get_module_logger("market", "INFO") | |||||
| class EasyMarket(BaseMarket): | class EasyMarket(BaseMarket): | ||||
| @@ -22,7 +22,7 @@ class EasyMarket(BaseMarket): | |||||
| self.count = 0 | self.count = 0 | ||||
| self.semantic_spec_list = C.semantic_specs | self.semantic_spec_list = C.semantic_specs | ||||
| self.reload_market() | self.reload_market() | ||||
| LOGGER.info("Market Initialized!") | |||||
| logger.info("Market Initialized!") | |||||
| def reload_market(self) -> bool: | def reload_market(self) -> bool: | ||||
| self.learnware_list, self.count = load_market_from_db() | self.learnware_list, self.count = load_market_from_db() | ||||
| @@ -42,10 +42,11 @@ class EasyMarket(BaseMarket): | |||||
| try: | try: | ||||
| spec_data = learnware.specification.stat_spec["RKME"].get_z() | spec_data = learnware.specification.stat_spec["RKME"].get_z() | ||||
| pred_spec = learnware.predict(spec_data) | pred_spec = learnware.predict(spec_data) | ||||
| return True | |||||
| except: | |||||
| except Exception: | |||||
| logger.warning(f"The learnware [{learnware.id}-{learnware.name}] is not avaliable!") | |||||
| return False | return False | ||||
| return True | |||||
| def add_learnware( | def add_learnware( | ||||
| self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict | self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict | ||||
| ) -> Tuple[str, bool]: | ) -> Tuple[str, bool]: | ||||
| @@ -88,14 +89,18 @@ class EasyMarket(BaseMarket): | |||||
| rkme_stat_spec.load(stat_spec_path) | rkme_stat_spec.load(stat_spec_path) | ||||
| stat_spec = {"RKME": rkme_stat_spec} | stat_spec = {"RKME": rkme_stat_spec} | ||||
| specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) | specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) | ||||
| id = "%08d" % (self.count) | id = "%08d" % (self.count) | ||||
| new_learnware = Learnware(id=id, name=learnware_name, model=model_path, specification=specification) | new_learnware = Learnware(id=id, name=learnware_name, model=model_path, specification=specification) | ||||
| if self.check_learnware(new_learnware): | |||||
| if self.check_learnware(new_learnware): | |||||
| self.learnware_list[id] = new_learnware | self.learnware_list[id] = new_learnware | ||||
| self.count += 1 | self.count += 1 | ||||
| add_learnware_to_db( | add_learnware_to_db( | ||||
| id, name=learnware_name, model_path=model_path, stat_spec_path=stat_spec_path, semantic_spec=semantic_spec | |||||
| id, | |||||
| name=learnware_name, | |||||
| model_path=model_path, | |||||
| stat_spec_path=stat_spec_path, | |||||
| semantic_spec=semantic_spec, | |||||
| ) | ) | ||||
| return id, True | return id, True | ||||
| else: | else: | ||||
| @@ -303,11 +308,13 @@ class EasyMarket(BaseMarket): | |||||
| if match_semantic_spec(learnware_semantic_spec, user_semantic_spec): | if match_semantic_spec(learnware_semantic_spec, user_semantic_spec): | ||||
| match_learnwares.append(learnware) | match_learnwares.append(learnware) | ||||
| return match_learnwares | return match_learnwares | ||||
| learnware_list = [self.learnware_list[key] for key in self.learnware_list] | learnware_list = [self.learnware_list[key] for key in self.learnware_list] | ||||
| return learnware_list | return learnware_list | ||||
| def search_learnware(self, user_info: BaseUserInfo, search_num=3) -> Tuple[List[float], List[Learnware], List[Learnware]]: | |||||
| def search_learnware( | |||||
| self, user_info: BaseUserInfo, search_num=3 | |||||
| ) -> Tuple[List[float], List[Learnware], List[Learnware]]: | |||||
| """Search learnwares based on user_info | """Search learnwares based on user_info | ||||
| Parameters | Parameters | ||||
| @@ -331,7 +338,9 @@ class EasyMarket(BaseMarket): | |||||
| else: | else: | ||||
| user_rkme = user_info.stat_info["RKME"] | user_rkme = user_info.stat_info["RKME"] | ||||
| sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) | sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) | ||||
| weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(learnware_list, user_rkme, search_num) | |||||
| weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture( | |||||
| learnware_list, user_rkme, search_num | |||||
| ) | |||||
| return sorted_dist_list, single_learnware_list, mixture_learnware_list | return sorted_dist_list, single_learnware_list, mixture_learnware_list | ||||
| def delete_learnware(self, id: str) -> bool: | def delete_learnware(self, id: str) -> bool: | ||||
| @@ -29,8 +29,9 @@ class Specification: | |||||
| def upload_semantic_spec(self, new_semantic_spec: dict): | def upload_semantic_spec(self, new_semantic_spec: dict): | ||||
| self.semantic_spec = new_semantic_spec | self.semantic_spec = new_semantic_spec | ||||
| def update_stat_spec(self, name, new_stat_spec: BaseStatSpecification): | |||||
| self.stat_spec[name] = new_stat_spec | |||||
| def update_stat_spec(self, **kwargs): | |||||
| for _k, _v in kwargs: | |||||
| self.stat_spec[_k] = _v | |||||
| def get_stat_spec_by_name(self, name: str): | def get_stat_spec_by_name(self, name: str): | ||||
| return self.stat_spec.get(name, None) | return self.stat_spec.get(name, None) | ||||
| @@ -41,6 +41,8 @@ REQUIRED = [ | |||||
| # "mkl-service>=2.3.0", | # "mkl-service>=2.3.0", | ||||
| "cvxopt>=1.3.0", | "cvxopt>=1.3.0", | ||||
| "tqdm>=4.65.0", | "tqdm>=4.65.0", | ||||
| "scikit-learn>=1.2.2", | |||||
| "joblib>=1.2.0", | |||||
| ] | ] | ||||
| here = os.path.abspath(os.path.dirname(__file__)) | here = os.path.abspath(os.path.dirname(__file__)) | ||||