| @@ -9,7 +9,7 @@ from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningR | |||
| import time | |||
| import pickle | |||
| from learnware.market import instatiate_learnware_market, BaseUserInfo | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.market import database_ops | |||
| from learnware.learnware import Learnware | |||
| import learnware.specification as specification | |||
| @@ -120,7 +120,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo | |||
| def prepare_market(): | |||
| text_market = instatiate_learnware_market(market_id="sst2", rebuild=True) | |||
| text_market = instantiate_learnware_market(market_id="sst2", rebuild=True) | |||
| try: | |||
| rmtree(learnware_pool_dir) | |||
| except: | |||
| @@ -144,10 +144,10 @@ def prepare_market(): | |||
| def test_search(gamma=0.1, load_market=True): | |||
| if load_market: | |||
| text_market = instatiate_learnware_market(market_id="sst2") | |||
| text_market = instantiate_learnware_market(market_id="sst2") | |||
| else: | |||
| prepare_market() | |||
| text_market = instatiate_learnware_market(market_id="sst2") | |||
| text_market = instantiate_learnware_market(market_id="sst2") | |||
| logger.info("Number of items in the market: %d" % len(text_market)) | |||
| select_list = [] | |||
| @@ -337,6 +337,7 @@ class ModelDockerContainer(ModelContainer): | |||
| "install", | |||
| "-r", | |||
| f"{requirements_path_filter}", | |||
| "--no-dependencies", | |||
| ] | |||
| ) | |||
| ) | |||
| @@ -92,7 +92,7 @@ class LearnwareClient: | |||
| @require_login | |||
| def upload_learnware(self, learnware_zip_path, semantic_specification): | |||
| assert self._check_semantic_specification(semantic_specification) | |||
| assert self._check_semantic_specification(semantic_specification), "Semantic specification check failed!" | |||
| file_hash = compute_file_hash(learnware_zip_path) | |||
| url_upload = f"{self.host}/user/chunked_upload" | |||
| @@ -276,8 +276,7 @@ class LearnwareClient: | |||
| response = requests.get(url, headers=self.headers) | |||
| result = response.json() | |||
| semantic_conf = result["data"]["semantic_specification"] | |||
| return semantic_conf[key]["Values"] | |||
| return semantic_conf[key.value]["Values"] | |||
| def load_learnware( | |||
| self, | |||
| @@ -386,14 +385,16 @@ class LearnwareClient: | |||
@staticmethod
def _check_semantic_specification(semantic_spec):
    """Report whether ``semantic_spec`` passes the semantic checker.

    Parameters
    ----------
    semantic_spec
        The semantic specification handed to ``EasySemanticChecker.check_semantic_spec``.

    Returns
    -------
    bool
        False when the checker yields ``EasySemanticChecker.INVALID_LEARNWARE``,
        True for any other status.
    """
    # Return a real boolean rather than the raw status code, so that callers
    # can safely use the result in ``assert`` / ``if``.
    status = EasySemanticChecker.check_semantic_spec(semantic_spec)
    return status != EasySemanticChecker.INVALID_LEARNWARE
| @staticmethod | |||
| def check_learnware(learnware_zip_path, semantic_specification=None): | |||
| semantic_specification = ( | |||
| get_semantic_specification() if semantic_specification is None else semantic_specification | |||
| ) | |||
| LearnwareClient._check_semantic_specification(semantic_specification) | |||
| assert LearnwareClient._check_semantic_specification( | |||
| semantic_specification | |||
| ), "Semantic specification check failed!" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file: | |||
| @@ -76,6 +76,7 @@ def install_environment(zip_path, conda_env): | |||
| "install", | |||
| "-r", | |||
| f"{requirements_path_filter}", | |||
| "--no-dependencies", | |||
| ] | |||
| ) | |||
| else: | |||
| @@ -6,4 +6,4 @@ from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatist | |||
| from .hetergeneous import HeterogeneousOrganizer, MappingFunction | |||
| from .easy import EasyMarket | |||
| from .module import instatiate_learnware_market | |||
| from .module import instantiate_learnware_market | |||
| @@ -1,7 +1,7 @@ | |||
| from __future__ import annotations | |||
| import zipfile | |||
| import tempfile | |||
| from typing import Tuple, Any, List, Union | |||
| from ..learnware import Learnware, get_learnware_from_dirpath | |||
| from ..logger import get_module_logger | |||
| @@ -47,10 +47,10 @@ class LearnwareMarket: | |||
| def __init__( | |||
| self, | |||
| market_id: str = None, | |||
| organizer: "BaseOrganizer" = None, | |||
| searcher: "BaseSearcher" = None, | |||
| checker_list: List["BaseChecker"] = None, | |||
| market_id: str = "default", | |||
| organizer: BaseOrganizer = None, | |||
| searcher: BaseSearcher = None, | |||
| checker_list: List[BaseChecker] = None, | |||
| rebuild=False, | |||
| ): | |||
| self.market_id = market_id | |||
| @@ -70,29 +70,27 @@ class LearnwareMarket: | |||
def reload_market(self, **kwargs) -> bool:
    """Reload the market by delegating to the organizer.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to ``self.learnware_organizer.reload_market``.

    Returns
    -------
    bool
        The organizer's result, passed through as-is.
        NOTE(review): confirm the organizer actually returns a bool.
    """
    # Propagate the organizer's result instead of implicitly returning None,
    # which contradicted the declared return type.
    return self.learnware_organizer.reload_market(**kwargs)
def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> int:
    """Run the selected checkers on a learnware zip and return the resulting status.

    Parameters
    ----------
    zip_path : str
        Path of the learnware zip file to check.
    semantic_spec : dict
        Semantic specification used to build the pending learnware.
    checker_names : List[str], optional
        Names (keys of ``self.learnware_checker``) of the checkers to run.
        None means every registered checker; an empty list runs none.

    Returns
    -------
    int
        ``BaseChecker.INVALID_LEARNWARE`` as soon as one checker reports it, or
        when any exception is raised while checking; otherwise the maximum of
        ``BaseChecker.NONUSABLE_LEARNWARE`` and the statuses the checkers returned.
    """
    try:
        # Resolve the default *before* looking at the length: calling len() on
        # None raised a TypeError that the handler below turned into INVALID.
        if checker_names is None:
            checker_names = list(self.learnware_checker.keys())

        final_status = BaseChecker.NONUSABLE_LEARNWARE
        if not checker_names:
            # Nothing to run, so there is no need to unpack the zip at all.
            return final_status

        with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir:
            with zipfile.ZipFile(zip_path, mode="r") as z_file:
                z_file.extractall(tempdir)

            pending_learnware = get_learnware_from_dirpath(
                id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir
            )
            for name in checker_names:
                check_status = self.learnware_checker[name](pending_learnware)
                if check_status == BaseChecker.INVALID_LEARNWARE:
                    return BaseChecker.INVALID_LEARNWARE
                final_status = max(final_status, check_status)

        return final_status
    except Exception as err:
        # Best-effort by design: any failure while checking marks the learnware invalid.
        logger.warning(f"Check learnware failed! Due to {err}.")
        return BaseChecker.INVALID_LEARNWARE
| @@ -122,8 +120,25 @@ class LearnwareMarket: | |||
| zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs | |||
| ) | |||
def search_learnware(
    self, user_info: BaseUserInfo, check_status: int = None, **kwargs
) -> Tuple[Any, List[Learnware]]:
    """Search learnwares based on user_info from learnwares with check_status

    Parameters
    ----------
    user_info : BaseUserInfo
        User information for searching learnwares
    check_status : int, optional
        - None: search from all learnwares
        - Others: search from learnwares with check_status
    **kwargs
        Extra options forwarded unchanged to ``self.learnware_searcher``.

    Returns
    -------
    Tuple[Any, List[Learnware]]
        Search results
    """
    # check_status is handed over positionally, directly after user_info.
    return self.learnware_searcher(user_info, check_status, **kwargs)
def delete_learnware(self, id: str, **kwargs) -> bool:
    """Delete the learnware ``id`` through the organizer and return the organizer's result."""
    organizer = self.learnware_organizer
    return organizer.delete_learnware(id, **kwargs)
| @@ -131,8 +146,8 @@ class LearnwareMarket: | |||
| def update_learnware( | |||
| self, | |||
| id: str, | |||
| zip_path: str, | |||
| semantic_spec: dict, | |||
| zip_path: str = None, | |||
| semantic_spec: dict = None, | |||
| checker_names: List[str] = None, | |||
| check_status: int = None, | |||
| **kwargs, | |||
| @@ -157,6 +172,12 @@ class LearnwareMarket: | |||
| int | |||
| The final learnware check_status. | |||
| """ | |||
| zip_path = self.get_learnware_path_by_ids(id) if zip_path is None else zip_path | |||
| semantic_spec = ( | |||
| self.get_learnware_by_ids(id).get_specification().get_semantic_spec() | |||
| if semantic_spec is None | |||
| else semantic_spec | |||
| ) | |||
| update_status = self.check_learnware(zip_path, semantic_spec, checker_names) | |||
| check_status = ( | |||
| update_status if check_status is None or update_status == BaseChecker.INVALID_LEARNWARE else check_status | |||
| @@ -166,14 +187,44 @@ class LearnwareMarket: | |||
| id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs | |||
| ) | |||
def get_learnware_ids(self, top: int = None, check_status: int = None, **kwargs) -> List[str]:
    """Get the list of learnware ids

    Parameters
    ----------
    top : int, optional
        The first top element to return, by default None
    check_status : int, optional
        - None: return all learnware ids
        - Others: return learnware ids with check_status
    **kwargs
        Extra options forwarded unchanged to the organizer.

    Returns
    -------
    List[str]
        the first top ids
    """
    return self.learnware_organizer.get_learnware_ids(top, check_status, **kwargs)
def get_learnwares(self, top: int = None, check_status: int = None, **kwargs) -> List[Learnware]:
    """Get the list of learnwares

    Parameters
    ----------
    top : int, optional
        The first top element to return, by default None
    check_status : int, optional
        - None: return all learnwares
        - Others: return learnwares with check_status
    **kwargs
        Extra options forwarded unchanged to the organizer.

    Returns
    -------
    List[Learnware]
        the first top learnwares
    """
    return self.learnware_organizer.get_learnwares(top, check_status, **kwargs)
def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[str, List[str]]:
    """Look up learnware path(s) for ``ids`` through the organizer.

    Parameters
    ----------
    ids : Union[str, List[str]]
        A single learnware id or a list of ids.
    **kwargs
        Extra options forwarded unchanged to the organizer.

    Returns
    -------
    Union[str, List[str]]
        Whatever the organizer's ``get_learnware_path_by_ids`` returns for ``ids``.
    """
    # The lookup result must be returned; raising it would fail with
    # "exceptions must derive from BaseException".
    return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs)
def get_learnware_by_ids(self, id: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]:
    """Fetch the learnware(s) for ``id`` through the organizer and return its result."""
    organizer = self.learnware_organizer
    return organizer.get_learnware_by_ids(id, **kwargs)
| @@ -298,13 +349,16 @@ class BaseOrganizer: | |||
| """ | |||
| raise NotImplementedError("get_learnware_path_by_ids is not implemented in BaseOrganizer") | |||
| def get_learnware_ids(self, top: int = None) -> List[str]: | |||
| def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]: | |||
| """get the list of learnware ids | |||
| Parameters | |||
| ---------- | |||
| top : int, optional | |||
| the first top element to return, by default None | |||
| The first top element to return, by default None | |||
| check_status : int, optional | |||
| - None: return all learnware ids | |||
| - Others: return learnware ids with check_status | |||
| Raises | |||
| ------ | |||
| @@ -313,13 +367,16 @@ class BaseOrganizer: | |||
| """ | |||
| raise NotImplementedError("get_learnware_ids is not implemented in BaseOrganizer") | |||
| def get_learnwares(self, top: int = None) -> List[Learnware]: | |||
| def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]: | |||
| """get the list of learnwares | |||
| Parameters | |||
| ---------- | |||
| top : int, optional | |||
| the first top element to return, by default None | |||
| The first top element to return, by default None | |||
| check_status : int, optional | |||
| - None: return all learnwares | |||
| - Others: return learnwares with check_status | |||
| Raises | |||
| ------ | |||
| @@ -334,18 +391,21 @@ class BaseOrganizer: | |||
class BaseSearcher:
    """Base class for searchers; keeps a reference to the organizer being searched."""

    def __init__(self, organizer: BaseOrganizer = None):
        # Stored under the correctly spelled attribute name only.
        self.learnware_organizer = organizer

    def reset(self, organizer):
        """Point the searcher at a different organizer."""
        self.learnware_organizer = organizer

    def __call__(self, user_info: BaseUserInfo, check_status: int = None):
        """Search learnwares based on user_info from learnwares with check_status

        Parameters
        ----------
        user_info : BaseUserInfo
            user_info contains semantic_spec and stat_info
        check_status : int, optional
            - None: search from all learnwares
            - Others: search from learnwares with check_status

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide the implementation.
        """
        raise NotImplementedError("'__call__' method is not implemented in BaseSearcher")
| @@ -356,10 +416,10 @@ class BaseChecker: | |||
| USABLE_LEARWARE = 1 | |||
| def __init__(self, organizer: BaseOrganizer = None): | |||
| self.learnware_oganizer = organizer | |||
| self.learnware_organizer = organizer | |||
| def reset(self, organizer): | |||
| self.learnware_oganizer = organizer | |||
| self.learnware_organizer = organizer | |||
| def __call__(self, learnware: Learnware) -> int: | |||
| """Check the utility of a learnware | |||
| @@ -55,7 +55,9 @@ class EasyOrganizer(BaseOrganizer): | |||
| self.count, | |||
| ) = self.dbops.load_market() | |||
| def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, int]: | |||
| def add_learnware( | |||
| self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str = None | |||
| ) -> Tuple[str, int]: | |||
| """Add a learnware into the market. | |||
| Parameters | |||
| @@ -66,7 +68,8 @@ class EasyOrganizer(BaseOrganizer): | |||
| semantic_spec for new learnware, in dictionary format. | |||
| check_status: int | |||
| A flag indicating whether the learnware is usable. | |||
| learnware_id: str, optional | |||
| The id to store the learnware under; when None, an id is generated from the market's counter | |||
| Returns | |||
| ------- | |||
| Tuple[str, int] | |||
| @@ -80,9 +83,9 @@ class EasyOrganizer(BaseOrganizer): | |||
| semantic_spec = copy.deepcopy(semantic_spec) | |||
| logger.info("Get new learnware from %s" % (zip_path)) | |||
| id = "%08d" % (self.count) | |||
| target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (id)) | |||
| target_folder_dir = os.path.join(self.learnware_folder_pool_path, id) | |||
| learnware_id = "%08d" % (self.count) if learnware_id is None else learnware_id | |||
| target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id)) | |||
| target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id) | |||
| copyfile(zip_path, target_zip_dir) | |||
| with zipfile.ZipFile(target_zip_dir, "r") as z_file: | |||
| @@ -91,7 +94,7 @@ class EasyOrganizer(BaseOrganizer): | |||
| try: | |||
| new_learnware = get_learnware_from_dirpath( | |||
| id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir | |||
| id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir | |||
| ) | |||
| except: | |||
| try: | |||
| @@ -107,19 +110,19 @@ class EasyOrganizer(BaseOrganizer): | |||
| learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE | |||
| self.dbops.add_learnware( | |||
| id=id, | |||
| id=learnware_id, | |||
| semantic_spec=semantic_spec, | |||
| zip_path=target_zip_dir, | |||
| folder_path=target_folder_dir, | |||
| use_flag=learnwere_status, | |||
| ) | |||
| self.learnware_list[id] = new_learnware | |||
| self.learnware_zip_list[id] = target_zip_dir | |||
| self.learnware_folder_list[id] = target_folder_dir | |||
| self.use_flags[id] = learnwere_status | |||
| self.learnware_list[learnware_id] = new_learnware | |||
| self.learnware_zip_list[learnware_id] = target_zip_dir | |||
| self.learnware_folder_list[learnware_id] = target_folder_dir | |||
| self.use_flags[learnware_id] = learnwere_status | |||
| self.count += 1 | |||
| return id, learnwere_status | |||
| return learnware_id, learnwere_status | |||
| def delete_learnware(self, id: str) -> bool: | |||
| """Delete Learnware from market | |||
| @@ -205,7 +208,8 @@ class EasyOrganizer(BaseOrganizer): | |||
| if new_learnware is None: | |||
| return BaseChecker.INVALID_LEARNWARE | |||
| copyfile(zip_path, target_zip_dir) | |||
| if zip_path != target_zip_dir: | |||
| copyfile(zip_path, target_zip_dir) | |||
| with zipfile.ZipFile(target_zip_dir, "r") as z_file: | |||
| z_file.extractall(target_folder_dir) | |||
| @@ -284,17 +288,52 @@ class EasyOrganizer(BaseOrganizer): | |||
| logger.warning("Learnware ID '%s' NOT Found!" % (ids)) | |||
| return None | |||
def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]:
    """Get learnware ids

    Parameters
    ----------
    top : int, optional
        The first top learnware ids to return, by default None
    check_status : int, optional
        - None: return all learnware ids
        - Others: return learnware ids with check_status
        ``True`` / ``False`` are still accepted and map to
        ``BaseChecker.USABLE_LEARWARE`` / ``BaseChecker.NONUSABLE_LEARNWARE``.

    Returns
    -------
    List[str]
        Learnware ids
    """
    if check_status is None:
        # Materialise the keys: a dict view cannot be sliced with [:top].
        filtered_ids = list(self.use_flags)
    else:
        if check_status is True:
            wanted_status = BaseChecker.USABLE_LEARWARE
        elif check_status is False:
            wanted_status = BaseChecker.NONUSABLE_LEARNWARE
        else:
            # Plain integer status codes previously left filtered_ids unbound.
            wanted_status = check_status
        filtered_ids = [key for key, value in self.use_flags.items() if value == wanted_status]

    return filtered_ids if top is None else filtered_ids[:top]
def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]:
    """Get learnware list

    Parameters
    ----------
    top : int, optional
        The first top learnwares to return, by default None
    check_status : int, optional
        - None: return all learnwares
        - Others: return learnwares with check_status

    Returns
    -------
    List[Learnware]
        Learnware list
    """
    selected = []
    # Id selection (top / check_status filtering) is done by get_learnware_ids.
    for learnware_id in self.get_learnware_ids(top, check_status):
        selected.append(self.learnware_list[learnware_id])
    return selected
def __len__(self):
    """Return the number of entries in ``self.learnware_list``."""
    learnware_count = len(self.learnware_list)
    return learnware_count
| @@ -613,14 +613,14 @@ class EasySearcher(BaseSearcher): | |||
| self.stat_searcher = EasyStatSearcher(organizer) | |||
def reset(self, organizer):
    """Switch this searcher and both of its sub-searchers to ``organizer``.

    Parameters
    ----------
    organizer
        The organizer to search from now on; also passed to the ``reset``
        method of ``self.semantic_searcher`` and ``self.stat_searcher``.
    """
    # Only the correctly spelled attribute is written.
    self.learnware_organizer = organizer
    for sub_searcher in (self.semantic_searcher, self.stat_searcher):
        sub_searcher.reset(organizer)
| def __call__( | |||
| self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" | |||
| self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" | |||
| ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: | |||
| """Search learnwares based on user_info | |||
| """Search learnwares based on user_info from learnwares with check_status | |||
| Parameters | |||
| ---------- | |||
| @@ -628,6 +628,9 @@ class EasySearcher(BaseSearcher): | |||
| user_info contains semantic_spec and stat_info | |||
| max_search_num : int | |||
| The maximum number of the returned learnwares | |||
| check_status : int, optional | |||
| - None: search from all learnwares | |||
| - Others: search from learnwares with check_status | |||
| Returns | |||
| ------- | |||
| @@ -637,7 +640,7 @@ class EasySearcher(BaseSearcher): | |||
| the third is the score of Learnware (mixture) | |||
| the fourth is the list of Learnware (mixture), the size is search_num | |||
| """ | |||
| learnware_list = self.learnware_oganizer.get_learnwares() | |||
| learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) | |||
| learnware_list = self.semantic_searcher(learnware_list, user_info) | |||
| if len(learnware_list) == 0: | |||
| @@ -10,7 +10,7 @@ MARKET_CONFIG = { | |||
| } | |||
| def instatiate_learnware_market(market_id, name="easy", **kwargs): | |||
| def instantiate_learnware_market(market_id="default", name="easy", **kwargs): | |||
| return LearnwareMarket( | |||
| market_id=market_id, | |||
| organizer=MARKET_CONFIG[name]["organizer"], | |||
| @@ -11,7 +11,7 @@ from sklearn.model_selection import train_test_split | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| from learnware.market import instatiate_learnware_market, BaseUserInfo | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| import learnware.specification as specification | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| @@ -43,7 +43,7 @@ class TestMarket(unittest.TestCase): | |||
def _init_learnware_market(self):
    """Initialize the learnware market used by the tests.

    Returns
    -------
    The market created by ``instantiate_learnware_market`` with
    ``market_id="sklearn_digits"``, ``name="easy"`` and ``rebuild=True``.
    """
    # Use the correctly spelled factory; the misspelled name is no longer imported.
    easy_market = instantiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True)
    return easy_market
| def test_prepare_learnware_randomly(self, learnware_num=5): | |||