| @@ -131,7 +131,13 @@ class EasyStatChecker(BaseChecker): | |||||
| message += "\r\n" + traceback.format_exc() | message += "\r\n" + traceback.format_exc() | ||||
| return self.INVALID_LEARNWARE, message | return self.INVALID_LEARNWARE, message | ||||
| if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"): | |||||
| # Check length of input and output | |||||
| if len(inputs) != len(outputs): | |||||
| message = f"The learnware {learnware.id} output length must be equal to input length!" | |||||
| logger.warning(message) | |||||
| return self.INVALID_LEARNWARE, message | |||||
| if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]: | |||||
| # Check output type | # Check output type | ||||
| if isinstance(outputs, torch.Tensor): | if isinstance(outputs, torch.Tensor): | ||||
| outputs = outputs.detach().cpu().numpy() | outputs = outputs.detach().cpu().numpy() | ||||
| @@ -142,11 +148,19 @@ class EasyStatChecker(BaseChecker): | |||||
| if outputs.ndim == 1: | if outputs.ndim == 1: | ||||
| outputs = outputs.reshape(-1, 1) | outputs = outputs.reshape(-1, 1) | ||||
| # Check output shape | # Check output shape | ||||
| if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != ( | |||||
| int(semantic_spec["Output"]["Dimension"]), | |||||
| ): | |||||
| message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" | |||||
| if outputs[0].shape != learnware_model.output_shape: | |||||
| message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" | |||||
| logger.warning(message) | |||||
| return self.INVALID_LEARNWARE, message | |||||
| # Check output dimension | |||||
| if semantic_spec["Task"]["Values"][0] in [ | |||||
| "Classification", | |||||
| "Regression", | |||||
| ] and learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]): | |||||
| message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" | |||||
| logger.warning(message) | logger.warning(message) | ||||
| return self.INVALID_LEARNWARE, message | return self.INVALID_LEARNWARE, message | ||||