diff --git a/examples/examples1/example_rkme.py b/examples/examples1/example_rkme.py index 33c8819..f7842b3 100644 --- a/examples/examples1/example_rkme.py +++ b/examples/examples1/example_rkme.py @@ -7,7 +7,7 @@ if __name__ == "__main__": for i in range(10): data_X[i, i] = np.nan spec1 = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=-1) - spec2 = specification.rkme.RKMESpecification() + spec2 = specification.rkme.RKMEStatSpecification() spec1.generate_stat_spec_from_data(data_X) spec1.save("spec.json") diff --git a/learnware/learnware/base.py b/learnware/learnware/base.py index 644d4c4..3be20b9 100644 --- a/learnware/learnware/base.py +++ b/learnware/learnware/base.py @@ -8,12 +8,11 @@ from ..utils import get_module_by_module_path class Learnware: - def __init__(self, id: str, name: str, model: Union[BaseModel, dict], specification: Specification, desc: str): + def __init__(self, id: str, name: str, model: Union[BaseModel, dict], specification: Specification): self.id = id self.name = name self.model = self._import_model(model) self.specification = specification - self.desc = desc def _import_model(self, model: Union[BaseModel, dict]) -> BaseModel: """_summary_ diff --git a/learnware/market/serial.py b/learnware/market/serial.py index 0a9cf19..ded9027 100644 --- a/learnware/market/serial.py +++ b/learnware/market/serial.py @@ -5,7 +5,7 @@ from typing import Tuple, Any, List, Union, Dict from .base import BaseMarket, BaseUserInfo from ..learnware import Learnware -from ..specification import RKMEStatSpecification +from ..specification import RKMEStatSpecification, Specification class SerialMarket(BaseMarket): @@ -104,11 +104,18 @@ class SerialMarket(BaseMarket): if (not os.path.exists(model_path)) or (not os.path.exists(stat_spec_path)): raise FileNotFoundError("Model or Stat_spec NOT Found.") - id = "%08d" % (self.count) - stat_spec = RKMEStatSpecification() - stat_spec_path.load(stat_spec_path) - - return str(self.count), True + id = "%08d"%(self.count) + rkme_stat_spec = RKMEStatSpecification() + rkme_stat_spec.load(stat_spec_path) + specification = Specification(semantic_spec=semantic_spec) + specification.upload_stat_spec("RKME", rkme_stat_spec) + model_dict = {"model_path":model_path, "class_name":"BaseModel"} + new_learnware = Learnware(id=id, name=learnware_name, + model=model_dict, specification=specification) + self.learnware_list[id] = new_learnware + self.count += 1 + + return id, True def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]: def search_by_semantic_spec(): diff --git a/learnware/specification/base.py b/learnware/specification/base.py index f223ffd..3b6df74 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -16,7 +16,7 @@ class BaseStatSpecification: class Specification: - def __init__(self, semantic_spec=None): + def __init__(self, semantic_spec:dict=None): self.semantic_spec = semantic_spec self.stat_spec = {} # stat_spec should be dict @@ -25,6 +25,9 @@ class Specification: def get_semantic_spec(self): return self.semantic_spec + + def upload_semantic_spec(self, new_semantic_spec: dict): + self.semantic_spec = new_semantic_spec def update_stat_spec(self, name, new_stat_spec: BaseStatSpecification): self.stat_spec[name] = new_stat_spec