Browse Source

[MNT] add details in stat checker

tags/v0.3.2
Gene 2 years ago
parent
commit
c65d3b7261
1 changed files with 19 additions and 5 deletions
  1. +19
    -5
      learnware/market/easy/checker.py

+ 19
- 5
learnware/market/easy/checker.py View File

@@ -131,7 +131,13 @@ class EasyStatChecker(BaseChecker):
message += "\r\n" + traceback.format_exc()
return self.INVALID_LEARNWARE, message

if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"):
# Check length of input and output
if len(inputs) != len(outputs):
message = f"The learnware {learnware.id} output length must be equal to input length!"
logger.warning(message)
return self.INVALID_LEARNWARE, message

if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]:
# Check output type
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
@@ -142,11 +148,19 @@ class EasyStatChecker(BaseChecker):

if outputs.ndim == 1:
outputs = outputs.reshape(-1, 1)

# Check output shape
if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != (
int(semantic_spec["Output"]["Dimension"]),
):
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}"
if outputs[0].shape != learnware_model.output_shape:
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

# Check output dimension
if semantic_spec["Task"]["Values"][0] in [
"Classification",
"Regression",
] and learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]):
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}"
logger.warning(message)
return self.INVALID_LEARNWARE, message



Loading…
Cancel
Save