| @@ -63,7 +63,9 @@ def test_search(): | |||
| for i in range(10): | |||
| user_spec = specification.rkme.RKMEStatSpecification() | |||
| user_spec.load(f"./learnware_pool/svm{i}/spec.json") | |||
| user_info = BaseUserInfo(id="user_0", semantic_spec={"desc": "test_user_number_0"}, stat_info={"RKME": user_spec}) | |||
| user_info = BaseUserInfo( | |||
| id="user_0", semantic_spec={"desc": "test_user_number_0"}, stat_info={"RKME": user_spec} | |||
| ) | |||
| sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) | |||
| print(f"search result of user{i}:") | |||
| @@ -71,7 +73,7 @@ def test_search(): | |||
| print(f"dist: {dist}, learnware_id: {learnware.id}, learnware_name: {learnware.name}") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| if __name__ == "__main__": | |||
| test_market() | |||
| @@ -17,7 +17,8 @@ def prepare_learnware(): | |||
| clf.fit(data_X, data_y) | |||
| joblib.dump(clf, "./svm/svm.pkl") | |||
| spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||
| spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1) | |||
| spec.save("./svm/spec.json") | |||
| @@ -1,2 +1,65 @@ | |||
| from .base import Learnware | |||
| from .reuse import BaseReuse | |||
| from .utils import get_stat_spec_from_config, get_model_from_config | |||
| from ..specification import RKMEStatSpecification, Specification | |||
| from ..utils import get_module_by_module_path | |||
| from ..logger import get_module_logger | |||
| from typing import Tuple | |||
| from .base import Learnware | |||
| logger = get_module_logger("learnware.learnware") | |||
| def get_learnware_from_config(id: int, file_config: dict, semantic_spec: dict) -> Learnware: | |||
| """Get the learnware object from config, and provide the manage interface tor Learnware class | |||
| Parameters | |||
| ---------- | |||
| id : int | |||
| The learnware id that is given by learnware market | |||
| file_config : dict | |||
| The learnware file config that demonstrates the name, model, and statistic specification config of learnware | |||
| semantic_spec : dict | |||
| The learnware semantice specifactions | |||
| Returns | |||
| ------- | |||
| Learnware | |||
| The contructed learnware object, return None if build failed | |||
| """ | |||
| learnware_config = { | |||
| "name": "None", | |||
| "model": { | |||
| "class_name": "Model", | |||
| "kwargs": {}, | |||
| }, | |||
| "stat_specifications": [ | |||
| { | |||
| "module_name": "learnware.specification", | |||
| "class_name": "RKMEStatSpecification", | |||
| "kwargs": {}, | |||
| }, | |||
| ], | |||
| } | |||
| if "name" in file_config: | |||
| learnware_config["name"] = file_config["name"] | |||
| if "model" in file_config: | |||
| learnware_config["model"].update(file_config["model"]) | |||
| if "stats_specifications" in file_config: | |||
| learnware_config["stat_specifications"] = file_config["stat_specifications"] | |||
| try: | |||
| learnware_spec = Specification() | |||
| for _stat_spec in learnware_config["stat_specifications"]: | |||
| stat_spac_name, stat_spec_inst = get_stat_spec_from_config(_stat_spec) | |||
| learnware_spec.update_stat_spec(**{stat_spac_name: stat_spec_inst}) | |||
| learnware_spec.upload_semantic_spec(semantic_spec) | |||
| learnware_model = get_model_from_config(learnware_config["model"]) | |||
| except Exception: | |||
| logger.warning(f"Load Learnware {id} failed!") | |||
| return None | |||
| return Learnware(id=id, name=learnware_config["name"], model=learnware_model, specification=learnware_spec) | |||
| @@ -8,42 +8,12 @@ from ..utils import get_module_by_module_path | |||
| class Learnware: | |||
| def __init__(self, id: str, name: str, model: Union[BaseModel, str], specification: Specification): | |||
| def __init__(self, id: str, name: str, model: BaseModel, specification: Specification): | |||
| self.id = id | |||
| self.name = name | |||
| self.model = self._import_model(model) | |||
| self.model = model | |||
| self.specification = specification | |||
| def _import_model(self, model: Union[BaseModel, str]) -> BaseModel: | |||
| """_summary_ | |||
| Parameters | |||
| ---------- | |||
| model : Union[BaseModel, dict] | |||
| - If isinstance(model, str), model is the path of the python file | |||
| - If isinstance(model, BaseModel), return model directly | |||
| Returns | |||
| ------- | |||
| BaseModel | |||
| The model that is given by user | |||
| Raises | |||
| ------ | |||
| TypeError | |||
| The type of model must be str or BaseModel, else raise error | |||
| """ | |||
| if isinstance(model, BaseModel): | |||
| return model | |||
| elif isinstance(model, str): | |||
| model_dict = { | |||
| "module_path": model, # path of python file | |||
| "class_name": "Model" # the name of class in python file, default is "Model", can be changed by yaml | |||
| } | |||
| # TODO: test yaml file, change model_dict["class_name"] | |||
| model_module = get_module_by_module_path(model_dict["module_path"]) | |||
| return getattr(model_module, model_dict["class_name"])() | |||
| else: | |||
| raise TypeError("model must be BaseModel or str") | |||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||
| return self.model.predict(X) | |||
| @@ -0,0 +1,51 @@ | |||
| from .base import Learnware | |||
| from .reuse import BaseReuse | |||
| from typing import Tuple, Union | |||
| from .base import Learnware | |||
| from ..model import BaseModel | |||
| from ..specification import BaseStatSpecification | |||
| from ..utils import get_module_by_module_path | |||
| import learnware.specification as specification | |||
| def get_model_from_config(model: Union[BaseModel, dict]) -> BaseModel: | |||
| """_summary_ | |||
| Parameters | |||
| ---------- | |||
| model : Union[BaseModel, dict] | |||
| - If isinstance(model, dict), model is must be the following format: | |||
| model_dict = { | |||
| "module_path": str, # path of python file | |||
| "class_name": str, # the name of class in python file | |||
| } | |||
| - If isinstance(model, BaseModel), return model directly | |||
| Returns | |||
| ------- | |||
| BaseModel | |||
| The model that is given by user | |||
| Raises | |||
| ------ | |||
| TypeError | |||
| The type of model must be dict or BaseModel, else raise error | |||
| """ | |||
| if isinstance(model, BaseModel): | |||
| return model | |||
| elif isinstance(model, dict): | |||
| model_module = get_module_by_module_path(model["module_path"]) | |||
| return getattr(model_module, model["class_name"])(**model["kwargs"]) | |||
| else: | |||
| raise TypeError("model must be type of BaseModel or str") | |||
| def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification: | |||
| stat_spec_module = get_module_by_module_path(stat_spec["module_path"]) | |||
| stat_spec_inst = getattr(stat_spec_module, stat_spec["class_name"])(**stat_spec["kwargs"]) | |||
| if not isinstance(stat_spec_inst, BaseStatSpecification): | |||
| raise TypeError( | |||
| f"Statistic specification must be type of BaseStatSpecification, not {BaseStatSpecification.__class__.__name__}" | |||
| ) | |||
| return stat_spec["class_name"], stat_spec_inst | |||
| @@ -36,6 +36,7 @@ def init_empty_db(func): | |||
| return wrapper | |||
| # Clear Learnware Database | |||
| # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! | |||
| # !!!!! !!!!! | |||
| @@ -47,6 +48,7 @@ def clear_learnware_table(cur): | |||
| LOGGER.warning("!!! Drop Learnware Table !!!") | |||
| cur.execute("DROP TABLE LEARNWARE") | |||
| @init_empty_db | |||
| def add_learnware_to_db(id: str, name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, cur): | |||
| semantic_spec_str = json.dumps(semantic_spec) | |||
| @@ -12,7 +12,7 @@ from ..specification import RKMEStatSpecification, Specification | |||
| from ..logger import get_module_logger | |||
| from ..config import C | |||
| LOGGER = get_module_logger("market", "INFO") | |||
| logger = get_module_logger("market", "INFO") | |||
| class EasyMarket(BaseMarket): | |||
| @@ -22,7 +22,7 @@ class EasyMarket(BaseMarket): | |||
| self.count = 0 | |||
| self.semantic_spec_list = C.semantic_specs | |||
| self.reload_market() | |||
| LOGGER.info("Market Initialized!") | |||
| logger.info("Market Initialized!") | |||
| def reload_market(self) -> bool: | |||
| self.learnware_list, self.count = load_market_from_db() | |||
| @@ -42,10 +42,11 @@ class EasyMarket(BaseMarket): | |||
| try: | |||
| spec_data = learnware.specification.stat_spec["RKME"].get_z() | |||
| pred_spec = learnware.predict(spec_data) | |||
| return True | |||
| except: | |||
| except Exception: | |||
| logger.warning(f"The learnware [{learnware.id}-{learnware.name}] is not avaliable!") | |||
| return False | |||
| return True | |||
| def add_learnware( | |||
| self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict | |||
| ) -> Tuple[str, bool]: | |||
| @@ -88,14 +89,18 @@ class EasyMarket(BaseMarket): | |||
| rkme_stat_spec.load(stat_spec_path) | |||
| stat_spec = {"RKME": rkme_stat_spec} | |||
| specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) | |||
| id = "%08d" % (self.count) | |||
| new_learnware = Learnware(id=id, name=learnware_name, model=model_path, specification=specification) | |||
| if self.check_learnware(new_learnware): | |||
| if self.check_learnware(new_learnware): | |||
| self.learnware_list[id] = new_learnware | |||
| self.count += 1 | |||
| add_learnware_to_db( | |||
| id, name=learnware_name, model_path=model_path, stat_spec_path=stat_spec_path, semantic_spec=semantic_spec | |||
| id, | |||
| name=learnware_name, | |||
| model_path=model_path, | |||
| stat_spec_path=stat_spec_path, | |||
| semantic_spec=semantic_spec, | |||
| ) | |||
| return id, True | |||
| else: | |||
| @@ -303,11 +308,13 @@ class EasyMarket(BaseMarket): | |||
| if match_semantic_spec(learnware_semantic_spec, user_semantic_spec): | |||
| match_learnwares.append(learnware) | |||
| return match_learnwares | |||
| learnware_list = [self.learnware_list[key] for key in self.learnware_list] | |||
| return learnware_list | |||
| def search_learnware(self, user_info: BaseUserInfo, search_num=3) -> Tuple[List[float], List[Learnware], List[Learnware]]: | |||
| def search_learnware( | |||
| self, user_info: BaseUserInfo, search_num=3 | |||
| ) -> Tuple[List[float], List[Learnware], List[Learnware]]: | |||
| """Search learnwares based on user_info | |||
| Parameters | |||
| @@ -331,7 +338,9 @@ class EasyMarket(BaseMarket): | |||
| else: | |||
| user_rkme = user_info.stat_info["RKME"] | |||
| sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) | |||
| weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(learnware_list, user_rkme, search_num) | |||
| weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture( | |||
| learnware_list, user_rkme, search_num | |||
| ) | |||
| return sorted_dist_list, single_learnware_list, mixture_learnware_list | |||
| def delete_learnware(self, id: str) -> bool: | |||
| @@ -29,8 +29,9 @@ class Specification: | |||
| def upload_semantic_spec(self, new_semantic_spec: dict): | |||
| self.semantic_spec = new_semantic_spec | |||
| def update_stat_spec(self, name, new_stat_spec: BaseStatSpecification): | |||
| self.stat_spec[name] = new_stat_spec | |||
| def update_stat_spec(self, **kwargs): | |||
| for _k, _v in kwargs: | |||
| self.stat_spec[_k] = _v | |||
| def get_stat_spec_by_name(self, name: str): | |||
| return self.stat_spec.get(name, None) | |||
| @@ -41,6 +41,8 @@ REQUIRED = [ | |||
| # "mkl-service>=2.3.0", | |||
| "cvxopt>=1.3.0", | |||
| "tqdm>=4.65.0", | |||
| "scikit-learn>=1.2.2", | |||
| "joblib>=1.2.0", | |||
| ] | |||
| here = os.path.abspath(os.path.dirname(__file__)) | |||