| @@ -2,7 +2,7 @@ from .anchor import AnchoredUserInfo, AnchoredOrganizer | |||
| from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher | |||
| from .evolve_anchor import EvolveAnchoredOrganizer | |||
| from .evolve import EvolvedOrganizer | |||
| from .easy2 import EasyChecker, EasyOrganizer, EasySearcher | |||
| from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker | |||
| from .hetergeneous import HeterogeneousOrganizer, MappingFunction | |||
| from .easy import EasyMarket | |||
| @@ -1,11 +1,12 @@ | |||
| import os | |||
| import torch | |||
| import tempfile | |||
| import traceback | |||
| import numpy as np | |||
| from typing import Tuple, Any, List, Union | |||
| from ..learnware import Learnware | |||
| from ..learnware import Learnware, get_learnware_from_dirpath | |||
| from ..logger import get_module_logger | |||
| logger = get_module_logger("market_base", "INFO") | |||
| @@ -51,27 +52,57 @@ class LearnwareMarket: | |||
| self, | |||
| market_id: str = None, | |||
| organizer: "BaseOrganizer" = None, | |||
| checker: "BaseChecker" = None, | |||
| searcher: "BaseSearcher" = None, | |||
| checker_list: List["BaseChecker"] = None, | |||
| rebuild=False, | |||
| ): | |||
| self.market_id = market_id | |||
| self.learnware_organizer = BaseOrganizer() if organizer is None else organizer | |||
| self.learnware_checker = BaseChecker() if checker is None else checker | |||
| self.learnware_checker.reset(organizer=self.learnware_organizer) | |||
| self.learnware_organizer.reset(market_id=market_id, checker=self.learnware_checker) | |||
| self.learnware_organizer.reset(market_id=market_id) | |||
| self.learnware_organizer.reload_market(rebuild=rebuild) | |||
| self.learnware_searcher = BaseSearcher() if searcher is None else searcher | |||
| self.learnware_searcher.reset(organizer=self.learnware_organizer) | |||
| if checker_list is None: | |||
| self.learnware_checker = {"BaseChecker": BaseChecker()} | |||
| else: | |||
| self.learnware_checker = {checker.__class__.__name__: checker for checker in checker_list} | |||
| for name, checker in self.learnware_checker.items(): | |||
| checker.reset(organizer=self.learnware_organizer) | |||
| def reload_market(self, **kwargs) -> bool: | |||
| self.learnware_organizer.reload_market(**kwargs) | |||
| def check_learnware(self, learnware: Learnware, **kwargs) -> bool: | |||
| return self.learnware_checker(learnware, **kwargs) | |||
| def add_learnware(self, zip_path: str, semantic_spec: dict, **kwargs) -> Tuple[str, bool]: | |||
| return self.learnware_organizer.add_learnware(zip_path, semantic_spec, **kwargs) | |||
| def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: | |||
| try: | |||
| with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: | |||
| with zipfile.ZipFile(zip_path, mode="r") as z_file: | |||
| z_file.extractall(tempdir) | |||
| pending_learnware = get_learnware_from_dirpath( | |||
| id="pending", semantic_spec=semantic_specification, learnware_dirpath=tempdir | |||
| ) | |||
| final_status = BaseChecker.INVALID_LEARNWARE | |||
| checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names | |||
| for name in checker_names: | |||
| checker = self.learnware_checker[name] | |||
| check_status = checker(pending_learnware) | |||
| final_status = max(final_status, check_status) | |||
| if check_status == BaseChecker.INVALID_LEARNWARE: | |||
| return BaseChecker.INVALID_LEARNWARE | |||
| return final_status | |||
| except Exception as err: | |||
| logger.warning(f"Check learnware failed! Due to {err}.") | |||
| return BaseChecker.INVALID_LEARNWARE | |||
| def add_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> Tuple[str, bool]: | |||
| check_status = self.check_learnware(zip_path, semantic_spec, checker_names) | |||
| return self.learnware_organizer.add_learnware(zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs) | |||
| def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]: | |||
| return self.learnware_searcher(user_info, **kwargs) | |||
| @@ -79,8 +110,9 @@ class LearnwareMarket: | |||
| def delete_learnware(self, id: str, **kwargs) -> bool: | |||
| return self.learnware_organizer.delete_learnware(id, **kwargs) | |||
| def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, **kwargs) -> bool: | |||
| return self.learnware_organizer.update_learnware(id, zip_path=zip_path, semantic_spec=semantic_spec, **kwargs) | |||
| def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: | |||
| check_status = self.check_learnware(zip_path, semantic_spec, checker_names) | |||
| return self.learnware_organizer.update_learnware(id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs) | |||
| def get_learnware_ids(self, top: int = None, **kwargs): | |||
| return self.learnware_organizer.get_learnware_ids(top, **kwargs) | |||
| @@ -99,12 +131,11 @@ class LearnwareMarket: | |||
| class BaseOrganizer: | |||
| def __init__(self, market_id=None, checker: BaseChecker = None): | |||
| self.reset(market_id=market_id, checker=checker) | |||
| def __init__(self, market_id=None): | |||
| self.reset(market_id=market_id) | |||
| def reset(self, market_id=None, checker: BaseChecker = None, **kwargs): | |||
| def reset(self, market_id=None, **kwargs): | |||
| self.market_id = market_id | |||
| self.checker = checker | |||
| def reload_market(self, rebuild=False, **kwargs) -> bool: | |||
| """Reload the learnware organizer when server restared. | |||
| @@ -117,7 +148,7 @@ class BaseOrganizer: | |||
| raise NotImplementedError("reload market is Not Implemented in BaseOrganizer") | |||
| def add_learnware(self, zip_path: str, semantic_spec: dict) -> Tuple[str, bool]: | |||
| def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, bool]: | |||
| """Add a learnware into the market. | |||
| .. note:: | |||
| @@ -167,7 +198,7 @@ class BaseOrganizer: | |||
| """ | |||
| raise NotImplementedError("delete learnware is Not Implemented in BaseOrganizer") | |||
| def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, **kwargs) -> bool: | |||
| def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, check_status: int) -> bool: | |||
| """ | |||
| Update Learnware with id and content to be updated. | |||
| @@ -1,3 +1,3 @@ | |||
| from .organizer import EasyOrganizer | |||
| from .checker import EasyChecker | |||
| from .searcher import EasySearcher | |||
| from .checker import EasySemanticChecker, EasyStatisticalChecker | |||
| @@ -3,71 +3,113 @@ import numpy as np | |||
| import torch | |||
| from ..base import BaseChecker | |||
| from ...config import C | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("easy_checker", "INFO") | |||
| class EasyChecker(BaseChecker): | |||
| class EasySemanticChecker(BaseChecker): | |||
| def __call__(self, learnware): | |||
| semantic_spec = learnware.get_specification().get_semantic_spec() | |||
| try: | |||
| for key in C["semantic_specs"]: | |||
| value = semantic_spec[key]["Values"] | |||
| valid_type = C["semantic_specs"][key]["Type"] | |||
| assert semantic_spec[key]["Type"] == valid_type, f"{key} type mismatch" | |||
| if valid_type == "Class": | |||
| valid_list = C["semantic_specs"][key]["Values"] | |||
| assert len(value) == 1, f"{key} must be unique" | |||
| assert value[0] in valid_list, f"{key} must be in {valid_list}" | |||
| elif valid_type == "Tag": | |||
| valid_list = C["semantic_specs"][key]["Values"] | |||
| assert len(value) >= 1, f"{key} cannot be empty" | |||
| for v in value: | |||
| assert v in valid_list, f"{key} must be in {valid_list}" | |||
| elif valid_type == "String": | |||
| assert isinstance(value, str), f"{key} must be string" | |||
| assert len(value) >= 1, f"{key} cannot be empty" | |||
| if semantic_spec["Data"]["Values"][0] == "Table": | |||
| assert semantic_spec["Input"] is not None, "Lack of input semantics" | |||
| dim = semantic_spec["Input"]["Dimension"] | |||
| for k, v in semantic_spec["Input"]["Description"].items(): | |||
| assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" | |||
| assert isinstance(v, str), "Description must be string" | |||
| if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]: | |||
| assert semantic_spec["Output"] is not None, "Lack of output semantics" | |||
| dim = semantic_spec["Output"]["Dimension"] | |||
| for k, v in semantic_spec["Output"]["Description"].items(): | |||
| assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" | |||
| assert isinstance(v, str), "Description must be string" | |||
| return self.NONUSABLE_LEARNWARE | |||
| except Exception as err: | |||
| logger.warning(f"semantic_specification is not valid due to {err}!") | |||
| return self.INVALID_LEARNWARE | |||
| class EasyStatisticalChecker(BaseChecker): | |||
| def __call__(self, learnware): | |||
| semantic_spec = learnware.get_specification().get_semantic_spec() | |||
| try: | |||
| # check model instantiation | |||
| # Check model instantiation | |||
| learnware.instantiate_model() | |||
| except Exception as e: | |||
| traceback.print_exc() | |||
| logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}") | |||
| return self.NONUSABLE_LEARNWARE | |||
| logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") | |||
| return self.INVALID_LEARNWARE | |||
| try: | |||
| learnware_model = learnware.get_model() | |||
| # check input shape | |||
| # Check input shape | |||
| if semantic_spec["Data"]["Values"][0] == "Table": | |||
| input_shape = (semantic_spec["Input"]["Dimension"],) | |||
| else: | |||
| input_shape = learnware_model.input_shape | |||
| pass | |||
| # check rkme dimension | |||
| # Check rkme dimension | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") | |||
| if stat_spec is not None: | |||
| if stat_spec.get_z().shape[1:] != input_shape: | |||
| logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") | |||
| return self.NONUSABLE_LEARNWARE | |||
| pass | |||
| logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") | |||
| return self.INVALID_LEARNWARE | |||
| inputs = np.random.randn(10, *input_shape) | |||
| outputs = learnware.predict(inputs) | |||
| # check output | |||
| # Check output | |||
| if outputs.ndim == 1: | |||
| outputs = outputs.reshape(-1, 1) | |||
| pass | |||
| if outputs.shape[1:] != learnware_model.output_shape: | |||
| logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") | |||
| return self.INVALID_LEARNWARE | |||
| if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"): | |||
| # check output type | |||
| # Check output type | |||
| if isinstance(outputs, torch.Tensor): | |||
| outputs = outputs.detach().cpu().numpy() | |||
| if not isinstance(outputs, np.ndarray): | |||
| logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") | |||
| return self.NONUSABLE_LEARNWARE | |||
| logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!") | |||
| return self.INVALID_LEARNWARE | |||
| # check output shape | |||
| # Check output shape | |||
| output_dim = int(semantic_spec["Output"]["Dimension"]) | |||
| if outputs[0].shape[0] != output_dim: | |||
| logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") | |||
| return self.NONUSABLE_LEARNWARE | |||
| pass | |||
| else: | |||
| if outputs.shape[1:] != learnware_model.output_shape: | |||
| logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") | |||
| return self.NONUSABLE_LEARNWARE | |||
| logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") | |||
| return self.INVALID_LEARNWARE | |||
| except Exception as e: | |||
| logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}") | |||
| return self.NONUSABLE_LEARNWARE | |||
| logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.") | |||
| return self.INVALID_LEARNWARE | |||
| return self.USABLE_LEARWARE | |||
| return self.USABLE_LEARWARE | |||
| @@ -13,7 +13,6 @@ from shutil import copyfile, rmtree | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from .database_ops import DatabaseOperations | |||
| from .checker import EasyChecker | |||
| from ..base import LearnwareMarket, BaseUserInfo | |||
| @@ -95,42 +94,6 @@ class EasyOrganizer(BaseOrganizer): | |||
| """ | |||
| semantic_spec = copy.deepcopy(semantic_spec) | |||
| if not os.path.exists(zip_path): | |||
| logger.warning("Zip Path NOT Found! Fail to add learnware.") | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| try: | |||
| if len(semantic_spec["Data"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please choose Data.") | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| if len(semantic_spec["Task"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please choose Task.") | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| if len(semantic_spec["Library"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please choose Device.") | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| if len(semantic_spec["Name"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please provide Name.") | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| if len(semantic_spec["Description"]["Values"]) == 0 and len(semantic_spec["Scenario"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please provide Scenario or Description.") | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| if ( | |||
| semantic_spec["Data"]["Type"] != "Class" | |||
| or semantic_spec["Task"]["Type"] != "Class" | |||
| or semantic_spec["Library"]["Type"] != "Class" | |||
| or semantic_spec["Scenario"]["Type"] != "Tag" | |||
| or semantic_spec["Name"]["Type"] != "String" | |||
| or semantic_spec["Description"]["Type"] != "String" | |||
| ): | |||
| logger.warning("Illegal semantic specification, please provide the right type.") | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| except: | |||
| print(semantic_spec) | |||
| logger.warning("Illegal semantic specification, some keys are missing.") | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| logger.info("Get new learnware from %s" % (zip_path)) | |||
| id = id if id is not None else "%08d" % (self.count) | |||
| @@ -152,12 +115,12 @@ class EasyOrganizer(BaseOrganizer): | |||
| rmtree(target_folder_dir) | |||
| except: | |||
| pass | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| return None, BaseChecker.INVALID_LEARNWARE | |||
| if new_learnware is None: | |||
| return None, EasyChecker.INVALID_LEARNWARE | |||
| return None, BaseChecker.INVALID_LEARNWARE | |||
| learnwere_status = check_status if check_status is not None else self.checker(new_learnware) | |||
| learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE | |||
| self.dbops.add_learnware( | |||
| id=id, | |||
| @@ -227,7 +190,7 @@ class EasyOrganizer(BaseOrganizer): | |||
| assert ( | |||
| zip_path is None and semantic_spec is None | |||
| ), f"at least one of 'zip_path' and 'semantic_spec' should not be None when update learnware" | |||
| assert check_status != EasyChecker.INVALID_LEARNWARE, f"'check_status' can not be INVALID_LEARNWARE" | |||
| assert check_status != BaseChecker.INVALID_LEARNWARE, f"'check_status' can not be INVALID_LEARNWARE" | |||
| if zip_path is None and check_status is not None: | |||
| logger.warning("check_status will be ignored when zip_path is None for learnware update") | |||
| @@ -252,12 +215,12 @@ class EasyOrganizer(BaseOrganizer): | |||
| id=id, semantic_spec=semantic_spec, learnware_dirpath=tempdir | |||
| ) | |||
| except Exception: | |||
| return EasyChecker.INVALID_LEARNWARE | |||
| return BaseChecker.INVALID_LEARNWARE | |||
| if new_learnware is None: | |||
| return EasyChecker.INVALID_LEARNWARE | |||
| return BaseChecker.INVALID_LEARNWARE | |||
| learnwere_status = self.checker.check_learnware(new_learnware) | |||
| learnwere_status = BaseChecker.NONUSABLE_LEARNWARE | |||
| else: | |||
| learnwere_status = self.use_flags[id] if zip_path is None else check_status | |||
| @@ -1,7 +1,7 @@ | |||
| from typing import List | |||
| from ..evolve.organizer import EvolvedOrganizer | |||
| from ..anchor.organizer import AnchoredOrganizer, AnchoredUserInfo | |||
| from ..evolve import EvolvedOrganizer | |||
| from ..anchor import AnchoredOrganizer, AnchoredUserInfo | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("evolve_anchor_organizer") | |||
| @@ -1,11 +1,11 @@ | |||
| from .base import LearnwareMarket | |||
| from .easy2 import EasyChecker, EasyOrganizer, EasySearcher | |||
| from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker | |||
| MARKET_CONFIG = { | |||
| "easy": { | |||
| "organizer": EasyOrganizer(), | |||
| "checker": EasyChecker(), | |||
| "searcher": EasySearcher(), | |||
| "checker_list": [EasySemanticChecker(), EasyStatisticalChecker()], | |||
| } | |||
| } | |||
| @@ -14,7 +14,7 @@ def instatiate_learnware_market(market_id, name="easy", **kwargs): | |||
| return LearnwareMarket( | |||
| market_id=market_id, | |||
| organizer=MARKET_CONFIG[name]["organizer"], | |||
| checker=MARKET_CONFIG[name]["checker"], | |||
| searcher=MARKET_CONFIG[name]["searcher"], | |||
| checker_list=MARKET_CONFIG[name]["checker_list"], | |||
| **kwargs | |||
| ) | |||