Browse Source

[ENH] Update check procudre

tags/v0.3.2
bxdd 3 years ago
parent
commit
f3a471ce88
3 changed files with 18 additions and 8 deletions
  1. +2
    -6
      learnware/__init__.py
  2. +3
    -2
      learnware/learnware/reuse.py
  3. +13
    -0
      learnware/market/easy.py

+ 2
- 6
learnware/__init__.py View File

@@ -23,9 +23,5 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs):
logger.info(f"make learnware dir successfully!")

## ignore tensorflow warning
os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel
logger.info(f"The tensorflow log level is setted to {tf_loglevel}")


## call init method by default
init()
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel
# logger.info(f"The tensorflow log level is setted to {tf_loglevel}")

+ 3
- 2
learnware/learnware/reuse.py View File

@@ -1,5 +1,6 @@
import torch
import numpy as np

# import tensorflow as tf
from typing import Tuple, Any, List, Union, Dict
from cvxopt import matrix, solvers
@@ -58,7 +59,7 @@ class JobSelectorReuser(BaseReuser):
pred_y = pred_y.detach().cpu().numpy()
# elif isinstance(pred_y, tf.Tensor):
# pred_y = pred_y.numpy()
if not isinstance(pred_y, np.ndarray):
raise TypeError(f"Model output must be np.ndarray or torch.Tensor")

@@ -300,7 +301,7 @@ class AveragingReuser(BaseReuser):

if not isinstance(pred_y, np.ndarray):
raise TypeError(f"Model output must be np.ndarray or torch.Tensor")
if self.mode == "mean":
if mean_pred_y is None:
mean_pred_y = pred_y


+ 13
- 0
learnware/market/easy.py View File

@@ -83,15 +83,28 @@ class EasyMarket(BaseMarket):
- The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction
"""
try:
# check model instantiation
learnware.instantiate_model()

except Exception as e:
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}")
return cls.NONUSABLE_LEARNWARE

try:
learnware_model = learnware.get_model()

# check input shape
inputs = np.random.randn(10, *learnware_model.input_shape)
outputs = learnware.predict(inputs)

# check output type
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
if not isinstance(outputs, np.ndarray):
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor")
return cls.NONUSABLE_LEARNWARE

# check output shape
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return cls.NONUSABLE_LEARNWARE


Loading…
Cancel
Save