diff --git a/examples/dataset_text_workflow/main.py b/examples/dataset_text_workflow/main.py index 6c712b7..e7e1c38 100644 --- a/examples/dataset_text_workflow/main.py +++ b/examples/dataset_text_workflow/main.py @@ -9,7 +9,7 @@ from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningR import time import pickle -from learnware.market import instatiate_learnware_market, BaseUserInfo +from learnware.market import instantiate_learnware_market, BaseUserInfo from learnware.market import database_ops from learnware.learnware import Learnware import learnware.specification as specification @@ -120,7 +120,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo def prepare_market(): - text_market = instatiate_learnware_market(market_id="sst2", rebuild=True) + text_market = instantiate_learnware_market(market_id="sst2", rebuild=True) try: rmtree(learnware_pool_dir) except: @@ -144,10 +144,10 @@ def prepare_market(): def test_search(gamma=0.1, load_market=True): if load_market: - text_market = instatiate_learnware_market(market_id="sst2") + text_market = instantiate_learnware_market(market_id="sst2") else: prepare_market() - text_market = instatiate_learnware_market(market_id="sst2") + text_market = instantiate_learnware_market(market_id="sst2") logger.info("Number of items in the market: %d" % len(text_market)) select_list = [] diff --git a/learnware/client/container.py b/learnware/client/container.py index 243d580..a07fd7d 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -337,6 +337,7 @@ class ModelDockerContainer(ModelContainer): "install", "-r", f"{requirements_path_filter}", + "--no-dependencies", ] ) ) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 90c0e35..ceacdda 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -92,7 +92,7 @@ class LearnwareClient: @require_login def upload_learnware(self, learnware_zip_path, semantic_specification): - assert self._check_semantic_specification(semantic_specification) + assert self._check_semantic_specification(semantic_specification), "Semantic specification check failed!" file_hash = compute_file_hash(learnware_zip_path) url_upload = f"{self.host}/user/chunked_upload" @@ -276,8 +276,7 @@ class LearnwareClient: response = requests.get(url, headers=self.headers) result = response.json() semantic_conf = result["data"]["semantic_specification"] - - return semantic_conf[key]["Values"] + return semantic_conf[key.value]["Values"] def load_learnware( self, @@ -386,14 +385,16 @@ class LearnwareClient: @staticmethod def _check_semantic_specification(semantic_spec): - return EasySemanticChecker.check_semantic_spec(semantic_spec) + return EasySemanticChecker.check_semantic_spec(semantic_spec) != EasySemanticChecker.INVALID_LEARNWARE @staticmethod def check_learnware(learnware_zip_path, semantic_specification=None): semantic_specification = ( get_semantic_specification() if semantic_specification is None else semantic_specification ) - LearnwareClient._check_semantic_specification(semantic_specification) + assert LearnwareClient._check_semantic_specification( + semantic_specification + ), "Semantic specification check failed!" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file: diff --git a/learnware/client/utils.py b/learnware/client/utils.py index 95a10c7..21bb12c 100644 --- a/learnware/client/utils.py +++ b/learnware/client/utils.py @@ -76,6 +76,7 @@ def install_environment(zip_path, conda_env): "install", "-r", f"{requirements_path_filter}", + "--no-dependencies", ] ) else: diff --git a/learnware/market/__init__.py b/learnware/market/__init__.py index 7dd1bc6..649baa5 100644 --- a/learnware/market/__init__.py +++ b/learnware/market/__init__.py @@ -6,4 +6,4 @@ from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatist from .hetergeneous import HeterogeneousOrganizer, MappingFunction from .easy import EasyMarket -from .module import instatiate_learnware_market +from .module import instantiate_learnware_market diff --git a/learnware/market/base.py b/learnware/market/base.py index 927d4fe..528cd46 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -1,7 +1,7 @@ +from __future__ import annotations + import zipfile import tempfile - - from typing import Tuple, Any, List, Union from ..learnware import Learnware, get_learnware_from_dirpath from ..logger import get_module_logger @@ -47,10 +47,10 @@ class LearnwareMarket: def __init__( self, - market_id: str = None, - organizer: "BaseOrganizer" = None, - searcher: "BaseSearcher" = None, - checker_list: List["BaseChecker"] = None, + market_id: str = "default", + organizer: BaseOrganizer = None, + searcher: BaseSearcher = None, + checker_list: List[BaseChecker] = None, rebuild=False, ): self.market_id = market_id @@ -70,29 +70,27 @@ class LearnwareMarket: def reload_market(self, **kwargs) -> bool: self.learnware_organizer.reload_market(**kwargs) - def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: + def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: try: - with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: - with zipfile.ZipFile(zip_path, mode="r") as z_file: - z_file.extractall(tempdir) - - pending_learnware = get_learnware_from_dirpath( - id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir - ) - - final_status = BaseChecker.INVALID_LEARNWARE - checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names - - for name in checker_names: - checker = self.learnware_checker[name] - check_status = checker(pending_learnware) - final_status = max(final_status, check_status) - - if check_status == BaseChecker.INVALID_LEARNWARE: - return BaseChecker.INVALID_LEARNWARE - - return final_status - + final_status = BaseChecker.NONUSABLE_LEARNWARE + if len(checker_names): + with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: + with zipfile.ZipFile(zip_path, mode="r") as z_file: + z_file.extractall(tempdir) + + pending_learnware = get_learnware_from_dirpath( + id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir + ) + checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names + + for name in checker_names: + checker = self.learnware_checker[name] + check_status = checker(pending_learnware) + final_status = max(final_status, check_status) + + if check_status == BaseChecker.INVALID_LEARNWARE: + return BaseChecker.INVALID_LEARNWARE + return final_status except Exception as err: logger.warning(f"Check learnware failed! Due to {err}.") return BaseChecker.INVALID_LEARNWARE @@ -122,8 +120,25 @@ class LearnwareMarket: zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) - def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]: - return self.learnware_searcher(user_info, **kwargs) + def search_learnware( + self, user_info: BaseUserInfo, check_status: int = None, **kwargs + ) -> Tuple[Any, List[Learnware]]: + """Search learnwares based on user_info from learnwares with check_status + + Parameters + ---------- + user_info : BaseUserInfo + User information for searching learnwares + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status + + Returns + ------- + Tuple[Any, List[Learnware]] + Search results + """ + return self.learnware_searcher(user_info, check_status, **kwargs) def delete_learnware(self, id: str, **kwargs) -> bool: return self.learnware_organizer.delete_learnware(id, **kwargs) @@ -131,8 +146,8 @@ class LearnwareMarket: def update_learnware( self, id: str, - zip_path: str, - semantic_spec: dict, + zip_path: str = None, + semantic_spec: dict = None, checker_names: List[str] = None, check_status: int = None, **kwargs, @@ -157,6 +172,12 @@ class LearnwareMarket: int The final learnware check_status. """ + zip_path = self.get_learnware_path_by_ids(id) if zip_path is None else zip_path + semantic_spec = ( + self.get_learnware_by_ids(id).get_specification().get_semantic_spec() + if semantic_spec is None + else semantic_spec + ) update_status = self.check_learnware(zip_path, semantic_spec, checker_names) check_status = ( update_status if check_status is None or update_status == BaseChecker.INVALID_LEARNWARE else check_status @@ -166,14 +187,44 @@ class LearnwareMarket: id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) - def get_learnware_ids(self, top: int = None, **kwargs): - return self.learnware_organizer.get_learnware_ids(top, **kwargs) + def get_learnware_ids(self, top: int = None, check_status: int = None, **kwargs) -> List[str]: + """get the list of learnware ids + + Parameters + ---------- + top : int, optional + The first top element to return, by default None + check_status : int, optional + - None: return all learnware ids + - Others: return learnware ids with check_status + + Raises + ------ + List[str] + the first top ids + """ + return self.learnware_organizer.get_learnware_ids(top, check_status, **kwargs) - def get_learnwares(self, top: int = None, **kwargs): - return self.learnware_organizer.get_learnwares(top, **kwargs) + def get_learnwares(self, top: int = None, check_status: int = None, **kwargs) -> List[Learnware]: + """get the list of learnwares + + Parameters + ---------- + top : int, optional + The first top element to return, by default None + check_status : int, optional + - None: return all learnwares + - Others: return learnwares with check_status + + Raises + ------ + List[Learnware] + the first top learnwares + """ + return self.learnware_organizer.get_learnwares(top, check_status, **kwargs) def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: - raise self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs) + return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs) def get_learnware_by_ids(self, id: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: return self.learnware_organizer.get_learnware_by_ids(id, **kwargs) @@ -298,13 +349,16 @@ class BaseOrganizer: """ raise NotImplementedError("get_learnware_path_by_ids is not implemented in BaseOrganizer") - def get_learnware_ids(self, top: int = None) -> List[str]: + def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]: """get the list of learnware ids Parameters ---------- top : int, optional - the first top element to return, by default None + The first top element to return, by default None + check_status : int, optional + - None: return all learnware ids + - Others: return learnware ids with check_status Raises ------ @@ -313,13 +367,16 @@ class BaseOrganizer: """ raise NotImplementedError("get_learnware_ids is not implemented in BaseOrganizer") - def get_learnwares(self, top: int = None) -> List[Learnware]: + def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]: """get the list of learnwares Parameters ---------- top : int, optional - the first top element to return, by default None + The first top element to return, by default None + check_status : int, optional + - None: return all learnwares + - Others: return learnwares with check_status Raises ------ @@ -334,18 +391,21 @@ class BaseOrganizer: class BaseSearcher: def __init__(self, organizer: BaseOrganizer = None): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer def reset(self, organizer): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer - def __call__(self, user_info: BaseUserInfo): - """Search learnwares based on user_info + def __call__(self, user_info: BaseUserInfo, check_status: int = None): + """Search learnwares based on user_info from learnwares with check_status Parameters ---------- user_info : BaseUserInfo user_info contains semantic_spec and stat_info + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status """ raise NotImplementedError("'__call__' method is not implemented in BaseSearcher") @@ -356,10 +416,10 @@ class BaseChecker: USABLE_LEARWARE = 1 def __init__(self, organizer: BaseOrganizer = None): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer def reset(self, organizer): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer def __call__(self, learnware: Learnware) -> int: """Check the utility of a learnware diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 18f67eb..cef55fc 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -55,7 +55,9 @@ class EasyOrganizer(BaseOrganizer): self.count, ) = self.dbops.load_market() - def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, int]: + def add_learnware( + self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str = None + ) -> Tuple[str, int]: """Add a learnware into the market. Parameters @@ -66,7 +68,8 @@ class EasyOrganizer(BaseOrganizer): semantic_spec for new learnware, in dictionary format. check_status: int A flag indicating whether the learnware is usable. - + learnware_id: int + A id in database for learnware Returns ------- Tuple[str, int] @@ -80,9 +83,9 @@ class EasyOrganizer(BaseOrganizer): semantic_spec = copy.deepcopy(semantic_spec) logger.info("Get new learnware from %s" % (zip_path)) - id = "%08d" % (self.count) - target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (id)) - target_folder_dir = os.path.join(self.learnware_folder_pool_path, id) + learnware_id = "%08d" % (self.count) if learnware_id is None else learnware_id + target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id)) + target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id) copyfile(zip_path, target_zip_dir) with zipfile.ZipFile(target_zip_dir, "r") as z_file: @@ -91,7 +94,7 @@ class EasyOrganizer(BaseOrganizer): try: new_learnware = get_learnware_from_dirpath( - id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir + id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir ) except: try: @@ -107,19 +110,19 @@ class EasyOrganizer(BaseOrganizer): learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE self.dbops.add_learnware( - id=id, + id=learnware_id, semantic_spec=semantic_spec, zip_path=target_zip_dir, folder_path=target_folder_dir, use_flag=learnwere_status, ) - self.learnware_list[id] = new_learnware - self.learnware_zip_list[id] = target_zip_dir - self.learnware_folder_list[id] = target_folder_dir - self.use_flags[id] = learnwere_status + self.learnware_list[learnware_id] = new_learnware + self.learnware_zip_list[learnware_id] = target_zip_dir + self.learnware_folder_list[learnware_id] = target_folder_dir + self.use_flags[learnware_id] = learnwere_status self.count += 1 - return id, learnwere_status + return learnware_id, learnwere_status def delete_learnware(self, id: str) -> bool: """Delete Learnware from market @@ -205,7 +208,8 @@ class EasyOrganizer(BaseOrganizer): if new_learnware is None: return BaseChecker.INVALID_LEARNWARE - copyfile(zip_path, target_zip_dir) + if zip_path != target_zip_dir: + copyfile(zip_path, target_zip_dir) with zipfile.ZipFile(target_zip_dir, "r") as z_file: z_file.extractall(target_folder_dir) @@ -284,17 +288,52 @@ class EasyOrganizer(BaseOrganizer): logger.warning("Learnware ID '%s' NOT Found!" % (ids)) return None - def get_learnware_ids(self, top: int = None) -> List[str]: - if top is None: - return list(self.learnware_list.keys()) - else: - return list(self.learnware_list.keys())[:top] + def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]: + """Get learnware ids + + Parameters + ---------- + top : int, optional + The first top learnware ids to return, by default None + check_status : bool, optional + - None: return all learnware ids + - Others: return learnware ids with check_status + + Returns + ------- + List[str] + Learnware ids + """ + if check_status is None: + filtered_ids = self.use_flags.keys() + elif check_status is True: + filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.USABLE_LEARWARE] + elif check_status is False: + filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.NONUSABLE_LEARNWARE] - def get_learnwares(self, top: int = None) -> List[str]: if top is None: - return list(self.learnware_list.values()) + return filtered_ids else: - return list(self.learnware_list.values())[:top] + return filtered_ids[:top] + + def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]: + """Get learnware list + + Parameters + ---------- + top : int, optional + The first top learnwares to return, by default None + check_status : bool, optional + - None: return all learnwares + - Others: return learnwares with check_status + + Returns + ------- + List[Learnware] + Learnware list + """ + learnware_ids = self.get_learnware_ids(top, check_status) + return [self.learnware_list[idx] for idx in learnware_ids] def __len__(self): return len(self.learnware_list) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index e07f861..8feb5d9 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -613,14 +613,14 @@ class EasySearcher(BaseSearcher): self.stat_searcher = EasyStatSearcher(organizer) def reset(self, organizer): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer self.semantic_searcher.reset(organizer) self.stat_searcher.reset(organizer) def __call__( - self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" + self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - """Search learnwares based on user_info + """Search learnwares based on user_info from learnwares with check_status Parameters ---------- @@ -628,6 +628,9 @@ class EasySearcher(BaseSearcher): user_info contains semantic_spec and stat_info max_search_num : int The maximum number of the returned learnwares + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status Returns ------- @@ -637,7 +640,7 @@ class EasySearcher(BaseSearcher): the third is the score of Learnware (mixture) the fourth is the list of Learnware (mixture), the size is search_num """ - learnware_list = self.learnware_oganizer.get_learnwares() + learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) learnware_list = self.semantic_searcher(learnware_list, user_info) if len(learnware_list) == 0: diff --git a/learnware/market/module.py b/learnware/market/module.py index 57821cd..3b55b52 100644 --- a/learnware/market/module.py +++ b/learnware/market/module.py @@ -10,7 +10,7 @@ MARKET_CONFIG = { } -def instatiate_learnware_market(market_id, name="easy", **kwargs): +def instantiate_learnware_market(market_id="default", name="easy", **kwargs): return LearnwareMarket( market_id=market_id, organizer=MARKET_CONFIG[name]["organizer"], diff --git a/tests/test_market/test_easy.py b/tests/test_market/test_easy.py index 16729e2..bb03839 100644 --- a/tests/test_market/test_easy.py +++ b/tests/test_market/test_easy.py @@ -11,7 +11,7 @@ from sklearn.model_selection import train_test_split from shutil import copyfile, rmtree import learnware -from learnware.market import instatiate_learnware_market, BaseUserInfo +from learnware.market import instantiate_learnware_market, BaseUserInfo import learnware.specification as specification curr_root = os.path.dirname(os.path.abspath(__file__)) @@ -43,7 +43,7 @@ class TestMarket(unittest.TestCase): def _init_learnware_market(self): """initialize learnware market""" - easy_market = instatiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True) + easy_market = instantiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True) return easy_market def test_prepare_learnware_randomly(self, learnware_num=5):