diff --git a/learnware/market/hetergeneous/searcher.py b/learnware/market/hetergeneous/searcher.py new file mode 100644 index 0000000..2bf12f6 --- /dev/null +++ b/learnware/market/hetergeneous/searcher.py @@ -0,0 +1,28 @@ +from typing import List + +from .organizer import HeteroMapTableOrganizer +from ...learnware import Learnware +from ..base import BaseSearcher, BaseUserInfo +from ...logger import get_module_logger + +logger = get_module_logger("hetero_searcher") + +class HeteroMapTableSearcher(BaseSearcher): + def __init__(self, organizer: HeteroMapTableOrganizer = None): + super(HeteroMapTableSearcher, self).__init__(organizer) + + def __call__(self, user_info: BaseUserInfo, check_status: int = None) -> Learnware: + # todo: use specially assigned search_gamma for calculating mmd dist + learnware_list = self.learnware_oganizer.get_learnwares() + target_learnware, min_dist = None, None + user_hetero_spec = self.learnware_oganizer.generate_hetero_map_spec(user_info) + for learnware in learnware_list.values(): + learnware_hetero_spec = learnware.specification.get_stat_spec_by_name("HeteroSpecification") + mmd_dist = learnware_hetero_spec.dist(user_hetero_spec) + if target_learnware is None or mmd_dist < min_dist: + min_dist = mmd_dist + target_learnware = learnware + return target_learnware + + def reset(self, organizer): + self.learnware_oganizer = organizer \ No newline at end of file