Browse Source

[MNT] now the test workflow is runnable

tags/v0.3.2
bxdd 2 years ago
parent
commit
1d5bbce2ff
5 changed files with 68 additions and 70 deletions
  1. +21
    -26
      learnware/market/base.py
  2. +8
    -10
      learnware/market/easy/searcher.py
  3. +1
    -1
      learnware/market/heterogeneous/searcher.py
  4. +12
    -10
      tests/test_hetero_market/test_hetero.py
  5. +26
    -23
      tests/test_workflow/test_workflow.py

+ 21
- 26
learnware/market/base.py View File

@@ -57,36 +57,31 @@ class BaseUserInfo:


@dataclass
class SearchItem:
learnwares: Union[List[Learnware] ,Learnware]
score: float = 1.0
class SingleSearchItem:
learnware: Learnware
score: Optional[float] = None

@dataclass
class MultipleSearchItem:
learnwares: List[Learnware]
score: float
class SearchResults:
def __init__(self, single_results: Optional[List[SearchItem]] = None, multiple_results: Optional[List[SearchItem]] = None):
self.search_mapping: Dict[str, List[SearchItem]] = {
"single": [] if single_results is None else single_results,
"multiple": [] if multiple_results is None else multiple_results,
}

def get_results(self, name):
if name in self.search_mapping:
return self.search_mapping[name]
else:
raise KeyError(f"name '{name}' is not supported for search results")
def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None):
self.update_single_results([] if single_results is None else single_results)
self.update_multiple_results([] if multiple_results is None else multiple_results)
def get_single_results(self) -> List[SingleSearchItem]:
return self.single_results
def update_results(self, name, search_result: List[SearchItem]):
if name in self.search_mapping:
self.search_mapping[name] = search_result
else:
raise KeyError(f"name '{name}' is not supported for search results")
def get_multiple_results(self) -> List[MultipleSearchItem]:
return self.multiple_results
def sort(self, name=None, ascending: bool=False):
assert name is None or name in self.search_mapping, f"name '{name}' is not supported for search results"
for key, search_result in self.search_mapping.items():
if name is None or name == key:
#search_result = [x for x in search_result]
self.search_mapping[key] = sorted(search_result, key=lambda x:x.score, reverse=(not ascending))

def update_single_results(self, single_results: List[SingleSearchItem]):
self.single_results = single_results
def update_multiple_results(self, multiple_results: List[MultipleSearchItem]):
self.multiple_results = multiple_results

class LearnwareMarket:
"""Base interface for market, it provides the interface of search/add/delete/update learnwares"""


+ 8
- 10
learnware/market/easy/searcher.py View File

@@ -6,7 +6,7 @@ from typing import Tuple, List, Union, Optional

from .organizer import EasyOrganizer
from ..utils import parse_specification_type
from ..base import BaseUserInfo, BaseSearcher, SearchResults, SearchItem
from ..base import BaseUserInfo, BaseSearcher, SearchResults, SingleSearchItem, MultipleSearchItem
from ...learnware import Learnware
from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification, rkme_solve_qp
from ...logger import get_module_logger
@@ -65,7 +65,7 @@ class EasyExactSemanticSearcher(BaseSearcher):
if self._match_semantic_spec(user_semantic_spec, learnware_semantic_spec):
match_learnwares.append(learnware)
logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list)))
return SearchResults(single_results=[SearchItem(learnwares=_learnware) for _learnware in match_learnwares])
return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in match_learnwares])

class EasyFuzzSemanticSearcher(BaseSearcher):
def _match_semantic_spec_tag(self, semantic_spec1, semantic_spec2) -> bool:
@@ -181,7 +181,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher):
final_result = matched_learnware_tag

logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list)))
return SearchResults(single_results=[SearchItem(learnwares=_learnware) for _learnware in final_result])
return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in final_result])


class EasyStatSearcher(BaseSearcher):
@@ -623,14 +623,12 @@ class EasyStatSearcher(BaseSearcher):

search_results = SearchResults()
search_results.update_results(
name="single",
search_result=[SearchItem(learnwares=_learnware, score=_score) for _score, _learnware in zip(sorted_score_list, single_learnware_list)]
search_results.update_single_results(
[SingleSearchItem(learnware=_learnware, score=_score) for _score, _learnware in zip(sorted_score_list, single_learnware_list)]
)
if mixture_score is not None and len(mixture_learnware_list) > 0:
search_results.update_results(
name="multiple",
search_result=[SearchItem(learnwares=mixture_learnware_list, score=mixture_score)]
search_results.update_multiple_results(
[MultipleSearchItem(learnwares=mixture_learnware_list, score=mixture_score)]
)
return search_results

@@ -672,7 +670,7 @@ class EasySearcher(BaseSearcher):
learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status)
semantic_search_result = self.semantic_searcher(learnware_list, user_info)

learnware_list = semantic_search_result.get_results(name="single")
learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()]
if len(learnware_list) == 0:
return SearchResults()



+ 1
- 1
learnware/market/heterogeneous/searcher.py View File

@@ -40,7 +40,7 @@ class HeteroSearcher(EasySearcher):
learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status)
semantic_search_result = self.semantic_searcher(learnware_list, user_info)

learnware_list = semantic_search_result.get_results(name="single")
learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()]
if len(learnware_list) == 0:
return SearchResults()



+ 12
- 10
tests/test_hetero_market/test_hetero.py View File

@@ -199,26 +199,28 @@ class TestMarket(unittest.TestCase):
semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}"

user_info = BaseUserInfo(semantic_spec=semantic_spec)
_, single_learnware_list, _, _ = hetero_market.search_learnware(user_info)
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()

print("User info:", user_info.get_semantic_spec())
print(f"Search result:")
assert len(single_learnware_list) == 1, f"Exact semantic search failed!"
for learnware in single_learnware_list:
semantic_spec1 = learnware.get_specification().get_semantic_spec()
print("Choose learnware:", learnware.id, semantic_spec1)
assert len(single_result) == 1, f"Exact semantic search failed!"
for search_item in single_result:
semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
print("Choose learnware:", search_item.learnware.id, semantic_spec1)
assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!"

semantic_spec["Name"]["Values"] = "laernwaer"
user_info = BaseUserInfo(semantic_spec=semantic_spec)
_, single_learnware_list, _, _ = hetero_market.search_learnware(user_info)
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()

print("User info:", user_info.get_semantic_spec())
print(f"Search result:")
assert len(single_learnware_list) == self.learnware_num, f"Fuzzy semantic search failed!"
for learnware in single_learnware_list:
semantic_spec1 = learnware.get_specification().get_semantic_spec()
print("Choose learnware:", learnware.id, semantic_spec1)
assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!"
for search_item in single_result:
semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
print("Choose learnware:", search_item.learnware.id, semantic_spec1)

def test_stat_search(self, learnware_num=5):
hetero_market = self.test_train_market_model(learnware_num)


+ 26
- 23
tests/test_workflow/test_workflow.py View File

@@ -143,12 +143,13 @@ class TestWorkflow(unittest.TestCase):
semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}"

user_info = BaseUserInfo(semantic_spec=semantic_spec)
_, single_learnware_list, _, _ = easy_market.search_learnware(user_info)

search_result = easy_market.search_learnware(user_info)
single_result = search_result.get_single_results()
print("User info:", user_info.get_semantic_spec())
print(f"Search result:")
for learnware in single_learnware_list:
print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec())
for search_item in single_result:
print("Choose learnware:", search_item.learnware.id, search_item.learnware.get_specification().get_semantic_spec())

rmtree(test_folder) # rm -r test_folder

@@ -171,20 +172,21 @@ class TestWorkflow(unittest.TestCase):
user_spec = RKMETableSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = easy_market.search_learnware(user_info)

assert len(single_learnware_list) >= 1, f"Statistical search failed!"
search_results = easy_market.search_learnware(user_info)

single_result = search_results.get_single_results()
multiple_result = search_results.get_multiple_results()
assert len(single_result) >= 1, f"Statistical search failed!"
print(f"search result of user{idx}:")
for score, learnware in zip(sorted_score_list, single_learnware_list):
print(f"score: {score}, learnware_id: {learnware.id}")
print(f"mixture_score: {mixture_score}\n")
mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
print(f"mixture_learnware: {mixture_id}\n")
for search_item in single_result:
print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}")
if len(multiple_result) > 0:
mixture_item = multiple_result[0]
print(f"mixture_score: {mixture_item.score}\n")
mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares])
print(f"mixture_learnware: {mixture_id}\n")

rmtree(test_folder) # rm -r test_folder

@@ -198,24 +200,25 @@ class TestWorkflow(unittest.TestCase):
stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})

_, _, _, mixture_learnware_list = easy_market.search_learnware(user_info)

search_results = easy_market.search_learnware(user_info)
multiple_result = search_results.get_multiple_results()
mixture_item = multiple_result[0]
# Based on user information, the learnware market returns a list of learnwares (learnware_list)
# Use jobselector reuser to reuse the searched learnwares to make prediction
reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list)
reuse_job_selector = JobSelectorReuser(learnware_list=mixture_item.learnwares)
job_selector_predict_y = reuse_job_selector.predict(user_data=data_X)

# Use averaging ensemble reuser to reuse the searched learnwares to make prediction
reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob")
reuse_ensemble = AveragingReuser(learnware_list=mixture_item.learnwares, mode="vote_by_prob")
ensemble_predict_y = reuse_ensemble.predict(user_data=data_X)

# Use ensemble pruning reuser to reuse the searched learnwares to make prediction
reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification")
reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_item.learnwares, mode="classification")
reuse_ensemble.fit(train_X[-200:], train_y[-200:])
ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X)

# Use feature augment reuser to reuse the searched learnwares to make prediction
reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_learnware_list, mode="classification")
reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_item.learnwares, mode="classification")
reuse_feature_augment.fit(train_X[-200:], train_y[-200:])
feature_augment_predict_y = reuse_feature_augment.predict(user_data=data_X)



Loading…
Cancel
Save