Browse Source

Merge pull request #19 from LAMDA-NJU/check_model_rkme_dim

[MNT] check model input with rkme dimension
tags/v0.3.2
zouxiaochuan GitHub 2 years ago
parent
commit
5224723f9b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 35 additions and 4 deletions
  1. +35
    -4
      learnware/market/easy.py

+ 35
- 4
learnware/market/easy.py View File

@@ -7,6 +7,7 @@ import numpy as np
import pandas as pd
from cvxopt import solvers, matrix
from typing import Tuple, Any, List, Union, Dict
import traceback

from .base import BaseMarket, BaseUserInfo
from .database_ops import DatabaseOperations
@@ -88,11 +89,15 @@ class EasyMarket(BaseMarket):
- The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency
- The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction
"""

semantic_spec = learnware.get_specification().get_semantic_spec()

try:
# check model instantiation
learnware.instantiate_model()

except Exception as e:
traceback.print_exc()
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}")
return cls.NONUSABLE_LEARNWARE

@@ -100,7 +105,21 @@ class EasyMarket(BaseMarket):
learnware_model = learnware.get_model()

# check input shape
inputs = np.random.randn(10, *learnware_model.input_shape)
if semantic_spec['Data']['Values'][0] == 'Table':
input_shape = (semantic_spec['Input']['Dimension'], )
else:
input_shape = learnware_model.input_shape
pass

# check rkme dimension
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification")
if stat_spec is not None:
if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification")
return cls.NONUSABLE_LEARNWARE
pass

inputs = np.random.randn(10, *input_shape)
outputs = learnware.predict(inputs)

# check output type
@@ -111,11 +130,23 @@ class EasyMarket(BaseMarket):
return cls.NONUSABLE_LEARNWARE

# check output shape
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return cls.NONUSABLE_LEARNWARE
if outputs.ndim == 1:
outputs = outputs.reshape(-1, 1)
pass
if semantic_spec['Task']['Values'][0] in ('Classification', 'Regression', 'Feature Extraction'):
output_dim = int(semantic_spec['Output']['Dimension'])
if outputs[0].shape[0] != output_dim:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return cls.NONUSABLE_LEARNWARE
pass
else:
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return cls.NONUSABLE_LEARNWARE

except Exception as e:
logger.exception
logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}")
return cls.NONUSABLE_LEARNWARE



Loading…
Cancel
Save