From d54a56d5aed2cfcfbbae27c8c5b58c8bc979f3a5 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 3 Nov 2023 02:59:21 +0800 Subject: [PATCH] [FIX] fix checker logic with many specification --- learnware/market/easy2/checker.py | 72 +++++++++++++++++-------------- learnware/market/utils.py | 12 ++++++ 2 files changed, 51 insertions(+), 33 deletions(-) create mode 100644 learnware/market/utils.py diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index fa8f26c..2ab04f5 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -5,6 +5,7 @@ import random import string from ..base import BaseChecker +from ..utils import parse_specification_type from ...config import C from ...logger import get_module_logger @@ -61,6 +62,23 @@ class EasySemanticChecker(BaseChecker): class EasyStatisticalChecker(BaseChecker): + + @staticmethod + def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): + text_list = [] + for i in range(num): + length = random.randint(min_len, max_len) + if text_type == "en": + characters = string.ascii_letters + string.digits + string.punctuation + result_str = "".join(random.choice(characters) for i in range(length)) + text_list.append(result_str) + elif text_type == "zh": + result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) + text_list.append(result_str) + else: + raise ValueError("Type should be en or zh") + return text_list + def __call__(self, learnware): semantic_spec = learnware.get_specification().get_semantic_spec() @@ -76,41 +94,30 @@ class EasyStatisticalChecker(BaseChecker): try: learnware_model = learnware.get_model() # Check input shape - if semantic_spec["Data"]["Values"][0] == "Table": - input_shape = (semantic_spec["Input"]["Dimension"],) - else: - input_shape = learnware_model.input_shape - - # Check rkme dimension - is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec - if is_text: - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification") - else: - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") - if stat_spec is not None and not is_text: + input_shape = learnware_model.input_shape + + ## WHY: why write this? + if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (int(semantic_spec["Input"]["Dimension"]),): + logger.warning("input shapes of model and semantic specifications are different") + return self.INVALID_LEARNWARE + + spec_type = parse_specification_type(learnware.get_specification()) + if spec_type is None: + logger.warning(f"No valid specification is found in stat spec {stat_spec}") + return self.INVALID_LEARNWARE + + if spec_type == "RKMETableSpecification": + stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") return self.INVALID_LEARNWARE - - def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): - text_list = [] - for i in range(num): - length = random.randint(min_len, max_len) - if text_type == "en": - characters = string.ascii_letters + string.digits + string.punctuation - result_str = "".join(random.choice(characters) for i in range(length)) - text_list.append(result_str) - elif text_type == "zh": - result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) - text_list.append(result_str) - else: - raise ValueError("Type should be en or zh") - return text_list - - if is_text: - inputs = generate_random_text_list(10) - else: inputs = np.random.randn(10, *input_shape) + elif spec_type == "RKMETextSpecification": + inputs = EasyStatisticalChecker._generate_random_text_list(10) + elif spec_type == "RKMEImageSpecification": + inputs = np.random.randint(0, 255, size=(10, *input_shape)) + else: + raise ValueError(f"not supported spec type for spec_type = {spec_type}") outputs = learnware.predict(inputs) # Check output if outputs.ndim == 1: @@ -129,8 +136,7 @@ class EasyStatisticalChecker(BaseChecker): return self.INVALID_LEARNWARE # Check output shape - output_dim = int(semantic_spec["Output"]["Dimension"]) - if outputs[0].shape[0] != output_dim: + if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]): logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") return self.INVALID_LEARNWARE diff --git a/learnware/market/utils.py b/learnware/market/utils.py new file mode 100644 index 0000000..f3d4e35 --- /dev/null +++ b/learnware/market/utils.py @@ -0,0 +1,12 @@ +from ..specification import Specification +from ..logger import get_module_logger + +logger = get_module_logger('market_utils') + +def parse_specification_type(stat_spec: Specification): + stat_specs = stat_spec.stat_spec + spec_list =['RKMETableSpecification', 'RKMETextSpecification', 'RKMEImageSpecification'] + for spec in spec_list: + if spec in stat_specs: + return spec + return None \ No newline at end of file