|
|
|
@@ -83,15 +83,28 @@ class EasyMarket(BaseMarket): |
|
|
|
- The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction |
|
|
|
""" |
|
|
|
try: |
|
|
|
# check model instantiation |
|
|
|
learnware.instantiate_model() |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}") |
|
|
|
return cls.NONUSABLE_LEARNWARE |
|
|
|
|
|
|
|
try: |
|
|
|
learnware_model = learnware.get_model() |
|
|
|
|
|
|
|
# check input shape |
|
|
|
inputs = np.random.randn(10, *learnware_model.input_shape) |
|
|
|
outputs = learnware.predict(inputs) |
|
|
|
|
|
|
|
# check output type |
|
|
|
if isinstance(outputs, torch.Tensor): |
|
|
|
outputs = outputs.detach().cpu().numpy() |
|
|
|
if not isinstance(outputs, np.ndarray): |
|
|
|
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") |
|
|
|
return cls.NONUSABLE_LEARNWARE |
|
|
|
|
|
|
|
# check output shape |
|
|
|
if outputs.shape[1:] != learnware_model.output_shape: |
|
|
|
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") |
|
|
|
return cls.NONUSABLE_LEARNWARE |
|
|
|
|