diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 0e6ae45..ef09b2f 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -172,7 +172,7 @@ def test_stat_search(): if __name__ == "__main__": learnware_num = 5 - # prepare_learnware(learnware_num) - # test_market() + prepare_learnware(learnware_num) + test_market() test_stat_search() - # test_search_sementics() + test_search_sementics() diff --git a/learnware/market/easy.py b/learnware/market/easy.py index fdf660a..98851d2 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -7,7 +7,7 @@ import pandas as pd from typing import Tuple, Any, List, Union, Dict from .base import BaseMarket, BaseUserInfo -from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db +from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db, clear_learnware_table from ..learnware import Learnware, get_learnware_from_dirpath from ..specification import RKMEStatSpecification, Specification @@ -18,17 +18,30 @@ logger = get_module_logger("market", "INFO") class EasyMarket(BaseMarket): - def __init__(self): - """Initializing an empty market""" + def __init__(self, rebuild:bool = False): + """Initialize Learnware Market. + Automatically reload from db if available. + Build an empty db otherwise. + + Parameters + ---------- + rebuild : bool, optional + Clear current database if set to True, by default False + !!! Do NOT set to True unless highly necessary !!! + """ self.learnware_list = {} # id: Learnware self.learnware_zip_list = {} self.learnware_folder_list = {} self.count = 0 self.semantic_spec_list = C.semantic_specs - self.reload_market() + self.reload_market(rebuild=rebuild) # Automatically reload the market logger.info("Market Initialized!") - def reload_market(self) -> bool: + def reload_market(self, rebuild:bool = False) -> bool: + if rebuild: + logger.warning("Warning! You are trying to clear current database!") + clear_learnware_table() + self.learnware_list, self.learnware_zip_list, self.learnware_folder_list, self.count = load_market_from_db() def check_learnware(self, learnware: Learnware) -> bool: @@ -61,15 +74,10 @@ class EasyMarket(BaseMarket): Parameters ---------- - model_path : str + zip_path : str Filepath for learnware model, a zipped file. - stat_spec_path : str - Filepath for statistical specification, a '.npy' file. - How to pass parameters requires further discussion. semantic_spec : dict semantic_spec for new learnware, in dictionary format. - desc : str - Brief desciption for new learnware. Returns ------- @@ -80,18 +88,10 @@ class EasyMarket(BaseMarket): ------ FileNotFoundError file for model or statistical specification not found - """ if not os.path.exists(zip_path): raise FileNotFoundError("Model or Stat_spec NOT Found.") - """ - rkme_stat_spec = RKMEStatSpecification() - rkme_stat_spec.load(stat_spec_path) - stat_spec = {"RKMEStatSpecification": rkme_stat_spec} - specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) - """ - logger.info("Get new learnware from %s" % (zip_path)) id = "%08d" % (self.count) target_zip_dir = os.path.join(C.learnware_zip_pool_path, "%s.zip" % (id)) @@ -403,17 +403,71 @@ class EasyMarket(BaseMarket): def get_semantic_spec_list(self) -> dict: return self.semantic_spec_list - def get_learnware_by_ids(self, id: str): - if not id in self.learnware_list: - raise Exception("Target id not found in market") + def get_learnware_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]: + """Search learnware by id or list of ids. + + Parameters + ---------- + ids : Union[str, List[str]] + Give a id or a list of ids + str: id of targer learware + List[str]: A list of ids of target learnwares + + Returns + ------- + Union[Learnware, List[Learnware]] + Return target learnware or list of target learnwares. + None for Learnware NOT Found. + """ + if isinstance(ids, list): + ret = [] + for id in ids: + if id in self.learnware_list: + ret.append(self.learnware_list[id]) + else: + logger.warning("Learnware ID '%s' NOT Found!"%(id)) + ret.append(None) + return ret else: - return self.learnware_list[id] + try: + return self.learnware_list[ids] + except: + logger.warning("Learnware ID '%s' NOT Found!"%(ids)) + return None - def get_learnware_path_by_ids(self, id: str) -> str: - if not id in self.learnware_zip_list: - raise Exception("Target id not found in market") + + def get_learnware_path_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]: + """Get Zipped Learnware file by id + + Parameters + ---------- + ids : Union[str, List[str]] + Give a id or a list of ids + str: id of targer learware + List[str]: A list of ids of target learnwares + + + Returns + ------- + Union[Learnware, List[Learnware]] + Return the path for target learnware or list of path. + None for Learnware NOT Found. + """ + if isinstance(ids, list): + ret = [] + for id in ids: + if id in self.learnware_zip_list: + ret.append(self.learnware_zip_list[id]) + else: + logger.warning("Learnware ID '%s' NOT Found!"%(id)) + ret.append(None) + return ret else: - return self.learnware_zip_list[id] + try: + return self.learnware_zip_list[ids] + except: + logger.warning("Learnware ID '%s' NOT Found!"%(ids)) + return None def __len__(self): return len(self.learnware_list.keys())