From b16ed73f09446cdbb65540761e2e62147d9af123 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 3 Nov 2023 11:26:12 +0800 Subject: [PATCH] [MNT] add check_status in get_learnwares, get_learnware_ids and search_learnware --- learnware/market/base.py | 80 ++++++++++++++++++++++++----- learnware/market/easy2/organizer.py | 51 +++++++++++++++--- learnware/market/easy2/searcher.py | 9 ++-- 3 files changed, 117 insertions(+), 23 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index 51ddfbf..8b973a7 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -122,8 +122,25 @@ class LearnwareMarket: zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) - def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]: - return self.learnware_searcher(user_info, **kwargs) + def search_learnware( + self, user_info: BaseUserInfo, check_status: int = None, **kwargs + ) -> Tuple[Any, List[Learnware]]: + """Search learnwares based on user_info from learnwares with check_status + + Parameters + ---------- + user_info : BaseUserInfo + User information for searching learnwares + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status + + Returns + ------- + Tuple[Any, List[Learnware]] + Search results + """ + return self.learnware_searcher(user_info, check_status, **kwargs) def delete_learnware(self, id: str, **kwargs) -> bool: return self.learnware_organizer.delete_learnware(id, **kwargs) @@ -166,11 +183,41 @@ class LearnwareMarket: id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) - def get_learnware_ids(self, top: int = None, **kwargs): - return self.learnware_organizer.get_learnware_ids(top, **kwargs) + def get_learnware_ids(self, top: int = None, check_status: int = None, **kwargs) -> List[str]: + """get the list of learnware ids + + Parameters + ---------- + top : int, optional + The first top element to return, by default None + check_status : int, optional + - None: return all learnware ids + - Others: return learnware ids with check_status + + Raises + ------ + List[str] + the first top ids + """ + return self.learnware_organizer.get_learnware_ids(top, check_status, **kwargs) + + def get_learnwares(self, top: int = None, check_status: int = None, **kwargs) -> List[Learnware]: + """get the list of learnwares + + Parameters + ---------- + top : int, optional + The first top element to return, by default None + check_status : int, optional + - None: return all learnwares + - Others: return learnwares with check_status - def get_learnwares(self, top: int = None, **kwargs): - return self.learnware_organizer.get_learnwares(top, **kwargs) + Raises + ------ + List[Learnware] + the first top learnwares + """ + return self.learnware_organizer.get_learnwares(top, check_status, **kwargs) def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs) @@ -298,13 +345,16 @@ class BaseOrganizer: """ raise NotImplementedError("get_learnware_path_by_ids is not implemented in BaseOrganizer") - def get_learnware_ids(self, top: int = None) -> List[str]: + def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]: """get the list of learnware ids Parameters ---------- top : int, optional - the first top element to return, by default None + The first top element to return, by default None + check_status : int, optional + - None: return all learnware ids + - Others: return learnware ids with check_status Raises ------ @@ -313,13 +363,16 @@ class BaseOrganizer: """ raise NotImplementedError("get_learnware_ids is not implemented in BaseOrganizer") - def get_learnwares(self, top: int = None) -> List[Learnware]: + def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]: """get the list of learnwares Parameters ---------- top : int, optional - the first top element to return, by default None + The first top element to return, by default None + check_status : int, optional + - None: return all learnwares + - Others: return learnwares with check_status Raises ------ @@ -339,13 +392,16 @@ class BaseSearcher: def reset(self, organizer): self.learnware_oganizer = organizer - def __call__(self, user_info: BaseUserInfo): - """Search learnwares based on user_info + def __call__(self, user_info: BaseUserInfo, check_status: int = None): + """Search learnwares based on user_info from learnwares with check_status Parameters ---------- user_info : BaseUserInfo user_info contains semantic_spec and stat_info + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status """ raise NotImplementedError("'__call__' method is not implemented in BaseSearcher") diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 1fbb9f6..dee7433 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -287,17 +287,52 @@ class EasyOrganizer(BaseOrganizer): logger.warning("Learnware ID '%s' NOT Found!" % (ids)) return None - def get_learnware_ids(self, top: int = None) -> List[str]: - if top is None: - return list(self.learnware_list.keys()) - else: - return list(self.learnware_list.keys())[:top] + def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]: + """Get learnware ids + + Parameters + ---------- + top : int, optional + The first top learnware ids to return, by default None + check_status : bool, optional + - None: return all learnware ids + - Others: return learnware ids with check_status + + Returns + ------- + List[str] + Learnware ids + """ + if check_status is None: + filtered_ids = self.use_flags.keys() + elif check_status is True: + filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.USABLE_LEARWARE] + elif check_status is False: + filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.NONUSABLE_LEARNWARE] - def get_learnwares(self, top: int = None) -> List[str]: if top is None: - return list(self.learnware_list.values()) + return filtered_ids else: - return list(self.learnware_list.values())[:top] + return filtered_ids[:top] + + def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]: + """Get learnware list + + Parameters + ---------- + top : int, optional + The first top learnwares to return, by default None + check_status : bool, optional + - None: return all learnwares + - Others: return learnwares with check_status + + Returns + ------- + List[Learnware] + Learnware list + """ + learnware_ids = self.get_learnware_ids(top, check_status) + return [self.learnware_list[idx] for idx in learnware_ids] def __len__(self): return len(self.learnware_list) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 8934fda..7c9a783 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -610,9 +610,9 @@ class EasySearcher(BaseSearcher): self.stat_searcher.reset(organizer) def __call__( - self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" + self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - """Search learnwares based on user_info + """Search learnwares based on user_info from learnwares with check_status Parameters ---------- @@ -620,6 +620,9 @@ class EasySearcher(BaseSearcher): user_info contains semantic_spec and stat_info max_search_num : int The maximum number of the returned learnwares + check_status : int, optional + - None: search from all learnwares + - Others: search from learnwares with check_status Returns ------- @@ -629,7 +632,7 @@ class EasySearcher(BaseSearcher): the third is the score of Learnware (mixture) the fourth is the list of Learnware (mixture), the size is search_num """ - learnware_list = self.learnware_oganizer.get_learnwares() + learnware_list = self.learnware_oganizer.get_learnwares(check_status=check_status) learnware_list = self.semantic_searcher(learnware_list, user_info) if len(learnware_list) == 0: