From 0448d82a41df4103e9e707d92feceed5b94cd650 Mon Sep 17 00:00:00 2001 From: zouxiaochuan Date: Thu, 31 Aug 2023 22:50:52 +0800 Subject: [PATCH] [MNT] only check output shape for classification, regression and feature extraction --- learnware/market/easy.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 9a91f05..ab9617e 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -122,19 +122,20 @@ class EasyMarket(BaseMarket): inputs = np.random.randn(10, *input_shape) outputs = learnware.predict(inputs) - # check output type - if isinstance(outputs, torch.Tensor): - outputs = outputs.detach().cpu().numpy() - if not isinstance(outputs, np.ndarray): - logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") - return cls.NONUSABLE_LEARNWARE - - # check output shape + # check output if outputs.ndim == 1: outputs = outputs.reshape(-1, 1) pass if semantic_spec['Task']['Values'][0] in ('Classification', 'Regression', 'Feature Extraction'): + # check output type + if isinstance(outputs, torch.Tensor): + outputs = outputs.detach().cpu().numpy() + if not isinstance(outputs, np.ndarray): + logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") + return cls.NONUSABLE_LEARNWARE + + # check output shape output_dim = int(semantic_spec['Output']['Dimension']) if outputs[0].shape[0] != output_dim: logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")