Browse Source

[FIX] fix parse spec type error

tags/v0.3.2
bxdd 2 years ago
parent
commit
7b307846df
4 changed files with 6 additions and 7 deletions
  1. +1
    -1
      learnware/market/easy2/checker.py
  2. +2
    -2
      learnware/market/easy2/searcher.py
  3. +1
    -2
      learnware/market/utils.py
  4. +2
    -2
      learnware/reuse/job_selector.py

+ 1
- 1
learnware/market/easy2/checker.py View File

@@ -101,7 +101,7 @@ class EasyStatChecker(BaseChecker):
logger.warning("input shapes of model and semantic specifications are different")
return self.INVALID_LEARNWARE

spec_type = parse_specification_type(learnware.get_specification())
spec_type = parse_specification_type(learnware.get_specification().stat_spec)
if spec_type is None:
logger.warning(f"No valid specification is found in stat spec {spec_type}")
return self.INVALID_LEARNWARE


+ 2
- 2
learnware/market/easy2/searcher.py View File

@@ -565,7 +565,7 @@ class EasyStatSearcher(BaseSearcher):
max_search_num: int = 5,
search_method: str = "greedy",
) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
self.stat_spec_type = parse_specification_type(stat_spec=user_info.stat_info)
self.stat_spec_type = parse_specification_type(stat_specs=user_info.stat_info)
if self.stat_spec_type is None:
raise KeyError("No supported stat specification is given in the user info")

@@ -646,7 +646,7 @@ class EasySearcher(BaseSearcher):
if len(learnware_list) == 0:
return [], [], 0.0, []

if parse_specification_type(stat_spec=user_info.stat_info) is not None:
if parse_specification_type(stat_specs=user_info.stat_info) is not None:
return self.stat_searcher(learnware_list, user_info, max_search_num, search_method)
else:
return None, learnware_list, 0.0, None

+ 1
- 2
learnware/market/utils.py View File

@@ -2,9 +2,8 @@ from ..specification import Specification


def parse_specification_type(
stat_spec: Specification, spec_list=["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"]
stat_specs: dict, spec_list=["RKMETableSpecification", "RKMETextSpecification", "RKMEImageSpecification"]
):
stat_specs = stat_spec.stat_spec
for spec in spec_list:
if spec in stat_specs:
return spec


+ 2
- 2
learnware/reuse/job_selector.py View File

@@ -49,7 +49,7 @@ class JobSelectorReuser(BaseReuser):
"""
raw_user_data = user_data
if isinstance(user_data[0], str):
stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification())
stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification().stat_spec)
assert (
stat_spec_type == "RKMETextSpecification"
), "stat_spec_type must be 'RKMETextSpecification' when user data is the List of string."
@@ -97,7 +97,7 @@ class JobSelectorReuser(BaseReuser):
user_data_num = len(user_data)
return np.array([0] * user_data_num)
else:
stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification())
stat_spec_type = parse_specification_type(self.learnware_list[0].get_specification().stat_spec)
learnware_rkme_spec_list = [
learnware.specification.get_stat_spec_by_name(stat_spec_type) for learnware in self.learnware_list
]


Loading…
Cancel
Save