Browse Source

[MNT] change single checker to multiple checker

tags/v0.3.2
Gene 2 years ago
parent
commit
0876597482
7 changed files with 131 additions and 95 deletions
  1. +1
    -1
      learnware/market/__init__.py
  2. +49
    -18
      learnware/market/base.py
  3. +1
    -1
      learnware/market/easy2/__init__.py
  4. +68
    -26
      learnware/market/easy2/checker.py
  5. +7
    -44
      learnware/market/easy2/organizer.py
  6. +2
    -2
      learnware/market/evolve_anchor/organizer.py
  7. +3
    -3
      learnware/market/module.py

+ 1
- 1
learnware/market/__init__.py View File

@@ -2,7 +2,7 @@ from .anchor import AnchoredUserInfo, AnchoredOrganizer
from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher
from .evolve_anchor import EvolveAnchoredOrganizer
from .evolve import EvolvedOrganizer
from .easy2 import EasyChecker, EasyOrganizer, EasySearcher
from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker
from .hetergeneous import HeterogeneousOrganizer, MappingFunction

from .easy import EasyMarket


+ 49
- 18
learnware/market/base.py View File

@@ -1,11 +1,12 @@
import os
import torch
import tempfile
import traceback
import numpy as np


from typing import Tuple, Any, List, Union
from ..learnware import Learnware
from ..learnware import Learnware, get_learnware_from_dirpath
from ..logger import get_module_logger

logger = get_module_logger("market_base", "INFO")
@@ -51,27 +52,57 @@ class LearnwareMarket:
self,
market_id: str = None,
organizer: "BaseOrganizer" = None,
checker: "BaseChecker" = None,
searcher: "BaseSearcher" = None,
checker_list: List["BaseChecker"] = None,
rebuild=False,
):
self.market_id = market_id
self.learnware_organizer = BaseOrganizer() if organizer is None else organizer
self.learnware_checker = BaseChecker() if checker is None else checker
self.learnware_checker.reset(organizer=self.learnware_organizer)
self.learnware_organizer.reset(market_id=market_id, checker=self.learnware_checker)
self.learnware_organizer.reset(market_id=market_id)
self.learnware_organizer.reload_market(rebuild=rebuild)
self.learnware_searcher = BaseSearcher() if searcher is None else searcher
self.learnware_searcher.reset(organizer=self.learnware_organizer)
if checker_list is None:
self.learnware_checker = {"BaseChecker": BaseChecker()}
else:
self.learnware_checker = {checker.__class__.__name__: checker for checker in checker_list}
for name, checker in self.learnware_checker.items():
checker.reset(organizer=self.learnware_organizer)

def reload_market(self, **kwargs) -> bool:
self.learnware_organizer.reload_market(**kwargs)

def check_learnware(self, learnware: Learnware, **kwargs) -> bool:
return self.learnware_checker(learnware, **kwargs)

def add_learnware(self, zip_path: str, semantic_spec: dict, **kwargs) -> Tuple[str, bool]:
return self.learnware_organizer.add_learnware(zip_path, semantic_spec, **kwargs)
def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool:
try:
with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir:
with zipfile.ZipFile(zip_path, mode="r") as z_file:
z_file.extractall(tempdir)
pending_learnware = get_learnware_from_dirpath(
id="pending", semantic_spec=semantic_specification, learnware_dirpath=tempdir
)
final_status = BaseChecker.INVALID_LEARNWARE
checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names

for name in checker_names:
checker = self.learnware_checker[name]
check_status = checker(pending_learnware)
final_status = max(final_status, check_status)

if check_status == BaseChecker.INVALID_LEARNWARE:
return BaseChecker.INVALID_LEARNWARE
return final_status
except Exception as err:
logger.warning(f"Check learnware failed! Due to {err}.")
return BaseChecker.INVALID_LEARNWARE

def add_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> Tuple[str, bool]:
check_status = self.check_learnware(zip_path, semantic_spec, checker_names)
return self.learnware_organizer.add_learnware(zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs)

def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]:
return self.learnware_searcher(user_info, **kwargs)
@@ -79,8 +110,9 @@ class LearnwareMarket:
def delete_learnware(self, id: str, **kwargs) -> bool:
return self.learnware_organizer.delete_learnware(id, **kwargs)

def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, **kwargs) -> bool:
return self.learnware_organizer.update_learnware(id, zip_path=zip_path, semantic_spec=semantic_spec, **kwargs)
def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool:
check_status = self.check_learnware(zip_path, semantic_spec, checker_names)
return self.learnware_organizer.update_learnware(id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs)

def get_learnware_ids(self, top: int = None, **kwargs):
return self.learnware_organizer.get_learnware_ids(top, **kwargs)
@@ -99,12 +131,11 @@ class LearnwareMarket:


class BaseOrganizer:
def __init__(self, market_id=None, checker: BaseChecker = None):
self.reset(market_id=market_id, checker=checker)
def __init__(self, market_id=None):
self.reset(market_id=market_id)

def reset(self, market_id=None, checker: BaseChecker = None, **kwargs):
def reset(self, market_id=None, **kwargs):
self.market_id = market_id
self.checker = checker

def reload_market(self, rebuild=False, **kwargs) -> bool:
"""Reload the learnware organizer when server restared.
@@ -117,7 +148,7 @@ class BaseOrganizer:

raise NotImplementedError("reload market is Not Implemented in BaseOrganizer")

def add_learnware(self, zip_path: str, semantic_spec: dict) -> Tuple[str, bool]:
def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, bool]:
"""Add a learnware into the market.

.. note::
@@ -167,7 +198,7 @@ class BaseOrganizer:
"""
raise NotImplementedError("delete learnware is Not Implemented in BaseOrganizer")

def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, **kwargs) -> bool:
def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, check_status: int) -> bool:
"""
Update Learnware with id and content to be updated.



+ 1
- 1
learnware/market/easy2/__init__.py View File

@@ -1,3 +1,3 @@
from .organizer import EasyOrganizer
from .checker import EasyChecker
from .searcher import EasySearcher
from .checker import EasySemanticChecker, EasyStatisticalChecker

+ 68
- 26
learnware/market/easy2/checker.py View File

@@ -3,71 +3,113 @@ import numpy as np
import torch

from ..base import BaseChecker
from ...config import C
from ...logger import get_module_logger

logger = get_module_logger("easy_checker", "INFO")


class EasyChecker(BaseChecker):
class EasySemanticChecker(BaseChecker):
def __call__(self, learnware):
semantic_spec = learnware.get_specification().get_semantic_spec()
try:
for key in C["semantic_specs"]:
value = semantic_spec[key]["Values"]
valid_type = C["semantic_specs"][key]["Type"]
assert semantic_spec[key]["Type"] == valid_type, f"{key} type mismatch"
if valid_type == "Class":
valid_list = C["semantic_specs"][key]["Values"]
assert len(value) == 1, f"{key} must be unique"
assert value[0] in valid_list, f"{key} must be in {valid_list}"
elif valid_type == "Tag":
valid_list = C["semantic_specs"][key]["Values"]
assert len(value) >= 1, f"{key} cannot be empty"
for v in value:
assert v in valid_list, f"{key} must be in {valid_list}"
elif valid_type == "String":
assert isinstance(value, str), f"{key} must be string"
assert len(value) >= 1, f"{key} cannot be empty"

if semantic_spec["Data"]["Values"][0] == "Table":
assert semantic_spec["Input"] is not None, "Lack of input semantics"
dim = semantic_spec["Input"]["Dimension"]
for k, v in semantic_spec["Input"]["Description"].items():
assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})"
assert isinstance(v, str), "Description must be string"

if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]:
assert semantic_spec["Output"] is not None, "Lack of output semantics"
dim = semantic_spec["Output"]["Dimension"]
for k, v in semantic_spec["Output"]["Description"].items():
assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})"
assert isinstance(v, str), "Description must be string"

return self.NONUSABLE_LEARNWARE

except Exception as err:
logger.warning(f"semantic_specification is not valid due to {err}!")
return self.INVALID_LEARNWARE


class EasyStatisticalChecker(BaseChecker):
def __call__(self, learnware):
semantic_spec = learnware.get_specification().get_semantic_spec()

try:
# check model instantiation
# Check model instantiation
learnware.instantiate_model()

except Exception as e:
traceback.print_exc()
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}")
return self.NONUSABLE_LEARNWARE
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.")
return self.INVALID_LEARNWARE

try:
learnware_model = learnware.get_model()

# check input shape
# Check input shape
if semantic_spec["Data"]["Values"][0] == "Table":
input_shape = (semantic_spec["Input"]["Dimension"],)
else:
input_shape = learnware_model.input_shape
pass

# check rkme dimension
# Check rkme dimension
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification")
if stat_spec is not None:
if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification")
return self.NONUSABLE_LEARNWARE
pass
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.")
return self.INVALID_LEARNWARE

inputs = np.random.randn(10, *input_shape)
outputs = learnware.predict(inputs)

# check output
# Check output
if outputs.ndim == 1:
outputs = outputs.reshape(-1, 1)
pass
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!")
return self.INVALID_LEARNWARE

if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"):
# check output type
# Check output type
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
if not isinstance(outputs, np.ndarray):
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor")
return self.NONUSABLE_LEARNWARE
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!")
return self.INVALID_LEARNWARE

# check output shape
# Check output shape
output_dim = int(semantic_spec["Output"]["Dimension"])
if outputs[0].shape[0] != output_dim:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return self.NONUSABLE_LEARNWARE
pass
else:
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return self.NONUSABLE_LEARNWARE
logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!")
return self.INVALID_LEARNWARE

except Exception as e:
logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}")
return self.NONUSABLE_LEARNWARE
logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.")
return self.INVALID_LEARNWARE

return self.USABLE_LEARWARE
return self.USABLE_LEARWARE

+ 7
- 44
learnware/market/easy2/organizer.py View File

@@ -13,7 +13,6 @@ from shutil import copyfile, rmtree
from typing import Tuple, Any, List, Union, Dict

from .database_ops import DatabaseOperations
from .checker import EasyChecker
from ..base import LearnwareMarket, BaseUserInfo


@@ -95,42 +94,6 @@ class EasyOrganizer(BaseOrganizer):

"""
semantic_spec = copy.deepcopy(semantic_spec)

if not os.path.exists(zip_path):
logger.warning("Zip Path NOT Found! Fail to add learnware.")
return None, EasyChecker.INVALID_LEARNWARE

try:
if len(semantic_spec["Data"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please choose Data.")
return None, EasyChecker.INVALID_LEARNWARE
if len(semantic_spec["Task"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please choose Task.")
return None, EasyChecker.INVALID_LEARNWARE
if len(semantic_spec["Library"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please choose Device.")
return None, EasyChecker.INVALID_LEARNWARE
if len(semantic_spec["Name"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please provide Name.")
return None, EasyChecker.INVALID_LEARNWARE
if len(semantic_spec["Description"]["Values"]) == 0 and len(semantic_spec["Scenario"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please provide Scenario or Description.")
return None, EasyChecker.INVALID_LEARNWARE
if (
semantic_spec["Data"]["Type"] != "Class"
or semantic_spec["Task"]["Type"] != "Class"
or semantic_spec["Library"]["Type"] != "Class"
or semantic_spec["Scenario"]["Type"] != "Tag"
or semantic_spec["Name"]["Type"] != "String"
or semantic_spec["Description"]["Type"] != "String"
):
logger.warning("Illegal semantic specification, please provide the right type.")
return None, EasyChecker.INVALID_LEARNWARE
except:
print(semantic_spec)
logger.warning("Illegal semantic specification, some keys are missing.")
return None, EasyChecker.INVALID_LEARNWARE

logger.info("Get new learnware from %s" % (zip_path))

id = id if id is not None else "%08d" % (self.count)
@@ -152,12 +115,12 @@ class EasyOrganizer(BaseOrganizer):
rmtree(target_folder_dir)
except:
pass
return None, EasyChecker.INVALID_LEARNWARE
return None, BaseChecker.INVALID_LEARNWARE

if new_learnware is None:
return None, EasyChecker.INVALID_LEARNWARE
return None, BaseChecker.INVALID_LEARNWARE

learnwere_status = check_status if check_status is not None else self.checker(new_learnware)
learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE

self.dbops.add_learnware(
id=id,
@@ -227,7 +190,7 @@ class EasyOrganizer(BaseOrganizer):
assert (
zip_path is None and semantic_spec is None
), f"at least one of 'zip_path' and 'semantic_spec' should not be None when update learnware"
assert check_status != EasyChecker.INVALID_LEARNWARE, f"'check_status' can not be INVALID_LEARNWARE"
assert check_status != BaseChecker.INVALID_LEARNWARE, f"'check_status' can not be INVALID_LEARNWARE"

if zip_path is None and check_status is not None:
logger.warning("check_status will be ignored when zip_path is None for learnware update")
@@ -252,12 +215,12 @@ class EasyOrganizer(BaseOrganizer):
id=id, semantic_spec=semantic_spec, learnware_dirpath=tempdir
)
except Exception:
return EasyChecker.INVALID_LEARNWARE
return BaseChecker.INVALID_LEARNWARE

if new_learnware is None:
return EasyChecker.INVALID_LEARNWARE
return BaseChecker.INVALID_LEARNWARE

learnwere_status = self.checker.check_learnware(new_learnware)
learnwere_status = BaseChecker.NONUSABLE_LEARNWARE
else:
learnwere_status = self.use_flags[id] if zip_path is None else check_status



+ 2
- 2
learnware/market/evolve_anchor/organizer.py View File

@@ -1,7 +1,7 @@
from typing import List

from ..evolve.organizer import EvolvedOrganizer
from ..anchor.organizer import AnchoredOrganizer, AnchoredUserInfo
from ..evolve import EvolvedOrganizer
from ..anchor import AnchoredOrganizer, AnchoredUserInfo
from ...logger import get_module_logger

logger = get_module_logger("evolve_anchor_organizer")


+ 3
- 3
learnware/market/module.py View File

@@ -1,11 +1,11 @@
from .base import LearnwareMarket
from .easy2 import EasyChecker, EasyOrganizer, EasySearcher
from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker

MARKET_CONFIG = {
"easy": {
"organizer": EasyOrganizer(),
"checker": EasyChecker(),
"searcher": EasySearcher(),
"checker_list": [EasySemanticChecker(), EasyStatisticalChecker()],
}
}

@@ -14,7 +14,7 @@ def instatiate_learnware_market(market_id, name="easy", **kwargs):
return LearnwareMarket(
market_id=market_id,
organizer=MARKET_CONFIG[name]["organizer"],
checker=MARKET_CONFIG[name]["checker"],
searcher=MARKET_CONFIG[name]["searcher"],
checker_list=MARKET_CONFIG[name]["checker_list"],
**kwargs
)

Loading…
Cancel
Save