From 087659748270fa5eee34cd7c6097c91ac7a34d82 Mon Sep 17 00:00:00 2001 From: Gene Date: Sat, 28 Oct 2023 21:44:48 +0800 Subject: [PATCH] [MNT] change single checker to multiple checker --- learnware/market/__init__.py | 2 +- learnware/market/base.py | 67 +++++++++++---- learnware/market/easy2/__init__.py | 2 +- learnware/market/easy2/checker.py | 94 +++++++++++++++------ learnware/market/easy2/organizer.py | 51 ++--------- learnware/market/evolve_anchor/organizer.py | 4 +- learnware/market/module.py | 6 +- 7 files changed, 131 insertions(+), 95 deletions(-) diff --git a/learnware/market/__init__.py b/learnware/market/__init__.py index bd939f5..bc6137e 100644 --- a/learnware/market/__init__.py +++ b/learnware/market/__init__.py @@ -2,7 +2,7 @@ from .anchor import AnchoredUserInfo, AnchoredOrganizer from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher from .evolve_anchor import EvolveAnchoredOrganizer from .evolve import EvolvedOrganizer -from .easy2 import EasyChecker, EasyOrganizer, EasySearcher +from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker from .hetergeneous import HeterogeneousOrganizer, MappingFunction from .easy import EasyMarket diff --git a/learnware/market/base.py b/learnware/market/base.py index dc91a9f..1bc2d2d 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -1,11 +1,12 @@ import os import torch +import tempfile import traceback import numpy as np from typing import Tuple, Any, List, Union -from ..learnware import Learnware +from ..learnware import Learnware, get_learnware_from_dirpath from ..logger import get_module_logger logger = get_module_logger("market_base", "INFO") @@ -51,27 +52,57 @@ class LearnwareMarket: self, market_id: str = None, organizer: "BaseOrganizer" = None, - checker: "BaseChecker" = None, searcher: "BaseSearcher" = None, + checker_list: List["BaseChecker"] = None, rebuild=False, ): self.market_id = market_id self.learnware_organizer = BaseOrganizer() if organizer is None else organizer - self.learnware_checker = BaseChecker() if checker is None else checker - self.learnware_checker.reset(organizer=self.learnware_organizer) - self.learnware_organizer.reset(market_id=market_id, checker=self.learnware_checker) + self.learnware_organizer.reset(market_id=market_id) self.learnware_organizer.reload_market(rebuild=rebuild) self.learnware_searcher = BaseSearcher() if searcher is None else searcher self.learnware_searcher.reset(organizer=self.learnware_organizer) + + if checker_list is None: + self.learnware_checker = {"BaseChecker": BaseChecker()} + else: + self.learnware_checker = {checker.__class__.__name__: checker for checker in checker_list} + for name, checker in self.learnware_checker.items(): + checker.reset(organizer=self.learnware_organizer) def reload_market(self, **kwargs) -> bool: self.learnware_organizer.reload_market(**kwargs) - def check_learnware(self, learnware: Learnware, **kwargs) -> bool: - return self.learnware_checker(learnware, **kwargs) - - def add_learnware(self, zip_path: str, semantic_spec: dict, **kwargs) -> Tuple[str, bool]: - return self.learnware_organizer.add_learnware(zip_path, semantic_spec, **kwargs) + def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: + try: + with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: + with zipfile.ZipFile(zip_path, mode="r") as z_file: + z_file.extractall(tempdir) + + pending_learnware = get_learnware_from_dirpath( + id="pending", semantic_spec=semantic_specification, learnware_dirpath=tempdir + ) + + final_status = BaseChecker.INVALID_LEARNWARE + checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names + + for name in checker_names: + checker = self.learnware_checker[name] + check_status = checker(pending_learnware) + final_status = max(final_status, check_status) + + if check_status == BaseChecker.INVALID_LEARNWARE: + return BaseChecker.INVALID_LEARNWARE + + return final_status + + except Exception as err: + logger.warning(f"Check learnware failed! Due to {err}.") + return BaseChecker.INVALID_LEARNWARE + + def add_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> Tuple[str, bool]: + check_status = self.check_learnware(zip_path, semantic_spec, checker_names) + return self.learnware_organizer.add_learnware(zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs) def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]: return self.learnware_searcher(user_info, **kwargs) @@ -79,8 +110,9 @@ class LearnwareMarket: def delete_learnware(self, id: str, **kwargs) -> bool: return self.learnware_organizer.delete_learnware(id, **kwargs) - def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, **kwargs) -> bool: - return self.learnware_organizer.update_learnware(id, zip_path=zip_path, semantic_spec=semantic_spec, **kwargs) + def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: + check_status = self.check_learnware(zip_path, semantic_spec, checker_names) + return self.learnware_organizer.update_learnware(id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs) def get_learnware_ids(self, top: int = None, **kwargs): return self.learnware_organizer.get_learnware_ids(top, **kwargs) @@ -99,12 +131,11 @@ class LearnwareMarket: class BaseOrganizer: - def __init__(self, market_id=None, checker: BaseChecker = None): - self.reset(market_id=market_id, checker=checker) + def __init__(self, market_id=None): + self.reset(market_id=market_id) - def reset(self, market_id=None, checker: BaseChecker = None, **kwargs): + def reset(self, market_id=None, **kwargs): self.market_id = market_id - self.checker = checker def reload_market(self, rebuild=False, **kwargs) -> bool: """Reload the learnware organizer when server restared. @@ -117,7 +148,7 @@ class BaseOrganizer: raise NotImplementedError("reload market is Not Implemented in BaseOrganizer") - def add_learnware(self, zip_path: str, semantic_spec: dict) -> Tuple[str, bool]: + def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, bool]: """Add a learnware into the market. .. note:: @@ -167,7 +198,7 @@ class BaseOrganizer: """ raise NotImplementedError("delete learnware is Not Implemented in BaseOrganizer") - def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, **kwargs) -> bool: + def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, check_status: int) -> bool: """ Update Learnware with id and content to be updated. diff --git a/learnware/market/easy2/__init__.py b/learnware/market/easy2/__init__.py index 2ab8c48..2178119 100644 --- a/learnware/market/easy2/__init__.py +++ b/learnware/market/easy2/__init__.py @@ -1,3 +1,3 @@ from .organizer import EasyOrganizer -from .checker import EasyChecker from .searcher import EasySearcher +from .checker import EasySemanticChecker, EasyStatisticalChecker diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 062f9ad..25ee452 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -3,71 +3,113 @@ import numpy as np import torch from ..base import BaseChecker +from ...config import C from ...logger import get_module_logger logger = get_module_logger("easy_checker", "INFO") -class EasyChecker(BaseChecker): +class EasySemanticChecker(BaseChecker): + def __call__(self, learnware): + semantic_spec = learnware.get_specification().get_semantic_spec() + try: + for key in C["semantic_specs"]: + value = semantic_spec[key]["Values"] + valid_type = C["semantic_specs"][key]["Type"] + assert semantic_spec[key]["Type"] == valid_type, f"{key} type mismatch" + + if valid_type == "Class": + valid_list = C["semantic_specs"][key]["Values"] + assert len(value) == 1, f"{key} must be unique" + assert value[0] in valid_list, f"{key} must be in {valid_list}" + + elif valid_type == "Tag": + valid_list = C["semantic_specs"][key]["Values"] + assert len(value) >= 1, f"{key} cannot be empty" + for v in value: + assert v in valid_list, f"{key} must be in {valid_list}" + + elif valid_type == "String": + assert isinstance(value, str), f"{key} must be string" + assert len(value) >= 1, f"{key} cannot be empty" + + if semantic_spec["Data"]["Values"][0] == "Table": + assert semantic_spec["Input"] is not None, "Lack of input semantics" + dim = semantic_spec["Input"]["Dimension"] + for k, v in semantic_spec["Input"]["Description"].items(): + assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" + assert isinstance(v, str), "Description must be string" + + if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]: + assert semantic_spec["Output"] is not None, "Lack of output semantics" + dim = semantic_spec["Output"]["Dimension"] + for k, v in semantic_spec["Output"]["Description"].items(): + assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" + assert isinstance(v, str), "Description must be string" + + return self.NONUSABLE_LEARNWARE + + except Exception as err: + logger.warning(f"semantic_specification is not valid due to {err}!") + return self.INVALID_LEARNWARE + + +class EasyStatisticalChecker(BaseChecker): def __call__(self, learnware): semantic_spec = learnware.get_specification().get_semantic_spec() try: - # check model instantiation + # Check model instantiation learnware.instantiate_model() except Exception as e: traceback.print_exc() - logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}") - return self.NONUSABLE_LEARNWARE + logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") + return self.INVALID_LEARNWARE try: learnware_model = learnware.get_model() - # check input shape + # Check input shape if semantic_spec["Data"]["Values"][0] == "Table": input_shape = (semantic_spec["Input"]["Dimension"],) else: input_shape = learnware_model.input_shape - pass - # check rkme dimension + # Check rkme dimension stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") if stat_spec is not None: if stat_spec.get_z().shape[1:] != input_shape: - logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") - return self.NONUSABLE_LEARNWARE - pass + logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") + return self.INVALID_LEARNWARE inputs = np.random.randn(10, *input_shape) outputs = learnware.predict(inputs) - # check output + # Check output if outputs.ndim == 1: outputs = outputs.reshape(-1, 1) - pass + + if outputs.shape[1:] != learnware_model.output_shape: + logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") + return self.INVALID_LEARNWARE if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"): - # check output type + # Check output type if isinstance(outputs, torch.Tensor): outputs = outputs.detach().cpu().numpy() if not isinstance(outputs, np.ndarray): - logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") - return self.NONUSABLE_LEARNWARE + logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!") + return self.INVALID_LEARNWARE - # check output shape + # Check output shape output_dim = int(semantic_spec["Output"]["Dimension"]) if outputs[0].shape[0] != output_dim: - logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") - return self.NONUSABLE_LEARNWARE - pass - else: - if outputs.shape[1:] != learnware_model.output_shape: - logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") - return self.NONUSABLE_LEARNWARE + logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") + return self.INVALID_LEARNWARE except Exception as e: - logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}") - return self.NONUSABLE_LEARNWARE + logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.") + return self.INVALID_LEARNWARE - return self.USABLE_LEARWARE + return self.USABLE_LEARWARE \ No newline at end of file diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 2f2c062..55780e3 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -13,7 +13,6 @@ from shutil import copyfile, rmtree from typing import Tuple, Any, List, Union, Dict from .database_ops import DatabaseOperations -from .checker import EasyChecker from ..base import LearnwareMarket, BaseUserInfo @@ -95,42 +94,6 @@ class EasyOrganizer(BaseOrganizer): """ semantic_spec = copy.deepcopy(semantic_spec) - - if not os.path.exists(zip_path): - logger.warning("Zip Path NOT Found! Fail to add learnware.") - return None, EasyChecker.INVALID_LEARNWARE - - try: - if len(semantic_spec["Data"]["Values"]) == 0: - logger.warning("Illegal semantic specification, please choose Data.") - return None, EasyChecker.INVALID_LEARNWARE - if len(semantic_spec["Task"]["Values"]) == 0: - logger.warning("Illegal semantic specification, please choose Task.") - return None, EasyChecker.INVALID_LEARNWARE - if len(semantic_spec["Library"]["Values"]) == 0: - logger.warning("Illegal semantic specification, please choose Device.") - return None, EasyChecker.INVALID_LEARNWARE - if len(semantic_spec["Name"]["Values"]) == 0: - logger.warning("Illegal semantic specification, please provide Name.") - return None, EasyChecker.INVALID_LEARNWARE - if len(semantic_spec["Description"]["Values"]) == 0 and len(semantic_spec["Scenario"]["Values"]) == 0: - logger.warning("Illegal semantic specification, please provide Scenario or Description.") - return None, EasyChecker.INVALID_LEARNWARE - if ( - semantic_spec["Data"]["Type"] != "Class" - or semantic_spec["Task"]["Type"] != "Class" - or semantic_spec["Library"]["Type"] != "Class" - or semantic_spec["Scenario"]["Type"] != "Tag" - or semantic_spec["Name"]["Type"] != "String" - or semantic_spec["Description"]["Type"] != "String" - ): - logger.warning("Illegal semantic specification, please provide the right type.") - return None, EasyChecker.INVALID_LEARNWARE - except: - print(semantic_spec) - logger.warning("Illegal semantic specification, some keys are missing.") - return None, EasyChecker.INVALID_LEARNWARE - logger.info("Get new learnware from %s" % (zip_path)) id = id if id is not None else "%08d" % (self.count) @@ -152,12 +115,12 @@ class EasyOrganizer(BaseOrganizer): rmtree(target_folder_dir) except: pass - return None, EasyChecker.INVALID_LEARNWARE + return None, BaseChecker.INVALID_LEARNWARE if new_learnware is None: - return None, EasyChecker.INVALID_LEARNWARE + return None, BaseChecker.INVALID_LEARNWARE - learnwere_status = check_status if check_status is not None else self.checker(new_learnware) + learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE self.dbops.add_learnware( id=id, @@ -227,7 +190,7 @@ class EasyOrganizer(BaseOrganizer): assert ( zip_path is None and semantic_spec is None ), f"at least one of 'zip_path' and 'semantic_spec' should not be None when update learnware" - assert check_status != EasyChecker.INVALID_LEARNWARE, f"'check_status' can not be INVALID_LEARNWARE" + assert check_status != BaseChecker.INVALID_LEARNWARE, f"'check_status' can not be INVALID_LEARNWARE" if zip_path is None and check_status is not None: logger.warning("check_status will be ignored when zip_path is None for learnware update") @@ -252,12 +215,12 @@ class EasyOrganizer(BaseOrganizer): id=id, semantic_spec=semantic_spec, learnware_dirpath=tempdir ) except Exception: - return EasyChecker.INVALID_LEARNWARE + return BaseChecker.INVALID_LEARNWARE if new_learnware is None: - return EasyChecker.INVALID_LEARNWARE + return BaseChecker.INVALID_LEARNWARE - learnwere_status = self.checker.check_learnware(new_learnware) + learnwere_status = BaseChecker.NONUSABLE_LEARNWARE else: learnwere_status = self.use_flags[id] if zip_path is None else check_status diff --git a/learnware/market/evolve_anchor/organizer.py b/learnware/market/evolve_anchor/organizer.py index 1e8173e..04e9779 100644 --- a/learnware/market/evolve_anchor/organizer.py +++ b/learnware/market/evolve_anchor/organizer.py @@ -1,7 +1,7 @@ from typing import List -from ..evolve.organizer import EvolvedOrganizer -from ..anchor.organizer import AnchoredOrganizer, AnchoredUserInfo +from ..evolve import EvolvedOrganizer +from ..anchor import AnchoredOrganizer, AnchoredUserInfo from ...logger import get_module_logger logger = get_module_logger("evolve_anchor_organizer") diff --git a/learnware/market/module.py b/learnware/market/module.py index 221650d..57821cd 100644 --- a/learnware/market/module.py +++ b/learnware/market/module.py @@ -1,11 +1,11 @@ from .base import LearnwareMarket -from .easy2 import EasyChecker, EasyOrganizer, EasySearcher +from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker MARKET_CONFIG = { "easy": { "organizer": EasyOrganizer(), - "checker": EasyChecker(), "searcher": EasySearcher(), + "checker_list": [EasySemanticChecker(), EasyStatisticalChecker()], } } @@ -14,7 +14,7 @@ def instatiate_learnware_market(market_id, name="easy", **kwargs): return LearnwareMarket( market_id=market_id, organizer=MARKET_CONFIG[name]["organizer"], - checker=MARKET_CONFIG[name]["checker"], searcher=MARKET_CONFIG[name]["searcher"], + checker_list=MARKET_CONFIG[name]["checker_list"], **kwargs )