|
|
|
@@ -5,6 +5,7 @@ import random |
|
|
|
import string |
|
|
|
|
|
|
|
from ..base import BaseChecker |
|
|
|
from ..utils import parse_specification_type |
|
|
|
from ...config import C |
|
|
|
from ...logger import get_module_logger |
|
|
|
|
|
|
|
@@ -61,6 +62,23 @@ class EasySemanticChecker(BaseChecker): |
|
|
|
|
|
|
|
|
|
|
|
class EasyStatisticalChecker(BaseChecker): |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): |
|
|
|
text_list = [] |
|
|
|
for i in range(num): |
|
|
|
length = random.randint(min_len, max_len) |
|
|
|
if text_type == "en": |
|
|
|
characters = string.ascii_letters + string.digits + string.punctuation |
|
|
|
result_str = "".join(random.choice(characters) for i in range(length)) |
|
|
|
text_list.append(result_str) |
|
|
|
elif text_type == "zh": |
|
|
|
result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) |
|
|
|
text_list.append(result_str) |
|
|
|
else: |
|
|
|
raise ValueError("Type should be en or zh") |
|
|
|
return text_list |
|
|
|
|
|
|
|
def __call__(self, learnware): |
|
|
|
semantic_spec = learnware.get_specification().get_semantic_spec() |
|
|
|
|
|
|
|
@@ -76,41 +94,30 @@ class EasyStatisticalChecker(BaseChecker): |
|
|
|
try: |
|
|
|
learnware_model = learnware.get_model() |
|
|
|
# Check input shape |
|
|
|
if semantic_spec["Data"]["Values"][0] == "Table": |
|
|
|
input_shape = (semantic_spec["Input"]["Dimension"],) |
|
|
|
else: |
|
|
|
input_shape = learnware_model.input_shape |
|
|
|
|
|
|
|
# Check rkme dimension |
|
|
|
is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec |
|
|
|
if is_text: |
|
|
|
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification") |
|
|
|
else: |
|
|
|
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") |
|
|
|
if stat_spec is not None and not is_text: |
|
|
|
input_shape = learnware_model.input_shape |
|
|
|
|
|
|
|
## WHY: why write this? |
|
|
|
if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (int(semantic_spec["Input"]["Dimension"]),): |
|
|
|
logger.warning("input shapes of model and semantic specifications are different") |
|
|
|
return self.INVALID_LEARNWARE |
|
|
|
|
|
|
|
spec_type = parse_specification_type(learnware.get_specification()) |
|
|
|
if spec_type is None: |
|
|
|
logger.warning(f"No valid specification is found in stat spec {stat_spec}") |
|
|
|
return self.INVALID_LEARNWARE |
|
|
|
|
|
|
|
if spec_type == "RKMETableSpecification": |
|
|
|
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) |
|
|
|
if stat_spec.get_z().shape[1:] != input_shape: |
|
|
|
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") |
|
|
|
return self.INVALID_LEARNWARE |
|
|
|
|
|
|
|
def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): |
|
|
|
text_list = [] |
|
|
|
for i in range(num): |
|
|
|
length = random.randint(min_len, max_len) |
|
|
|
if text_type == "en": |
|
|
|
characters = string.ascii_letters + string.digits + string.punctuation |
|
|
|
result_str = "".join(random.choice(characters) for i in range(length)) |
|
|
|
text_list.append(result_str) |
|
|
|
elif text_type == "zh": |
|
|
|
result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) |
|
|
|
text_list.append(result_str) |
|
|
|
else: |
|
|
|
raise ValueError("Type should be en or zh") |
|
|
|
return text_list |
|
|
|
|
|
|
|
if is_text: |
|
|
|
inputs = generate_random_text_list(10) |
|
|
|
else: |
|
|
|
inputs = np.random.randn(10, *input_shape) |
|
|
|
elif spec_type == "RKMETextSpecification": |
|
|
|
inputs = EasyStatisticalChecker._generate_random_text_list(10) |
|
|
|
elif spec_type == "RKMEImageSpecification": |
|
|
|
inputs = np.random.randint(0, 255, size=(10, *input_shape)) |
|
|
|
else: |
|
|
|
raise ValueError(f"not supported spec type for spec_type = {spec_type}") |
|
|
|
outputs = learnware.predict(inputs) |
|
|
|
# Check output |
|
|
|
if outputs.ndim == 1: |
|
|
|
@@ -129,8 +136,7 @@ class EasyStatisticalChecker(BaseChecker): |
|
|
|
return self.INVALID_LEARNWARE |
|
|
|
|
|
|
|
# Check output shape |
|
|
|
output_dim = int(semantic_spec["Output"]["Dimension"]) |
|
|
|
if outputs[0].shape[0] != output_dim: |
|
|
|
if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]): |
|
|
|
logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") |
|
|
|
return self.INVALID_LEARNWARE |
|
|
|
|
|
|
|
|