Browse Source

[MNT,DOC] Update Doc, support search learnware by id list

tags/v0.3.2
chenzx 2 years ago
parent
commit
6fd2906400
2 changed files with 84 additions and 30 deletions
  1. +3
    -3
      examples/example_market_db/example_db.py
  2. +81
    -27
      learnware/market/easy.py

+ 3
- 3
examples/example_market_db/example_db.py View File

@@ -172,7 +172,7 @@ def test_stat_search():

if __name__ == "__main__":
learnware_num = 5
# prepare_learnware(learnware_num)
# test_market()
prepare_learnware(learnware_num)
test_market()
test_stat_search()
# test_search_sementics()
test_search_sementics()

+ 81
- 27
learnware/market/easy.py View File

@@ -7,7 +7,7 @@ import pandas as pd
from typing import Tuple, Any, List, Union, Dict

from .base import BaseMarket, BaseUserInfo
from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db
from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db, clear_learnware_table

from ..learnware import Learnware, get_learnware_from_dirpath
from ..specification import RKMEStatSpecification, Specification
@@ -18,17 +18,30 @@ logger = get_module_logger("market", "INFO")


class EasyMarket(BaseMarket):
def __init__(self):
"""Initializing an empty market"""
def __init__(self, rebuild:bool = False):
"""Initialize Learnware Market.
Automatically reload from db if available.
Build an empty db otherwise.

Parameters
----------
rebuild : bool, optional
Clear current database if set to True, by default False
!!! Do NOT set to True unless highly necessary !!!
"""
self.learnware_list = {} # id: Learnware
self.learnware_zip_list = {}
self.learnware_folder_list = {}
self.count = 0
self.semantic_spec_list = C.semantic_specs
self.reload_market()
self.reload_market(rebuild=rebuild) # Automatically reload the market
logger.info("Market Initialized!")

def reload_market(self) -> bool:
def reload_market(self, rebuild:bool = False) -> bool:
if rebuild:
logger.warning("Warning! You are trying to clear current database!")
clear_learnware_table()

self.learnware_list, self.learnware_zip_list, self.learnware_folder_list, self.count = load_market_from_db()

def check_learnware(self, learnware: Learnware) -> bool:
@@ -61,15 +74,10 @@ class EasyMarket(BaseMarket):

Parameters
----------
model_path : str
zip_path : str
Filepath for learnware model, a zipped file.
stat_spec_path : str
Filepath for statistical specification, a '.npy' file.
How to pass parameters requires further discussion.
semantic_spec : dict
semantic_spec for new learnware, in dictionary format.
desc : str
Brief desciption for new learnware.

Returns
-------
@@ -80,18 +88,10 @@ class EasyMarket(BaseMarket):
------
FileNotFoundError
file for model or statistical specification not found

"""
if not os.path.exists(zip_path):
raise FileNotFoundError("Model or Stat_spec NOT Found.")

"""
rkme_stat_spec = RKMEStatSpecification()
rkme_stat_spec.load(stat_spec_path)
stat_spec = {"RKMEStatSpecification": rkme_stat_spec}
specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec)
"""

logger.info("Get new learnware from %s" % (zip_path))
id = "%08d" % (self.count)
target_zip_dir = os.path.join(C.learnware_zip_pool_path, "%s.zip" % (id))
@@ -403,17 +403,71 @@ class EasyMarket(BaseMarket):
def get_semantic_spec_list(self) -> dict:
return self.semantic_spec_list

def get_learnware_by_ids(self, id: str):
if not id in self.learnware_list:
raise Exception("Target id not found in market")
def get_learnware_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]:
"""Search learnware by id or list of ids.

Parameters
----------
ids : Union[str, List[str]]
Give a id or a list of ids
str: id of targer learware
List[str]: A list of ids of target learnwares

Returns
-------
Union[Learnware, List[Learnware]]
Return target learnware or list of target learnwares.
None for Learnware NOT Found.
"""
if isinstance(ids, list):
ret = []
for id in ids:
if id in self.learnware_list:
ret.append(self.learnware_list[id])
else:
logger.warning("Learnware ID '%s' NOT Found!"%(id))
ret.append(None)
return ret
else:
return self.learnware_list[id]
try:
return self.learnware_list[ids]
except:
logger.warning("Learnware ID '%s' NOT Found!"%(ids))
return None

def get_learnware_path_by_ids(self, id: str) -> str:
if not id in self.learnware_zip_list:
raise Exception("Target id not found in market")

def get_learnware_path_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]:
"""Get Zipped Learnware file by id

Parameters
----------
ids : Union[str, List[str]]
Give a id or a list of ids
str: id of targer learware
List[str]: A list of ids of target learnwares


Returns
-------
Union[Learnware, List[Learnware]]
Return the path for target learnware or list of path.
None for Learnware NOT Found.
"""
if isinstance(ids, list):
ret = []
for id in ids:
if id in self.learnware_zip_list:
ret.append(self.learnware_zip_list[id])
else:
logger.warning("Learnware ID '%s' NOT Found!"%(id))
ret.append(None)
return ret
else:
return self.learnware_zip_list[id]
try:
return self.learnware_zip_list[ids]
except:
logger.warning("Learnware ID '%s' NOT Found!"%(ids))
return None

def __len__(self):
return len(self.learnware_list.keys())


Loading…
Cancel
Save