From f3a471ce88ab1a200be0b1d84c4e4698c0fdbeef Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 24 Apr 2023 14:16:07 +0800 Subject: [PATCH] [ENH] Update check procudre --- learnware/__init__.py | 8 ++------ learnware/learnware/reuse.py | 5 +++-- learnware/market/easy.py | 13 +++++++++++++ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/learnware/__init__.py b/learnware/__init__.py index 122ea33..eea97ff 100644 --- a/learnware/__init__.py +++ b/learnware/__init__.py @@ -23,9 +23,5 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): logger.info(f"make learnware dir successfully!") ## ignore tensorflow warning - os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel - logger.info(f"The tensorflow log level is setted to {tf_loglevel}") - - -## call init method by default -init() + # os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel + # logger.info(f"The tensorflow log level is setted to {tf_loglevel}") diff --git a/learnware/learnware/reuse.py b/learnware/learnware/reuse.py index 1526190..7c289cc 100644 --- a/learnware/learnware/reuse.py +++ b/learnware/learnware/reuse.py @@ -1,5 +1,6 @@ import torch import numpy as np + # import tensorflow as tf from typing import Tuple, Any, List, Union, Dict from cvxopt import matrix, solvers @@ -58,7 +59,7 @@ class JobSelectorReuser(BaseReuser): pred_y = pred_y.detach().cpu().numpy() # elif isinstance(pred_y, tf.Tensor): # pred_y = pred_y.numpy() - + if not isinstance(pred_y, np.ndarray): raise TypeError(f"Model output must be np.ndarray or torch.Tensor") @@ -300,7 +301,7 @@ class AveragingReuser(BaseReuser): if not isinstance(pred_y, np.ndarray): raise TypeError(f"Model output must be np.ndarray or torch.Tensor") - + if self.mode == "mean": if mean_pred_y is None: mean_pred_y = pred_y diff --git a/learnware/market/easy.py b/learnware/market/easy.py index bf9c254..57c0631 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -83,15 +83,28 @@ class EasyMarket(BaseMarket): - The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction """ try: + # check model instantiation learnware.instantiate_model() + except Exception as e: logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}") return cls.NONUSABLE_LEARNWARE try: learnware_model = learnware.get_model() + + # check input shape inputs = np.random.randn(10, *learnware_model.input_shape) outputs = learnware.predict(inputs) + + # check output type + if isinstance(outputs, torch.Tensor): + outputs = outputs.detach().cpu().numpy() + if not isinstance(outputs, np.ndarray): + logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") + return cls.NONUSABLE_LEARNWARE + + # check output shape if outputs.shape[1:] != learnware_model.output_shape: logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") return cls.NONUSABLE_LEARNWARE