diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index cc9e544..7ac2303 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -63,7 +63,9 @@ def test_search(): for i in range(10): user_spec = specification.rkme.RKMEStatSpecification() user_spec.load(f"./learnware_pool/svm{i}/spec.json") - user_info = BaseUserInfo(id="user_0", semantic_spec={"desc": "test_user_number_0"}, stat_info={"RKME": user_spec}) + user_info = BaseUserInfo( + id="user_0", semantic_spec={"desc": "test_user_number_0"}, stat_info={"RKME": user_spec} + ) sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) print(f"search result of user{i}:") @@ -71,7 +73,7 @@ def test_search(): print(f"dist: {dist}, learnware_id: {learnware.id}, learnware_name: {learnware.name}") mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) print(f"mixture_learnware: {mixture_id}\n") - + if __name__ == "__main__": test_market() diff --git a/examples/examples2/example_learnware.py b/examples/examples2/example_learnware.py index 811ef2a..6ea7065 100644 --- a/examples/examples2/example_learnware.py +++ b/examples/examples2/example_learnware.py @@ -17,7 +17,8 @@ def prepare_learnware(): clf.fit(data_X, data_y) joblib.dump(clf, "./svm/svm.pkl") - spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) + spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1) + spec.save("./svm/spec.json") diff --git a/images/(README)-pic_1680488105261.png b/images/(README)-pic_1680488105261.png deleted file mode 100644 index c3368de..0000000 Binary files a/images/(README)-pic_1680488105261.png and /dev/null differ diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index 9bedb2a..9a2de34 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -1,2 +1,65 @@ from .base import 
Learnware
-from .reuse import BaseReuse
+from .utils import get_stat_spec_from_config, get_model_from_config
+from ..specification import RKMEStatSpecification, Specification
+from ..utils import get_module_by_module_path
+from ..logger import get_module_logger
+
+from typing import Tuple
+
+from .base import Learnware
+
+logger = get_module_logger("learnware.learnware")
+
+
+def get_learnware_from_config(id: int, file_config: dict, semantic_spec: dict) -> Learnware:
+    """Get the learnware object from config, and provide the manage interface for Learnware class
+
+    Parameters
+    ----------
+    id : int
+        The learnware id that is given by learnware market
+    file_config : dict
+        The learnware file config that demonstrates the name, model, and statistic specification config of learnware
+    semantic_spec : dict
+        The learnware semantic specifications
+
+    Returns
+    -------
+    Learnware
+        The constructed learnware object, return None if build failed
+    """
+    learnware_config = {
+        "name": "None",
+        "model": {
+            "class_name": "Model",
+            "kwargs": {},
+        },
+        "stat_specifications": [
+            {  # NOTE(review): confirm get_module_by_module_path accepts a dotted module name
+                "module_path": "learnware.specification",  # the key get_stat_spec_from_config reads
+                "class_name": "RKMEStatSpecification",
+                "kwargs": {},
+            },
+        ],
+    }
+    if "name" in file_config:
+        learnware_config["name"] = file_config["name"]
+    if "model" in file_config:
+        learnware_config["model"].update(file_config["model"])
+    if "stat_specifications" in file_config:  # must be the same key that is read on the next line
+        learnware_config["stat_specifications"] = file_config["stat_specifications"]
+
+    try:
+        learnware_spec = Specification()
+        for _stat_spec in learnware_config["stat_specifications"]:
+            stat_spec_name, stat_spec_inst = get_stat_spec_from_config(_stat_spec)
+            learnware_spec.update_stat_spec(**{stat_spec_name: stat_spec_inst})
+
+        learnware_spec.upload_semantic_spec(semantic_spec)
+        learnware_model = get_model_from_config(learnware_config["model"])
+
+    except Exception:
+        logger.warning(f"Load Learnware {id} failed!")
+        return None
+
+    return Learnware(id=id, name=learnware_config["name"], model=learnware_model, specification=learnware_spec)
diff --git a/learnware/learnware/base.py b/learnware/learnware/base.py
index 8578f91..c696fc2 100644
--- a/learnware/learnware/base.py
+++ b/learnware/learnware/base.py
@@ -8,42 +8,12 @@ from ..utils import get_module_by_module_path
 
 
 class Learnware:
-    def __init__(self, id: str, name: str, model: Union[BaseModel, str], specification: Specification):
+    def __init__(self, id: str, name: str, model: BaseModel, specification: Specification):
         self.id = id
         self.name = name
-        self.model = self._import_model(model)
+        self.model = model
         self.specification = specification
 
-    def _import_model(self, model: Union[BaseModel, str]) -> BaseModel:
-        """_summary_
-
-        Parameters
-        ----------
-        model : Union[BaseModel, dict]
-            - If isinstance(model, str), model is the path of the python file
-            - If isinstance(model, BaseModel), return model directly
-        Returns
-        -------
-        BaseModel
-            The model that is given by user
-        Raises
-        ------
-        TypeError
-            The type of model must be str or BaseModel, else raise error
-        """
-        if isinstance(model, BaseModel):
-            return model
-        elif isinstance(model, str):
-            model_dict = {
-                "module_path": model, # path of python file
-                "class_name": "Model" # the name of class in python file, default is "Model", can be changed by yaml
-            }
-            # TODO: test yaml file, change model_dict["class_name"]
-            model_module = get_module_by_module_path(model_dict["module_path"])
-            return getattr(model_module, model_dict["class_name"])()
-        else:
-            raise TypeError("model must be BaseModel or str")
-
     def predict(self, X: np.ndarray) -> np.ndarray:
         return self.model.predict(X)
 
diff --git a/learnware/learnware/utils.py b/learnware/learnware/utils.py
new file mode 100644
index 0000000..4dacc33
--- /dev/null
+++ b/learnware/learnware/utils.py
@@ -0,0 +1,51 @@
+from .base import Learnware
+from .reuse import BaseReuse
+
+
+from typing import Tuple, Union
+
+from .base import Learnware
+from ..model import
BaseModel
+from ..specification import BaseStatSpecification
+from ..utils import get_module_by_module_path
+import learnware.specification as specification
+
+
+def get_model_from_config(model: Union[BaseModel, dict]) -> BaseModel:
+    """Return a model instance built from a config dict, or the given model unchanged.
+
+    Parameters
+    ----------
+    model : Union[BaseModel, dict]
+        - If isinstance(model, dict), model must be in the following format:
+            model_dict = {"module_path": str,  # path of python file
+                          "class_name": str,  # the name of class in python file
+                          "kwargs": dict}  # optional keyword arguments passed to the class
+        - If isinstance(model, BaseModel), return model directly
+
+    Returns
+    -------
+    BaseModel
+        The model that is given by user
+    Raises
+    ------
+    TypeError
+        The type of model must be dict or BaseModel, else raise error
+    """
+    if isinstance(model, BaseModel):
+        return model
+    elif isinstance(model, dict):
+        model_module = get_module_by_module_path(model["module_path"])
+        return getattr(model_module, model["class_name"])(**model.get("kwargs", {}))
+    else:
+        raise TypeError("model must be type of BaseModel or dict")
+
+
+def get_stat_spec_from_config(stat_spec: dict) -> Tuple[str, BaseStatSpecification]:
+    """Instantiate the statistic specification described by ``stat_spec``; return ``(class_name, instance)``."""
+    stat_spec_cls = getattr(get_module_by_module_path(stat_spec["module_path"]), stat_spec["class_name"])
+    stat_spec_inst = stat_spec_cls(**stat_spec.get("kwargs", {}))
+    if not isinstance(stat_spec_inst, BaseStatSpecification):
+        bad_type = type(stat_spec_inst).__name__  # report the offending instance's type, not the base class
+        raise TypeError(f"Statistic specification must be type of BaseStatSpecification, not {bad_type}")
+    return stat_spec["class_name"], stat_spec_inst
diff --git a/learnware/market/database_ops.py b/learnware/market/database_ops.py
index 03009af..e7aa6a8 100644
--- a/learnware/market/database_ops.py
+++ b/learnware/market/database_ops.py
@@ -36,6 +36,7 @@ def init_empty_db(func):
 
     return wrapper
 
+
 # Clear Learnware Database
 # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 # !!!!! !!!!!
@@ -47,6 +48,7 @@ def clear_learnware_table(cur):
     LOGGER.warning("!!!
Drop Learnware Table !!!") cur.execute("DROP TABLE LEARNWARE") + @init_empty_db def add_learnware_to_db(id: str, name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, cur): semantic_spec_str = json.dumps(semantic_spec) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index a752081..1ecf1b3 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -12,7 +12,7 @@ from ..specification import RKMEStatSpecification, Specification from ..logger import get_module_logger from ..config import C -LOGGER = get_module_logger("market", "INFO") +logger = get_module_logger("market", "INFO") class EasyMarket(BaseMarket): @@ -22,7 +22,7 @@ class EasyMarket(BaseMarket): self.count = 0 self.semantic_spec_list = C.semantic_specs self.reload_market() - LOGGER.info("Market Initialized!") + logger.info("Market Initialized!") def reload_market(self) -> bool: self.learnware_list, self.count = load_market_from_db() @@ -42,10 +42,11 @@ class EasyMarket(BaseMarket): try: spec_data = learnware.specification.stat_spec["RKME"].get_z() pred_spec = learnware.predict(spec_data) - return True - except: + except Exception: + logger.warning(f"The learnware [{learnware.id}-{learnware.name}] is not avaliable!") return False - + return True + def add_learnware( self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict ) -> Tuple[str, bool]: @@ -88,14 +89,18 @@ class EasyMarket(BaseMarket): rkme_stat_spec.load(stat_spec_path) stat_spec = {"RKME": rkme_stat_spec} specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) - + id = "%08d" % (self.count) new_learnware = Learnware(id=id, name=learnware_name, model=model_path, specification=specification) - if self.check_learnware(new_learnware): + if self.check_learnware(new_learnware): self.learnware_list[id] = new_learnware self.count += 1 add_learnware_to_db( - id, name=learnware_name, model_path=model_path, stat_spec_path=stat_spec_path, 
semantic_spec=semantic_spec
+                id,
+                name=learnware_name,
+                model_path=model_path,
+                stat_spec_path=stat_spec_path,
+                semantic_spec=semantic_spec,
             )
             return id, True
         else:
@@ -303,11 +308,13 @@ class EasyMarket(BaseMarket):
                 if match_semantic_spec(learnware_semantic_spec, user_semantic_spec):
                     match_learnwares.append(learnware)
             return match_learnwares
-        
+
         learnware_list = [self.learnware_list[key] for key in self.learnware_list]
         return learnware_list
-    
-    def search_learnware(self, user_info: BaseUserInfo, search_num=3) -> Tuple[List[float], List[Learnware], List[Learnware]]:
+
+    def search_learnware(
+        self, user_info: BaseUserInfo, search_num=3
+    ) -> Tuple[List[float], List[Learnware], List[Learnware]]:
         """Search learnwares based on user_info
 
         Parameters
@@ -331,7 +338,9 @@ class EasyMarket(BaseMarket):
         else:
             user_rkme = user_info.stat_info["RKME"]
             sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
-            weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(learnware_list, user_rkme, search_num)
+            weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(
+                learnware_list, user_rkme, search_num
+            )
             return sorted_dist_list, single_learnware_list, mixture_learnware_list
 
     def delete_learnware(self, id: str) -> bool:
diff --git a/learnware/specification/base.py b/learnware/specification/base.py
index 12b61c6..53e9c9f 100644
--- a/learnware/specification/base.py
+++ b/learnware/specification/base.py
@@ -29,8 +29,10 @@ class Specification:
     def upload_semantic_spec(self, new_semantic_spec: dict):
         self.semantic_spec = new_semantic_spec
 
-    def update_stat_spec(self, name, new_stat_spec: BaseStatSpecification):
-        self.stat_spec[name] = new_stat_spec
+    def update_stat_spec(self, **kwargs):
+        """Store each ``name=stat_spec`` keyword argument in ``self.stat_spec`` under its name."""
+        for _k, _v in kwargs.items():  # iterating the dict itself yields only the keys
+            self.stat_spec[_k] = _v
 
     def get_stat_spec_by_name(self, name: str):
         return self.stat_spec.get(name, None)
diff --git a/setup.py b/setup.py
index 03a356c..99daa89 100644
--- a/setup.py
+++
b/setup.py @@ -41,6 +41,8 @@ REQUIRED = [ # "mkl-service>=2.3.0", "cvxopt>=1.3.0", "tqdm>=4.65.0", + "scikit-learn>=1.2.2", + "joblib>=1.2.0", ] here = os.path.abspath(os.path.dirname(__file__))