From a8806827c93e9490c9a2685620b250e8190cccbf Mon Sep 17 00:00:00 2001 From: zouxiaochuan Date: Mon, 28 Aug 2023 17:17:06 +0800 Subject: [PATCH] [MNT] check model input with rkme dimension --- learnware/market/easy.py | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index d6af5f5..e6fcdf8 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd from cvxopt import solvers, matrix from typing import Tuple, Any, List, Union, Dict +import traceback from .base import BaseMarket, BaseUserInfo from .database_ops import DatabaseOperations @@ -88,11 +89,15 @@ class EasyMarket(BaseMarket): - The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency - The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction """ + + semantic_spec = learnware.get_specification().get_semantic_spec() + try: # check model instantiation learnware.instantiate_model() except Exception as e: + traceback.print_exc() logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}") return cls.NONUSABLE_LEARNWARE @@ -100,7 +105,21 @@ class EasyMarket(BaseMarket): learnware_model = learnware.get_model() # check input shape - inputs = np.random.randn(10, *learnware_model.input_shape) + if semantic_spec['Data']['Values'][0] == 'Table': + input_shape = (semantic_spec['Input']['Dimension'], ) + else: + input_shape = learnware_model.input_shape + pass + + # check rkme dimension + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") + if stat_spec is not None: + if stat_spec.get_z().shape[1:] != input_shape: + logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") + return cls.NONUSABLE_LEARNWARE + pass + + inputs = np.random.randn(10, *input_shape) outputs = learnware.predict(inputs) # check output type @@ -111,11 +130,23 @@ class EasyMarket(BaseMarket): return cls.NONUSABLE_LEARNWARE # check output shape - if outputs.shape[1:] != learnware_model.output_shape: - logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") - return cls.NONUSABLE_LEARNWARE + if outputs.ndim == 1: + outputs = outputs.reshape(-1, 1) + pass + + if semantic_spec['Task']['Values'][0] in ('Classification', 'Regression', 'Feature Extraction'): + output_dim = int(semantic_spec['Output']['Dimension']) + if outputs[0].shape[0] != output_dim: + logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") + return cls.NONUSABLE_LEARNWARE + pass + else: + if outputs.shape[1:] != learnware_model.output_shape: + logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") + return cls.NONUSABLE_LEARNWARE except Exception as e: + logger.exception logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}") return cls.NONUSABLE_LEARNWARE