diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 68aa02f..1ae8aa7 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -101,7 +101,7 @@ class EasyStatChecker(BaseChecker): logger.warning("input shapes of model and semantic specifications are different") return self.INVALID_LEARNWARE - spec_type = parse_specification_type(learnware.get_specification()) + spec_type = parse_specification_type(learnware.get_specification().stat_spec) if spec_type is None: logger.warning(f"No valid specification is found in stat spec {spec_type}") return self.INVALID_LEARNWARE diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 8feb5d9..f86c06c 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -565,7 +565,7 @@ class EasyStatSearcher(BaseSearcher): max_search_num: int = 5, search_method: str = "greedy", ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - self.stat_spec_type = parse_specification_type(stat_spec=user_info.stat_info) + self.stat_spec_type = parse_specification_type(stat_specs=user_info.stat_info) if self.stat_spec_type is None: raise KeyError("No supported stat specification is given in the user info") @@ -646,7 +646,7 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] - if parse_specification_type(stat_spec=user_info.stat_info) is not None: + if parse_specification_type(stat_specs=user_info.stat_info) is not None: return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/market/utils.py b/learnware/market/utils.py index c0cc319..76d41b9 100644 --- a/learnware/market/utils.py +++ b/learnware/market/utils.py @@ -2,9 +2,8 @@ from ..specification import Specification def parse_specification_type( - stat_spec: Specification, spec_list=["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"] + stat_specs: dict, spec_list=["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"] ): - stat_specs = stat_spec.stat_spec for spec in spec_list: if spec in stat_specs: return spec diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 7503b4a..a131fd5 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -49,7 +49,7 @@ class JobSelectorReuser(BaseReuser): """ raw_user_data = user_data if isinstance(user_data[0], str): - stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) + stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification().stat_spec) assert ( stat_spec_type == "RKMETextSpecification" ), "stat_spec_type must be 'RKMETextSpecification' when user data is the List of string." @@ -97,7 +97,7 @@ class JobSelectorReuser(BaseReuser): user_data_num = len(user_data) return np.array([0] * user_data_num) else: - stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification()) + stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification().stat_spec) learnware_rkme_spec_list = [ learnware.specification.get_stat_spec_by_name(stat_spec_type) for learnware in self.learnware_list ]