|
|
|
@@ -122,19 +122,20 @@ class EasyMarket(BaseMarket): |
|
|
|
inputs = np.random.randn(10, *input_shape) |
|
|
|
outputs = learnware.predict(inputs) |
|
|
|
|
|
|
|
# check output type |
|
|
|
if isinstance(outputs, torch.Tensor): |
|
|
|
outputs = outputs.detach().cpu().numpy() |
|
|
|
if not isinstance(outputs, np.ndarray): |
|
|
|
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") |
|
|
|
return cls.NONUSABLE_LEARNWARE |
|
|
|
|
|
|
|
# check output shape |
|
|
|
# check output |
|
|
|
if outputs.ndim == 1: |
|
|
|
outputs = outputs.reshape(-1, 1) |
|
|
|
pass |
|
|
|
|
|
|
|
if semantic_spec['Task']['Values'][0] in ('Classification', 'Regression', 'Feature Extraction'): |
|
|
|
# check output type |
|
|
|
if isinstance(outputs, torch.Tensor): |
|
|
|
outputs = outputs.detach().cpu().numpy() |
|
|
|
if not isinstance(outputs, np.ndarray): |
|
|
|
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") |
|
|
|
return cls.NONUSABLE_LEARNWARE |
|
|
|
|
|
|
|
# check output shape |
|
|
|
output_dim = int(semantic_spec['Output']['Dimension']) |
|
|
|
if outputs[0].shape[0] != output_dim: |
|
|
|
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") |
|
|
|
|