From c65d3b7261832c8d33d00a7266c537cf0166cb03 Mon Sep 17 00:00:00 2001 From: Gene Date: Wed, 22 Nov 2023 11:26:47 +0800 Subject: [PATCH] [MNT] add details in stat checker --- learnware/market/easy/checker.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index c4e3ef9..1acf2e4 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -131,7 +131,13 @@ class EasyStatChecker(BaseChecker): message += "\r\n" + traceback.format_exc() return self.INVALID_LEARNWARE, message - if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"): + # Check length of input and output + if len(inputs) != len(outputs): + message = f"The learnware {learnware.id} output length must be equal to input length!" + logger.warning(message) + return self.INVALID_LEARNWARE, message + + if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]: # Check output type if isinstance(outputs, torch.Tensor): outputs = outputs.detach().cpu().numpy() @@ -142,11 +148,19 @@ class EasyStatChecker(BaseChecker): if outputs.ndim == 1: outputs = outputs.reshape(-1, 1) + # Check output shape - if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != ( - int(semantic_spec["Output"]["Dimension"]), - ): - message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" + if outputs[0].shape != learnware_model.output_shape: + message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" + logger.warning(message) + return self.INVALID_LEARNWARE, message + + # Check output dimension + if semantic_spec["Task"]["Values"][0] in [ + "Classification", + "Regression", + ] and learnware_model.output_shape != int(semantic_spec["Output"]["Dimension"]): + message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" logger.warning(message) return self.INVALID_LEARNWARE, message