Browse Source

[ENH] add organizer, searcher, checker for market

tags/v0.3.2
bxdd 2 years ago
parent
commit
4e4912cf5a
6 changed files with 155 additions and 32 deletions
  1. +47
    -32
      learnware/market/base.py
  2. +7
    -0
      learnware/market/easy/__init__.py
  3. +72
    -0
      learnware/market/easy/checker.py
  4. +11
    -0
      learnware/market/easy/organizer.py
  5. +8
    -0
      learnware/market/easy/searcher.py
  6. +10
    -0
      learnware/market/searcher.py

+ 47
- 32
learnware/market/base.py View File

@@ -1,11 +1,14 @@
import os import os
import torch
import traceback
import numpy as np import numpy as np
import pandas as pd
from typing import Tuple, Any, List, Union, Dict



from typing import Tuple, Any, List, Union
from ..learnware import Learnware from ..learnware import Learnware
from ..specification import RKMEStatSpecification
from ..logger import get_module_logger


logger = get_module_logger("market_base", "INFO")


class BaseUserInfo: class BaseUserInfo:
"""User Information for searching learnware""" """User Information for searching learnware"""
@@ -43,19 +46,13 @@ class BaseUserInfo:
class BaseMarket: class BaseMarket:
"""Base interface for market, it provide the interface of search/add/detele/update learnwares""" """Base interface for market, it provide the interface of search/add/detele/update learnwares"""


def __init__(self, market_id: str = None):
def __init__(self, market_id: str = None, checker: 'LearnwareChecker' = None):
self.market_id = market_id self.market_id = market_id
self.learnware_checker = LearnwareChecker() if checker is None else checker


def reload_market(self, market_path: str, semantic_spec_list_path: str) -> bool:
def reload_market(self, **kwargs) -> bool:
"""Reload the market when server restared. """Reload the market when server restared.

Parameters
----------
market_path : str
Directory for market data. '_IP_:_port_' for loading from database.
semantic_spec_list_path : str
Directory for available semantic_spec. Should be a json file.

Returns Returns
------- -------
bool bool
@@ -64,8 +61,7 @@ class BaseMarket:


raise NotImplementedError("reload market is Not Implemented") raise NotImplementedError("reload market is Not Implemented")


@classmethod
def check_learnware(cls, learnware: Learnware) -> bool:
def check_learnware(self, learnware: Learnware) -> bool:
"""Check the utility of a learnware """Check the utility of a learnware


Parameters Parameters
@@ -77,7 +73,7 @@ class BaseMarket:
bool bool
A flag indicating whether the learnware can be accepted. A flag indicating whether the learnware can be accepted.
""" """
return True
return self.learnware_checker(learnware)


def add_learnware( def add_learnware(
self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str
@@ -221,9 +217,7 @@ class LearnwareOrganizer:


raise NotImplementedError("reload market is Not Implemented") raise NotImplementedError("reload market is Not Implemented")
def add_learnware(
self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str
) -> Tuple[str, bool]:
def add_learnware(self, zip_path: str, semantic_spec: dict) -> Tuple[str, bool]:
"""Add a learnware into the market. """Add a learnware into the market.


.. note:: .. note::
@@ -233,22 +227,17 @@ class LearnwareOrganizer:


Parameters Parameters
---------- ----------
learnware_name : str
Name of new learnware.
model_path : str
zip_path : str
Filepath for learnware model, a zipped file. Filepath for learnware model, a zipped file.
stat_spec_path : str
Filepath for statistical specification, a '.npy' file.
How to pass parameters requires further discussion.
semantic_spec : dict semantic_spec : dict
semantic_spec for new learnware, in dictionary format. semantic_spec for new learnware, in dictionary format.
desc : str
Brief desciption for new learnware.


Returns Returns
------- -------
Tuple[str, bool]
str indicating model_id, bool indicating whether the learnware is added successfully.
Tuple[str, int]
- str indicating model_id
- int indicating what the flag of learnware is added.



Raises Raises
------ ------
@@ -280,7 +269,33 @@ class LearnwareOrganizer:
raise NotImplementedError("delete learnware is Not Implemented") raise NotImplementedError("delete learnware is Not Implemented")


class LearnwareSearcher: class LearnwareSearcher:
def __init__(self, learnware_organizor):
def __init__(self, organizer):
self.learnware_organizer = organizer
def __call__(self, user_info: BaseUserInfo):
raise NotImplementedError("'__call__' method is not implemented in LearnwareSearcher")
def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]:
pass

class LearnwareChecker:
INVALID_LEARNWARE = -1
NONUSABLE_LEARNWARE = 0
USABLE_LEARWARE = 1
@classmethod
def __call__(cls, learnware: Learnware) -> int:
"""Check the utility of a learnware

Parameters
----------
learnware : Learnware

Returns
-------
int
A flag indicating whether the learnware can be accepted.
- The INVALID_LEARNWARE denotes the learnware does not pass the check
- The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency
- The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction
"""
raise NotImplementedError("'__call__' method is not implemented in LearnwareChecker")

+ 7
- 0
learnware/market/easy/__init__.py View File

@@ -0,0 +1,7 @@
from ..base import LearnwareSearcher, LearnwareOrganizer

class EasySearcher(LearnwareSearcher):
pass

class EasyOrganizer(LearnwareOrganizer):
pass

+ 72
- 0
learnware/market/easy/checker.py View File

@@ -0,0 +1,72 @@
import traceback

from ..base import LearnwareChecker
from ...logger import get_module_logger

logger = get_module_logger("easy_checker", "INFO")

class EasyChecker(LearnwareChecker):
@classmethod
def __call__(cls, learnware):
semantic_spec = learnware.get_specification().get_semantic_spec()

try:
# check model instantiation
learnware.instantiate_model()

except Exception as e:
traceback.print_exc()
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}")
return cls.NONUSABLE_LEARNWARE

try:
learnware_model = learnware.get_model()

# check input shape
if semantic_spec["Data"]["Values"][0] == "Table":
input_shape = (semantic_spec["Input"]["Dimension"],)
else:
input_shape = learnware_model.input_shape
pass

# check rkme dimension
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification")
if stat_spec is not None:
if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification")
return cls.NONUSABLE_LEARNWARE
pass

inputs = np.random.randn(10, *input_shape)
outputs = learnware.predict(inputs)

# check output
if outputs.ndim == 1:
outputs = outputs.reshape(-1, 1)
pass

if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"):
# check output type
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
if not isinstance(outputs, np.ndarray):
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor")
return cls.NONUSABLE_LEARNWARE

# check output shape
output_dim = int(semantic_spec["Output"]["Dimension"])
if outputs[0].shape[0] != output_dim:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return cls.NONUSABLE_LEARNWARE
pass
else:
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return cls.NONUSABLE_LEARNWARE

except Exception as e:
logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}")
return cls.NONUSABLE_LEARNWARE

return cls.USABLE_LEARWARE

+ 11
- 0
learnware/market/easy/organizer.py View File

@@ -0,0 +1,11 @@
import traceback

from ..base import LearnwareOrganizer
from ...logger import get_module_logger

logger = get_module_logger("easy_organizer")


class EasyOrganizer(LearnwareOrganizer):

+ 8
- 0
learnware/market/easy/searcher.py View File

@@ -0,0 +1,8 @@
from ..base import LearnwareSearcher
from ...logger import get_module_logger

logger = get_module_logger('easy_seacher')

class EasySearcher(LearnwareSearcher):
pass

+ 10
- 0
learnware/market/searcher.py View File

@@ -0,0 +1,10 @@


from typing import Tuple, Any, List

from .base import BaseUserInfo
from ..learnware import Learnware
from ..logger import get_module_logger

logger = get_module_logger('model')


Loading…
Cancel
Save