| @@ -164,19 +164,34 @@ class EasyStatChecker(BaseChecker): | |||||
| # Check output shape | # Check output shape | ||||
| if outputs[0].shape != learnware_model.output_shape: | if outputs[0].shape != learnware_model.output_shape: | ||||
| message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" | |||||
| message = f"The learnware [{learnware.id}] output dimension mismatch, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" | |||||
| logger.warning(message) | logger.warning(message) | ||||
| return self.INVALID_LEARNWARE, message | return self.INVALID_LEARNWARE, message | ||||
| # Check output dimension | |||||
| if semantic_spec["Task"]["Values"][0] in [ | |||||
| "Classification", | |||||
| "Regression", | |||||
| ] and learnware_model.output_shape[0] != int(semantic_spec["Output"]["Dimension"]): | |||||
| message = f"The learnware [{learnware.id}] output dimension mismatch!, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" | |||||
| # Check output dimension for regression | |||||
| if semantic_spec["Task"]["Values"][0] == "Regression" and learnware_model.output_shape[0] != int( | |||||
| semantic_spec["Output"]["Dimension"] | |||||
| ): | |||||
| message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" | |||||
| logger.warning(message) | logger.warning(message) | ||||
| return self.INVALID_LEARNWARE, message | return self.INVALID_LEARNWARE, message | ||||
| # Check output dimension for classification | |||||
| if semantic_spec["Task"]["Values"][0] == "Classification": | |||||
| model_output_shape = learnware_model.output_shape[0] | |||||
| semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) | |||||
| if model_output_shape == 1: | |||||
| if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs): | |||||
| message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" | |||||
| logger.warning(message) | |||||
| return self.INVALID_LEARNWARE, message | |||||
| else: | |||||
| if model_output_shape != semantic_output_shape: | |||||
| message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(semantic_output_shape, )}" | |||||
| logger.warning(message) | |||||
| return self.INVALID_LEARNWARE, message | |||||
| except Exception as e: | except Exception as e: | ||||
| message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." | message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." | ||||
| logger.warning(message) | logger.warning(message) | ||||