| @@ -178,7 +178,7 @@ For example, the following code is designed to work with Reduced Set Kernel Embe | |||
| ```python | |||
| import learnware.specification as specification | |||
| user_spec = specification.rkme.RKMEStatSpecification() | |||
| user_spec = specification.RKMEStatSpecification() | |||
| user_spec.load(os.path.join(unzip_path, "rkme.json")) | |||
| user_info = BaseUserInfo( | |||
| semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} | |||
| @@ -11,7 +11,7 @@ Here you can find all ``learnware`` interfaces. | |||
| Market | |||
| ==================== | |||
| .. autoclass:: learnware.market.BaseMarket | |||
| .. autoclass:: learnware.market.LearnwareMarket | |||
| :members: | |||
| .. autoclass:: learnware.market.EasyMarket | |||
| @@ -123,7 +123,7 @@ You can search learnware by providing a statistical specification. The statistic | |||
| import learnware.specification as specification | |||
| user_spec = specification.rkme.RKMEStatSpecification() | |||
| user_spec = specification.RKMEStatSpecification() | |||
| user_spec.load(os.path.join(unzip_path, "rkme.json")) | |||
| specification = learnware.specification.Specification() | |||
| @@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares | |||
| senarioes=[], | |||
| input_description={}, output_description={}) | |||
| stat_spec = specification.rkme.RKMEStatSpecification() | |||
| stat_spec = specification.RKMEStatSpecification() | |||
| stat_spec.load(os.path.join(unzip_path, "rkme.json")) | |||
| specification = learnware.specification.Specification() | |||
| specification.update_semantic_spec(semantic_spec) | |||
| @@ -170,7 +170,7 @@ For example, the code below executes learnware search when using Reduced Set Ker | |||
| import learnware.specification as specification | |||
| user_spec = specification.rkme.RKMEStatSpecification() | |||
| user_spec = specification.RKMEStatSpecification() | |||
| # unzip_path: directory for unzipped learnware zipfile | |||
| user_spec.load(os.path.join(unzip_path, "rkme.json")) | |||
| @@ -73,7 +73,7 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb | |||
| import learnware.specification as specification | |||
| user_spec = specification.rkme.RKMEStatSpecification() | |||
| user_spec = specification.RKMEStatSpecification() | |||
| user_spec.load(os.path.join("rkme.json")) | |||
| user_info = BaseUserInfo( | |||
| semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} | |||
| @@ -148,7 +148,7 @@ class LearnwareMarketWorkflow: | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=unzip_dir) | |||
| user_spec = specification.rkme.RKMEStatSpecification() | |||
| user_spec = specification.RKMEStatSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) | |||
| ( | |||
| @@ -1,6 +1,9 @@ | |||
| from .anchor import AnchoredUserInfo, AnchoredMarket | |||
| from .base import BaseUserInfo, BaseMarket | |||
| from .evolve_anchor import EvolvedAnchoredMarket | |||
| from .evolve import EvolvedMarket | |||
| from .anchor import AnchoredUserInfo, AnchoredOrganizer | |||
| from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher | |||
| from .evolve_anchor import EvolvedAnchoredOrganizer | |||
| from .evolve import EvolvedOrganizer | |||
| from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker | |||
| from .hetergeneous import HeterogeneousOrganizer, MappingFunction | |||
| from .easy import EasyMarket | |||
| from .heterogeneous_feature import HeterogeneousFeatureMarket | |||
| from .module import instatiate_learnware_market | |||
| @@ -1,140 +0,0 @@ | |||
| import os | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from ..learnware import Learnware | |||
| from .base import BaseMarket, BaseUserInfo | |||
class AnchoredUserInfo(BaseUserInfo):
    """User information for searching learnwares, extended with the anchor design.

    - Keeps the anchor learnwares acquired from the market, keyed by learnware id.
    - Lets the caller update ``stat_info`` with values computed on those anchors.
    """

    def __init__(self, id: str, semantic_spec: dict = None, stat_info: dict = None):
        """Create the user info with an empty anchor table.

        Parameters
        ----------
        id : str
            Forwarded unchanged to ``BaseUserInfo``.
        semantic_spec : dict, optional
            Semantic specification; a fresh empty dict is used when omitted.
        stat_info : dict, optional
            Statistical information; a fresh empty dict is used when omitted.
        """
        # A ``dict()`` default is created once and shared by every call that
        # omits the argument; ``update_stat_info`` writes into ``self.stat_info``,
        # so a shared default would leak entries between users
        # (presumably BaseUserInfo stores the dict it is given -- verify).
        semantic_spec = {} if semantic_spec is None else semantic_spec
        stat_info = {} if stat_info is None else stat_info
        super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
        self.anchor_learnware_list = {}  # id: Learnware

    def add_anchor_learnware(self, learnware_id: str, learnware: Learnware):
        """Add the anchor learnware acquired from the market

        Parameters
        ----------
        learnware_id : str
            Id of anchor learnware
        learnware : Learnware
            Anchor learnware for capturing user requirements
        """
        # An existing entry with the same id is overwritten.
        self.anchor_learnware_list[learnware_id] = learnware

    def update_stat_info(self, name: str, item: Any):
        """Update stat_info based on anchor learnwares

        Parameters
        ----------
        name : str
            Name of stat_info
        item : Any
            Statistical information calculated on anchor learnwares
        """
        self.stat_info[name] = item
class AnchoredMarket(BaseMarket):
    """BaseMarket extended with the anchor design.

    In addition to whatever ``BaseMarket`` manages, the market keeps a table
    of anchor learnwares (``anchor_learnware_list``, anchor id -> learnware).
    The three search/update entry points below are empty stubs that return
    ``None``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``BaseMarket`` and start with no anchors."""
        super(AnchoredMarket, self).__init__(*args, **kwargs)
        self.anchor_learnware_list = {}  # anchor_id: anchor learnware

    def _update_anchor_learnware(self, anchor_id: str, anchor_learnware: Learnware):
        """Update anchor_learnware_list

        Parameters
        ----------
        anchor_id : str
            Id of anchor learnware
        anchor_learnware : Learnware
            Anchor learnware
        """
        # Inserts a new anchor or replaces the one already stored under this id.
        self.anchor_learnware_list[anchor_id] = anchor_learnware

    def _delete_anchor_learnware(self, anchor_id: str) -> bool:
        """Delete anchor learnware in anchor_learnware_list

        Parameters
        ----------
        anchor_id : str
            Id of anchor learnware

        Returns
        -------
        bool
            True if the target anchor learnware is deleted successfully.

        Raises
        ------
        Exception
            Raise an exception when given anchor_id is NOT found in anchor_learnware_list
        """
        if anchor_id not in self.anchor_learnware_list:
            raise Exception("Anchor learnware id:{} NOT Found!".format(anchor_id))

        self.anchor_learnware_list.pop(anchor_id)
        return True

    def update_anchor_learnware_list(self, learnware_list: Dict[str, Learnware]):
        """Update anchor_learnware_list

        Not implemented here: the body is empty and returns None.

        Parameters
        ----------
        learnware_list : Dict[str, Learnware]
            Learnwares for updating anchor_learnware_list
        """
        pass

    def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
        """Search anchor Learnwares from anchor_learnware_list based on user_info

        Not implemented here: the body is empty and returns None.

        Parameters
        ----------
        user_info : AnchoredUserInfo
            - user_info with semantic specifications and statistical information
            - some statistical information calculated on previous anchor learnwares

        Returns
        -------
        Tuple[Any, List[Learnware]]:
            return two items:

            - first is the usage of anchor learnwares, e.g., how to use anchors to calculate some statistical information
            - second is a list of anchor learnwares
        """
        pass

    def search_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
        """Find helpful learnwares from learnware_list based on user_info

        Not implemented here: the body is empty and returns None.

        Parameters
        ----------
        user_info : AnchoredUserInfo
            - user_info with semantic specifications and statistical information
            - some statistical information calculated on anchor learnwares

        Returns
        -------
        Tuple[Any, List[Any]]
            return two items:

            - first is recommended combination, None when no recommended combination is calculated or statistical specification is not provided.
            - second is a list of matched learnwares
        """
        pass
| @@ -0,0 +1,2 @@ | |||
| from .organizer import AnchoredOrganizer | |||
| from .searcher import AnchoredUserInfo | |||
| @@ -0,0 +1,62 @@ | |||
| from typing import List, Dict, Tuple, Any | |||
| from ..easy2.organizer import EasyOrganizer | |||
| from ...logger import get_module_logger | |||
| from ...learnware import Learnware | |||
| from ...specification import BaseStatSpecification | |||
| logger = get_module_logger("anchor_organizer") | |||
class AnchoredOrganizer(EasyOrganizer):
    """EasyOrganizer that additionally keeps a table of anchor learnwares.

    ``anchor_learnware_list`` maps an anchor id to its anchor learnware.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``EasyOrganizer`` and start with no anchors."""
        super(AnchoredOrganizer, self).__init__(*args, **kwargs)
        self.anchor_learnware_list = {}  # anchor_id: anchor learnware

    def _update_anchor_learnware(self, anchor_id: str, anchor_learnware: Learnware):
        """Update anchor_learnware_list

        Parameters
        ----------
        anchor_id : str
            Id of anchor learnware
        anchor_learnware : Learnware
            Anchor learnware
        """
        # Inserts a new anchor or replaces the one already stored under this id.
        self.anchor_learnware_list[anchor_id] = anchor_learnware

    def _delete_anchor_learnware(self, anchor_id: str) -> bool:
        """Delete anchor learnware in anchor_learnware_list

        Parameters
        ----------
        anchor_id : str
            Id of anchor learnware

        Returns
        -------
        bool
            True if the target anchor learnware is deleted successfully.

        Raises
        ------
        Exception
            Raise an exception when given anchor_id is NOT found in anchor_learnware_list
        """
        if anchor_id not in self.anchor_learnware_list:
            raise Exception("Anchor learnware id:{} NOT Found!".format(anchor_id))

        self.anchor_learnware_list.pop(anchor_id)
        return True

    def update_anchor_learnware_list(self, learnware_list: Dict[str, Learnware]):
        """Update anchor_learnware_list

        Not implemented here: the body is empty and returns None.

        Parameters
        ----------
        learnware_list : Dict[str, Learnware]
            Learnwares for updating anchor_learnware_list
        """
        pass
| @@ -0,0 +1,111 @@ | |||
| from typing import List, Dict, Tuple, Any, Union | |||
| from ..base import BaseUserInfo | |||
| from ..easy2.searcher import EasySearcher | |||
| from ...logger import get_module_logger | |||
| from ...learnware import Learnware | |||
| logger = get_module_logger("anchor_searcher") | |||
class AnchoredUserInfo(BaseUserInfo):
    """User information for searching learnwares, extended with the anchor design.

    - Holds the list of anchor learnware ids acquired from the market.
    - Lets the caller update ``stat_info`` with values computed on anchors.
    """

    def __init__(
        self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None
    ):
        """Create the user info.

        Parameters
        ----------
        id : str
            Forwarded unchanged to ``BaseUserInfo``.
        semantic_spec : dict, optional
            Forwarded unchanged to ``BaseUserInfo``.
        stat_info : dict, optional
            Forwarded unchanged to ``BaseUserInfo``.
        anchor_learnware_ids : List[str], optional
            Initial anchor learnware ids; an empty list when omitted.
        """
        super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
        # Copy the incoming list: ``add_anchor_learnware_ids`` extends the stored
        # list in place, which would otherwise also modify the caller's object.
        self.anchor_learnware_ids = [] if anchor_learnware_ids is None else list(anchor_learnware_ids)

    def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]):
        """Add the anchor learnware ids acquired from the market

        Parameters
        ----------
        learnware_ids : Union[str, List[str]]
            Anchor learnware ids; a single id may be given as a plain string.
        """
        # Wrap a lone id so it is appended as one entry instead of being
        # iterated character by character.
        if isinstance(learnware_ids, str):
            learnware_ids = [learnware_ids]
        self.anchor_learnware_ids.extend(learnware_ids)

    def update_stat_info(self, name: str, item: Any):
        """Update stat_info based on anchor learnwares

        Parameters
        ----------
        name : str
            Name of stat_info
        item : Any
            Statistical information calculated on anchor learnwares
        """
        self.stat_info[name] = item
class AnchoredSearcher(EasySearcher):
    """Searcher offering two lookups: anchor learnwares and regular learnwares.

    Both lookups are empty stubs in this class and return ``None``;
    ``__call__`` only selects which of the two to run.
    """

    def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
        """Search anchor learnwares from anchor_learnware_list based on user_info.

        Parameters
        ----------
        user_info : AnchoredUserInfo
            - user_info with semantic specifications and statistical information
            - some statistical information calculated on previous anchor learnwares

        Returns
        -------
        Tuple[Any, List[Learnware]]
            Intended contract (the stub itself returns None):

            - first item: how the anchor learnwares are to be used, e.g. how to
              calculate some statistical information with them
            - second item: a list of anchor learnwares
        """
        return None

    def search_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
        """Find helpful learnwares from learnware_list based on user_info.

        Parameters
        ----------
        user_info : AnchoredUserInfo
            - user_info with semantic specifications and statistical information
            - some statistical information calculated on anchor learnwares

        Returns
        -------
        Tuple[Any, List[Any]]
            Intended contract (the stub itself returns None):

            - first item: recommended combination, None when no recommended
              combination is calculated or no statistical specification is provided
            - second item: a list of matched learnwares
        """
        return None

    def __call__(self, user_info: AnchoredUserInfo, anchor_flag: bool = False) -> Tuple[Any, List[Learnware]]:
        """Search learnwares in an anchored market.

        Parameters
        ----------
        user_info : AnchoredUserInfo
            - user_info with semantic specifications and statistical information
            - some statistical information calculated on anchor learnwares
        anchor_flag : bool, optional
            True: run ``search_anchor_learnware``.
            False (default): run ``search_learnware``.

        Returns
        -------
        Tuple[Any, List[Any]]
            Whatever the selected search method returns.
        """
        chosen_search = self.search_anchor_learnware if anchor_flag else self.search_learnware
        return chosen_search(user_info)
| @@ -1,10 +1,12 @@ | |||
| import os | |||
| import numpy as np | |||
| import pandas as pd | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| import zipfile | |||
| import tempfile | |||
| from ..learnware import Learnware | |||
| from ..specification import RKMEStatSpecification | |||
| from typing import Tuple, Any, List, Union | |||
| from ..learnware import Learnware, get_learnware_from_dirpath | |||
| from ..logger import get_module_logger | |||
| logger = get_module_logger("market_base", "INFO") | |||
| class BaseUserInfo: | |||
| @@ -40,47 +42,165 @@ class BaseUserInfo: | |||
| return self.stat_info.get(name, None) | |||
| class BaseMarket: | |||
class LearnwareMarket:
    """Base interface for the market: search / add / delete / update learnwares.

    The market itself holds no storage or search logic. It wires together
    three collaborators and delegates to them:

    - an organizer (``learnware_organizer``) that stores the learnwares,
    - a searcher (``learnware_searcher``) that answers user queries,
    - a dict of checkers (``learnware_checker``), keyed by checker class name,
      that validate a learnware before it is added or updated.
    """

    def __init__(
        self,
        market_id: str = None,
        organizer: "BaseOrganizer" = None,
        searcher: "BaseSearcher" = None,
        checker_list: List["BaseChecker"] = None,
        rebuild=False,
    ):
        """Wire up the collaborators and reload the organizer.

        Parameters
        ----------
        market_id : str, optional
            Stored on the market and passed to ``organizer.reset``.
        organizer : BaseOrganizer, optional
            Defaults to a plain ``BaseOrganizer()``.
        searcher : BaseSearcher, optional
            Defaults to a plain ``BaseSearcher()``.
        checker_list : List[BaseChecker], optional
            Checkers to register; defaults to a single ``BaseChecker()``.
            Two checkers of the same class share one key, so the later one wins.
        rebuild : bool, optional
            Forwarded to ``organizer.reload_market``.
        """
        self.market_id = market_id

        self.learnware_organizer = BaseOrganizer() if organizer is None else organizer
        self.learnware_organizer.reset(market_id=market_id)
        self.learnware_organizer.reload_market(rebuild=rebuild)

        self.learnware_searcher = BaseSearcher() if searcher is None else searcher
        self.learnware_searcher.reset(organizer=self.learnware_organizer)

        if checker_list is None:
            self.learnware_checker = {"BaseChecker": BaseChecker()}
        else:
            self.learnware_checker = {checker.__class__.__name__: checker for checker in checker_list}

        # Every checker gets access to the same organizer as the market.
        for checker in self.learnware_checker.values():
            checker.reset(organizer=self.learnware_organizer)

    def reload_market(self, **kwargs) -> bool:
        """Reload the organizer and return the flag it reports."""
        return self.learnware_organizer.reload_market(**kwargs)

    def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> int:
        """Unzip a pending learnware into a temporary directory and run the checkers on it.

        Parameters
        ----------
        zip_path : str
            Filepath for learnware model, a zipped file.
        semantic_spec : dict
            semantic_spec for the pending learnware, in dictionary format.
        checker_names : List[str], optional
            Names (keys of ``learnware_checker``) of the checkers to run,
            by default None, meaning all registered checkers.

        Returns
        -------
        int
            ``BaseChecker.INVALID_LEARNWARE`` as soon as one checker reports it,
            when no checker is run, or when anything in the process raises
            (the error is logged as a warning). Otherwise the largest status
            reported by the checkers that were run.
        """
        try:
            with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir:
                with zipfile.ZipFile(zip_path, mode="r") as z_file:
                    z_file.extractall(tempdir)

                pending_learnware = get_learnware_from_dirpath(
                    id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir
                )

                final_status = BaseChecker.INVALID_LEARNWARE
                checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names

                # Checkers run while the temporary directory still exists.
                for name in checker_names:
                    checker = self.learnware_checker[name]
                    check_status = checker(pending_learnware)
                    final_status = max(final_status, check_status)

                    if check_status == BaseChecker.INVALID_LEARNWARE:
                        return BaseChecker.INVALID_LEARNWARE

                return final_status
        except Exception as err:
            # Deliberate best-effort: a learnware that cannot even be checked
            # is reported as invalid instead of crashing the market.
            logger.warning(f"Check learnware failed! Due to {err}.")
            return BaseChecker.INVALID_LEARNWARE

    def add_learnware(
        self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs
    ) -> Tuple[str, int]:
        """Add a learnware into the market.

        Parameters
        ----------
        zip_path : str
            Filepath for learnware model, a zipped file.
        semantic_spec : dict
            semantic_spec for new learnware, in dictionary format.
        checker_names : List[str], optional
            List contains checker names, by default None

        Returns
        -------
        Tuple[str, int]
            - str indicating model_id
            - int indicating the final learnware check_status
        """
        check_status = self.check_learnware(zip_path, semantic_spec, checker_names)
        return self.learnware_organizer.add_learnware(
            zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
        )

    def search_learnware(self, user_info: "BaseUserInfo", **kwargs) -> Tuple[Any, List["Learnware"]]:
        """Delegate the search to the searcher and return its result."""
        return self.learnware_searcher(user_info, **kwargs)

    def delete_learnware(self, id: str, **kwargs) -> bool:
        """Delegate the deletion to the organizer and return its result."""
        return self.learnware_organizer.delete_learnware(id, **kwargs)

    def update_learnware(
        self,
        id: str,
        zip_path: str,
        semantic_spec: dict,
        checker_names: List[str] = None,
        check_status: int = None,
        **kwargs,
    ) -> int:
        """Update learnware with zip_path and semantic_specification

        Parameters
        ----------
        id : str
            Learnware id
        zip_path : str
            Filepath for learnware model, a zipped file.
        semantic_spec : dict
            semantic_spec for new learnware, in dictionary format.
        checker_names : List[str], optional
            List contains checker names, by default None.
        check_status : int, optional
            A flag indicating whether the learnware is usable, by default None.

        Returns
        -------
        int
            The final learnware check_status.
        """
        update_status = self.check_learnware(zip_path, semantic_spec, checker_names)
        # A caller-supplied status is honoured unless the fresh check says the
        # learnware is invalid; with no supplied status the fresh check is used.
        check_status = (
            update_status if check_status is None or update_status == BaseChecker.INVALID_LEARNWARE else check_status
        )

        return self.learnware_organizer.update_learnware(
            id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
        )

    def get_learnware_ids(self, top: int = None, **kwargs):
        """Delegate to ``organizer.get_learnware_ids``."""
        return self.learnware_organizer.get_learnware_ids(top, **kwargs)

    def get_learnwares(self, top: int = None, **kwargs):
        """Delegate to ``organizer.get_learnwares``."""
        return self.learnware_organizer.get_learnwares(top, **kwargs)

    def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[str, List[str]]:
        """Delegate to ``organizer.get_learnware_path_by_ids`` and return its result."""
        # The result is a value to hand back, not an exception: ``raise`` here
        # would fail with "exceptions must derive from BaseException".
        return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs)

    def get_learnware_by_ids(self, id: Union[str, List[str]], **kwargs) -> Union["Learnware", List["Learnware"]]:
        """Delegate to ``organizer.get_learnware_by_ids``."""
        return self.learnware_organizer.get_learnware_by_ids(id, **kwargs)

    def __len__(self):
        """Number of learnwares, as reported by the organizer."""
        return len(self.learnware_organizer)
| class BaseOrganizer: | |||
| def __init__(self, market_id=None): | |||
| self.reset(market_id=market_id) | |||
| def reset(self, market_id=None, **kwargs): | |||
| self.market_id = market_id | |||
| def reload_market(self, rebuild=False, **kwargs) -> bool: | |||
| """Reload the learnware organizer when server restared. | |||
| Returns | |||
| ------- | |||
| bool | |||
| A flag indicating whether the learnware can be accepted. | |||
| A flag indicating whether the market is reload successfully. | |||
| """ | |||
| return True | |||
| def add_learnware( | |||
| self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str | |||
| ) -> Tuple[str, bool]: | |||
| raise NotImplementedError("reload market is Not Implemented in BaseOrganizer") | |||
| def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, bool]: | |||
| """Add a learnware into the market. | |||
| .. note:: | |||
| @@ -90,22 +210,17 @@ class BaseMarket: | |||
| Parameters | |||
| ---------- | |||
| learnware_name : str | |||
| Name of new learnware. | |||
| model_path : str | |||
| zip_path : str | |||
| Filepath for learnware model, a zipped file. | |||
| stat_spec_path : str | |||
| Filepath for statistical specification, a '.npy' file. | |||
| How to pass parameters requires further discussion. | |||
| semantic_spec : dict | |||
| semantic_spec for new learnware, in dictionary format. | |||
| desc : str | |||
| Brief desciption for new learnware. | |||
| Returns | |||
| ------- | |||
| Tuple[str, bool] | |||
| str indicating model_id, bool indicating whether the learnware is added successfully. | |||
| Tuple[str, int] | |||
| - str indicating model_id | |||
| - int indicating what the flag of learnware is added. | |||
| Raises | |||
| ------ | |||
| @@ -113,26 +228,38 @@ class BaseMarket: | |||
| file for model or statistical specification not found | |||
| """ | |||
| raise NotImplementedError("add learnware is Not Implemented") | |||
| raise NotImplementedError("add learnware is Not Implemented in BaseOrganizer") | |||
| def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]: | |||
| """Search Learnware based on user_info | |||
| def delete_learnware(self, id: str) -> bool: | |||
| """Delete a learnware from market | |||
| Parameters | |||
| ---------- | |||
| user_info : BaseUserInfo | |||
| user_info with emantic specifications and statistical information | |||
| id : str | |||
| id of learnware to be deleted | |||
| Returns | |||
| ------- | |||
| Tuple[Any, List[Any]] | |||
| return two items: | |||
| bool | |||
| True if the target learnware is deleted successfully. | |||
| - first is recommended combination, None when no recommended combination is calculated or statistical specification is not provided. | |||
| - second is a list of matched learnwares | |||
| Raises | |||
| ------ | |||
| Exception | |||
| Raise an excpetion when given id is NOT found in learnware list | |||
| """ | |||
| raise NotImplementedError("delete learnware is Not Implemented in BaseOrganizer") | |||
| raise NotImplementedError("search learnware is Not Implemented") | |||
| def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, check_status: int) -> bool: | |||
| """ | |||
| Update Learnware with id and content to be updated. | |||
| Parameters | |||
| ---------- | |||
| id : str | |||
| id of target learnware. | |||
| """ | |||
| raise NotImplementedError("update learnware is Not Implemented in BaseOrganizer") | |||
| def get_learnware_by_ids(self, id: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]: | |||
| """ | |||
| @@ -151,47 +278,103 @@ class BaseMarket: | |||
| - The returned items are search results. | |||
| - 'None' indicating the target id not found. | |||
| """ | |||
| raise NotImplementedError("search learnware is Not Implemented") | |||
| raise NotImplementedError("get_learnware_by_ids is not implemented in BaseOrganizer") | |||
| def delete_learnware(self, id: str) -> bool: | |||
| """Delete a learnware from market | |||
| def get_learnware_path_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]: | |||
| """Get Zipped Learnware file by id | |||
| Parameters | |||
| ---------- | |||
| id : str | |||
| id of learnware to be deleted | |||
| ids : Union[str, List[str]] | |||
| Give a id or a list of ids | |||
| str: id of targer learware | |||
| List[str]: A list of ids of target learnwares | |||
| Returns | |||
| ------- | |||
| bool | |||
| True if the target learnware is deleted successfully. | |||
| Union[Learnware, List[Learnware]] | |||
| Return the path for target learnware or list of path. | |||
| None for Learnware NOT Found. | |||
| """ | |||
| raise NotImplementedError("get_learnware_path_by_ids is not implemented in BaseOrganizer") | |||
| def get_learnware_ids(self, top: int = None) -> List[str]: | |||
| """get the list of learnware ids | |||
| Parameters | |||
| ---------- | |||
| top : int, optional | |||
| the first top element to return, by default None | |||
| Raises | |||
| ------ | |||
| Exception | |||
| Raise an excpetion when given id is NOT found in learnware list | |||
| List[str] | |||
| the first top ids | |||
| """ | |||
| raise NotImplementedError("delete learnware is Not Implemented") | |||
| raise NotImplementedError("get_learnware_ids is not implemented in BaseOrganizer") | |||
| def get_learnwares(self, top: int = None) -> List[Learnware]: | |||
| """get the list of learnwares | |||
| def update_learnware(self, id: str) -> bool: | |||
| Parameters | |||
| ---------- | |||
| top : int, optional | |||
| the first top element to return, by default None | |||
| Raises | |||
| ------ | |||
| List[Learnware] | |||
| the first top learnwares | |||
| """ | |||
| Update Learnware with id and content to be updated. | |||
| Empty interface. TODO | |||
| raise NotImplementedError("get_learnwares is not implemented in BaseOrganizer") | |||
| def __len__(self): | |||
| raise NotImplementedError("__len__ is not implemented in BaseOrganizer") | |||
class BaseSearcher:
    """Interface for searchers: callable objects that look up learnwares for a user.

    The searcher holds a reference to the organizer it searches in.
    """

    def __init__(self, organizer: "BaseOrganizer" = None):
        """Store the organizer to search in (may be set later through ``reset``)."""
        # NOTE: the attribute name is misspelled ("oganizer"); it is kept as is
        # because other code may read it under this name.
        self.learnware_oganizer = organizer

    def reset(self, organizer):
        """Replace the organizer this searcher works on."""
        self.learnware_oganizer = organizer

    def __call__(self, user_info: "BaseUserInfo", **kwargs):
        """Search learnwares based on user_info

        Parameters
        ----------
        user_info : BaseUserInfo
            user_info contains semantic_spec and stat_info
        **kwargs
            Extra search options. ``LearnwareMarket.search_learnware`` forwards
            its keyword arguments here, so the base signature accepts them and
            fails with NotImplementedError rather than a TypeError.

        Raises
        ------
        NotImplementedError
            Always, in ``BaseSearcher``.
        """
        raise NotImplementedError("'__call__' method is not implemented in BaseSearcher")
| def get_semantic_spec_list(self) -> dict: | |||
| """Return all semantic specifications available | |||
class BaseChecker:
    """Interface for checkers: callable objects that grade a learnware.

    A checker returns one of the integer status codes defined below. The
    codes are ordered, and ``LearnwareMarket.check_learnware`` combines them
    with ``max``.
    """

    INVALID_LEARNWARE = -1
    NONUSABLE_LEARNWARE = 0
    USABLE_LEARWARE = 1
    # Correctly spelled alias; the original (misspelled) name stays available
    # for existing callers.
    USABLE_LEARNWARE = USABLE_LEARWARE

    def __init__(self, organizer: "BaseOrganizer" = None):
        """Store the organizer (may be set later through ``reset``)."""
        # NOTE: the attribute name is misspelled ("oganizer"); it is kept as is
        # because other code may read it under this name.
        self.learnware_oganizer = organizer

    def reset(self, organizer):
        """Replace the organizer this checker works on."""
        self.learnware_oganizer = organizer

    def __call__(self, learnware: "Learnware") -> int:
        """Check the utility of a learnware

        Parameters
        ----------
        learnware : Learnware

        Returns
        -------
        int
            A flag indicating whether the learnware can be accepted.

            - The INVALID_LEARNWARE denotes the learnware does not pass the check
            - The NONUSABLE_LEARNWARE denotes the learnware passes the check but cannot make prediction due to some env dependency
            - The USABLE_LEARWARE denotes the learnware passes the check and can make prediction

        Raises
        ------
        NotImplementedError
            Always, in ``BaseChecker``.
        """
        raise NotImplementedError("'__call__' method is not implemented in BaseChecker")
| @@ -117,25 +117,25 @@ class DatabaseOperations(object): | |||
| pass | |||
| pass | |||
| def update_learnware_semantic_spec(self, learnware_id: str, semantic_spec: dict): | |||
| def delete_learnware(self, id: str): | |||
| with self.engine.connect() as conn: | |||
| semantic_spec_str = json.dumps(semantic_spec) | |||
| conn.execute( | |||
| text("UPDATE tb_learnware SET semantic_spec=:semantic_spec WHERE id=:id;"), | |||
| dict(id=learnware_id, semantic_spec=semantic_spec_str), | |||
| ) | |||
| conn.execute(text("DELETE FROM tb_learnware WHERE id=:id;"), dict(id=id)) | |||
| conn.commit() | |||
| pass | |||
| pass | |||
| def delete_learnware(self, id: str): | |||
| def update_learnware_semantic_specification(self, id: str, semantic_spec: dict): | |||
| with self.engine.connect() as conn: | |||
| conn.execute(text("DELETE FROM tb_learnware WHERE id=:id;"), dict(id=id)) | |||
| semantic_spec_str = json.dumps(semantic_spec) | |||
| r = conn.execute( | |||
| text("UPDATE tb_learnware SET semantic_spec=:semantic_spec WHERE id=:id;"), | |||
| dict(id=id, semantic_spec=semantic_spec_str), | |||
| ) | |||
| conn.commit() | |||
| pass | |||
| pass | |||
| def update_learnware_semantic_specification(self, id: str, semantic_spec: dict): | |||
| def update_learnware_use_flag(self, id: str, semantic_spec: dict): | |||
| with self.engine.connect() as conn: | |||
| semantic_spec_str = json.dumps(semantic_spec) | |||
| r = conn.execute( | |||
| @@ -11,7 +11,7 @@ from cvxopt import solvers, matrix | |||
| from shutil import copyfile, rmtree | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from .base import BaseMarket, BaseUserInfo | |||
| from .base import LearnwareMarket, BaseUserInfo | |||
| from .database_ops import DatabaseOperations | |||
| from .. import utils | |||
| @@ -24,8 +24,8 @@ from ..specification import RKMEStatSpecification, Specification | |||
| logger = get_module_logger("market", "INFO") | |||
| class EasyMarket(BaseMarket): | |||
| """EasyMarket provide an easy and simple implementation for BaseMarket | |||
| class EasyMarket(LearnwareMarket): | |||
| """EasyMarket provide an easy and simple implementation for LearnwareMarket | |||
| - EasyMarket stores learnwares with file system and database | |||
| - EasyMarket search the learnwares with the match of semantical tag and the statistical RKME | |||
| - EasyMarket does not support the search between heterogeneous features learnwars | |||
| @@ -956,11 +956,11 @@ class EasyMarket(BaseMarket): | |||
| logger.warning("Learnware ID '%s' NOT Found!" % (ids)) | |||
| return None | |||
| def update_learnware_semantic_spec(self, learnware_id: str, semantic_spec: dict) -> bool: | |||
| def update_learnware_semantic_specification(self, learnware_id: str, semantic_spec: dict) -> bool: | |||
| """Update Learnware semantic_spec""" | |||
| # update database | |||
| self.dbops.update_learnware_semantic_spec(learnware_id=learnware_id, semantic_spec=semantic_spec) | |||
| self.dbops.update_learnware_semantic_specification(learnware_id=learnware_id, semantic_spec=semantic_spec) | |||
| # update file | |||
| folder_path = self.learnware_folder_list[learnware_id] | |||
| @@ -0,0 +1,3 @@ | |||
| from .organizer import EasyOrganizer | |||
| from .searcher import EasySearcher | |||
| from .checker import EasySemanticChecker, EasyStatisticalChecker | |||
| @@ -0,0 +1,115 @@ | |||
| import traceback | |||
| import numpy as np | |||
| import torch | |||
| from ..base import BaseChecker | |||
| from ...config import C | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("easy_checker", "INFO") | |||
class EasySemanticChecker(BaseChecker):
    """Validate a learnware's semantic specification against ``C["semantic_specs"]``."""

    @staticmethod
    def _require(condition, message):
        """Raise ``ValueError(message)`` unless ``condition`` is truthy.

        Used instead of ``assert`` so the validation still runs when Python is
        started with ``-O`` (which strips assert statements).
        """
        if not condition:
            raise ValueError(message)

    def __call__(self, learnware):
        """Check the semantic specification of ``learnware``.

        Parameters
        ----------
        learnware
            Object exposing ``get_specification().get_semantic_spec()``.

        Returns
        -------
        ``self.NONUSABLE_LEARNWARE`` when every check passes,
        ``self.INVALID_LEARNWARE`` when any check fails or any lookup raises.
        """
        semantic_spec = learnware.get_specification().get_semantic_spec()
        require = self._require

        try:
            for key in C["semantic_specs"]:
                value = semantic_spec[key]["Values"]
                valid_type = C["semantic_specs"][key]["Type"]
                require(semantic_spec[key]["Type"] == valid_type, f"{key} type mismatch")

                if valid_type == "Class":
                    valid_list = C["semantic_specs"][key]["Values"]
                    require(len(value) == 1, f"{key} must be unique")
                    require(value[0] in valid_list, f"{key} must be in {valid_list}")
                elif valid_type == "Tag":
                    valid_list = C["semantic_specs"][key]["Values"]
                    require(len(value) >= 1, f"{key} cannot be empty")
                    for v in value:
                        require(v in valid_list, f"{key} must be in {valid_list}")
                elif valid_type == "String":
                    require(isinstance(value, str), f"{key} must be string")
                    require(len(value) >= 1, f"{key} cannot be empty")

            # Tabular data must describe its input dimensions.
            if semantic_spec["Data"]["Values"][0] == "Table":
                require(semantic_spec["Input"] is not None, "Lack of input semantics")
                dim = semantic_spec["Input"]["Dimension"]
                for k, v in semantic_spec["Input"]["Description"].items():
                    require(0 <= int(k) < dim, f"Dimension number in [0, {dim})")
                    require(isinstance(v, str), "Description must be string")

            # These tasks must describe their output dimensions.
            if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]:
                require(semantic_spec["Output"] is not None, "Lack of output semantics")
                dim = semantic_spec["Output"]["Dimension"]
                for k, v in semantic_spec["Output"]["Description"].items():
                    require(0 <= int(k) < dim, f"Dimension number in [0, {dim})")
                    require(isinstance(v, str), "Description must be string")

            # Only the semantic part is checked here, so success is reported as
            # NONUSABLE rather than USABLE.
            return self.NONUSABLE_LEARNWARE

        except Exception as err:
            logger.warning(f"semantic_specification is not valid due to {err}!")
            return self.INVALID_LEARNWARE
class EasyStatisticalChecker(BaseChecker):
    """Check that a learnware can be instantiated and that a trial prediction has the declared shapes."""

    def _reject(self, message):
        """Log ``message`` as a warning and return ``self.INVALID_LEARNWARE``."""
        logger.warning(message)
        return self.INVALID_LEARNWARE

    def __call__(self, learnware):
        """Run the instantiation and prediction checks on ``learnware``.

        Returns
        -------
        ``self.USABLE_LEARWARE`` when every check passes,
        ``self.INVALID_LEARNWARE`` otherwise.
        """
        semantic_spec = learnware.get_specification().get_semantic_spec()

        # Step 1: the model must be instantiable.
        try:
            learnware.instantiate_model()
        except Exception as e:
            traceback.print_exc()
            return self._reject(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.")

        # Step 2: feed random inputs and compare shapes against the declarations.
        try:
            model = learnware.get_model()

            is_table = semantic_spec["Data"]["Values"][0] == "Table"
            input_shape = (semantic_spec["Input"]["Dimension"],) if is_table else model.input_shape

            # The RKME points (when present) must share the model's input shape.
            stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification")
            if stat_spec is not None and stat_spec.get_z().shape[1:] != input_shape:
                return self._reject(
                    f"The learnware [{learnware.id}] input dimension mismatch with stat specification."
                )

            outputs = learnware.predict(np.random.randn(10, *input_shape))

            # Treat a flat prediction vector as a single output column.
            if outputs.ndim == 1:
                outputs = outputs.reshape(-1, 1)

            if outputs.shape[1:] != model.output_shape:
                return self._reject(f"The learnware [{learnware.id}] output dimention mismatch!")

            task = semantic_spec["Task"]["Values"][0]
            if task in ("Classification", "Regression", "Feature Extraction"):
                if isinstance(outputs, torch.Tensor):
                    outputs = outputs.detach().cpu().numpy()
                if not isinstance(outputs, np.ndarray):
                    return self._reject(
                        f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!"
                    )

                expected_dim = int(semantic_spec["Output"]["Dimension"])
                if outputs[0].shape[0] != expected_dim:
                    return self._reject(f"The learnware [{learnware.id}] output dimention mismatch!")

        except Exception as e:
            return self._reject(
                f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}."
            )

        return self.USABLE_LEARWARE
| @@ -0,0 +1,176 @@ | |||
| from sqlalchemy.ext.declarative import declarative_base | |||
| from sqlalchemy import create_engine, text | |||
| from sqlalchemy import Column, Integer, Text, DateTime, String | |||
| import os | |||
| import json | |||
| from ...learnware import get_learnware_from_dirpath | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("database") | |||
| DeclarativeBase = declarative_base() | |||
class Learnware(DeclarativeBase):
    """ORM mapping of the ``tb_learnware`` table, one row per stored learnware."""

    __tablename__ = "tb_learnware"

    # Learnware identifier (primary key, at most 10 characters).
    id = Column(String(10), primary_key=True, nullable=False)
    # Semantic specification, stored as text.
    semantic_spec = Column(Text, nullable=False)
    # Location of the learnware zip file.
    zip_path = Column(Text, nullable=False)
    # Location of the unzipped learnware folder.
    folder_path = Column(Text, nullable=False)
    # Usability flag, stored as text.
    use_flag = Column(Text, nullable=False)
class DatabaseOperations(object):
    """Persistence helper for the ``tb_learnware`` table on SQLite or PostgreSQL."""

    def __init__(self, url: str, database_name: str):
        """Build the full database URL and make sure the database exists.

        Parameters
        ----------
        url : str
            Base database URL. Only URLs starting with ``sqlite`` or
            ``postgresql`` are supported.
        database_name : str
            Name of the database. For SQLite it becomes the file
            ``<database_name>.db`` joined onto ``url``.
        """
        if url.startswith("sqlite"):
            url = os.path.join(url, f"{database_name}.db")
        else:
            url = f"{url}/{database_name}"

        self.url = url
        self.create_database_if_not_exists(url)

    @staticmethod
    def _ensure_sqlite_location(url) -> bool:
        """Return True if the SQLite file exists; otherwise create its parent directory and return False."""
        path = url[url.find(":///") + 4 :]
        if os.path.exists(path):
            return True

        # A bare file name has no directory part, and os.makedirs("") raises.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        return False

    @staticmethod
    def _ensure_postgresql_database(url) -> bool:
        """Return True if the PostgreSQL database exists; otherwise create it and return False."""
        dbname_start = url.rfind("/")
        dbname = url[dbname_start + 1 :]
        # Connect to the "postgres" database to inspect and create databases.
        admin_engine = create_engine(url[:dbname_start] + "/postgres")
        try:
            with admin_engine.connect() as conn:
                result = conn.execute(text("SELECT datname FROM pg_database;"))
                existing = {row[0].lower() for row in result.fetchall()}
                if dbname.lower() in existing:
                    return True

                # NOTE(review): dbname is interpolated into the statement, so it
                # must come from trusted configuration.
                conn.execution_options(isolation_level="AUTOCOMMIT").execute(
                    text("CREATE DATABASE {0};".format(dbname))
                )
                return False
        finally:
            # Dispose the admin engine even when a query above fails.
            admin_engine.dispose()

    def create_database_if_not_exists(self, url):
        """Create the database behind ``url`` if needed and bind ``self.engine`` to it.

        Tables are created only when the database did not exist before.

        Raises
        ------
        ValueError
            If ``url`` is neither a SQLite nor a PostgreSQL URL.
        """
        if url.startswith("sqlite"):
            database_exists = self._ensure_sqlite_location(url)
        elif url.startswith("postgresql"):
            database_exists = self._ensure_postgresql_database(url)
        else:
            raise ValueError(f"Unsupported database url: {url}")

        self.engine = create_engine(url, future=True)
        if not database_exists:
            DeclarativeBase.metadata.create_all(self.engine)

    def clear_learnware_table(self):
        """Delete every row of ``tb_learnware``."""
        with self.engine.connect() as conn:
            conn.execute(text("DELETE FROM tb_learnware;"))
            conn.commit()

    def add_learnware(self, id: str, semantic_spec: dict, zip_path, folder_path, use_flag: str):
        """Insert one learnware row; ``semantic_spec`` is stored as a JSON string."""
        with self.engine.connect() as conn:
            semantic_spec_str = json.dumps(semantic_spec)
            conn.execute(
                text(
                    (
                        "INSERT INTO tb_learnware (id, semantic_spec, zip_path, folder_path, use_flag)"
                        "VALUES (:id, :semantic_spec, :zip_path, :folder_path, :use_flag);"
                    )
                ),
                dict(
                    id=id,
                    semantic_spec=semantic_spec_str,
                    zip_path=zip_path,
                    folder_path=folder_path,
                    use_flag=use_flag,
                ),
            )
            conn.commit()

    def delete_learnware(self, id: str):
        """Delete the row whose primary key is ``id``."""
        with self.engine.connect() as conn:
            conn.execute(text("DELETE FROM tb_learnware WHERE id=:id;"), dict(id=id))
            conn.commit()

    def update_learnware_semantic_specification(self, id: str, semantic_spec: dict):
        """Replace the stored semantic specification (as JSON) of learnware ``id``."""
        with self.engine.connect() as conn:
            semantic_spec_str = json.dumps(semantic_spec)
            conn.execute(
                text("UPDATE tb_learnware SET semantic_spec=:semantic_spec WHERE id=:id;"),
                dict(id=id, semantic_spec=semantic_spec_str),
            )
            conn.commit()

    def update_learnware_use_flag(self, id: str, use_flag: str):
        """Replace the stored ``use_flag`` of learnware ``id``."""
        with self.engine.connect() as conn:
            conn.execute(
                text("UPDATE tb_learnware SET use_flag=:use_flag WHERE id=:id;"),
                dict(id=id, use_flag=use_flag),
            )
            conn.commit()

    def load_market(self):
        """Load every stored learnware from the database.

        Returns
        -------
        tuple
            ``(learnware_list, zip_list, folder_list, use_flags, next_count)``.
            The first four are dicts keyed by learnware id; ``next_count`` is
            one more than the largest id read as an integer.
        """
        with self.engine.connect() as conn:
            cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;"))

            learnware_list = {}
            zip_list = {}
            folder_list = {}
            use_flags = {}
            max_count = 0

            for id, semantic_spec, zip_path, folder_path, use_flag in cursor:
                id = id.strip()
                semantic_spec_dict = json.loads(semantic_spec)
                new_learnware = get_learnware_from_dirpath(
                    id=id, semantic_spec=semantic_spec_dict, learnware_dirpath=folder_path
                )
                logger.info(f"Load learnware: {id}")
                learnware_list[id] = new_learnware
                zip_list[id] = zip_path
                folder_list[id] = folder_path
                # NOTE(review): use_flag comes back from a Text column; confirm
                # whether callers expect the stored string or an int.
                use_flags[id] = use_flag
                max_count = max(max_count, int(id))

            # NOTE(review): an empty table yields 1 here, not 0 - confirm intended.
            return learnware_list, zip_list, folder_list, use_flags, max_count + 1
| @@ -0,0 +1,313 @@ | |||
| import os | |||
| import json | |||
| import copy | |||
| import torch | |||
| import zipfile | |||
| import traceback | |||
| import tempfile | |||
| import numpy as np | |||
| import pandas as pd | |||
| from rapidfuzz import fuzz | |||
| from cvxopt import solvers, matrix | |||
| from shutil import copyfile, rmtree | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from .database_ops import DatabaseOperations | |||
| from ..base import LearnwareMarket, BaseUserInfo | |||
| from ... import utils | |||
| from ...config import C as conf | |||
| from ...logger import get_module_logger | |||
| from ...learnware import Learnware, get_learnware_from_dirpath | |||
| from ...specification import RKMEStatSpecification, Specification | |||
| from ..base import BaseOrganizer, BaseChecker | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("easy_organizer") | |||
class EasyOrganizer(BaseOrganizer):
    """Organizer that keeps learnwares on the file system and indexes them in a database."""

    def reload_market(self, rebuild=False) -> bool:
        """Reload the learnware organizer when the server restarts.

        Parameters
        ----------
        rebuild : bool, optional
            When True, the learnware table and the on-disk pool are cleared
            (best effort) before loading, by default False.

        Returns
        -------
        bool
            True once the market state has been loaded from the database.
        """
        self.market_store_path = os.path.join(conf.market_root_path, self.market_id)
        self.learnware_pool_path = os.path.join(self.market_store_path, "learnware_pool")
        self.learnware_zip_pool_path = os.path.join(self.learnware_pool_path, "zips")
        self.learnware_folder_pool_path = os.path.join(self.learnware_pool_path, "unzipped_learnwares")
        self.learnware_list = {}  # id: Learnware
        self.learnware_zip_list = {}
        self.learnware_folder_list = {}
        self.use_flags = {}
        self.count = 0
        self.semantic_spec_list = conf.semantic_specs
        self.dbops = DatabaseOperations(conf.database_url, "market_" + self.market_id)

        if rebuild:
            logger.warning("Warning! You are trying to clear current database!")
            try:
                self.dbops.clear_learnware_table()
                rmtree(self.learnware_pool_path)
            except Exception as err:
                # Clearing stays best effort (the pool may not exist yet), but
                # the failure is no longer silent.
                logger.warning(f"Failed to fully clear the market during rebuild: {err!r}")

        os.makedirs(self.learnware_pool_path, exist_ok=True)
        os.makedirs(self.learnware_zip_pool_path, exist_ok=True)
        os.makedirs(self.learnware_folder_pool_path, exist_ok=True)
        (
            self.learnware_list,
            self.learnware_zip_list,
            self.learnware_folder_list,
            self.use_flags,
            self.count,
        ) = self.dbops.load_market()
        return True

    @staticmethod
    def _discard_learnware_files(zip_file: str, folder: str) -> None:
        """Best-effort removal of a copied zip file and its unpacked folder."""
        try:
            os.remove(zip_file)
        except OSError:
            pass
        rmtree(folder, ignore_errors=True)

    def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, int]:
        """Add a learnware into the market.

        Parameters
        ----------
        zip_path : str
            Filepath for learnware model, a zipped file.
        semantic_spec : dict
            semantic_spec for new learnware, in dictionary format.
        check_status: int
            A flag indicating whether the learnware is usable.

        Returns
        -------
        Tuple[str, int]
            - str indicating model_id (None when the learnware is rejected)
            - int indicating the final learnware check_status
        """
        if check_status == BaseChecker.INVALID_LEARNWARE:
            logger.warning("Learnware is invalid!")
            return None, BaseChecker.INVALID_LEARNWARE

        semantic_spec = copy.deepcopy(semantic_spec)
        logger.info("Get new learnware from %s" % (zip_path))

        learnware_id = "%08d" % (self.count)
        target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id))
        target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id)
        copyfile(zip_path, target_zip_dir)

        with zipfile.ZipFile(target_zip_dir, "r") as z_file:
            z_file.extractall(target_folder_dir)
        logger.info("Learnware move to %s, and unzip to %s" % (target_zip_dir, target_folder_dir))

        try:
            new_learnware = get_learnware_from_dirpath(
                id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
            )
        except Exception:
            new_learnware = None

        if new_learnware is None:
            # Clean up on both failure paths: the id is not consumed, so stale
            # files would be mixed into the next learnware given this id.
            self._discard_learnware_files(target_zip_dir, target_folder_dir)
            return None, BaseChecker.INVALID_LEARNWARE

        learnware_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE

        self.dbops.add_learnware(
            id=learnware_id,
            semantic_spec=semantic_spec,
            zip_path=target_zip_dir,
            folder_path=target_folder_dir,
            use_flag=learnware_status,
        )

        self.learnware_list[learnware_id] = new_learnware
        self.learnware_zip_list[learnware_id] = target_zip_dir
        self.learnware_folder_list[learnware_id] = target_folder_dir
        self.use_flags[learnware_id] = learnware_status
        self.count += 1
        return learnware_id, learnware_status

    def delete_learnware(self, id: str) -> bool:
        """Delete Learnware from market

        Parameters
        ----------
        id : str
            Learnware to be deleted

        Returns
        -------
        bool
            True for successful operation.
            False for id not found.
        """
        if id not in self.learnware_list:
            logger.warning("Learnware id:'{}' NOT Found!".format(id))
            return False

        # Files go first, then the in-memory index, then the database row.
        os.remove(self.learnware_zip_list[id])
        rmtree(self.learnware_folder_list[id])
        self.learnware_list.pop(id)
        self.learnware_zip_list.pop(id)
        self.learnware_folder_list.pop(id)
        self.use_flags.pop(id)
        self.dbops.delete_learnware(id=id)
        return True

    def update_learnware(self, id: str, zip_path: str = None, semantic_spec: dict = None, check_status: int = None):
        """Update learnware with zip_path, semantic_specification and check_status

        Parameters
        ----------
        id : str
            Learnware id
        zip_path : str, optional
            Filepath for learnware model, a zipped file.
        semantic_spec : dict, optional
            semantic_spec for new learnware, in dictionary format.
        check_status : int, optional
            A flag indicating whether the learnware is usable.

        Returns
        -------
        int
            The final learnware check_status. ``BaseChecker.INVALID_LEARNWARE``
            is returned, with nothing modified, when the request is rejected.
        """
        if check_status == BaseChecker.INVALID_LEARNWARE:
            logger.warning("Learnware is invalid!")
            return BaseChecker.INVALID_LEARNWARE

        if zip_path is None and semantic_spec is None and check_status is None:
            logger.warning(
                "At least one of 'zip_path', 'semantic_spec' and 'check_status' should not be None when update learnware"
            )
            return BaseChecker.INVALID_LEARNWARE

        if id not in self.learnware_list:
            # Same handling as delete_learnware: warn instead of raising KeyError.
            logger.warning("Learnware id:'{}' NOT Found!".format(id))
            return BaseChecker.INVALID_LEARNWARE

        if semantic_spec is None:
            semantic_spec = self.learnware_list[id].get_specification().get_semantic_spec()

        target_zip_dir = self.learnware_zip_list[id]
        target_folder_dir = self.learnware_folder_list[id]

        if zip_path is not None:
            # Validate the new zip in a scratch directory before touching any
            # stored state, so a bad zip leaves files and database untouched.
            with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
                with zipfile.ZipFile(zip_path, "r") as z_file:
                    z_file.extractall(tempdir)

                try:
                    new_learnware = get_learnware_from_dirpath(
                        id=id, semantic_spec=semantic_spec, learnware_dirpath=tempdir
                    )
                except Exception:
                    return BaseChecker.INVALID_LEARNWARE

                if new_learnware is None:
                    return BaseChecker.INVALID_LEARNWARE

            copyfile(zip_path, target_zip_dir)
            with zipfile.ZipFile(target_zip_dir, "r") as z_file:
                z_file.extractall(target_folder_dir)

        # Persist the semantic specification only after the zip has been accepted.
        self.dbops.update_learnware_semantic_specification(id, semantic_spec)

        # Update check_status
        if check_status is not None:
            self.use_flags[id] = check_status
        self.dbops.update_learnware_use_flag(id, self.use_flags[id])

        # Update learnware list
        self.learnware_list[id] = get_learnware_from_dirpath(
            id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
        )
        return self.use_flags[id]

    @staticmethod
    def _lookup_by_ids(table: dict, ids):
        """Look ``ids`` (one id or a list of ids) up in ``table``; missing ids map to None with a warning."""
        if isinstance(ids, list):
            found = []
            for item_id in ids:
                if item_id in table:
                    found.append(table[item_id])
                else:
                    logger.warning(f"Learnware ID '{item_id}' NOT Found!")
                    found.append(None)
            return found

        try:
            return table[ids]
        except (KeyError, TypeError):
            logger.warning(f"Learnware ID '{ids}' NOT Found!")
            return None

    def get_learnware_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]:
        """Search learnware by id or list of ids.

        Parameters
        ----------
        ids : Union[str, List[str]]
            Give a id or a list of ids
            str: id of target learnware
            List[str]: A list of ids of target learnwares

        Returns
        -------
        Union[Learnware, List[Learnware]]
            Return target learnware or list of target learnwares.
            None for Learnware NOT Found.
        """
        return self._lookup_by_ids(self.learnware_list, ids)

    def get_learnware_path_by_ids(self, ids: Union[str, List[str]]) -> Union[str, List[str]]:
        """Get Zipped Learnware file by id

        Parameters
        ----------
        ids : Union[str, List[str]]
            Give a id or a list of ids
            str: id of target learnware
            List[str]: A list of ids of target learnwares

        Returns
        -------
        Union[str, List[str]]
            Return the path for target learnware or list of path.
            None for Learnware NOT Found.
        """
        return self._lookup_by_ids(self.learnware_zip_list, ids)

    def get_learnware_ids(self, top: int = None) -> List[str]:
        """Return the stored learnware ids, or only the first ``top`` of them when ``top`` is given."""
        # Slicing with None keeps the whole list.
        return list(self.learnware_list.keys())[:top]

    def get_learnwares(self, top: int = None) -> List[Learnware]:
        """Return the stored learnwares, or only the first ``top`` of them when ``top`` is given."""
        return list(self.learnware_list.values())[:top]

    def __len__(self):
        """Return the number of learnwares currently held."""
        return len(self.learnware_list)
| @@ -0,0 +1,637 @@ | |||
| import torch | |||
| import numpy as np | |||
| from rapidfuzz import fuzz | |||
| from cvxopt import solvers, matrix | |||
| from typing import Tuple, List | |||
| from .organizer import EasyOrganizer | |||
| from ..base import BaseUserInfo, BaseSearcher | |||
| from ...learnware import Learnware | |||
| from ...specification import RKMEStatSpecification | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("easy_seacher") | |||
class EasyExactSemanticSearcher(BaseSearcher):
    """Filter learnwares whose semantic specification exactly matches the user's."""

    def _match_semantic_spec(self, semantic_spec1, semantic_spec2):
        """Return True when the learnware spec satisfies every non-empty user entry.

        Parameters
        ----------
        semantic_spec1 :
            semantic spec input by user
        semantic_spec2 :
            semantic spec in database; it may contain more keys than the user input

        Returns
        -------
        bool
            - "Name"/"Description": the lower-cased user text must be a
              substring of the learnware's name or description.
            - "Class" entries: the first values must be equal.
            - "Tag" entries: the value sets must intersect.
        """
        name2 = semantic_spec2["Name"]["Values"].lower()
        description2 = semantic_spec2["Description"]["Values"].lower()

        for key, user_entry in semantic_spec1.items():
            v1 = user_entry.get("Values", "")
            # A key absent from the database spec counts as empty (no match)
            # instead of raising KeyError.
            v2 = semantic_spec2.get(key, {}).get("Values", "")

            if len(v1) == 0:
                # user input is empty, no need to search
                continue

            if key in ("Name", "Description"):
                v1 = v1.lower()
                if v1 not in name2 and v1 not in description2:
                    return False
                continue

            if len(v2) == 0:
                # user input contains some key that is not in database
                return False

            if user_entry["Type"] == "Class":
                if isinstance(v1, list):
                    v1 = v1[0]
                if isinstance(v2, list):
                    v2 = v2[0]
                if v1 != v2:
                    return False
            elif user_entry["Type"] == "Tag":
                if not (set(v1) & set(v2)):
                    return False

        return True

    def __call__(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]:
        """Return the learnwares of ``learnware_list`` matching ``user_info``'s semantic spec."""
        # The user spec is the same for every learnware, so fetch it once.
        user_semantic_spec = user_info.get_semantic_spec()
        match_learnwares = [
            learnware
            for learnware in learnware_list
            if self._match_semantic_spec(user_semantic_spec, learnware.get_specification().get_semantic_spec())
        ]
        logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list)))
        return match_learnwares
class EasyFuzzSemanticSearcher(BaseSearcher):
    """Filter learnwares by tag/class match, then by exact or fuzzy name match."""

    def _match_semantic_spec_tag(self, semantic_spec1, semantic_spec2) -> bool:
        """Judge if tags of two semantic specs are consistent

        Parameters
        ----------
        semantic_spec1 :
            semantic spec input by user
        semantic_spec2 :
            semantic spec in database

        Returns
        -------
        bool
            consistent (True) or not consistent (False)
        """
        for key, user_entry in semantic_spec1.items():
            v1 = user_entry.get("Values", "")
            # A key absent from the database spec counts as empty (no match)
            # instead of raising KeyError.
            v2 = semantic_spec2.get(key, {}).get("Values", "")

            if len(v1) == 0:
                # user input is empty, no need to search
                continue

            # "Name" is matched separately in __call__. The previous
            # `key not in "Name"` was a substring test on the string "Name".
            # NOTE(review): EasyExactSemanticSearcher also exempts "Description";
            # confirm whether it should be exempt here as well.
            if key != "Name":
                if len(v2) == 0:
                    # user input contains some key that is not in database
                    return False

                if user_entry["Type"] == "Class":
                    if isinstance(v1, list):
                        v1 = v1[0]
                    if isinstance(v2, list):
                        v2 = v2[0]
                    if v1 != v2:
                        return False
                elif user_entry["Type"] == "Tag":
                    if not (set(v1) & set(v2)):
                        return False

        return True

    def __call__(
        self, learnware_list: List[Learnware], user_info: BaseUserInfo, max_num: int = 50000, min_score: float = 75.0
    ) -> List[Learnware]:
        """Search learnware by fuzzy matching of semantic spec

        Parameters
        ----------
        learnware_list : List[Learnware]
            The list of learnwares
        user_info : BaseUserInfo
            user_info contains semantic_spec
        max_num : int, optional
            maximum number of learnwares returned by the fuzzy stage, by default 50000
        min_score : float, optional
            Minimum fuzzy matching score of learnwares returned, by default 75.0

        Returns
        -------
        List[Learnware]
            The list of returned learnwares
        """
        user_semantic_spec = user_info.get_semantic_spec()

        # Stage 1: keep learnwares whose class/tag entries are consistent.
        matched_learnware_tag = [
            learnware
            for learnware in learnware_list
            if self._match_semantic_spec_tag(user_semantic_spec, learnware.get_specification().get_semantic_spec())
        ]
        final_result = matched_learnware_tag

        name_user = ""
        if matched_learnware_tag and "Name" in user_semantic_spec:
            name_user = user_semantic_spec["Name"]["Values"].lower()

        if name_user:
            specs = [learnware.get_specification().get_semantic_spec() for learnware in matched_learnware_tag]
            name_list = [spec["Name"]["Values"].lower() for spec in specs]
            des_list = [spec["Description"]["Values"].lower() for spec in specs]

            # Stage 2: exact substring match on name or description.
            matched_learnware_exact = [
                learnware
                for learnware, name, des in zip(matched_learnware_tag, name_list, des_list)
                if name_user in name or name_user in des
            ]

            if matched_learnware_exact:
                final_result = matched_learnware_exact
            else:
                # Stage 3: fuzzy match, used only when nothing matched exactly.
                scored = []
                for learnware, name, des in zip(matched_learnware_tag, name_list, des_list):
                    final_score = max(fuzz.partial_ratio(name_user, name), fuzz.partial_ratio(name_user, des))
                    if final_score >= min_score:
                        scored.append((final_score, learnware))

                # Stable sort: equal scores keep their original relative order.
                scored.sort(key=lambda pair: pair[0], reverse=True)
                final_result = [learnware for _, learnware in scored[:max_num]]

        logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list)))
        return final_result
| class EasyTableSearcher(BaseSearcher): | |||
def _convert_dist_to_score(
    self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92
) -> List[float]:
    """Convert mmd dist list into min_max score list

    Parameters
    ----------
    dist_list : List[float]
        The list of mmd distances from learnware rkmes to user rkme
    dist_epsilon: float
        The parameter for converting mmd dist to score
    min_score: float
        The minimum score for maximum returned score

    Returns
    -------
    List[float]
        The list of min_max scores of each learnware. An empty input gives an
        empty list; when all distances are equal every score is 1.
    """
    if len(dist_list) == 0:
        return []

    min_dist, max_dist = min(dist_list), max(dist_list)
    if min_dist == max_dist:
        return [1 for dist in dist_list]

    if min_dist < dist_epsilon:
        # The best learnware is closer than epsilon: anchor the top score at it.
        dist_epsilon = min_dist
    elif (max_dist - min_dist) / (max_dist - dist_epsilon) < min_score:
        # The top score is evaluated only in this branch, where
        # max_dist > min_dist >= dist_epsilon keeps the denominator positive.
        # Evaluating it up front divided by zero when max_dist == dist_epsilon.
        dist_epsilon = max_dist - (max_dist - min_dist) / min_score

    return [(max_dist - dist) / (max_dist - dist_epsilon) for dist in dist_list]
def _calculate_rkme_spec_mixture_weight(
    self,
    learnware_list: List[Learnware],
    user_rkme: RKMEStatSpecification,
    intermediate_K: np.ndarray = None,
    intermediate_C: np.ndarray = None,
) -> Tuple[List[float], float]:
    """Calculate mixture weight for the learnware_list based on a user's rkme

    Parameters
    ----------
    learnware_list : List[Learnware]
        A list of existing learnwares
    user_rkme : RKMEStatSpecification
        User RKME statistical specification
    intermediate_K : np.ndarray, optional
        Intermediate kernel matrix K, by default None
    intermediate_C : np.ndarray, optional
        Intermediate inner product vector C, by default None

    Returns
    -------
    Tuple[List[float], float]
        The first is the list of mixture weights
        The second is the mmd dist between the mixture of learnware rkmes and the user's rkme
    """
    rkme_specs = [item.specification.get_stat_spec_by_name("RKMEStatSpecification") for item in learnware_list]
    count = len(learnware_list)

    # Pairwise inner products between learnware RKMEs, unless supplied by the caller.
    if type(intermediate_K) == np.ndarray:
        gram = intermediate_K
    else:
        gram = np.zeros((count, count))
        for row in range(gram.shape[0]):
            gram[row, row] = rkme_specs[row].inner_prod(rkme_specs[row])
            for col in range(row + 1, gram.shape[0]):
                gram[row, col] = gram[col, row] = rkme_specs[row].inner_prod(rkme_specs[col])

    # Inner products between the user RKME and each learnware RKME, unless supplied.
    if type(intermediate_C) == np.ndarray:
        cross = intermediate_C
    else:
        cross = np.zeros((count, 1))
        for row in range(cross.shape[0]):
            cross[row, 0] = user_rkme.inner_prod(rkme_specs[row])

    device = user_rkme.device
    gram_t = torch.from_numpy(gram).double().to(device)
    cross_t = torch.from_numpy(cross).double().to(device)

    # Quadratic program: weights constrained to be non-negative and to sum to one.
    size = gram_t.shape[0]
    solvers.options["show_progress"] = False
    solution = solvers.qp(
        matrix(gram_t.cpu().numpy()),
        matrix(-cross_t.cpu().numpy()),
        matrix(-np.eye(size)),
        matrix(np.zeros((size, 1))),
        matrix(np.ones((1, size))),
        matrix(np.ones((1, 1))),
    )

    weight = torch.from_numpy(np.array(solution["x"])).reshape(-1).double().to(device)
    score = user_rkme.inner_prod(user_rkme) + 2 * solution["primal objective"]
    return weight.detach().cpu().numpy().reshape(-1), score
def _calculate_intermediate_K_and_C(
    self,
    learnware_list: List[Learnware],
    user_rkme: RKMEStatSpecification,
    intermediate_K: np.ndarray = None,
    intermediate_C: np.ndarray = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Incrementally update the values of intermediate_K and intermediate_C

    Both arrays are modified in place: the last row of ``intermediate_K`` and
    the last entry of ``intermediate_C`` are filled for the last learnware.

    Parameters
    ----------
    learnware_list : List[Learnware]
        The list of learnwares up till now
    user_rkme : RKMEStatSpecification
        User RKME statistical specification
    intermediate_K : np.ndarray, optional
        Intermediate kernel matrix K, by default None
    intermediate_C : np.ndarray, optional
        Intermediate inner product vector C, by default None

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The first is the intermediate value of K
        The second is the intermediate value of C
    """
    # NOTE(review): both arrays are dereferenced unconditionally, so the None
    # defaults are not usable in practice.
    last_row = intermediate_K.shape[0] - 1
    rkme_specs = [item.specification.get_stat_spec_by_name("RKMEStatSpecification") for item in learnware_list]
    newest = rkme_specs[-1]

    # Only the last row is written; the matching column is left untouched.
    for col in range(intermediate_K.shape[0]):
        intermediate_K[last_row, col] = newest.inner_prod(rkme_specs[col])
    intermediate_C[last_row, 0] = user_rkme.inner_prod(newest)

    return intermediate_K, intermediate_C
| def _search_by_rkme_spec_mixture_auto( | |||
| self, | |||
| learnware_list: List[Learnware], | |||
| user_rkme: RKMEStatSpecification, | |||
| max_search_num: int, | |||
| weight_cutoff: float = 0.98, | |||
| ) -> Tuple[float, List[float], List[Learnware]]: | |||
| """Select learnwares based on a total mixture ratio, then recalculate their mixture weights | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMEStatSpecification | |||
| User RKME statistical specification | |||
| max_search_num : int | |||
| The maximum number of the returned learnwares | |||
| weight_cutoff : float, optional | |||
| The ratio for selecting out the mose relevant learnwares, by default 0.9 | |||
| Returns | |||
| ------- | |||
| Tuple[float, List[float], List[Learnware]] | |||
| The first is the mixture mmd dist | |||
| The second is the list of weight | |||
| The third is the list of Learnware | |||
| """ | |||
| learnware_num = len(learnware_list) | |||
| if learnware_num == 0: | |||
| return [], [] | |||
| if learnware_num < max_search_num: | |||
| logger.warning("Available Learnware num less than search_num!") | |||
| max_search_num = learnware_num | |||
| weight, _ = self._calculate_rkme_spec_mixture_weight(learnware_list, user_rkme) | |||
| sort_by_weight_idx_list = sorted(range(learnware_num), key=lambda k: weight[k], reverse=True) | |||
| weight_sum = 0 | |||
| mixture_list = [] | |||
| for idx in sort_by_weight_idx_list: | |||
| weight_sum += weight[idx] | |||
| if weight_sum <= weight_cutoff: | |||
| mixture_list.append(learnware_list[idx]) | |||
| else: | |||
| break | |||
| if len(mixture_list) <= 1: | |||
| mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] | |||
| mixture_weight = [1] | |||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) | |||
| else: | |||
| if len(mixture_list) > max_search_num: | |||
| mixture_list = mixture_list[:max_search_num] | |||
| mixture_weight, mmd_dist = self._calculate_rkme_spec_mixture_weight(mixture_list, user_rkme) | |||
| return mmd_dist, mixture_weight, mixture_list | |||
| def _filter_by_rkme_spec_single( | |||
| self, | |||
| sorted_score_list: List[float], | |||
| learnware_list: List[Learnware], | |||
| filter_score: float = 0.5, | |||
| min_num: int = 15, | |||
| ) -> Tuple[List[float], List[Learnware]]: | |||
| """Filter search result of _search_by_rkme_spec_single | |||
| Parameters | |||
| ---------- | |||
| sorted_score_list : List[float] | |||
| The list of score transformed by mmd dist | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| filter_score: float | |||
| The learnware whose score is lower than filter_score will be filtered | |||
| min_num: int | |||
| The minimum number of returned learnwares | |||
| Returns | |||
| ------- | |||
| Tuple[List[float], List[Learnware]] | |||
| the first is the list of score | |||
| the second is the list of Learnware | |||
| """ | |||
| idx = min(min_num, len(learnware_list)) | |||
| while idx < len(learnware_list): | |||
| if sorted_score_list[idx] < filter_score: | |||
| break | |||
| idx = idx + 1 | |||
| return sorted_score_list[:idx], learnware_list[:idx] | |||
| def _filter_by_rkme_spec_dimension( | |||
| self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification | |||
| ) -> List[Learnware]: | |||
| """Filter learnwares whose rkme dimension different from user_rkme | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMEStatSpecification | |||
| User RKME statistical specification | |||
| Returns | |||
| ------- | |||
| List[Learnware] | |||
| Learnwares whose rkme dimensions equal user_rkme in user_info | |||
| """ | |||
| filtered_learnware_list = [] | |||
| user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) | |||
| for learnware in learnware_list: | |||
| rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") | |||
| rkme_dim = str(list(rkme.get_z().shape)[1:]) | |||
| if rkme_dim == user_rkme_dim: | |||
| filtered_learnware_list.append(learnware) | |||
| return filtered_learnware_list | |||
| def _search_by_rkme_spec_mixture_greedy( | |||
| self, | |||
| learnware_list: List[Learnware], | |||
| user_rkme: RKMEStatSpecification, | |||
| max_search_num: int, | |||
| score_cutoff: float = 0.001, | |||
| ) -> Tuple[float, List[float], List[Learnware]]: | |||
| """Greedily match learnwares such that their mixture become closer and closer to user's rkme | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMEStatSpecification | |||
| User RKME statistical specification | |||
| max_search_num : int | |||
| The maximum number of the returned learnwares | |||
| score_cutof: float | |||
| The minimum mmd dist as threshold to stop further rkme_spec matching | |||
| Returns | |||
| ------- | |||
| Tuple[float, List[float], List[Learnware]] | |||
| The first is the mixture mmd dist | |||
| The second is the list of weight | |||
| The third is the list of Learnware | |||
| """ | |||
| learnware_num = len(learnware_list) | |||
| if learnware_num == 0: | |||
| return None, [], [] | |||
| if learnware_num < max_search_num: | |||
| logger.warning("Available Learnware num less than search_num!") | |||
| max_search_num = learnware_num | |||
| flag_list = [0 for _ in range(learnware_num)] | |||
| mixture_list, mmd_dist = [], None | |||
| intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1)) | |||
| for k in range(max_search_num): | |||
| idx_min, score_min = -1, -1 | |||
| weight_min = None | |||
| mixture_list.append(None) | |||
| if k != 0: | |||
| intermediate_K = np.c_[intermediate_K, np.zeros((k, 1))] | |||
| intermediate_K = np.r_[intermediate_K, np.zeros((1, k + 1))] | |||
| intermediate_C = np.r_[intermediate_C, np.zeros((1, 1))] | |||
| for idx in range(len(learnware_list)): | |||
| if flag_list[idx] == 0: | |||
| mixture_list[-1] = learnware_list[idx] | |||
| intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( | |||
| mixture_list, user_rkme, intermediate_K, intermediate_C | |||
| ) | |||
| weight, score = self._calculate_rkme_spec_mixture_weight( | |||
| mixture_list, user_rkme, intermediate_K, intermediate_C | |||
| ) | |||
| if idx_min == -1 or score < score_min: | |||
| idx_min, score_min, weight_min = idx, score, weight | |||
| mmd_dist = score_min | |||
| mixture_list[-1] = learnware_list[idx_min] | |||
| if score_min < score_cutoff: | |||
| break | |||
| else: | |||
| flag_list[idx_min] = 1 | |||
| intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( | |||
| mixture_list, user_rkme, intermediate_K, intermediate_C | |||
| ) | |||
| return mmd_dist, weight_min, mixture_list | |||
| def _search_by_rkme_spec_single( | |||
| self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification | |||
| ) -> Tuple[List[float], List[Learnware]]: | |||
| """Calculate the distances between learnwares in the given learnware_list and user_rkme | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMEStatSpecification | |||
| user RKME statistical specification | |||
| Returns | |||
| ------- | |||
| Tuple[List[float], List[Learnware]] | |||
| the first is the list of mmd dist | |||
| the second is the list of Learnware | |||
| both lists are sorted by mmd dist | |||
| """ | |||
| RKME_list = [ | |||
| learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list | |||
| ] | |||
| mmd_dist_list = [] | |||
| for RKME in RKME_list: | |||
| mmd_dist = RKME.dist(user_rkme) | |||
| mmd_dist_list.append(mmd_dist) | |||
| sorted_idx_list = sorted(range(len(learnware_list)), key=lambda k: mmd_dist_list[k]) | |||
| sorted_dist_list = [mmd_dist_list[idx] for idx in sorted_idx_list] | |||
| sorted_learnware_list = [learnware_list[idx] for idx in sorted_idx_list] | |||
| return sorted_dist_list, sorted_learnware_list | |||
def __call__(
    self,
    learnware_list: List[Learnware],
    user_info: BaseUserInfo,
    max_search_num: int = 5,
    search_method: str = "greedy",
) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
    """Search learnwares by the user's RKME statistical specification

    Parameters
    ----------
    learnware_list : List[Learnware]
        The candidate learnwares
    user_info : BaseUserInfo
        user_info whose stat_info contains "RKMEStatSpecification"
    max_search_num : int
        The maximum number of learnwares in the returned mixture
    search_method : str
        "greedy" or "auto"; any other value skips the mixture search

    Returns
    -------
    Tuple[List[float], List[Learnware], float, List[Learnware]]
        the first is the list of scores of the single learnwares
        the second is the list of single learnwares matching those scores
        the third is the score of the mixture (None when no mixture was found)
        the fourth is the list of learnwares in the mixture
    """
    user_rkme = user_info.stat_info["RKMEStatSpecification"]
    learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme)
    logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}")

    sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
    if search_method == "auto":
        mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_auto(
            learnware_list, user_rkme, max_search_num
        )
    elif search_method == "greedy":
        mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_greedy(
            learnware_list, user_rkme, max_search_num
        )
    else:
        # The f prefix was inside the quotes, so the method name was never interpolated.
        logger.warning(f"{search_method} not supported!")
        mixture_dist = None
        weight_list = []
        mixture_learnware_list = []

    if mixture_dist is None:
        sorted_score_list = self._convert_dist_to_score(sorted_dist_list)
        mixture_score = None
    else:
        # Convert single and mixture distances together, then split them again.
        merge_score_list = self._convert_dist_to_score(sorted_dist_list + [mixture_dist])
        sorted_score_list = merge_score_list[:-1]
        mixture_score = merge_score_list[-1]

    logger.info(f"After search by rkme spec, learnware_list length is {len(learnware_list)}")
    # filter learnware with low score
    sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single(
        sorted_score_list, single_learnware_list
    )

    # Report the filtered list; learnware_list is not changed by the score filter.
    logger.info(f"After filter by rkme spec, learnware_list length is {len(single_learnware_list)}")
    return sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list
class EasySearcher(BaseSearcher):
    """Searcher that applies a semantic filter first, then a statistical (RKME) search"""

    def __init__(self, organizer: EasyOrganizer = None):
        super().__init__(organizer)
        self.semantic_searcher = EasyFuzzSemanticSearcher(organizer)
        self.table_searcher = EasyTableSearcher(organizer)

    def reset(self, organizer):
        """Point this searcher and both of its sub-searchers at a new organizer"""
        self.learnware_oganizer = organizer
        for sub_searcher in (self.semantic_searcher, self.table_searcher):
            sub_searcher.reset(organizer)

    def __call__(
        self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy"
    ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
        """Search learnwares based on user_info

        Parameters
        ----------
        user_info : BaseUserInfo
            user_info contains semantic_spec and stat_info
        max_search_num : int
            The maximum number of the returned learnwares
        search_method : str
            Forwarded to the table searcher, by default "greedy"

        Returns
        -------
        Tuple[List[float], List[Learnware], float, List[Learnware]]
            the first is the sorted list of rkme dist
            the second is the sorted list of Learnware (single) by the rkme dist
            the third is the score of Learnware (mixture)
            the fourth is the list of Learnware (mixture), the size is search_num
        """
        candidates = self.semantic_searcher(self.learnware_oganizer.get_learnwares(), user_info)

        if len(candidates) == 0:
            return [], [], 0.0, []
        if "RKMEStatSpecification" not in user_info.stat_info:
            # Semantic-only query: no scores and no mixture are produced.
            return None, candidates, 0.0, None
        return self.table_searcher(candidates, user_info, max_search_num, search_method)
| @@ -0,0 +1 @@ | |||
| from .organizer import EvolvedOrganizer | |||
| @@ -1,21 +1,18 @@ | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from typing import List | |||
| from .base import BaseMarket | |||
| from ..learnware import Learnware | |||
| from ..specification import BaseStatSpecification | |||
| from ..easy2.organizer import EasyOrganizer | |||
| from ...learnware import Learnware | |||
| from ...specification import BaseStatSpecification | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("evolve_organizer") | |||
| class EvolvedMarket(BaseMarket): | |||
| """Organize learnwares and enable them to continuously evolve | |||
| Parameters | |||
| ---------- | |||
| BaseMarket : _type_ | |||
| Basic market version | |||
| """ | |||
| class EvolvedOrganizer(EasyOrganizer): | |||
| """Organize learnwares and enable them to continuously evolve""" | |||
| def __init__(self, *args, **kwargs): | |||
| super(EvolvedMarket, self).__init__(*args, **kwargs) | |||
| super(EvolvedOrganizer, self).__init__(*args, **kwargs) | |||
| def generate_new_stat_specification(self, learnware: Learnware) -> BaseStatSpecification: | |||
| """Generate new statistical specification for learnwares | |||
| @@ -0,0 +1 @@ | |||
| from .organizer import EvolvedAnchoredOrganizer | |||
| @@ -1,22 +1,17 @@ | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from typing import List | |||
| from .anchor import AnchoredUserInfo, AnchoredMarket | |||
| from .evolve import EvolvedMarket | |||
| from ..evolve import EvolvedOrganizer | |||
| from ..anchor import AnchoredOrganizer, AnchoredUserInfo | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("evolve_anchor_organizer") | |||
| class EvolvedAnchoredMarket(AnchoredMarket, EvolvedMarket): | |||
| """Organize learnwares with anchors and enable them to continuously evolve | |||
| Parameters | |||
| ---------- | |||
| AnchoredMarket : _type_ | |||
| Market version with anchors | |||
| EvolvedMarket : _type_ | |||
| Market version with evolved learnwares | |||
| """ | |||
| class EvolvedAnchoredOrganizer(AnchoredOrganizer, EvolvedOrganizer): | |||
| """Organize learnwares and enable them to continuously evolve""" | |||
| def __init__(self, *args, **kwargs): | |||
| super(EvolvedAnchoredMarket, self).__init__(*args, **kwargs) | |||
| AnchoredOrganizer.__init__(self, *args, **kwargs) | |||
| def evolve_anchor_learnware_list(self, anchor_id_list: List[str]): | |||
| """Enable anchor learnwares to evolve, e.g., new stat_spec | |||
| @@ -0,0 +1 @@ | |||
| from .organizer import MappingFunction, HeterogeneousOrganizer | |||
| @@ -1,8 +1,8 @@ | |||
| import numpy as np | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from typing import List | |||
| from .evolve import EvolvedMarket | |||
| from ..learnware import Learnware | |||
| from ..evolve.organizer import EvolvedOrganizer | |||
| from ...learnware import Learnware | |||
| class MappingFunction: | |||
| @@ -25,17 +25,11 @@ class MappingFunction: | |||
| pass | |||
| class HeterogeneousFeatureMarket(EvolvedMarket): | |||
| """Organize learnwares with heterogeneous feature spaces | |||
| Parameters | |||
| ---------- | |||
| EvolvedMarket : _type_ | |||
| Market version with evolved learnwares | |||
| """ | |||
| class HeterogeneousOrganizer(EvolvedOrganizer): | |||
| """Organize learnwares with heterogeneous feature spaces, organizer version with evolved learnwares""" | |||
| def __init__(self, *args, **kwargs): | |||
| super(HeterogeneousFeatureMarket, self).__init__(*args, **kwargs) | |||
| super(HeterogeneousOrganizer, self).__init__(*args, **kwargs) | |||
| self.mapping_function_list = {} | |||
| def _mapping_function_list_initialization(self, learnware_list: List[Learnware]): | |||
| @@ -0,0 +1,20 @@ | |||
| from .base import LearnwareMarket | |||
| from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker | |||
# Registry mapping a market name to the components used to build that market.
# NOTE(review): the component instances are created once, when this module is
# imported, so markets built from the same name appear to share the same
# organizer/searcher/checker objects. Confirm this sharing is intended.
MARKET_CONFIG = {
    "easy": {
        "organizer": EasyOrganizer(),
        "searcher": EasySearcher(),
        "checker_list": [EasySemanticChecker(), EasyStatisticalChecker()],
    }
}
def instatiate_learnware_market(market_id, name="easy", **kwargs):
    """Build a LearnwareMarket from the components registered under ``name``

    Parameters
    ----------
    market_id : Any
        Forwarded to LearnwareMarket as market_id
    name : str
        Key into MARKET_CONFIG, by default "easy"; an unknown key raises KeyError
    **kwargs
        Extra keyword arguments forwarded to LearnwareMarket

    Returns
    -------
    LearnwareMarket
        The market assembled from the registered organizer, searcher and checkers
    """
    components = MARKET_CONFIG[name]
    return LearnwareMarket(
        market_id=market_id,
        organizer=components["organizer"],
        searcher=components["searcher"],
        checker_list=components["checker_list"],
        **kwargs
    )
| @@ -1,4 +1,4 @@ | |||
| from .utils import generate_stat_spec | |||
| from .base import Specification, BaseStatSpecification | |||
| from .rkme import RKMEStatSpecification | |||
| from .image import RKMEImageStatSpecification | |||
| from .table import RKMEStatSpecification | |||
| @@ -6,6 +6,15 @@ from typing import Dict | |||
| class BaseStatSpecification: | |||
| """The Statistical Specification Interface, which provide save and load method""" | |||
def __init__(self, type: str):
    """Initialize the type of the statistical specification

    Parameters
    ----------
    type : str
        The type of the statistical specification
    """
    # The parameter shadows the builtin ``type``; the name is kept because
    # subclasses pass it by keyword (``type=...``).
    self.type = type
| def generate_stat_spec_from_data(self, **kwargs): | |||
| """Construct statistical specification from raw dataset | |||
| - kwargs may include the feature, label and model | |||
| @@ -0,0 +1 @@ | |||
| from .rkme import RKMEStatSpecification | |||
| @@ -20,8 +20,8 @@ try: | |||
| except ImportError: | |||
| _FAISS_INSTALLED = False | |||
| from .base import BaseStatSpecification | |||
| from ..logger import get_module_logger | |||
| from ..base import BaseStatSpecification | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("rkme") | |||
| @@ -51,6 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification): | |||
| torch.cuda.empty_cache() | |||
| self.device = choose_device(cuda_idx=cuda_idx) | |||
| setup_seed(0) | |||
| super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__) | |||
| def get_beta(self) -> np.ndarray: | |||
| """Move beta(RKME weights) back to memory accessible to the CPU. | |||
| @@ -427,6 +428,7 @@ class RKMEStatSpecification(BaseStatSpecification): | |||
| rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() | |||
| rkme_to_save["beta"] = rkme_to_save["beta"].tolist() | |||
| rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" | |||
| rkme_to_save["type"] = self.type | |||
| json.dump( | |||
| rkme_to_save, | |||
| codecs.open(save_path, "w", encoding="utf-8"), | |||
| @@ -4,7 +4,7 @@ import pandas as pd | |||
| from typing import Union | |||
| from .base import BaseStatSpecification | |||
| from .rkme import RKMEStatSpecification | |||
| from .table import RKMEStatSpecification | |||
| from ..config import C | |||
| @@ -0,0 +1,10 @@ | |||
| ## How to Generate Environment Yaml | |||
| * Create the environment config with conda: | |||
| ```shell | |||
| conda env export | grep -v "^prefix: " > environment.yml | |||
| ``` | |||
| * Recreate the environment from the config: | |||
| ```shell | |||
| conda env create -f environment.yml | |||
| ``` | |||
| @@ -0,0 +1,27 @@ | |||
| name: learnware_example_env | |||
| channels: | |||
| - defaults | |||
| dependencies: | |||
| - _libgcc_mutex=0.1=main | |||
| - _openmp_mutex=5.1=1_gnu | |||
| - ca-certificates=2023.01.10=h06a4308_0 | |||
| - ld_impl_linux-64=2.38=h1181459_1 | |||
| - libffi=3.4.2=h6a678d5_6 | |||
| - libgcc-ng=11.2.0=h1234567_1 | |||
| - libgomp=11.2.0=h1234567_1 | |||
| - libstdcxx-ng=11.2.0=h1234567_1 | |||
| - ncurses=6.4=h6a678d5_0 | |||
| - openssl=1.1.1t=h7f8727e_0 | |||
| - pip=23.0.1=py38h06a4308_0 | |||
| - python=3.8.16=h7a1cb2a_3 | |||
| - readline=8.2=h5eee18b_0 | |||
| - setuptools=66.0.0=py38h06a4308_0 | |||
| - sqlite=3.41.2=h5eee18b_0 | |||
| - tk=8.6.12=h1ccaba5_0 | |||
| - wheel=0.38.4=py38h06a4308_0 | |||
| - xz=5.2.10=h5eee18b_1 | |||
| - zlib=1.2.13=h5eee18b_0 | |||
| - pip: | |||
| - joblib==1.2.0 | |||
| - learnware==0.0.1.99 | |||
| - numpy==1.19.5 | |||
| @@ -0,0 +1,8 @@ | |||
| model: | |||
| class_name: SVM | |||
| kwargs: {} | |||
| stat_specifications: | |||
| - module_path: learnware.specification | |||
| class_name: RKMEStatSpecification | |||
| file_name: svm.json | |||
| kwargs: {} | |||
| @@ -0,0 +1,20 @@ | |||
| import os | |||
| import joblib | |||
| import numpy as np | |||
| from learnware.model import BaseModel | |||
class SVM(BaseModel):
    """Example learnware model wrapping a pre-trained estimator stored next to this file"""

    def __init__(self):
        super().__init__(input_shape=(64,), output_shape=(10,))
        package_dir = os.path.dirname(os.path.abspath(__file__))
        model_file = os.path.join(package_dir, "svm.pkl")
        # NOTE(review): joblib.load unpickles the file; only load trusted learnwares.
        self.model = joblib.load(model_file)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """No-op: the wrapped estimator is loaded already trained"""
        pass

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the wrapped estimator's ``predict_proba`` output for X"""
        probabilities = self.model.predict_proba(X)
        return probabilities

    def finetune(self, X: np.ndarray, y: np.ndarray):
        """No-op: fine-tuning is not implemented for this example"""
        pass
| @@ -0,0 +1,205 @@ | |||
| import sys | |||
| import unittest | |||
| import os | |||
| import copy | |||
| import joblib | |||
| import zipfile | |||
| import numpy as np | |||
| from sklearn import svm | |||
| from sklearn.datasets import load_digits | |||
| from sklearn.model_selection import train_test_split | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| from learnware.market import instatiate_learnware_market, BaseUserInfo | |||
| import learnware.specification as specification | |||
# Directory containing this test file; test fixtures are created beneath it.
curr_root = os.path.dirname(os.path.abspath(__file__))

# Semantic specification template. The tests deep-copy it for each uploaded
# learnware (overriding "Name"/"Description") and also use it as the user's query.
user_semantic = {
    "Data": {"Values": ["Image"], "Type": "Class"},
    "Task": {
        "Values": ["Classification"],
        "Type": "Class",
    },
    "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
    "Scenario": {"Values": ["Education"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
    "Output": {
        "Dimension": 10,
        "Description": {
            "0": "the probability of the label is zero",
        },
    },
}
class TestMarket(unittest.TestCase):
    """Workflow tests for the "easy" learnware market: upload, delete, semantic and statistical search"""

    @classmethod
    def setUpClass(cls) -> None:
        np.random.seed(2023)
        learnware.init()

    def _init_learnware_market(self):
        """initialize learnware market"""
        easy_market = instatiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True)
        return easy_market

    def test_prepare_learnware_randomly(self, learnware_num=5):
        """Train ``learnware_num`` SVMs on the digits data and package each one as a learnware zip"""
        self.zip_path_list = []
        X, y = load_digits(return_X_y=True)

        for i in range(learnware_num):
            dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i))
            os.makedirs(dir_path, exist_ok=True)

            print("Preparing Learnware: %d" % (i))

            data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
            clf = svm.SVC(kernel="linear", probability=True)
            clf.fit(data_X, data_y)

            joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))

            spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
            spec.save(os.path.join(dir_path, "svm.json"))

            init_file = os.path.join(dir_path, "__init__.py")
            copyfile(
                os.path.join(curr_root, "learnware_example/example_init.py"), init_file
            )  # cp example_init.py init_file

            yaml_file = os.path.join(dir_path, "learnware.yaml")
            copyfile(os.path.join(curr_root, "learnware_example/example.yaml"), yaml_file)  # cp example.yaml yaml_file

            env_file = os.path.join(dir_path, "environment.yaml")
            copyfile(os.path.join(curr_root, "learnware_example/environment.yaml"), env_file)

            zip_file = dir_path + ".zip"
            # zip -q -r -j zip_file dir_path
            with zipfile.ZipFile(zip_file, "w") as zip_obj:
                for foldername, subfolders, filenames in os.walk(dir_path):
                    for filename in filenames:
                        file_path = os.path.join(foldername, filename)
                        zip_info = zipfile.ZipInfo(filename)
                        zip_info.compress_type = zipfile.ZIP_STORED
                        with open(file_path, "rb") as file:
                            zip_obj.writestr(zip_info, file.read())

            rmtree(dir_path)  # rm -r dir_path

            self.zip_path_list.append(zip_file)

    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
        """Upload the prepared learnwares, optionally delete them again, and return the market"""
        easy_market = self._init_learnware_market()
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

        print("Total Item:", len(easy_market))
        assert len(easy_market) == 0, f"The market should be empty!"

        for idx, zip_path in enumerate(self.zip_path_list):
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
            semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
            easy_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(easy_market))
        assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        curr_inds = easy_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        if delete:
            for learnware_id in curr_inds:
                easy_market.delete_learnware(learnware_id)
                self.learnware_num -= 1
                assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

            curr_inds = easy_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)
            assert len(curr_inds) == 0, f"The market should be empty!"

        return easy_market

    def test_search_semantics(self, learnware_num=5):
        """Check exact-name and fuzzy-name semantic search"""
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))
        assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        semantic_spec = copy.deepcopy(user_semantic)
        semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}"

        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        _, single_learnware_list, _, _ = easy_market.search_learnware(user_info)

        print("User info:", user_info.get_semantic_spec())
        print(f"Search result:")
        assert len(single_learnware_list) == 1, f"Exact semantic search failed!"
        for learnware in single_learnware_list:
            semantic_spec1 = learnware.get_specification().get_semantic_spec()
            print("Choose learnware:", learnware.id, semantic_spec1)
            assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!"

        # A misspelled name is expected to match every uploaded learnware.
        semantic_spec["Name"]["Values"] = "laernwaer"
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        _, single_learnware_list, _, _ = easy_market.search_learnware(user_info)

        print("User info:", user_info.get_semantic_spec())
        print(f"Search result:")
        assert len(single_learnware_list) == self.learnware_num, f"Fuzzy semantic search failed!"
        for learnware in single_learnware_list:
            semantic_spec1 = learnware.get_specification().get_semantic_spec()
            print("Choose learnware:", learnware.id, semantic_spec1)

    def test_stat_search(self, learnware_num=5):
        """Search the market with each uploaded learnware's own RKME specification"""
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

        test_folder = os.path.join(curr_root, "test_stat")

        for idx, zip_path in enumerate(self.zip_path_list):
            unzip_dir = os.path.join(test_folder, f"{idx}")

            # unzip -o -q zip_path -d unzip_dir
            if os.path.exists(unzip_dir):
                rmtree(unzip_dir)
            os.makedirs(unzip_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_obj:
                zip_obj.extractall(path=unzip_dir)

            # Use the package-level export: the rkme module now lives under
            # specification.table, so specification.rkme is no longer available.
            user_spec = specification.RKMEStatSpecification()
            user_spec.load(os.path.join(unzip_dir, "svm.json"))
            user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
            (
                sorted_score_list,
                single_learnware_list,
                mixture_score,
                mixture_learnware_list,
            ) = easy_market.search_learnware(user_info)

            assert len(single_learnware_list) == self.learnware_num, f"Statistical search failed!"
            print(f"search result of user{idx}:")
            for score, learnware in zip(sorted_score_list, single_learnware_list):
                print(f"score: {score}, learnware_id: {learnware.id}")
            print(f"mixture_score: {mixture_score}\n")
            mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
            print(f"mixture_learnware: {mixture_id}\n")

        rmtree(test_folder)  # rm -r test_folder
def suite():
    """Build the TestMarket suite with its tests in a fixed order"""
    ordered_names = (
        "test_prepare_learnware_randomly",
        "test_upload_delete_learnware",
        "test_search_semantics",
        "test_stat_search",
    )
    _suite = unittest.TestSuite()
    for test_name in ordered_names:
        _suite.addTest(TestMarket(test_name))
    return _suite
# Run the ordered suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.TextTestRunner().run(suite())
| @@ -0,0 +1,31 @@ | |||
| import os | |||
| import json | |||
| import unittest | |||
| import tempfile | |||
| import numpy as np | |||
| import learnware | |||
| import learnware.specification as specification | |||
| from learnware.specification import RKMEStatSpecification | |||
class TestRKME(unittest.TestCase):
    """Check that an RKME specification records its type when saved and restores it when loaded"""

    def test_rkme(self):
        X = np.random.uniform(-10000, 10000, size=(5000, 200))
        rkme = specification.utils.generate_rkme_spec(X)

        with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
            rkme_path = os.path.join(tempdir, "rkme.json")
            rkme.save(rkme_path)

            with open(rkme_path, "r") as f:
                data = json.load(f)
            # assertEqual is not stripped under ``python -O`` and reports both
            # values on failure, unlike a bare assert.
            self.assertEqual(data["type"], "RKMEStatSpecification")

            rkme2 = RKMEStatSpecification()
            rkme2.load(rkme_path)
            self.assertEqual(rkme2.type, "RKMEStatSpecification")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
| @@ -155,7 +155,7 @@ class TestAllWorkflow(unittest.TestCase): | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=unzip_dir) | |||
| user_spec = specification.rkme.RKMEStatSpecification() | |||
| user_spec = specification.RKMEStatSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) | |||
| ( | |||