|
|
|
@@ -164,19 +164,34 @@ class EasyStatChecker(BaseChecker): |
|
|
|
|
|
|
|
# Check output shape |
|
|
|
if outputs[0].shape != learnware_model.output_shape: |
|
|
|
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" |
|
|
|
message = f"The learnware [{learnware.id}] output dimension mismatch, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" |
|
|
|
logger.warning(message) |
|
|
|
return self.INVALID_LEARNWARE, message |
|
|
|
|
|
|
|
# Check output dimension |
|
|
|
if semantic_spec["Task"]["Values"][0] in [ |
|
|
|
"Classification", |
|
|
|
"Regression", |
|
|
|
] and learnware_model.output_shape[0] != int(semantic_spec["Output"]["Dimension"]): |
|
|
|
message = f"The learnware [{learnware.id}] output dimension mismatch!, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" |
|
|
|
# Check output dimension for regression |
|
|
|
if semantic_spec["Task"]["Values"][0] == "Regression" and learnware_model.output_shape[0] != int( |
|
|
|
semantic_spec["Output"]["Dimension"] |
|
|
|
): |
|
|
|
message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" |
|
|
|
logger.warning(message) |
|
|
|
return self.INVALID_LEARNWARE, message |
|
|
|
|
|
|
|
# Check output dimension for classification |
|
|
|
if semantic_spec["Task"]["Values"][0] == "Classification": |
|
|
|
model_output_shape = learnware_model.output_shape[0] |
|
|
|
semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) |
|
|
|
|
|
|
|
if model_output_shape == 1: |
|
|
|
if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs): |
|
|
|
message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" |
|
|
|
logger.warning(message) |
|
|
|
return self.INVALID_LEARNWARE, message |
|
|
|
else: |
|
|
|
if model_output_shape != semantic_output_shape: |
|
|
|
message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(semantic_output_shape, )}" |
|
|
|
logger.warning(message) |
|
|
|
return self.INVALID_LEARNWARE, message |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." |
|
|
|
logger.warning(message) |
|
|
|
|