Browse Source

[FIX] fix checker logic with many specification

tags/v0.3.2
bxdd 2 years ago
parent
commit
d54a56d5ae
2 changed files with 51 additions and 33 deletions
  1. +39
    -33
      learnware/market/easy2/checker.py
  2. +12
    -0
      learnware/market/utils.py

+ 39
- 33
learnware/market/easy2/checker.py View File

@@ -5,6 +5,7 @@ import random
import string

from ..base import BaseChecker
from ..utils import parse_specification_type
from ...config import C
from ...logger import get_module_logger

@@ -61,6 +62,23 @@ class EasySemanticChecker(BaseChecker):


class EasyStatisticalChecker(BaseChecker):

@staticmethod
def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
text_list = []
for i in range(num):
length = random.randint(min_len, max_len)
if text_type == "en":
characters = string.ascii_letters + string.digits + string.punctuation
result_str = "".join(random.choice(characters) for i in range(length))
text_list.append(result_str)
elif text_type == "zh":
result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length))
text_list.append(result_str)
else:
raise ValueError("Type should be en or zh")
return text_list
def __call__(self, learnware):
semantic_spec = learnware.get_specification().get_semantic_spec()

@@ -76,41 +94,30 @@ class EasyStatisticalChecker(BaseChecker):
try:
learnware_model = learnware.get_model()
# Check input shape
if semantic_spec["Data"]["Values"][0] == "Table":
input_shape = (semantic_spec["Input"]["Dimension"],)
else:
input_shape = learnware_model.input_shape

# Check rkme dimension
is_text = "RKMETextSpecification" in learnware.get_specification().stat_spec
if is_text:
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextSpecification")
else:
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification")
if stat_spec is not None and not is_text:
input_shape = learnware_model.input_shape

## WHY: why write this?
if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (int(semantic_spec["Input"]["Dimension"]),):
logger.warning("input shapes of model and semantic specifications are different")
return self.INVALID_LEARNWARE
spec_type = parse_specification_type(learnware.get_specification())
if spec_type is None:
logger.warning(f"No valid specification is found in stat spec {stat_spec}")
return self.INVALID_LEARNWARE
if spec_type == "RKMETableSpecification":
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type)
if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.")
return self.INVALID_LEARNWARE

def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
text_list = []
for i in range(num):
length = random.randint(min_len, max_len)
if text_type == "en":
characters = string.ascii_letters + string.digits + string.punctuation
result_str = "".join(random.choice(characters) for i in range(length))
text_list.append(result_str)
elif text_type == "zh":
result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length))
text_list.append(result_str)
else:
raise ValueError("Type should be en or zh")
return text_list

if is_text:
inputs = generate_random_text_list(10)
else:
inputs = np.random.randn(10, *input_shape)
elif spec_type == "RKMETextSpecification":
inputs = EasyStatisticalChecker._generate_random_text_list(10)
elif spec_type == "RKMEImageSpecification":
inputs = np.random.randint(0, 255, size=(10, *input_shape))
else:
raise ValueError(f"not supported spec type for spec_type = {spec_type}")
outputs = learnware.predict(inputs)
# Check output
if outputs.ndim == 1:
@@ -129,8 +136,7 @@ class EasyStatisticalChecker(BaseChecker):
return self.INVALID_LEARNWARE

# Check output shape
output_dim = int(semantic_spec["Output"]["Dimension"])
if outputs[0].shape[0] != output_dim:
if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]):
logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!")
return self.INVALID_LEARNWARE



+ 12
- 0
learnware/market/utils.py View File

@@ -0,0 +1,12 @@
from ..specification import Specification
from ..logger import get_module_logger

logger = get_module_logger('market_utils')

def parse_specification_type(stat_spec: Specification):
stat_specs = stat_spec.stat_spec
spec_list =['RKMETableSpecification', 'RKMETextSpecification', 'RKMEImageSpecification']
for spec in spec_list:
if spec in stat_specs:
return spec
return None

Loading…
Cancel
Save