|
|
|
@@ -131,7 +131,13 @@ class EasyStatChecker(BaseChecker): |
|
|
|
message += "\r\n" + traceback.format_exc() |
|
|
|
return self.INVALID_LEARNWARE, message |
|
|
|
|
|
|
|
if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"): |
|
|
|
# Check length of input and output |
|
|
|
if len(inputs) != len(outputs): |
|
|
|
message = f"The learnware {learnware.id} output length must be equal to input length!" |
|
|
|
logger.warning(message) |
|
|
|
return self.INVALID_LEARNWARE, message |
|
|
|
|
|
|
|
if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]: |
|
|
|
# Check output type |
|
|
|
if isinstance(outputs, torch.Tensor): |
|
|
|
outputs = outputs.detach().cpu().numpy() |
|
|
|
@@ -142,11 +148,19 @@ class EasyStatChecker(BaseChecker): |
|
|
|
|
|
|
|
if outputs.ndim == 1: |
|
|
|
outputs = outputs.reshape(-1, 1) |
|
|
|
|
|
|
|
# Check output shape |
|
|
|
if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != ( |
|
|
|
int(semantic_spec["Output"]["Dimension"]), |
|
|
|
): |
|
|
|
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" |
|
|
|
if outputs[0].shape != learnware_model.output_shape: |
|
|
|
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" |
|
|
|
logger.warning(message) |
|
|
|
return self.INVALID_LEARNWARE, message |
|
|
|
|
|
|
|
# Check output dimension |
|
|
|
if semantic_spec["Task"]["Values"][0] in [ |
|
|
|
"Classification", |
|
|
|
"Regression", |
|
|
|
] and learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]): |
|
|
|
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" |
|
|
|
logger.warning(message) |
|
|
|
return self.INVALID_LEARNWARE, message |
|
|
|
|
|
|
|
|