From d8208d98c666aba21321e26092331d351d4d37e4 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 2 Nov 2023 01:45:47 +0800 Subject: [PATCH 01/15] [MNT] add default market id --- learnware/market/base.py | 12 ++++++------ learnware/market/module.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index 927d4fe..ef9d46a 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -1,7 +1,7 @@ +from __future__ import annotations + import zipfile import tempfile - - from typing import Tuple, Any, List, Union from ..learnware import Learnware, get_learnware_from_dirpath from ..logger import get_module_logger @@ -47,10 +47,10 @@ class LearnwareMarket: def __init__( self, - market_id: str = None, - organizer: "BaseOrganizer" = None, - searcher: "BaseSearcher" = None, - checker_list: List["BaseChecker"] = None, + market_id: str = "default", + organizer: BaseOrganizer = None, + searcher: BaseSearcher = None, + checker_list: List[BaseChecker] = None, rebuild=False, ): self.market_id = market_id diff --git a/learnware/market/module.py b/learnware/market/module.py index 57821cd..6ad363b 100644 --- a/learnware/market/module.py +++ b/learnware/market/module.py @@ -10,7 +10,7 @@ MARKET_CONFIG = { } -def instatiate_learnware_market(market_id, name="easy", **kwargs): +def instatiate_learnware_market(market_id="default", name="easy", **kwargs): return LearnwareMarket( market_id=market_id, organizer=MARKET_CONFIG[name]["organizer"], From fd5a8b3442615a6401c1486d8b70eaf7eef60c08 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 2 Nov 2023 10:10:48 +0800 Subject: [PATCH 02/15] [MNT] add learnware_id in EasyOrganizer --- learnware/market/easy2/organizer.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 18f67eb..1fbb9f6 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -55,7 +55,9 @@ class EasyOrganizer(BaseOrganizer): self.count, ) = self.dbops.load_market() - def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, int]: + def add_learnware( + self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str = None + ) -> Tuple[str, int]: """Add a learnware into the market. Parameters @@ -66,7 +68,8 @@ class EasyOrganizer(BaseOrganizer): semantic_spec for new learnware, in dictionary format. check_status: int A flag indicating whether the learnware is usable. - + learnware_id: int + A id in database for learnware Returns ------- Tuple[str, int] @@ -80,9 +83,9 @@ class EasyOrganizer(BaseOrganizer): semantic_spec = copy.deepcopy(semantic_spec) logger.info("Get new learnware from %s" % (zip_path)) - id = "%08d" % (self.count) - target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (id)) - target_folder_dir = os.path.join(self.learnware_folder_pool_path, id) + learnware_id = "%08d" % (self.count) if learnware_id is None else learnware_id + target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id)) + target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id) copyfile(zip_path, target_zip_dir) with zipfile.ZipFile(target_zip_dir, "r") as z_file: @@ -91,7 +94,7 @@ class EasyOrganizer(BaseOrganizer): try: new_learnware = get_learnware_from_dirpath( - id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir + id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir ) except: try: @@ -107,19 +110,19 @@ class EasyOrganizer(BaseOrganizer): learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE self.dbops.add_learnware( - id=id, + id=learnware_id, semantic_spec=semantic_spec, zip_path=target_zip_dir, folder_path=target_folder_dir, use_flag=learnwere_status, ) - self.learnware_list[id] = new_learnware - self.learnware_zip_list[id] = target_zip_dir - self.learnware_folder_list[id] = target_folder_dir - self.use_flags[id] = learnwere_status + self.learnware_list[learnware_id] = new_learnware + self.learnware_zip_list[learnware_id] = target_zip_dir + self.learnware_folder_list[learnware_id] = target_folder_dir + self.use_flags[learnware_id] = learnwere_status self.count += 1 - return id, learnwere_status + return learnware_id, learnwere_status def delete_learnware(self, id: str) -> bool: """Delete Learnware from market From e029674ca8caaebe3f79bb94f87a27334c8bced5 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 2 Nov 2023 10:10:59 +0800 Subject: [PATCH 03/15] [MNT] format code by black --- learnware/market/easy2/checker.py | 4 ++-- learnware/specification/__init__.py | 8 +++++++- learnware/specification/regular/text/rkme.py | 3 ++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 1fc0fef..fa8f26c 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -91,7 +91,7 @@ class EasyStatisticalChecker(BaseChecker): if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") return self.INVALID_LEARNWARE - + def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): text_list = [] for i in range(num): @@ -106,7 +106,7 @@ class EasyStatisticalChecker(BaseChecker): else: raise ValueError("Type should be en or zh") return text_list - + if is_text: inputs = generate_random_text_list(10) else: diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 7fbf500..b27ef5b 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,9 @@ from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification -from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification +from .regular import ( + RegularStatsSpecification, + RKMEStatSpecification, + RKMETableSpecification, + RKMEImageSpecification, + RKMETextSpecification, +) diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index cc8659e..117b032 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -10,6 +10,7 @@ logger = get_module_logger("RKMETextSpecification", "INFO") class RKMETextSpecification(RKMETableSpecification): """Reduced Kernel Mean Embedding (RKME) Specification for Text""" + def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): RKMETableSpecification.__init__(self, gamma, cuda_idx) self.language = [] @@ -59,7 +60,7 @@ class RKMETextSpecification(RKMETableSpecification): @staticmethod def get_language_ids(X): try: - text = ' '.join(X) + text = " ".join(X) lang = langdetect.detect(text) langs = langdetect.detect_langs(text) return [l.lang for l in langs] From 3827100c9c55c15c93eb11f7f04c5d22c9d015c2 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 2 Nov 2023 14:59:47 +0800 Subject: [PATCH 04/15] [FIX] fix bug in get_learnware_path_by_ids --- learnware/market/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index ef9d46a..51ddfbf 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -173,7 +173,7 @@ class LearnwareMarket: return self.learnware_organizer.get_learnwares(top, **kwargs) def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: - raise self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs) + return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs) def get_learnware_by_ids(self, id: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: return self.learnware_organizer.get_learnware_by_ids(id, **kwargs) From b16ed73f09446cdbb65540761e2e62147d9af123 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 11:26:12 +0800 Subject: [PATCH 05/15] [MNT] add check_status in get_learnwares, get_learnware_ids and search_learnware --- learnware/market/base.py | 80 ++++++++++++++++++++++++----- learnware/market/easy2/organizer.py | 51 +++++++++++++++--- learnware/market/easy2/searcher.py | 9 ++-- 3 files changed, 117 insertions(+), 23 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index 51ddfbf..8b973a7 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -122,8 +122,25 @@ class LearnwareMarket: zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) - def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]: - return self.learnware_searcher(user_info, **kwargs) + def search_learnware( + self, user_info: BaseUserInfo, check_status: int = None, **kwargs + ) -> Tuple[Any, List[Learnware]]: + """Search learnwares based on user_info from learnwares with check_status + + Parameters + ---------- + user_info : BaseUserInfo + User information for searching learnwares + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status + + Returns + ------- + Tuple[Any, List[Learnware]] + Search results + """ + return self.learnware_searcher(user_info, check_status, **kwargs) def delete_learnware(self, id: str, **kwargs) -> bool: return self.learnware_organizer.delete_learnware(id, **kwargs) @@ -166,11 +183,41 @@ class LearnwareMarket: id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) - def get_learnware_ids(self, top: int = None, **kwargs): - return self.learnware_organizer.get_learnware_ids(top, **kwargs) + def get_learnware_ids(self, top: int = None, check_status: int = None, **kwargs) -> List[str]: + """get the list of learnware ids + + Parameters + ---------- + top : int, optional + The first top element to return, by default None + check_status : int, optional + - None: return all learnware ids + - Others: return learnware ids with check_status + + Raises + ------ + List[str] + the first top ids + """ + return self.learnware_organizer.get_learnware_ids(top, check_status, **kwargs) + + def get_learnwares(self, top: int = None, check_status: int = None, **kwargs) -> List[Learnware]: + """get the list of learnwares + + Parameters + ---------- + top : int, optional + The first top element to return, by default None + check_status : int, optional + - None: return all learnwares + - Others: return learnwares with check_status - def get_learnwares(self, top: int = None, **kwargs): - return self.learnware_organizer.get_learnwares(top, **kwargs) + Raises + ------ + List[Learnware] + the first top learnwares + """ + return self.learnware_organizer.get_learnwares(top, check_status, **kwargs) def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs) @@ -298,13 +345,16 @@ class BaseOrganizer: """ raise NotImplementedError("get_learnware_path_by_ids is not implemented in BaseOrganizer") - def get_learnware_ids(self, top: int = None) -> List[str]: + def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]: """get the list of learnware ids Parameters ---------- top : int, optional - the first top element to return, by default None + The first top element to return, by default None + check_status : int, optional + - None: return all learnware ids + - Others: return learnware ids with check_status Raises ------ @@ -313,13 +363,16 @@ class BaseOrganizer: """ raise NotImplementedError("get_learnware_ids is not implemented in BaseOrganizer") - def get_learnwares(self, top: int = None) -> List[Learnware]: + def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]: """get the list of learnwares Parameters ---------- top : int, optional - the first top element to return, by default None + The first top element to return, by default None + check_status : int, optional + - None: return all learnwares + - Others: return learnwares with check_status Raises ------ @@ -339,13 +392,16 @@ class BaseSearcher: def reset(self, organizer): self.learnware_oganizer = organizer - def __call__(self, user_info: BaseUserInfo): - """Search learnwares based on user_info + def __call__(self, user_info: BaseUserInfo, check_status: int = None): + """Search learnwares based on user_info from learnwares with check_status Parameters ---------- user_info : BaseUserInfo user_info contains semantic_spec and stat_info + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status """ raise NotImplementedError("'__call__' method is not implemented in BaseSearcher") diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 1fbb9f6..dee7433 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -287,17 +287,52 @@ class EasyOrganizer(BaseOrganizer): logger.warning("Learnware ID '%s' NOT Found!" % (ids)) return None - def get_learnware_ids(self, top: int = None) -> List[str]: - if top is None: - return list(self.learnware_list.keys()) - else: - return list(self.learnware_list.keys())[:top] + def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]: + """Get learnware ids + + Parameters + ---------- + top : int, optional + The first top learnware ids to return, by default None + check_status : bool, optional + - None: return all learnware ids + - Others: return learnware ids with check_status + + Returns + ------- + List[str] + Learnware ids + """ + if check_status is None: + filtered_ids = self.use_flags.keys() + elif check_status is True: + filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.USABLE_LEARWARE] + elif check_status is False: + filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.NONUSABLE_LEARNWARE] - def get_learnwares(self, top: int = None) -> List[str]: if top is None: - return list(self.learnware_list.values()) + return filtered_ids else: - return list(self.learnware_list.values())[:top] + return filtered_ids[:top] + + def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]: + """Get learnware list + + Parameters + ---------- + top : int, optional + The first top learnwares to return, by default None + check_status : bool, optional + - None: return all learnwares + - Others: return learnwares with check_status + + Returns + ------- + List[Learnware] + Learnware list + """ + learnware_ids = self.get_learnware_ids(top, check_status) + return [self.learnware_list[idx] for idx in learnware_ids] def __len__(self): return len(self.learnware_list) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 8934fda..7c9a783 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -610,9 +610,9 @@ class EasySearcher(BaseSearcher): self.stat_searcher.reset(organizer) def __call__( - self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" + self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - """Search learnwares based on user_info + """Search learnwares based on user_info from learnwares with check_status Parameters ---------- @@ -620,6 +620,9 @@ class EasySearcher(BaseSearcher): user_info contains semantic_spec and stat_info max_search_num : int The maximum number of the returned learnwares + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status Returns ------- @@ -629,7 +632,7 @@ class EasySearcher(BaseSearcher): the third is the score of Learnware (mixture) the fourth is the list of Learnware (mixture), the size is search_num """ - learnware_list = self.learnware_oganizer.get_learnwares() + learnware_list = self.learnware_oganizer.get_learnwares(check_status=check_status) learnware_list = self.semantic_searcher(learnware_list, user_info) if len(learnware_list) == 0: From a135188d54eb76f1bf505301c2a5c86230985530 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 13:45:22 +0800 Subject: [PATCH 06/15] [MNT] modify update_learnware --- learnware/market/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index 8b973a7..9300cb5 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -148,8 +148,8 @@ class LearnwareMarket: def update_learnware( self, id: str, - zip_path: str, - semantic_spec: dict, + zip_path: str = None, + semantic_spec: dict = None, checker_names: List[str] = None, check_status: int = None, **kwargs, @@ -174,6 +174,12 @@ class LearnwareMarket: int The final learnware check_status. """ + zip_path = self.get_learnware_path_by_ids(id) if zip_path is None else zip_path + semantic_spec = ( + self.get_learnware_by_ids(id).get_specification().get_semantic_spec() + if semantic_spec is None + else semantic_spec + ) update_status = self.check_learnware(zip_path, semantic_spec, checker_names) check_status = ( update_status if check_status is None or update_status == BaseChecker.INVALID_LEARNWARE else check_status From fd4e7f8dbf4da4da74612c8aba1ed96a0bf447c9 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 14:32:08 +0800 Subject: [PATCH 07/15] [FIX] fix bug in check_semantic_specification --- learnware/client/learnware_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 9c4671b..786f7fe 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -99,7 +99,7 @@ class LearnwareClient: @require_login def upload_learnware(self, learnware_zip_path, semantic_specification): - assert self._check_semantic_specification(semantic_specification) + assert self._check_semantic_specification(semantic_specification), "Semantic specification check failed!" file_hash = compute_file_hash(learnware_zip_path) url_upload = f"{self.host}/user/chunked_upload" @@ -405,14 +405,14 @@ class LearnwareClient: @staticmethod def _check_semantic_specification(semantic_spec): - return EasySemanticChecker.check_semantic_spec(semantic_spec) + return EasySemanticChecker.check_semantic_spec(semantic_spec) != EasySemanticChecker.INVALID_LEARNWARE @staticmethod def check_learnware(learnware_zip_path, semantic_specification=None): semantic_specification = ( get_semantic_specification() if semantic_specification is None else semantic_specification ) - LearnwareClient._check_semantic_specification(semantic_specification) + assert LearnwareClient._check_semantic_specification(semantic_specification), "Semantic specification check failed!" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file: From a267fcefeda9fd19df37003c6d72bc00c6777aa6 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 15:26:08 +0800 Subject: [PATCH 08/15] [FIX] fix bug in list_semantic_specification_values --- learnware/client/learnware_client.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 786f7fe..e196102 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -40,16 +40,12 @@ def file_chunks(file_path): if not chunk: break yield chunk - pass - pass - pass def compute_file_hash(file_path): file_hash = hashlib.md5() for chunk in file_chunks(file_path): file_hash.update(chunk) - pass return file_hash.hexdigest() @@ -58,7 +54,6 @@ class SemanticSpecificationKey(Enum): TASK_TYPE = "Task" LIBRARY_TYPE = "Library" SENARIOES = "Scenario" - pass class LearnwareClient: @@ -85,7 +80,6 @@ class LearnwareClient: token = result["data"]["token"] self.headers = {"Authorization": f"Bearer {token}"} - pass @require_login def logout(self): @@ -95,7 +89,6 @@ class LearnwareClient: if result["code"] != 0: raise Exception("logout failed: " + json.dumps(result)) self.headers = None - pass @require_login def upload_learnware(self, learnware_zip_path, semantic_specification): @@ -126,7 +119,6 @@ class LearnwareClient: begin += len(chunk) bar.update(1) - pass bar.close() url_add = f"{self.host}/user/add_learnware_uploaded" @@ -169,9 +161,6 @@ class LearnwareClient: for chunk in response.iter_content(chunk_size=CHUNK_SIZE): f.write(chunk) bar.update(1) - pass - pass - pass @require_login def list_learnware(self): @@ -199,19 +188,16 @@ class LearnwareClient: stat_spec = list(stat_spec.values())[0] else: stat_spec = None - pass returns = [] with tempfile.NamedTemporaryFile(prefix="learnware_stat_", suffix=".json") as ftemp: if stat_spec is not None: stat_spec.save(ftemp.name) - pass with open(ftemp.name, "r") as fin: semantic_specification = specification.get_semantic_spec() if semantic_specification is None: semantic_specification = {} - pass semantic_specification.pop("Input", None) semantic_specification.pop("Output", None) @@ -220,7 +206,6 @@ class LearnwareClient: files = None else: files = {"statistical_specification": fin} - pass response = requests.post( url, @@ -246,9 +231,6 @@ class LearnwareClient: "matching": learnware["matching"], } ) - pass - pass - pass return returns @@ -261,7 +243,6 @@ class LearnwareClient: if result["code"] != 0: raise Exception("delete failed: " + json.dumps(result)) - pass def create_semantic_specification( self, @@ -295,8 +276,9 @@ class LearnwareClient: response = requests.get(url, headers=self.headers) result = response.json() semantic_conf = result["data"]["semantic_specification"] + print("!" * 100, semantic_conf) - return semantic_conf[key]["Values"] + return semantic_conf[key.value]["Values"] def load_learnware( self, @@ -412,7 +394,9 @@ class LearnwareClient: semantic_specification = ( get_semantic_specification() if semantic_specification is None else semantic_specification ) - assert LearnwareClient._check_semantic_specification(semantic_specification), "Semantic specification check failed!" + assert LearnwareClient._check_semantic_specification( + semantic_specification + ), "Semantic specification check failed!" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file: From 0d04a50cd7417d1ded7a95d1cd53c9d42101ae99 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 15:27:23 +0800 Subject: [PATCH 09/15] [FIX] delete print --- learnware/client/learnware_client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index e196102..976422b 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -276,8 +276,6 @@ class LearnwareClient: response = requests.get(url, headers=self.headers) result = response.json() semantic_conf = result["data"]["semantic_specification"] - print("!" * 100, semantic_conf) - return semantic_conf[key.value]["Values"] def load_learnware( From 465327ba9ec747d4ef5d6ce708e64c1cfbe3ade9 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 15:49:17 +0800 Subject: [PATCH 10/15] [MNT] add --no-dependencies for installing requirements --- learnware/client/container.py | 1 + learnware/client/utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/learnware/client/container.py b/learnware/client/container.py index 243d580..a07fd7d 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -337,6 +337,7 @@ class ModelDockerContainer(ModelContainer): "install", "-r", f"{requirements_path_filter}", + "--no-dependencies", ] ) ) diff --git a/learnware/client/utils.py b/learnware/client/utils.py index 95a10c7..21bb12c 100644 --- a/learnware/client/utils.py +++ b/learnware/client/utils.py @@ -76,6 +76,7 @@ def install_environment(zip_path, conda_env): "install", "-r", f"{requirements_path_filter}", + "--no-dependencies", ] ) else: From adf1f81db3f752654d09492b1efd0dac5be8b0e0 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 16:33:44 +0800 Subject: [PATCH 11/15] [FIX] fix typo --- learnware/market/__init__.py | 2 +- learnware/market/module.py | 2 +- tests/test_market/test_easy.py | 4 ++-- tests/test_text_workflow/main.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/learnware/market/__init__.py b/learnware/market/__init__.py index 7dd1bc6..649baa5 100644 --- a/learnware/market/__init__.py +++ b/learnware/market/__init__.py @@ -6,4 +6,4 @@ from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatist from .hetergeneous import HeterogeneousOrganizer, MappingFunction from .easy import EasyMarket -from .module import instatiate_learnware_market +from .module import instantiate_learnware_market diff --git a/learnware/market/module.py b/learnware/market/module.py index 6ad363b..3b55b52 100644 --- a/learnware/market/module.py +++ b/learnware/market/module.py @@ -10,7 +10,7 @@ MARKET_CONFIG = { } -def instatiate_learnware_market(market_id="default", name="easy", **kwargs): +def instantiate_learnware_market(market_id="default", name="easy", **kwargs): return LearnwareMarket( market_id=market_id, organizer=MARKET_CONFIG[name]["organizer"], diff --git a/tests/test_market/test_easy.py b/tests/test_market/test_easy.py index 16729e2..bb03839 100644 --- a/tests/test_market/test_easy.py +++ b/tests/test_market/test_easy.py @@ -11,7 +11,7 @@ from sklearn.model_selection import train_test_split from shutil import copyfile, rmtree import learnware -from learnware.market import instatiate_learnware_market, BaseUserInfo +from learnware.market import instantiate_learnware_market, BaseUserInfo import learnware.specification as specification curr_root = os.path.dirname(os.path.abspath(__file__)) @@ -43,7 +43,7 @@ class TestMarket(unittest.TestCase): def _init_learnware_market(self): """initialize learnware market""" - easy_market = instatiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True) + easy_market = instantiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True) return easy_market def test_prepare_learnware_randomly(self, learnware_num=5): diff --git a/tests/test_text_workflow/main.py b/tests/test_text_workflow/main.py index baa54f4..9ae39ba 100644 --- a/tests/test_text_workflow/main.py +++ b/tests/test_text_workflow/main.py @@ -9,7 +9,7 @@ from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningR import time import pickle -from learnware.market import instatiate_learnware_market, BaseUserInfo +from learnware.market import instantiate_learnware_market, BaseUserInfo from learnware.market import database_ops from learnware.learnware import Learnware import learnware.specification as specification @@ -120,7 +120,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo def prepare_market(): - text_market = instatiate_learnware_market(market_id="sst2", rebuild=True) + text_market = instantiate_learnware_market(market_id="sst2", rebuild=True) try: rmtree(learnware_pool_dir) except: @@ -144,10 +144,10 @@ def prepare_market(): def test_search(gamma=0.1, load_market=True): if load_market: - text_market = instatiate_learnware_market(market_id="sst2") + text_market = instantiate_learnware_market(market_id="sst2") else: prepare_market() - text_market = instatiate_learnware_market(market_id="sst2") + text_market = instantiate_learnware_market(market_id="sst2") logger.info("Number of items in the market: %d" % len(text_market)) select_list = [] From 4540d1a59bd932e2850e43214aee9e356d2770ae Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 17:22:17 +0800 Subject: [PATCH 12/15] [MNT] add check_status in add_learnware --- learnware/market/base.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index 9300cb5..6864d01 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -80,7 +80,7 @@ class LearnwareMarket: id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir ) - final_status = BaseChecker.INVALID_LEARNWARE + final_status = BaseChecker.NONUSABLE_LEARNWARE checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names for name in checker_names: @@ -98,7 +98,7 @@ class LearnwareMarket: return BaseChecker.INVALID_LEARNWARE def add_learnware( - self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs + self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, check_status: int = None, **kwargs ) -> Tuple[str, int]: """Add a learnware into the market. @@ -110,6 +110,8 @@ class LearnwareMarket: semantic_spec for new learnware, in dictionary format. checker_names : List[str], optional List contains checker names, by default None + check_status : int, optional + A flag indicating whether the learnware is usable, by default None. Returns ------- @@ -117,7 +119,11 @@ class LearnwareMarket: - str indicating model_id - int indicating the final learnware check_status """ - check_status = self.check_learnware(zip_path, semantic_spec, checker_names) + add_status = self.check_learnware(zip_path, semantic_spec, checker_names) + check_status = ( + add_status if check_status is None or add_status == BaseChecker.INVALID_LEARNWARE else check_status + ) + return self.learnware_organizer.add_learnware( zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) From ad85996ac03a3dcb0e2d877c8ecb3f21a013363b Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 17:38:31 +0800 Subject: [PATCH 13/15] [FIX] fix typo --- learnware/market/base.py | 8 ++++---- learnware/market/easy2/searcher.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index 6864d01..b5c110d 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -399,10 +399,10 @@ class BaseOrganizer: class BaseSearcher: def __init__(self, organizer: BaseOrganizer = None): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer def reset(self, organizer): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer def __call__(self, user_info: BaseUserInfo, check_status: int = None): """Search learnwares based on user_info from learnwares with check_status @@ -424,10 +424,10 @@ class BaseChecker: USABLE_LEARWARE = 1 def __init__(self, organizer: BaseOrganizer = None): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer def reset(self, organizer): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer def __call__(self, learnware: Learnware) -> int: """Check the utility of a learnware diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 7c9a783..106d47b 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -605,7 +605,7 @@ class EasySearcher(BaseSearcher): self.stat_searcher = EasyStatSearcher(organizer) def reset(self, organizer): - self.learnware_oganizer = organizer + self.learnware_organizer = organizer self.semantic_searcher.reset(organizer) self.stat_searcher.reset(organizer) @@ -632,7 +632,7 @@ class EasySearcher(BaseSearcher): the third is the score of Learnware (mixture) the fourth is the list of Learnware (mixture), the size is search_num """ - learnware_list = self.learnware_oganizer.get_learnwares(check_status=check_status) + learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) learnware_list = self.semantic_searcher(learnware_list, user_info) if len(learnware_list) == 0: From 9dcab883a478d807e7b39166521b5a3913b83213 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 21:01:42 +0800 Subject: [PATCH 14/15] [MNT] modify add_learnware --- learnware/market/base.py | 52 +++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index b5c110d..528cd46 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -70,35 +70,33 @@ class LearnwareMarket: def reload_market(self, **kwargs) -> bool: self.learnware_organizer.reload_market(**kwargs) - def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: + def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: try: - with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: - with zipfile.ZipFile(zip_path, mode="r") as z_file: - z_file.extractall(tempdir) - - pending_learnware = get_learnware_from_dirpath( - id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir - ) - - final_status = BaseChecker.NONUSABLE_LEARNWARE - checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names - - for name in checker_names: - checker = self.learnware_checker[name] - check_status = checker(pending_learnware) - final_status = max(final_status, check_status) - - if check_status == BaseChecker.INVALID_LEARNWARE: - return BaseChecker.INVALID_LEARNWARE - - return final_status - + final_status = BaseChecker.NONUSABLE_LEARNWARE + if len(checker_names): + with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: + with zipfile.ZipFile(zip_path, mode="r") as z_file: + z_file.extractall(tempdir) + + pending_learnware = get_learnware_from_dirpath( + id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir + ) + checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names + + for name in checker_names: + checker = self.learnware_checker[name] + check_status = checker(pending_learnware) + final_status = max(final_status, check_status) + + if check_status == BaseChecker.INVALID_LEARNWARE: + return BaseChecker.INVALID_LEARNWARE + return final_status except Exception as err: logger.warning(f"Check learnware failed! Due to {err}.") return BaseChecker.INVALID_LEARNWARE def add_learnware( - self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, check_status: int = None, **kwargs + self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs ) -> Tuple[str, int]: """Add a learnware into the market. @@ -110,8 +108,6 @@ class LearnwareMarket: semantic_spec for new learnware, in dictionary format. checker_names : List[str], optional List contains checker names, by default None - check_status : int, optional - A flag indicating whether the learnware is usable, by default None. Returns ------- @@ -119,11 +115,7 @@ class LearnwareMarket: - str indicating model_id - int indicating the final learnware check_status """ - add_status = self.check_learnware(zip_path, semantic_spec, checker_names) - check_status = ( - add_status if check_status is None or add_status == BaseChecker.INVALID_LEARNWARE else check_status - ) - + check_status = self.check_learnware(zip_path, semantic_spec, checker_names) return self.learnware_organizer.add_learnware( zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) From 1dce67953108ec88fe00efda3981406d68d4464e Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 23:08:30 +0800 Subject: [PATCH 15/15] [FIX] fix bugs in copyfile --- learnware/market/easy2/organizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index dee7433..cef55fc 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -208,7 +208,8 @@ class EasyOrganizer(BaseOrganizer): if new_learnware is None: return BaseChecker.INVALID_LEARNWARE - copyfile(zip_path, target_zip_dir) + if zip_path != target_zip_dir: + copyfile(zip_path, target_zip_dir) with zipfile.ZipFile(target_zip_dir, "r") as z_file: z_file.extractall(target_folder_dir)