Browse Source

[MNT] only check output shape for classification, regression and feature extraction

tags/v0.3.2
zouxiaochuan 2 years ago
parent
commit
0448d82a41
1 changed files with 9 additions and 8 deletions
  1. +9
    -8
      learnware/market/easy.py

+ 9
- 8
learnware/market/easy.py View File

@@ -122,19 +122,20 @@ class EasyMarket(BaseMarket):
inputs = np.random.randn(10, *input_shape)
outputs = learnware.predict(inputs)

# check output type
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
if not isinstance(outputs, np.ndarray):
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor")
return cls.NONUSABLE_LEARNWARE

# check output shape
# check output
if outputs.ndim == 1:
outputs = outputs.reshape(-1, 1)
pass
if semantic_spec['Task']['Values'][0] in ('Classification', 'Regression', 'Feature Extraction'):
# check output type
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
if not isinstance(outputs, np.ndarray):
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor")
return cls.NONUSABLE_LEARNWARE
# check output shape
output_dim = int(semantic_spec['Output']['Dimension'])
if outputs[0].shape[0] != output_dim:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")


Loading…
Cancel
Save