Browse Source

[ENH] Add import learnware methods

tags/v0.3.2
bxdd 2 years ago
parent
commit
79f4e4fd52
10 changed files with 151 additions and 50 deletions
  1. +4
    -2
      examples/example_market_db/example_db.py
  2. +2
    -1
      examples/examples2/example_learnware.py
  3. BIN
      images/(README)-pic_1680488105261.png
  4. +64
    -1
      learnware/learnware/__init__.py
  5. +2
    -32
      learnware/learnware/base.py
  6. +51
    -0
      learnware/learnware/utils.py
  7. +2
    -0
      learnware/market/database_ops.py
  8. +21
    -12
      learnware/market/easy.py
  9. +3
    -2
      learnware/specification/base.py
  10. +2
    -0
      setup.py

+ 4
- 2
examples/example_market_db/example_db.py View File

@@ -63,7 +63,9 @@ def test_search():
for i in range(10):
user_spec = specification.rkme.RKMEStatSpecification()
user_spec.load(f"./learnware_pool/svm{i}/spec.json")
user_info = BaseUserInfo(id="user_0", semantic_spec={"desc": "test_user_number_0"}, stat_info={"RKME": user_spec})
user_info = BaseUserInfo(
id="user_0", semantic_spec={"desc": "test_user_number_0"}, stat_info={"RKME": user_spec}
)
sorted_dist_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)

print(f"search result of user{i}:")
@@ -71,7 +73,7 @@ def test_search():
print(f"dist: {dist}, learnware_id: {learnware.id}, learnware_name: {learnware.name}")
mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
print(f"mixture_learnware: {mixture_id}\n")

if __name__ == "__main__":
test_market()


+ 2
- 1
examples/examples2/example_learnware.py View File

@@ -17,7 +17,8 @@ def prepare_learnware():
clf.fit(data_X, data_y)
joblib.dump(clf, "./svm/svm.pkl")

spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1)

spec.save("./svm/spec.json")




BIN
images/(README)-pic_1680488105261.png View File

Before After
Width: 500  |  Height: 378  |  Size: 27 kB

+ 64
- 1
learnware/learnware/__init__.py View File

@@ -1,2 +1,65 @@
from .base import Learnware
from .reuse import BaseReuse
from .utils import get_stat_spec_from_config, get_model_from_config
from ..specification import RKMEStatSpecification, Specification
from ..utils import get_module_by_module_path
from ..logger import get_module_logger

from typing import Tuple

from .base import Learnware

logger = get_module_logger("learnware.learnware")


def get_learnware_from_config(id: int, file_config: dict, semantic_spec: dict) -> Learnware:
"""Get the learnware object from config, and provide the manage interface tor Learnware class

Parameters
----------
id : int
The learnware id that is given by learnware market
file_config : dict
The learnware file config that demonstrates the name, model, and statistic specification config of learnware
semantic_spec : dict
The learnware semantice specifactions

Returns
-------
Learnware
The contructed learnware object, return None if build failed
"""
learnware_config = {
"name": "None",
"model": {
"class_name": "Model",
"kwargs": {},
},
"stat_specifications": [
{
"module_name": "learnware.specification",
"class_name": "RKMEStatSpecification",
"kwargs": {},
},
],
}
if "name" in file_config:
learnware_config["name"] = file_config["name"]
if "model" in file_config:
learnware_config["model"].update(file_config["model"])
if "stats_specifications" in file_config:
learnware_config["stat_specifications"] = file_config["stat_specifications"]

try:
learnware_spec = Specification()
for _stat_spec in learnware_config["stat_specifications"]:
stat_spac_name, stat_spec_inst = get_stat_spec_from_config(_stat_spec)
learnware_spec.update_stat_spec(**{stat_spac_name: stat_spec_inst})

learnware_spec.upload_semantic_spec(semantic_spec)
learnware_model = get_model_from_config(learnware_config["model"])

except Exception:
logger.warning(f"Load Learnware {id} failed!")
return None

return Learnware(id=id, name=learnware_config["name"], model=learnware_model, specification=learnware_spec)

+ 2
- 32
learnware/learnware/base.py View File

@@ -8,42 +8,12 @@ from ..utils import get_module_by_module_path


class Learnware:
def __init__(self, id: str, name: str, model: Union[BaseModel, str], specification: Specification):
def __init__(self, id: str, name: str, model: BaseModel, specification: Specification):
self.id = id
self.name = name
self.model = self._import_model(model)
self.model = model
self.specification = specification

def _import_model(self, model: Union[BaseModel, str]) -> BaseModel:
"""_summary_

Parameters
----------
model : Union[BaseModel, dict]
- If isinstance(model, str), model is the path of the python file
- If isinstance(model, BaseModel), return model directly
Returns
-------
BaseModel
The model that is given by user
Raises
------
TypeError
The type of model must be str or BaseModel, else raise error
"""
if isinstance(model, BaseModel):
return model
elif isinstance(model, str):
model_dict = {
"module_path": model, # path of python file
"class_name": "Model" # the name of class in python file, default is "Model", can be changed by yaml
}
# TODO: test yaml file, change model_dict["class_name"]
model_module = get_module_by_module_path(model_dict["module_path"])
return getattr(model_module, model_dict["class_name"])()
else:
raise TypeError("model must be BaseModel or str")

def predict(self, X: np.ndarray) -> np.ndarray:
return self.model.predict(X)



+ 51
- 0
learnware/learnware/utils.py View File

@@ -0,0 +1,51 @@
from .base import Learnware
from .reuse import BaseReuse


from typing import Tuple, Union

from .base import Learnware
from ..model import BaseModel
from ..specification import BaseStatSpecification
from ..utils import get_module_by_module_path
import learnware.specification as specification


def get_model_from_config(model: Union[BaseModel, dict]) -> BaseModel:
"""_summary_

Parameters
----------
model : Union[BaseModel, dict]
- If isinstance(model, dict), model is must be the following format:
model_dict = {
"module_path": str, # path of python file
"class_name": str, # the name of class in python file
}
- If isinstance(model, BaseModel), return model directly
Returns
-------
BaseModel
The model that is given by user
Raises
------
TypeError
The type of model must be dict or BaseModel, else raise error
"""
if isinstance(model, BaseModel):
return model
elif isinstance(model, dict):
model_module = get_module_by_module_path(model["module_path"])
return getattr(model_module, model["class_name"])(**model["kwargs"])
else:
raise TypeError("model must be type of BaseModel or str")


def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification:
stat_spec_module = get_module_by_module_path(stat_spec["module_path"])
stat_spec_inst = getattr(stat_spec_module, stat_spec["class_name"])(**stat_spec["kwargs"])
if not isinstance(stat_spec_inst, BaseStatSpecification):
raise TypeError(
f"Statistic specification must be type of BaseStatSpecification, not {BaseStatSpecification.__class__.__name__}"
)
return stat_spec["class_name"], stat_spec_inst

+ 2
- 0
learnware/market/database_ops.py View File

@@ -36,6 +36,7 @@ def init_empty_db(func):

return wrapper


# Clear Learnware Database
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# !!!!! !!!!!
@@ -47,6 +48,7 @@ def clear_learnware_table(cur):
LOGGER.warning("!!! Drop Learnware Table !!!")
cur.execute("DROP TABLE LEARNWARE")


@init_empty_db
def add_learnware_to_db(id: str, name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, cur):
semantic_spec_str = json.dumps(semantic_spec)


+ 21
- 12
learnware/market/easy.py View File

@@ -12,7 +12,7 @@ from ..specification import RKMEStatSpecification, Specification
from ..logger import get_module_logger
from ..config import C

LOGGER = get_module_logger("market", "INFO")
logger = get_module_logger("market", "INFO")


class EasyMarket(BaseMarket):
@@ -22,7 +22,7 @@ class EasyMarket(BaseMarket):
self.count = 0
self.semantic_spec_list = C.semantic_specs
self.reload_market()
LOGGER.info("Market Initialized!")
logger.info("Market Initialized!")

def reload_market(self) -> bool:
self.learnware_list, self.count = load_market_from_db()
@@ -42,10 +42,11 @@ class EasyMarket(BaseMarket):
try:
spec_data = learnware.specification.stat_spec["RKME"].get_z()
pred_spec = learnware.predict(spec_data)
return True
except:
except Exception:
logger.warning(f"The learnware [{learnware.id}-{learnware.name}] is not avaliable!")
return False
return True

def add_learnware(
self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict
) -> Tuple[str, bool]:
@@ -88,14 +89,18 @@ class EasyMarket(BaseMarket):
rkme_stat_spec.load(stat_spec_path)
stat_spec = {"RKME": rkme_stat_spec}
specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec)
id = "%08d" % (self.count)
new_learnware = Learnware(id=id, name=learnware_name, model=model_path, specification=specification)
if self.check_learnware(new_learnware):
if self.check_learnware(new_learnware):
self.learnware_list[id] = new_learnware
self.count += 1
add_learnware_to_db(
id, name=learnware_name, model_path=model_path, stat_spec_path=stat_spec_path, semantic_spec=semantic_spec
id,
name=learnware_name,
model_path=model_path,
stat_spec_path=stat_spec_path,
semantic_spec=semantic_spec,
)
return id, True
else:
@@ -303,11 +308,13 @@ class EasyMarket(BaseMarket):
if match_semantic_spec(learnware_semantic_spec, user_semantic_spec):
match_learnwares.append(learnware)
return match_learnwares
learnware_list = [self.learnware_list[key] for key in self.learnware_list]
return learnware_list
def search_learnware(self, user_info: BaseUserInfo, search_num=3) -> Tuple[List[float], List[Learnware], List[Learnware]]:

def search_learnware(
self, user_info: BaseUserInfo, search_num=3
) -> Tuple[List[float], List[Learnware], List[Learnware]]:
"""Search learnwares based on user_info

Parameters
@@ -331,7 +338,9 @@ class EasyMarket(BaseMarket):
else:
user_rkme = user_info.stat_info["RKME"]
sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(learnware_list, user_rkme, search_num)
weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture(
learnware_list, user_rkme, search_num
)
return sorted_dist_list, single_learnware_list, mixture_learnware_list

def delete_learnware(self, id: str) -> bool:


+ 3
- 2
learnware/specification/base.py View File

@@ -29,8 +29,9 @@ class Specification:
def upload_semantic_spec(self, new_semantic_spec: dict):
self.semantic_spec = new_semantic_spec

def update_stat_spec(self, name, new_stat_spec: BaseStatSpecification):
self.stat_spec[name] = new_stat_spec
def update_stat_spec(self, **kwargs):
for _k, _v in kwargs:
self.stat_spec[_k] = _v

def get_stat_spec_by_name(self, name: str):
return self.stat_spec.get(name, None)

+ 2
- 0
setup.py View File

@@ -41,6 +41,8 @@ REQUIRED = [
# "mkl-service>=2.3.0",
"cvxopt>=1.3.0",
"tqdm>=4.65.0",
"scikit-learn>=1.2.2",
"joblib>=1.2.0",
]

here = os.path.abspath(os.path.dirname(__file__))


Loading…
Cancel
Save