Browse Source

[MNT] add check_status in get_learnwares, get_learnware_ids and search_learnware

tags/v0.3.2
Gene 2 years ago
parent
commit
b16ed73f09
3 changed files with 117 additions and 23 deletions
  1. +68
    -12
      learnware/market/base.py
  2. +43
    -8
      learnware/market/easy2/organizer.py
  3. +6
    -3
      learnware/market/easy2/searcher.py

+ 68
- 12
learnware/market/base.py View File

@@ -122,8 +122,25 @@ class LearnwareMarket:
zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]:
return self.learnware_searcher(user_info, **kwargs)
def search_learnware(
self, user_info: BaseUserInfo, check_status: int = None, **kwargs
) -> Tuple[Any, List[Learnware]]:
"""Search learnwares based on user_info from learnwares with check_status

Parameters
----------
user_info : BaseUserInfo
User information for searching learnwares
check_status : int, optional
- None: search from all learnwares
- Others: search from learnwares with check_status

Returns
-------
Tuple[Any, List[Learnware]]
Search results
"""
return self.learnware_searcher(user_info, check_status, **kwargs)

def delete_learnware(self, id: str, **kwargs) -> bool:
return self.learnware_organizer.delete_learnware(id, **kwargs)
@@ -166,11 +183,41 @@ class LearnwareMarket:
id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def get_learnware_ids(self, top: int = None, **kwargs):
return self.learnware_organizer.get_learnware_ids(top, **kwargs)
def get_learnware_ids(self, top: int = None, check_status: int = None, **kwargs) -> List[str]:
"""get the list of learnware ids

Parameters
----------
top : int, optional
The first top element to return, by default None
check_status : int, optional
- None: return all learnware ids
- Others: return learnware ids with check_status

Raises
------
List[str]
the first top ids
"""
return self.learnware_organizer.get_learnware_ids(top, check_status, **kwargs)

def get_learnwares(self, top: int = None, check_status: int = None, **kwargs) -> List[Learnware]:
"""get the list of learnwares

Parameters
----------
top : int, optional
The first top element to return, by default None
check_status : int, optional
- None: return all learnwares
- Others: return learnwares with check_status

def get_learnwares(self, top: int = None, **kwargs):
return self.learnware_organizer.get_learnwares(top, **kwargs)
Raises
------
List[Learnware]
the first top learnwares
"""
return self.learnware_organizer.get_learnwares(top, check_status, **kwargs)

def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]:
return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs)
@@ -298,13 +345,16 @@ class BaseOrganizer:
"""
raise NotImplementedError("get_learnware_path_by_ids is not implemented in BaseOrganizer")

def get_learnware_ids(self, top: int = None) -> List[str]:
def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]:
"""get the list of learnware ids

Parameters
----------
top : int, optional
the first top element to return, by default None
The first top element to return, by default None
check_status : int, optional
- None: return all learnware ids
- Others: return learnware ids with check_status

Raises
------
@@ -313,13 +363,16 @@ class BaseOrganizer:
"""
raise NotImplementedError("get_learnware_ids is not implemented in BaseOrganizer")

def get_learnwares(self, top: int = None) -> List[Learnware]:
def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]:
"""get the list of learnwares

Parameters
----------
top : int, optional
the first top element to return, by default None
The first top element to return, by default None
check_status : int, optional
- None: return all learnwares
- Others: return learnwares with check_status

Raises
------
@@ -339,13 +392,16 @@ class BaseSearcher:
def reset(self, organizer):
self.learnware_oganizer = organizer

def __call__(self, user_info: BaseUserInfo):
"""Search learnwares based on user_info
def __call__(self, user_info: BaseUserInfo, check_status: int = None):
"""Search learnwares based on user_info from learnwares with check_status

Parameters
----------
user_info : BaseUserInfo
user_info contains semantic_spec and stat_info
check_status : int, optional
- None: search from all learnwares
- Others: search from learnwares with check_status
"""
raise NotImplementedError("'__call__' method is not implemented in BaseSearcher")



+ 43
- 8
learnware/market/easy2/organizer.py View File

@@ -287,17 +287,52 @@ class EasyOrganizer(BaseOrganizer):
logger.warning("Learnware ID '%s' NOT Found!" % (ids))
return None

def get_learnware_ids(self, top: int = None) -> List[str]:
if top is None:
return list(self.learnware_list.keys())
else:
return list(self.learnware_list.keys())[:top]
def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]:
"""Get learnware ids

Parameters
----------
top : int, optional
The first top learnware ids to return, by default None
check_status : bool, optional
- None: return all learnware ids
- Others: return learnware ids with check_status

Returns
-------
List[str]
Learnware ids
"""
if check_status is None:
filtered_ids = self.use_flags.keys()
elif check_status is True:
filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.USABLE_LEARWARE]
elif check_status is False:
filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.NONUSABLE_LEARNWARE]

def get_learnwares(self, top: int = None) -> List[str]:
if top is None:
return list(self.learnware_list.values())
return filtered_ids
else:
return list(self.learnware_list.values())[:top]
return filtered_ids[:top]

def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]:
"""Get learnware list

Parameters
----------
top : int, optional
The first top learnwares to return, by default None
check_status : bool, optional
- None: return all learnwares
- Others: return learnwares with check_status

Returns
-------
List[Learnware]
Learnware list
"""
learnware_ids = self.get_learnware_ids(top, check_status)
return [self.learnware_list[idx] for idx in learnware_ids]

def __len__(self):
return len(self.learnware_list)

+ 6
- 3
learnware/market/easy2/searcher.py View File

@@ -610,9 +610,9 @@ class EasySearcher(BaseSearcher):
self.stat_searcher.reset(organizer)

def __call__(
self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy"
self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy"
) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
"""Search learnwares based on user_info
"""Search learnwares based on user_info from learnwares with check_status

Parameters
----------
@@ -620,6 +620,9 @@ class EasySearcher(BaseSearcher):
user_info contains semantic_spec and stat_info
max_search_num : int
The maximum number of the returned learnwares
check_status : int, optional
- None: search from all learnwares
- Others: search from learnwares with check_status

Returns
-------
@@ -629,7 +632,7 @@ class EasySearcher(BaseSearcher):
the third is the score of Learnware (mixture)
the fourth is the list of Learnware (mixture), the size is search_num
"""
learnware_list = self.learnware_oganizer.get_learnwares()
learnware_list = self.learnware_oganizer.get_learnwares(check_status=check_status)
learnware_list = self.semantic_searcher(learnware_list, user_info)

if len(learnware_list) == 0:


Loading…
Cancel
Save