Browse Source

[ENH] Implement add learnware for market

tags/v0.3.2
chenzx 3 years ago
parent
commit
8baedbdfad
4 changed files with 19 additions and 10 deletions
  1. +1
    -1
      examples/examples1/example_rkme.py
  2. +1
    -2
      learnware/learnware/base.py
  3. +13
    -6
      learnware/market/serial.py
  4. +4
    -1
      learnware/specification/base.py

+ 1
- 1
examples/examples1/example_rkme.py View File

@@ -7,7 +7,7 @@ if __name__ == "__main__":
for i in range(10):
data_X[i, i] = np.nan
spec1 = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=-1)
spec2 = specification.rkme.RKMESpecification()
spec2 = specification.rkme.RKMEStatSpecification()
spec1.generate_stat_spec_from_data(data_X)
spec1.save("spec.json")



+ 1
- 2
learnware/learnware/base.py View File

@@ -8,12 +8,11 @@ from ..utils import get_module_by_module_path


class Learnware:
def __init__(self, id: str, name: str, model: Union[BaseModel, dict], specification: Specification, desc: str):
def __init__(self, id: str, name: str, model: Union[BaseModel, dict], specification: Specification):
self.id = id
self.name = name
self.model = self._import_model(model)
self.specification = specification
self.desc = desc

def _import_model(self, model: Union[BaseModel, dict]) -> BaseModel:
"""_summary_


+ 13
- 6
learnware/market/serial.py View File

@@ -5,7 +5,7 @@ from typing import Tuple, Any, List, Union, Dict

from .base import BaseMarket, BaseUserInfo
from ..learnware import Learnware
from ..specification import RKMEStatSpecification
from ..specification import RKMEStatSpecification, Specification


class SerialMarket(BaseMarket):
@@ -104,11 +104,18 @@ class SerialMarket(BaseMarket):
if (not os.path.exists(model_path)) or (not os.path.exists(stat_spec_path)):
raise FileNotFoundError("Model or Stat_spec NOT Found.")

id = "%08d" % (self.count)
stat_spec = RKMEStatSpecification()
stat_spec_path.load(stat_spec_path)

return str(self.count), True
id = "%08d"%(self.count)
rkme_stat_spec = RKMEStatSpecification()
rkme_stat_spec.load(stat_spec_path)
specification = Specification(semantic_spec=semantic_spec)
specification.upload_stat_spec("RKME", rkme_stat_spec)
model_dict = {"model_path":model_path, "class_name":"BaseModel"}
new_learnware = Learnware(id=id, name=learnware_name,
model=model_dict, specification=specification)
self.learnware_list[id] = new_learnware
self.count += 1

return id, True

def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]:
def search_by_semantic_spec():


+ 4
- 1
learnware/specification/base.py View File

@@ -16,7 +16,7 @@ class BaseStatSpecification:


class Specification:
def __init__(self, semantic_spec=None):
def __init__(self, semantic_spec:dict=None):
self.semantic_spec = semantic_spec
self.stat_spec = {} # stat_spec should be dict

@@ -25,6 +25,9 @@ class Specification:

def get_semantic_spec(self):
return self.semantic_spec
def upload_semantic_spec(self, new_semantic_spec: dict):
self.semantic_spec = new_semantic_spec

def update_stat_spec(self, name, new_stat_spec: BaseStatSpecification):
self.stat_spec[name] = new_stat_spec


Loading…
Cancel
Save