From 4e4912cf5a9f66de68a598366df8166cf250f372 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 24 Oct 2023 22:49:08 +0800 Subject: [PATCH] [ENH] add organizer, searcher, checker for market --- learnware/market/base.py | 79 ++++++++++++++++++------------ learnware/market/easy/__init__.py | 7 +++ learnware/market/easy/checker.py | 72 +++++++++++++++++++++++++++ learnware/market/easy/organizer.py | 11 +++++ learnware/market/easy/searcher.py | 8 +++ learnware/market/searcher.py | 10 ++++ 6 files changed, 155 insertions(+), 32 deletions(-) create mode 100644 learnware/market/easy/__init__.py create mode 100644 learnware/market/easy/checker.py create mode 100644 learnware/market/easy/organizer.py create mode 100644 learnware/market/easy/searcher.py create mode 100644 learnware/market/searcher.py diff --git a/learnware/market/base.py b/learnware/market/base.py index e03e3c4..3998612 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -1,11 +1,14 @@ import os +import torch +import traceback import numpy as np -import pandas as pd -from typing import Tuple, Any, List, Union, Dict + +from typing import Tuple, Any, List, Union from ..learnware import Learnware -from ..specification import RKMEStatSpecification +from ..logger import get_module_logger +logger = get_module_logger("market_base", "INFO") class BaseUserInfo: """User Information for searching learnware""" @@ -43,19 +46,13 @@ class BaseUserInfo: class BaseMarket: """Base interface for market, it provide the interface of search/add/detele/update learnwares""" - def __init__(self, market_id: str = None): + def __init__(self, market_id: str = None, checker: 'LearnwareChecker' = None): self.market_id = market_id + + self.learnware_checker = LearnwareChecker() if checker is None else checker - def reload_market(self, market_path: str, semantic_spec_list_path: str) -> bool: + def reload_market(self, **kwargs) -> bool: """Reload the market when server restared. - - Parameters - ---------- - market_path : str - Directory for market data. '_IP_:_port_' for loading from database. - semantic_spec_list_path : str - Directory for available semantic_spec. Should be a json file. - Returns ------- bool @@ -64,8 +61,7 @@ class BaseMarket: raise NotImplementedError("reload market is Not Implemented") - @classmethod - def check_learnware(cls, learnware: Learnware) -> bool: + def check_learnware(self, learnware: Learnware) -> bool: """Check the utility of a learnware Parameters @@ -77,7 +73,7 @@ class BaseMarket: bool A flag indicating whether the learnware can be accepted. """ - return True + return self.learnware_checker(learnware) def add_learnware( self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str @@ -221,9 +217,7 @@ class LearnwareOrganizer: raise NotImplementedError("reload market is Not Implemented") - def add_learnware( - self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str - ) -> Tuple[str, bool]: + def add_learnware(self, zip_path: str, semantic_spec: dict) -> Tuple[str, bool]: """Add a learnware into the market. .. note:: @@ -233,22 +227,17 @@ class LearnwareOrganizer: Parameters ---------- - learnware_name : str - Name of new learnware. - model_path : str + zip_path : str Filepath for learnware model, a zipped file. - stat_spec_path : str - Filepath for statistical specification, a '.npy' file. - How to pass parameters requires further discussion. semantic_spec : dict semantic_spec for new learnware, in dictionary format. - desc : str - Brief desciption for new learnware. Returns ------- - Tuple[str, bool] - str indicating model_id, bool indicating whether the learnware is added successfully. + Tuple[str, int] + - str indicating model_id + - int indicating what the flag of learnware is added. + Raises ------ @@ -280,7 +269,33 @@ class LearnwareOrganizer: raise NotImplementedError("delete learnware is Not Implemented") class LearnwareSearcher: - def __init__(self, learnware_organizor): + def __init__(self, organizer): + self.learnware_organizer = organizer + + def __call__(self, user_info: BaseUserInfo): + raise NotImplementedError("'__call__' method is not implemented in LearnwareSearcher") - def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]: - pass \ No newline at end of file + +class LearnwareChecker: + INVALID_LEARNWARE = -1 + NONUSABLE_LEARNWARE = 0 + USABLE_LEARWARE = 1 + + @classmethod + def __call__(cls, learnware: Learnware) -> int: + """Check the utility of a learnware + + Parameters + ---------- + learnware : Learnware + + Returns + ------- + int + A flag indicating whether the learnware can be accepted. + - The INVALID_LEARNWARE denotes the learnware does not pass the check + - The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency + - The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction + """ + + raise NotImplementedError("'__call__' method is not implemented in LearnwareChecker") \ No newline at end of file diff --git a/learnware/market/easy/__init__.py b/learnware/market/easy/__init__.py new file mode 100644 index 0000000..96f0a34 --- /dev/null +++ b/learnware/market/easy/__init__.py @@ -0,0 +1,7 @@ +from ..base import LearnwareSearcher, LearnwareOrganizer + +class EasySearcher(LearnwareSearcher): + pass + +class EasyOrganizer(LearnwareOrganizer): + pass \ No newline at end of file diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py new file mode 100644 index 0000000..88d4266 --- /dev/null +++ b/learnware/market/easy/checker.py @@ -0,0 +1,72 @@ +import traceback + +from ..base import LearnwareChecker +from ...logger import get_module_logger + +logger = get_module_logger("easy_checker", "INFO") + +class EasyChecker(LearnwareChecker): + + @classmethod + def __call__(cls, learnware): + semantic_spec = learnware.get_specification().get_semantic_spec() + + try: + # check model instantiation + learnware.instantiate_model() + + except Exception as e: + traceback.print_exc() + logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}") + return cls.NONUSABLE_LEARNWARE + + try: + learnware_model = learnware.get_model() + + # check input shape + if semantic_spec["Data"]["Values"][0] == "Table": + input_shape = (semantic_spec["Input"]["Dimension"],) + else: + input_shape = learnware_model.input_shape + pass + + # check rkme dimension + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") + if stat_spec is not None: + if stat_spec.get_z().shape[1:] != input_shape: + logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") + return cls.NONUSABLE_LEARNWARE + pass + + inputs = np.random.randn(10, *input_shape) + outputs = learnware.predict(inputs) + + # check output + if outputs.ndim == 1: + outputs = outputs.reshape(-1, 1) + pass + + if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"): + # check output type + if isinstance(outputs, torch.Tensor): + outputs = outputs.detach().cpu().numpy() + if not isinstance(outputs, np.ndarray): + logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") + return cls.NONUSABLE_LEARNWARE + + # check output shape + output_dim = int(semantic_spec["Output"]["Dimension"]) + if outputs[0].shape[0] != output_dim: + logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") + return cls.NONUSABLE_LEARNWARE + pass + else: + if outputs.shape[1:] != learnware_model.output_shape: + logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") + return cls.NONUSABLE_LEARNWARE + + except Exception as e: + logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}") + return cls.NONUSABLE_LEARNWARE + + return cls.USABLE_LEARWARE diff --git a/learnware/market/easy/organizer.py b/learnware/market/easy/organizer.py new file mode 100644 index 0000000..cdbaf69 --- /dev/null +++ b/learnware/market/easy/organizer.py @@ -0,0 +1,11 @@ +import traceback + +from ..base import LearnwareOrganizer +from ...logger import get_module_logger + +logger = get_module_logger("easy_organizer") + + +class EasyOrganizer(LearnwareOrganizer): + + \ No newline at end of file diff --git a/learnware/market/easy/searcher.py b/learnware/market/easy/searcher.py new file mode 100644 index 0000000..eb97a7f --- /dev/null +++ b/learnware/market/easy/searcher.py @@ -0,0 +1,8 @@ +from ..base import LearnwareSearcher +from ...logger import get_module_logger + +logger = get_module_logger('easy_seacher') + +class EasySearcher(LearnwareSearcher): + pass + \ No newline at end of file diff --git a/learnware/market/searcher.py b/learnware/market/searcher.py new file mode 100644 index 0000000..09e6f30 --- /dev/null +++ b/learnware/market/searcher.py @@ -0,0 +1,10 @@ + + +from typing import Tuple, Any, List + +from .base import BaseUserInfo +from ..learnware import Learnware +from ..logger import get_module_logger + +logger = get_module_logger('model') +