From 71ebe0f1e6f743b3bbb1ea0ebb7600a7dd4b6495 Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 26 Nov 2023 16:49:35 +0800 Subject: [PATCH] [MNT] modify check for output dimension --- learnware/market/easy/checker.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 6a0720a..db19ca3 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -164,19 +164,34 @@ class EasyStatChecker(BaseChecker): # Check output shape if outputs[0].shape != learnware_model.output_shape: - message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" + message = f"The learnware [{learnware.id}] output dimension mismatch, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" logger.warning(message) return self.INVALID_LEARNWARE, message - # Check output dimension - if semantic_spec["Task"]["Values"][0] in [ - "Classification", - "Regression", - ] and learnware_model.output_shape[0] != int(semantic_spec["Output"]["Dimension"]): - message = f"The learnware [{learnware.id}] output dimension mismatch!, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" + # Check output dimension for regression + if semantic_spec["Task"]["Values"][0] == "Regression" and learnware_model.output_shape[0] != int( + semantic_spec["Output"]["Dimension"] + ): + message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" logger.warning(message) return self.INVALID_LEARNWARE, message + # Check output dimension for classification + if semantic_spec["Task"]["Values"][0] == "Classification": + model_output_shape = learnware_model.output_shape[0] + semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) + + if model_output_shape == 1: + if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs): + message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" + logger.warning(message) + return self.INVALID_LEARNWARE, message + else: + if model_output_shape != semantic_output_shape: + message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(semantic_output_shape, )}" + logger.warning(message) + return self.INVALID_LEARNWARE, message + except Exception as e: message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." logger.warning(message)