Browse Source

Merge branch 'main' into dev_rkme_image

tags/v0.3.2
Googol2002 GitHub 2 years ago
parent
commit
f8d6f288bb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 2064 additions and 277 deletions
  1. +1
    -1
      README.md
  2. +1
    -1
      docs/references/api.rst
  3. +2
    -2
      docs/start/client.rst
  4. +1
    -1
      docs/start/quick.rst
  5. +1
    -1
      docs/workflow/identify.rst
  6. +1
    -1
      examples/workflow_by_code/main.py
  7. +8
    -5
      learnware/market/__init__.py
  8. +0
    -140
      learnware/market/anchor.py
  9. +2
    -0
      learnware/market/anchor/__init__.py
  10. +62
    -0
      learnware/market/anchor/organizer.py
  11. +111
    -0
      learnware/market/anchor/searcher.py
  12. +252
    -69
      learnware/market/base.py
  13. +9
    -9
      learnware/market/database_ops.py
  14. +5
    -5
      learnware/market/easy.py
  15. +3
    -0
      learnware/market/easy2/__init__.py
  16. +115
    -0
      learnware/market/easy2/checker.py
  17. +176
    -0
      learnware/market/easy2/database_ops.py
  18. +313
    -0
      learnware/market/easy2/organizer.py
  19. +637
    -0
      learnware/market/easy2/searcher.py
  20. +1
    -0
      learnware/market/evolve/__init__.py
  21. +9
    -12
      learnware/market/evolve/organizer.py
  22. +1
    -0
      learnware/market/evolve_anchor/__init__.py
  23. +8
    -13
      learnware/market/evolve_anchor/organizer.py
  24. +1
    -0
      learnware/market/hetergeneous/__init__.py
  25. +6
    -12
      learnware/market/hetergeneous/organizer.py
  26. +20
    -0
      learnware/market/module.py
  27. +1
    -1
      learnware/specification/__init__.py
  28. +9
    -0
      learnware/specification/base.py
  29. +1
    -0
      learnware/specification/table/__init__.py
  30. +4
    -2
      learnware/specification/table/rkme.py
  31. +1
    -1
      learnware/specification/utils.py
  32. +10
    -0
      tests/test_market/learnware_example/README.md
  33. +27
    -0
      tests/test_market/learnware_example/environment.yaml
  34. +8
    -0
      tests/test_market/learnware_example/example.yaml
  35. +20
    -0
      tests/test_market/learnware_example/example_init.py
  36. +205
    -0
      tests/test_market/test_easy.py
  37. +31
    -0
      tests/test_specification/test_rkme.py
  38. +1
    -1
      tests/test_workflow/test_workflow.py

+ 1
- 1
README.md View File

@@ -178,7 +178,7 @@ For example, the following code is designed to work with Reduced Set Kernel Embe
```python
import learnware.specification as specification

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_path, "rkme.json"))
user_info = BaseUserInfo(
semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}


+ 1
- 1
docs/references/api.rst View File

@@ -11,7 +11,7 @@ Here you can find all ``learnware`` interfaces.
Market
====================

.. autoclass:: learnware.market.BaseMarket
.. autoclass:: learnware.market.LearnwareMarket
:members:

.. autoclass:: learnware.market.EasyMarket


+ 2
- 2
docs/start/client.rst View File

@@ -123,7 +123,7 @@ You can search learnware by providing a statistical specification. The statistic

import learnware.specification as specification

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_path, "rkme.json"))
specification = learnware.specification.Specification()
@@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares
senarioes=[],
input_description={}, output_description={})

stat_spec = specification.rkme.RKMEStatSpecification()
stat_spec = specification.RKMEStatSpecification()
stat_spec.load(os.path.join(unzip_path, "rkme.json"))
specification = learnware.specification.Specification()
specification.update_semantic_spec(semantic_spec)


+ 1
- 1
docs/start/quick.rst View File

@@ -170,7 +170,7 @@ For example, the code below executes learnware search when using Reduced Set Ker

import learnware.specification as specification

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()

# unzip_path: directory for unzipped learnware zipfile
user_spec.load(os.path.join(unzip_path, "rkme.json"))


+ 1
- 1
docs/workflow/identify.rst View File

@@ -73,7 +73,7 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb

import learnware.specification as specification

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join("rkme.json"))
user_info = BaseUserInfo(
semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}


+ 1
- 1
examples/workflow_by_code/main.py View File

@@ -148,7 +148,7 @@ class LearnwareMarketWorkflow:
with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir)

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
(


+ 8
- 5
learnware/market/__init__.py View File

@@ -1,6 +1,9 @@
from .anchor import AnchoredUserInfo, AnchoredMarket
from .base import BaseUserInfo, BaseMarket
from .evolve_anchor import EvolvedAnchoredMarket
from .evolve import EvolvedMarket
from .anchor import AnchoredUserInfo, AnchoredOrganizer
from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher
from .evolve_anchor import EvolvedAnchoredOrganizer
from .evolve import EvolvedOrganizer
from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker
from .hetergeneous import HeterogeneousOrganizer, MappingFunction

from .easy import EasyMarket
from .heterogeneous_feature import HeterogeneousFeatureMarket
from .module import instatiate_learnware_market

+ 0
- 140
learnware/market/anchor.py View File

@@ -1,140 +0,0 @@
import os
from typing import Tuple, Any, List, Union, Dict

from ..learnware import Learnware
from .base import BaseMarket, BaseUserInfo


class AnchoredUserInfo(BaseUserInfo):
"""
User Information for searching learnware (add the anchor design)

- UserInfo contains the anchor list acquired from the market
- UserInfo can update stat_info based on anchors
"""

def __init__(self, id: str, semantic_spec: dict = dict(), stat_info: dict = dict()):
super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
self.anchor_learnware_list = {} # id: Learnware

def add_anchor_learnware(self, learnware_id: str, learnware: Learnware):
"""Add the anchor learnware acquired from the market

Parameters
----------
learnware_id : str
Id of anchor learnware
learnware : Learnware
Anchor learnware for capturing user requirements
"""
self.anchor_learnware_list[learnware_id] = learnware

def update_stat_info(self, name: str, item: Any):
"""Update stat_info based on anchor learnwares

Parameters
----------
name : str
Name of stat_info
item : Any
Statistical information calculated on anchor learnwares
"""
self.stat_info[name] = item


class AnchoredMarket(BaseMarket):
"""Add the anchor design to the BaseMarket

Parameters
----------
BaseMarket : _type_
Basic market version
"""

def __init__(self, *args, **kwargs):
super(AnchoredMarket, self).__init__(*args, **kwargs)
self.anchor_learnware_list = {} # anchor_id: anchor learnware

def _update_anchor_learnware(self, anchor_id: str, anchor_learnware: Learnware):
"""Update anchor_learnware_list

Parameters
----------
anchor_id : str
Id of anchor learnware
anchor_learnware : Learnware
Anchor learnware
"""
self.anchor_learnware_list[anchor_id] = anchor_learnware

def _delete_anchor_learnware(self, anchor_id: str) -> bool:
"""Delete anchor learnware in anchor_learnware_list

Parameters
----------
anchor_id : str
Id of anchor learnware

Returns
-------
bool
True if the target anchor learnware is deleted successfully.

Raises
------
Exception
Raise an excpetion when given anchor_id is NOT found in anchor_learnware_list
"""
if not anchor_id in self.anchor_learnware_list:
raise Exception("Anchor learnware id:{} NOT Found!".format(anchor_id))

self.anchor_learnware_list.pop(anchor_id)
return True

def update_anchor_learnware_list(self, learnware_list: Dict[str, Learnware]):
"""Update anchor_learnware_list

Parameters
----------
learnware_list : Dict[str, Learnware]
Learnwares for updating anchor_learnware_list
"""
pass

def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
"""Search anchor Learnwares from anchor_learnware_list based on user_info

Parameters
----------
user_info : AnchoredUserInfo
- user_info with semantic specifications and statistical information
- some statistical information calculated on previous anchor learnwares

Returns
-------
Tuple[Any, List[Learnware]]:
return two items:

- first is the usage of anchor learnwares, e.g., how to use anchors to calculate some statistical information
- second is a list of anchor learnwares
"""
pass

def search_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
"""Find helpful learnwares from learnware_list based on user_info

Parameters
----------
user_info : AnchoredUserInfo
- user_info with semantic specifications and statistical information
- some statistical information calculated on anchor learnwares

Returns
-------
Tuple[Any, List[Any]]
return two items:

- first is recommended combination, None when no recommended combination is calculated or statistical specification is not provided.
- second is a list of matched learnwares
"""
pass

+ 2
- 0
learnware/market/anchor/__init__.py View File

@@ -0,0 +1,2 @@
from .organizer import AnchoredOrganizer
from .searcher import AnchoredUserInfo

+ 62
- 0
learnware/market/anchor/organizer.py View File

@@ -0,0 +1,62 @@
from typing import List, Dict, Tuple, Any

from ..easy2.organizer import EasyOrganizer
from ...logger import get_module_logger
from ...learnware import Learnware
from ...specification import BaseStatSpecification

logger = get_module_logger("anchor_organizer")


class AnchoredOrganizer(EasyOrganizer):
    """Organizer that also keeps a registry of anchor learnwares.

    On top of everything inherited from ``EasyOrganizer``, each instance owns
    ``anchor_learnware_list``, a dict mapping an anchor id to its anchor learnware.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``EasyOrganizer`` and start with no anchors."""
        super(AnchoredOrganizer, self).__init__(*args, **kwargs)
        self.anchor_learnware_list = {}  # anchor_id: anchor learnware

    def _update_anchor_learnware(self, anchor_id: str, anchor_learnware: Learnware):
        """Insert or overwrite an entry in anchor_learnware_list

        Parameters
        ----------
        anchor_id : str
            Id of anchor learnware
        anchor_learnware : Learnware
            Anchor learnware
        """
        self.anchor_learnware_list[anchor_id] = anchor_learnware

    def _delete_anchor_learnware(self, anchor_id: str) -> bool:
        """Delete anchor learnware in anchor_learnware_list

        Parameters
        ----------
        anchor_id : str
            Id of anchor learnware

        Returns
        -------
        bool
            True if the target anchor learnware is deleted successfully.

        Raises
        ------
        Exception
            Raise an exception when given anchor_id is NOT found in anchor_learnware_list
        """
        if anchor_id not in self.anchor_learnware_list:
            raise Exception("Anchor learnware id:{} NOT Found!".format(anchor_id))

        self.anchor_learnware_list.pop(anchor_id)
        return True

    def update_anchor_learnware_list(self, learnware_list: Dict[str, Learnware]):
        """Update anchor_learnware_list

        Placeholder: the body is empty, so calling this currently does nothing.

        Parameters
        ----------
        learnware_list : Dict[str, Learnware]
            Learnwares for updating anchor_learnware_list
        """
        pass

+ 111
- 0
learnware/market/anchor/searcher.py View File

@@ -0,0 +1,111 @@
from typing import List, Dict, Tuple, Any, Union

from ..base import BaseUserInfo
from ..easy2.searcher import EasySearcher
from ...logger import get_module_logger
from ...learnware import Learnware

logger = get_module_logger("anchor_searcher")


class AnchoredUserInfo(BaseUserInfo):
    """
    User Information for searching learnware (add the anchor design)

    - UserInfo contains the anchor id list acquired from the market
    - UserInfo can update stat_info based on anchors
    """

    def __init__(
        self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None
    ):
        """Create the user info and remember the anchor ids already known.

        Parameters
        ----------
        id : str
            Forwarded unchanged to ``BaseUserInfo``
        semantic_spec : dict, optional
            Forwarded unchanged to ``BaseUserInfo``, by default None
        stat_info : dict, optional
            Forwarded unchanged to ``BaseUserInfo``, by default None
        anchor_learnware_ids : List[str], optional
            Initial anchor learnware ids, by default None (empty list)
        """
        # NOTE(review): presumably BaseUserInfo turns a None stat_info into a dict,
        # otherwise update_stat_info below would fail -- confirm against BaseUserInfo.
        super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
        # Copy the incoming list so that add_anchor_learnware_ids never mutates
        # a list object that the caller still owns.
        self.anchor_learnware_ids = [] if anchor_learnware_ids is None else list(anchor_learnware_ids)

    def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]):
        """Add the anchor learnware ids acquired from the market

        Parameters
        ----------
        learnware_ids : Union[str, List[str]]
            Anchor learnware ids; a single str is treated as a one-element list
        """
        if isinstance(learnware_ids, str):
            learnware_ids = [learnware_ids]
        self.anchor_learnware_ids.extend(learnware_ids)

    def update_stat_info(self, name: str, item: Any):
        """Update stat_info based on anchor learnwares

        Parameters
        ----------
        name : str
            Name of stat_info
        item : Any
            Statistical information calculated on anchor learnwares
        """
        self.stat_info[name] = item


class AnchoredSearcher(EasySearcher):
    """Searcher exposing one entry point for anchor learnwares and one for ordinary learnwares."""

    def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
        """Look up anchor learnwares for the given user.

        Placeholder: nothing is searched yet and None is returned.

        Parameters
        ----------
        user_info : AnchoredUserInfo
            - user_info with semantic specifications and statistical information
            - some statistical information calculated on previous anchor learnwares

        Returns
        -------
        Tuple[Any, List[Learnware]]:
            intended to hold two items:

            - first is the usage of anchor learnwares, e.g., how to use anchors to calculate some statistical information
            - second is a list of anchor learnwares
        """
        return None

    def search_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
        """Look up helpful learnwares for the given user.

        Placeholder: nothing is searched yet and None is returned.

        Parameters
        ----------
        user_info : AnchoredUserInfo
            - user_info with semantic specifications and statistical information
            - some statistical information calculated on anchor learnwares

        Returns
        -------
        Tuple[Any, List[Any]]
            intended to hold two items:

            - first is recommended combination, None when no recommended combination is calculated or statistical specification is not provided.
            - second is a list of matched learnwares
        """
        return None

    def __call__(self, user_info: AnchoredUserInfo, anchor_flag: bool = False) -> Tuple[Any, List[Learnware]]:
        """Dispatch a search to the anchor search or the ordinary search.

        Parameters
        ----------
        user_info : AnchoredUserInfo
            - user_info with semantic specifications and statistical information
            - some statistical information calculated on anchor learnwares
        anchor_flag : bool, optional
            When true, ``search_anchor_learnware`` is used; otherwise
            ``search_learnware`` is used. By default False.

        Returns
        -------
        Tuple[Any, List[Any]]
            Whatever the selected search method returns.
        """
        selected_search = self.search_anchor_learnware if anchor_flag else self.search_learnware
        return selected_search(user_info)

+ 252
- 69
learnware/market/base.py View File

@@ -1,10 +1,12 @@
import os
import numpy as np
import pandas as pd
from typing import Tuple, Any, List, Union, Dict
import zipfile
import tempfile

from ..learnware import Learnware
from ..specification import RKMEStatSpecification

from typing import Tuple, Any, List, Union
from ..learnware import Learnware, get_learnware_from_dirpath
from ..logger import get_module_logger

logger = get_module_logger("market_base", "INFO")


class BaseUserInfo:
@@ -40,47 +42,165 @@ class BaseUserInfo:
return self.stat_info.get(name, None)


class BaseMarket:
class LearnwareMarket:
"""Base interface for market, it provide the interface of search/add/detele/update learnwares"""

def __init__(self, market_id: str = None):
def __init__(
self,
market_id: str = None,
organizer: "BaseOrganizer" = None,
searcher: "BaseSearcher" = None,
checker_list: List["BaseChecker"] = None,
rebuild=False,
):
self.market_id = market_id
self.learnware_organizer = BaseOrganizer() if organizer is None else organizer
self.learnware_organizer.reset(market_id=market_id)
self.learnware_organizer.reload_market(rebuild=rebuild)
self.learnware_searcher = BaseSearcher() if searcher is None else searcher
self.learnware_searcher.reset(organizer=self.learnware_organizer)

if checker_list is None:
self.learnware_checker = {"BaseChecker": BaseChecker()}
else:
self.learnware_checker = {checker.__class__.__name__: checker for checker in checker_list}
for name, checker in self.learnware_checker.items():
checker.reset(organizer=self.learnware_organizer)

def reload_market(self, **kwargs) -> bool:
    """Reload the market by delegating to the organizer.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to ``self.learnware_organizer.reload_market``.

    Returns
    -------
    bool
        The value returned by the organizer's ``reload_market``.
    """
    # The organizer's result was previously dropped, so callers always got None
    # despite the declared bool return type.
    return self.learnware_organizer.reload_market(**kwargs)

def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool:
try:
with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir:
with zipfile.ZipFile(zip_path, mode="r") as z_file:
z_file.extractall(tempdir)

pending_learnware = get_learnware_from_dirpath(
id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir
)

final_status = BaseChecker.INVALID_LEARNWARE
checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names

def reload_market(self, market_path: str, semantic_spec_list_path: str) -> bool:
"""Reload the market when server restared.
for name in checker_names:
checker = self.learnware_checker[name]
check_status = checker(pending_learnware)
final_status = max(final_status, check_status)

if check_status == BaseChecker.INVALID_LEARNWARE:
return BaseChecker.INVALID_LEARNWARE

return final_status

except Exception as err:
logger.warning(f"Check learnware failed! Due to {err}.")
return BaseChecker.INVALID_LEARNWARE

def add_learnware(
self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs
) -> Tuple[str, int]:
"""Add a learnware into the market.

Parameters
----------
market_path : str
Directory for market data. '_IP_:_port_' for loading from database.
semantic_spec_list_path : str
Directory for available semantic_spec. Should be a json file.
zip_path : str
Filepath for learnware model, a zipped file.
semantic_spec : dict
semantic_spec for new learnware, in dictionary format.
checker_names : List[str], optional
List contains checker names, by default None

Returns
-------
bool
A flag indicating whether the market is reload successfully.
Tuple[str, int]
- str indicating model_id
- int indicating the final learnware check_status
"""

raise NotImplementedError("reload market is Not Implemented")

def check_learnware(self, learnware: Learnware) -> bool:
"""Check the utility of a learnware
check_status = self.check_learnware(zip_path, semantic_spec, checker_names)
return self.learnware_organizer.add_learnware(
zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]:
    """Run the configured searcher on ``user_info`` and return its result.

    Extra keyword arguments are forwarded unchanged to the searcher.
    """
    searcher = self.learnware_searcher
    return searcher(user_info, **kwargs)

def delete_learnware(self, id: str, **kwargs) -> bool:
    """Ask the organizer to delete the learnware ``id`` and return its result.

    Extra keyword arguments are forwarded unchanged to the organizer.
    """
    organizer = self.learnware_organizer
    return organizer.delete_learnware(id, **kwargs)

def update_learnware(
self,
id: str,
zip_path: str,
semantic_spec: dict,
checker_names: List[str] = None,
check_status: int = None,
**kwargs,
) -> int:
"""Update learnware with zip_path and semantic_specification

Parameters
----------
learnware : Learnware
id : str
Learnware id
zip_path : str
Filepath for learnware model, a zipped file.
semantic_spec : dict
semantic_spec for new learnware, in dictionary format.
checker_names : List[str], optional
List contains checker names, by default None.
check_status : int, optional
A flag indicating whether the learnware is usable, by default None.

Returns
-------
int
The final learnware check_status.
"""
update_status = self.check_learnware(zip_path, semantic_spec, checker_names)
check_status = (
update_status if check_status is None or update_status == BaseChecker.INVALID_LEARNWARE else check_status
)

return self.learnware_organizer.update_learnware(
id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def get_learnware_ids(self, top: int = None, **kwargs):
    """Return the organizer's learnware ids; ``top`` and ``kwargs`` are forwarded unchanged."""
    organizer = self.learnware_organizer
    return organizer.get_learnware_ids(top, **kwargs)

def get_learnwares(self, top: int = None, **kwargs):
    """Return the organizer's learnwares; ``top`` and ``kwargs`` are forwarded unchanged."""
    organizer = self.learnware_organizer
    return organizer.get_learnwares(top, **kwargs)

def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> "Union[Learnware, List[Learnware]]":
    """Look up learnware path(s) by id through the organizer.

    Parameters
    ----------
    ids : Union[str, List[str]]
        A single id or a list of ids.
    **kwargs
        Forwarded unchanged to the organizer.

    Returns
    -------
    Whatever ``self.learnware_organizer.get_learnware_path_by_ids`` returns.
    """
    # NOTE(review): the return annotation mirrors get_learnware_by_ids, while the
    # organizer docstring speaks of paths -- confirm the intended type.
    # This used to `raise` the organizer's return value, which made every call
    # fail; the value must be returned instead.
    return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs)

def get_learnware_by_ids(self, id: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]:
    """Look up learnware(s) by id through the organizer; ``kwargs`` are forwarded unchanged."""
    organizer = self.learnware_organizer
    return organizer.get_learnware_by_ids(id, **kwargs)

def __len__(self):
return len(self.learnware_organizer)


class BaseOrganizer:
def __init__(self, market_id=None):
self.reset(market_id=market_id)

def reset(self, market_id=None, **kwargs):
self.market_id = market_id

def reload_market(self, rebuild=False, **kwargs) -> bool:
"""Reload the learnware organizer when server restared.

Returns
-------
bool
A flag indicating whether the learnware can be accepted.
A flag indicating whether the market is reload successfully.
"""
return True

def add_learnware(
self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str
) -> Tuple[str, bool]:
raise NotImplementedError("reload market is Not Implemented in BaseOrganizer")
def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, bool]:
"""Add a learnware into the market.

.. note::
@@ -90,22 +210,17 @@ class BaseMarket:

Parameters
----------
learnware_name : str
Name of new learnware.
model_path : str
zip_path : str
Filepath for learnware model, a zipped file.
stat_spec_path : str
Filepath for statistical specification, a '.npy' file.
How to pass parameters requires further discussion.
semantic_spec : dict
semantic_spec for new learnware, in dictionary format.
desc : str
Brief desciption for new learnware.

Returns
-------
Tuple[str, bool]
str indicating model_id, bool indicating whether the learnware is added successfully.
Tuple[str, int]
- str indicating model_id
- int indicating what the flag of learnware is added.


Raises
------
@@ -113,26 +228,38 @@ class BaseMarket:
file for model or statistical specification not found

"""
raise NotImplementedError("add learnware is Not Implemented")
raise NotImplementedError("add learnware is Not Implemented in BaseOrganizer")

def search_learnware(self, user_info: BaseUserInfo) -> Tuple[Any, List[Learnware]]:
"""Search Learnware based on user_info
def delete_learnware(self, id: str) -> bool:
"""Delete a learnware from market

Parameters
----------
user_info : BaseUserInfo
user_info with emantic specifications and statistical information
id : str
id of learnware to be deleted

Returns
-------
Tuple[Any, List[Any]]
return two items:
bool
True if the target learnware is deleted successfully.

- first is recommended combination, None when no recommended combination is calculated or statistical specification is not provided.
- second is a list of matched learnwares
Raises
------
Exception
Raise an excpetion when given id is NOT found in learnware list
"""
raise NotImplementedError("delete learnware is Not Implemented in BaseOrganizer")

raise NotImplementedError("search learnware is Not Implemented")
def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, check_status: int) -> bool:
"""
Update Learnware with id and content to be updated.

Parameters
----------
id : str
id of target learnware.
"""
raise NotImplementedError("update learnware is Not Implemented in BaseOrganizer")

def get_learnware_by_ids(self, id: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]:
"""
@@ -151,47 +278,103 @@ class BaseMarket:
- The returned items are search results.
- 'None' indicating the target id not found.
"""
raise NotImplementedError("search learnware is Not Implemented")
raise NotImplementedError("get_learnware_by_ids is not implemented in BaseOrganizer")

def delete_learnware(self, id: str) -> bool:
"""Delete a learnware from market
def get_learnware_path_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]:
"""Get Zipped Learnware file by id

Parameters
----------
id : str
id of learnware to be deleted
ids : Union[str, List[str]]
Give a id or a list of ids
str: id of target learnware
List[str]: A list of ids of target learnwares

Returns
-------
bool
True if the target learnware is deleted successfully.
Union[Learnware, List[Learnware]]
Return the path for target learnware or list of path.
None for Learnware NOT Found.
"""
raise NotImplementedError("get_learnware_path_by_ids is not implemented in BaseOrganizer")

def get_learnware_ids(self, top: int = None) -> List[str]:
"""get the list of learnware ids

Parameters
----------
top : int, optional
the first top element to return, by default None

Raises
------
Exception
Raise an excpetion when given id is NOT found in learnware list
List[str]
the first top ids
"""
raise NotImplementedError("delete learnware is Not Implemented")
raise NotImplementedError("get_learnware_ids is not implemented in BaseOrganizer")

def get_learnwares(self, top: int = None) -> List[Learnware]:
"""get the list of learnwares

def update_learnware(self, id: str) -> bool:
Parameters
----------
top : int, optional
the first top element to return, by default None

Returns
------
List[Learnware]
the first top learnwares
"""
Update Learnware with id and content to be updated.
Empty interface. TODO
raise NotImplementedError("get_learnwares is not implemented in BaseOrganizer")

def __len__(self):
raise NotImplementedError("__len__ is not implemented in BaseOrganizer")


class BaseSearcher:
def __init__(self, organizer: BaseOrganizer = None):
self.learnware_oganizer = organizer

def reset(self, organizer):
self.learnware_oganizer = organizer

def __call__(self, user_info: BaseUserInfo):
"""Search learnwares based on user_info

Parameters
----------
id : str
id of target learnware.
user_info : BaseUserInfo
user_info contains semantic_spec and stat_info
"""
raise NotImplementedError("update learnware is Not Implemented")
raise NotImplementedError("'__call__' method is not implemented in BaseSearcher")

def get_semantic_spec_list(self) -> dict:
"""Return all semantic specifications available

class BaseChecker:
INVALID_LEARNWARE = -1
NONUSABLE_LEARNWARE = 0
USABLE_LEARWARE = 1

def __init__(self, organizer: BaseOrganizer = None):
self.learnware_oganizer = organizer

def reset(self, organizer):
self.learnware_oganizer = organizer

def __call__(self, learnware: Learnware) -> int:
"""Check the utility of a learnware

Parameters
----------
learnware : Learnware

Returns
-------
dict
All emantic specifications in dictionary format

int
A flag indicating whether the learnware can be accepted.
- The INVALID_LEARNWARE denotes the learnware does not pass the check
- The NONUSABLE_LEARNWARE denotes the learnware passes the check but cannot make prediction due to some env dependency
- The USABLE_LEARWARE denotes the learnware passes the check and can make prediction
"""
raise NotImplementedError("get semantic spec list is not implemented")

raise NotImplementedError("'__call__' method is not implemented in BaseChecker")

+ 9
- 9
learnware/market/database_ops.py View File

@@ -117,25 +117,25 @@ class DatabaseOperations(object):
pass
pass

def update_learnware_semantic_spec(self, learnware_id: str, semantic_spec: dict):
def delete_learnware(self, id: str):
with self.engine.connect() as conn:
semantic_spec_str = json.dumps(semantic_spec)
conn.execute(
text("UPDATE tb_learnware SET semantic_spec=:semantic_spec WHERE id=:id;"),
dict(id=learnware_id, semantic_spec=semantic_spec_str),
)
conn.execute(text("DELETE FROM tb_learnware WHERE id=:id;"), dict(id=id))
conn.commit()
pass
pass

def delete_learnware(self, id: str):
def update_learnware_semantic_specification(self, id: str, semantic_spec: dict):
with self.engine.connect() as conn:
conn.execute(text("DELETE FROM tb_learnware WHERE id=:id;"), dict(id=id))
semantic_spec_str = json.dumps(semantic_spec)
r = conn.execute(
text("UPDATE tb_learnware SET semantic_spec=:semantic_spec WHERE id=:id;"),
dict(id=id, semantic_spec=semantic_spec_str),
)
conn.commit()
pass
pass

def update_learnware_semantic_specification(self, id: str, semantic_spec: dict):
def update_learnware_use_flag(self, id: str, semantic_spec: dict):
with self.engine.connect() as conn:
semantic_spec_str = json.dumps(semantic_spec)
r = conn.execute(


+ 5
- 5
learnware/market/easy.py View File

@@ -11,7 +11,7 @@ from cvxopt import solvers, matrix
from shutil import copyfile, rmtree
from typing import Tuple, Any, List, Union, Dict

from .base import BaseMarket, BaseUserInfo
from .base import LearnwareMarket, BaseUserInfo
from .database_ops import DatabaseOperations

from .. import utils
@@ -24,8 +24,8 @@ from ..specification import RKMEStatSpecification, Specification
logger = get_module_logger("market", "INFO")


class EasyMarket(BaseMarket):
"""EasyMarket provide an easy and simple implementation for BaseMarket
class EasyMarket(LearnwareMarket):
"""EasyMarket provide an easy and simple implementation for LearnwareMarket
- EasyMarket stores learnwares with file system and database
- EasyMarket search the learnwares with the match of semantical tag and the statistical RKME
- EasyMarket does not support the search between heterogeneous features learnwars
@@ -956,11 +956,11 @@ class EasyMarket(BaseMarket):
logger.warning("Learnware ID '%s' NOT Found!" % (ids))
return None

def update_learnware_semantic_spec(self, learnware_id: str, semantic_spec: dict) -> bool:
def update_learnware_semantic_specification(self, learnware_id: str, semantic_spec: dict) -> bool:
"""Update Learnware semantic_spec"""

# update database
self.dbops.update_learnware_semantic_spec(learnware_id=learnware_id, semantic_spec=semantic_spec)
self.dbops.update_learnware_semantic_specification(learnware_id=learnware_id, semantic_spec=semantic_spec)
# update file

folder_path = self.learnware_folder_list[learnware_id]


+ 3
- 0
learnware/market/easy2/__init__.py View File

@@ -0,0 +1,3 @@
from .organizer import EasyOrganizer
from .searcher import EasySearcher
from .checker import EasySemanticChecker, EasyStatisticalChecker

+ 115
- 0
learnware/market/easy2/checker.py View File

@@ -0,0 +1,115 @@
import traceback
import numpy as np
import torch

from ..base import BaseChecker
from ...config import C
from ...logger import get_module_logger

logger = get_module_logger("easy_checker", "INFO")


class EasySemanticChecker(BaseChecker):
    """Validate a learnware's semantic specification against ``C["semantic_specs"]``."""

    @staticmethod
    def _check_descriptions(section: dict) -> None:
        """Require every description key to index into [0, Dimension) and every value to be a str."""
        dim = section["Dimension"]
        for k, v in section["Description"].items():
            if not 0 <= int(k) < dim:
                raise ValueError(f"Dimension number in [0, {dim})")
            if not isinstance(v, str):
                raise ValueError("Description must be string")

    def __call__(self, learnware):
        """Check the semantic specification of ``learnware``.

        Parameters
        ----------
        learnware : Learnware
            Learnware whose semantic specification is validated.

        Returns
        -------
        int
            - ``NONUSABLE_LEARNWARE`` when every semantic rule is satisfied
            - ``INVALID_LEARNWARE`` when any rule fails or any exception is raised
              while reading the specification (a warning is logged)
        """
        semantic_spec = learnware.get_specification().get_semantic_spec()
        try:
            # Explicit raises instead of ``assert``: assert statements are removed
            # under ``python -O``, which would silently disable the validation.
            for key in C["semantic_specs"]:
                value = semantic_spec[key]["Values"]
                rule = C["semantic_specs"][key]
                valid_type = rule["Type"]
                if semantic_spec[key]["Type"] != valid_type:
                    raise ValueError(f"{key} type mismatch")

                if valid_type == "Class":
                    valid_list = rule["Values"]
                    if len(value) != 1:
                        raise ValueError(f"{key} must be unique")
                    if value[0] not in valid_list:
                        raise ValueError(f"{key} must be in {valid_list}")

                elif valid_type == "Tag":
                    valid_list = rule["Values"]
                    if len(value) < 1:
                        raise ValueError(f"{key} cannot be empty")
                    for v in value:
                        if v not in valid_list:
                            raise ValueError(f"{key} must be in {valid_list}")

                elif valid_type == "String":
                    if not isinstance(value, str):
                        raise ValueError(f"{key} must be string")
                    if len(value) < 1:
                        raise ValueError(f"{key} cannot be empty")

            if semantic_spec["Data"]["Values"][0] == "Table":
                if semantic_spec["Input"] is None:
                    raise ValueError("Lack of input semantics")
                self._check_descriptions(semantic_spec["Input"])

            if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]:
                if semantic_spec["Output"] is None:
                    raise ValueError("Lack of output semantics")
                self._check_descriptions(semantic_spec["Output"])

            return self.NONUSABLE_LEARNWARE

        except Exception as err:
            # Any failure (rule violation, missing key, bad type) marks the learnware invalid.
            logger.warning(f"semantic_specification is not valid due to {err}!")
            return self.INVALID_LEARNWARE


class EasyStatisticalChecker(BaseChecker):
    """Checker that instantiates a learnware and runs a smoke-test prediction.

    The model is instantiated, fed a random batch of 10 samples, and the
    prediction is compared against the model's declared ``output_shape`` and,
    for supervised / feature-extraction tasks, against the output dimension
    recorded in the semantic specification.
    """

    def __call__(self, learnware):
        """Check that ``learnware`` can be instantiated and produces well-shaped outputs.

        Parameters
        ----------
        learnware
            The learnware under test.

        Returns
        -------
        int
            ``self.USABLE_LEARWARE`` when every check passes, otherwise
            ``self.INVALID_LEARNWARE``.
        """
        semantic_spec = learnware.get_specification().get_semantic_spec()

        try:
            # Check model instantiation
            learnware.instantiate_model()

        except Exception as e:
            traceback.print_exc()
            logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.")
            return self.INVALID_LEARNWARE

        try:
            learnware_model = learnware.get_model()

            # Check input shape
            if semantic_spec["Data"]["Values"][0] == "Table":
                # Coerce to int, as is already done for the output dimension below,
                # so that a numeric value stored in another form is handled the same way.
                input_shape = (int(semantic_spec["Input"]["Dimension"]),)
            else:
                input_shape = learnware_model.input_shape

            # Check rkme dimension
            stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification")
            if stat_spec is not None and stat_spec.get_z().shape[1:] != input_shape:
                logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.")
                return self.INVALID_LEARNWARE

            # Random batch of 10 samples used purely as a smoke test.
            inputs = np.random.randn(10, *input_shape)
            outputs = learnware.predict(inputs)

            # Check output
            if outputs.ndim == 1:
                outputs = outputs.reshape(-1, 1)

            if outputs.shape[1:] != learnware_model.output_shape:
                logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!")
                return self.INVALID_LEARNWARE

            if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"):
                # Check output type
                if isinstance(outputs, torch.Tensor):
                    outputs = outputs.detach().cpu().numpy()
                if not isinstance(outputs, np.ndarray):
                    logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!")
                    return self.INVALID_LEARNWARE

                # Check output shape
                output_dim = int(semantic_spec["Output"]["Dimension"])
                if outputs[0].shape[0] != output_dim:
                    logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!")
                    return self.INVALID_LEARNWARE

        except Exception as e:
            # Any failure while predicting or inspecting the result makes the learnware invalid.
            logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.")
            return self.INVALID_LEARNWARE

        return self.USABLE_LEARWARE

+ 176
- 0
learnware/market/easy2/database_ops.py View File

@@ -0,0 +1,176 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import create_engine, text
from sqlalchemy import Column, Integer, Text, DateTime, String
import os
import json
from ...learnware import get_learnware_from_dirpath
from ...logger import get_module_logger

# Module-level logger for the database helpers in this file.
logger = get_module_logger("database")
# Declarative base class that the ORM model(s) below inherit from.
DeclarativeBase = declarative_base()


class Learnware(DeclarativeBase):
    """ORM mapping for the ``tb_learnware`` table (one row per stored learnware)."""

    __tablename__ = "tb_learnware"

    # Learnware identifier, used as primary key.
    id = Column(String(10), nullable=False, primary_key=True)
    # Semantic specification; DatabaseOperations writes it as a JSON string.
    semantic_spec = Column(Text, nullable=False)
    # Location of the stored zip archive.
    zip_path = Column(Text, nullable=False)
    # Location of the unzipped learnware folder.
    folder_path = Column(Text, nullable=False)
    # Check status flag, stored as text.
    use_flag = Column(Text, nullable=False)


class DatabaseOperations(object):
    """Thin persistence layer for the ``tb_learnware`` table.

    Supports sqlite and postgresql URLs. The database (and, when it is newly
    created, the table) is set up on construction.
    """

    def __init__(self, url: str, database_name: str):
        """Build the full database URL and make sure the database exists.

        Parameters
        ----------
        url : str
            Base database URL (``sqlite...`` or ``postgresql...``).
        database_name : str
            Name of the database; for sqlite it becomes ``<database_name>.db``.
        """
        if url.startswith("sqlite"):
            url = os.path.join(url, f"{database_name}.db")
        else:
            url = f"{url}/{database_name}"

        self.url = url
        self.create_database_if_not_exists(url)

    def create_database_if_not_exists(self, url):
        """Create the database behind ``url`` if needed and open ``self.engine``.

        Parameters
        ----------
        url : str
            Full database URL, including the database name.

        Raises
        ------
        Exception
            If ``url`` is neither a sqlite nor a postgresql URL.
        """
        if url.startswith("sqlite"):
            # Everything after ":///" is the path of the sqlite file.
            start = url.find(":///")
            path = url[start + 4 :]
            database_exists = os.path.exists(path)
            if not database_exists:
                # os.makedirs("") raises, so only create a parent when there is one.
                parent_dir = os.path.dirname(path)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)
        elif url.startswith("postgresql"):
            # Connect to the maintenance database "postgres" to list / create databases.
            dbname_start = url.rfind("/")
            dbname = url[dbname_start + 1 :]
            url_no_dbname = url[:dbname_start] + "/postgres"
            engine = create_engine(url_no_dbname)

            with engine.connect() as conn:
                result = conn.execute(text("SELECT datname FROM pg_database;"))
                db_list = {row[0].lower() for row in result.fetchall()}

                database_exists = dbname.lower() in db_list
                if not database_exists:
                    # NOTE(review): dbname is interpolated into the statement; it is
                    # expected to come from trusted configuration only.
                    conn.execution_options(isolation_level="AUTOCOMMIT").execute(
                        text("CREATE DATABASE {0};".format(dbname))
                    )
            engine.dispose()
        else:
            raise Exception(f"Unsupported database url: {url}")

        self.engine = create_engine(url, future=True)

        if not database_exists:
            DeclarativeBase.metadata.create_all(self.engine)

    def clear_learnware_table(self):
        """Delete every row of ``tb_learnware``."""
        with self.engine.connect() as conn:
            conn.execute(text("DELETE FROM tb_learnware;"))
            conn.commit()

    def add_learnware(self, id: str, semantic_spec: dict, zip_path, folder_path, use_flag: str):
        """Insert one learnware row; ``semantic_spec`` is serialized as JSON."""
        with self.engine.connect() as conn:
            semantic_spec_str = json.dumps(semantic_spec)
            conn.execute(
                text(
                    (
                        "INSERT INTO tb_learnware (id, semantic_spec, zip_path, folder_path, use_flag)"
                        "VALUES (:id, :semantic_spec, :zip_path, :folder_path, :use_flag);"
                    )
                ),
                dict(
                    id=id,
                    semantic_spec=semantic_spec_str,
                    zip_path=zip_path,
                    folder_path=folder_path,
                    use_flag=use_flag,
                ),
            )
            conn.commit()

    def delete_learnware(self, id: str):
        """Delete the row whose primary key is ``id``."""
        with self.engine.connect() as conn:
            conn.execute(text("DELETE FROM tb_learnware WHERE id=:id;"), dict(id=id))
            conn.commit()

    def update_learnware_semantic_specification(self, id: str, semantic_spec: dict):
        """Overwrite the stored semantic specification (as JSON) of learnware ``id``."""
        with self.engine.connect() as conn:
            semantic_spec_str = json.dumps(semantic_spec)
            conn.execute(
                text("UPDATE tb_learnware SET semantic_spec=:semantic_spec WHERE id=:id;"),
                dict(id=id, semantic_spec=semantic_spec_str),
            )
            conn.commit()

    def update_learnware_use_flag(self, id: str, use_flag: str):
        """Overwrite the stored ``use_flag`` of learnware ``id``."""
        with self.engine.connect() as conn:
            conn.execute(
                text("UPDATE tb_learnware SET use_flag=:use_flag WHERE id=:id;"),
                dict(id=id, use_flag=use_flag),
            )
            conn.commit()

    def load_market(self):
        """Load every stored learnware.

        Returns
        -------
        tuple
            ``(learnware_list, zip_list, folder_list, use_flags, next_count)`` where
            the first four are dicts keyed by learnware id and ``next_count`` is the
            largest numeric id found plus one (1 when the table is empty).
        """
        with self.engine.connect() as conn:
            cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;"))

            learnware_list = {}
            zip_list = {}
            folder_list = {}
            use_flags = {}
            max_count = 0

            for learnware_id, semantic_spec, zip_path, folder_path, use_flag in cursor:
                learnware_id = learnware_id.strip()
                semantic_spec_dict = json.loads(semantic_spec)
                new_learnware = get_learnware_from_dirpath(
                    id=learnware_id, semantic_spec=semantic_spec_dict, learnware_dirpath=folder_path
                )
                logger.info(f"Load learnware: {learnware_id}")
                learnware_list[learnware_id] = new_learnware
                zip_list[learnware_id] = zip_path
                folder_list[learnware_id] = folder_path
                # NOTE(review): the column is Text, so this value presumably comes back
                # as a string even when an int was written - confirm against callers.
                use_flags[learnware_id] = use_flag
                max_count = max(max_count, int(learnware_id))

            return learnware_list, zip_list, folder_list, use_flags, max_count + 1

+ 313
- 0
learnware/market/easy2/organizer.py View File

@@ -0,0 +1,313 @@
import os
import json
import copy
import torch
import zipfile
import traceback
import tempfile
import numpy as np
import pandas as pd
from rapidfuzz import fuzz
from cvxopt import solvers, matrix
from shutil import copyfile, rmtree
from typing import Tuple, Any, List, Union, Dict

from .database_ops import DatabaseOperations
from ..base import LearnwareMarket, BaseUserInfo


from ... import utils
from ...config import C as conf
from ...logger import get_module_logger
from ...learnware import Learnware, get_learnware_from_dirpath
from ...specification import RKMEStatSpecification, Specification

from ..base import BaseOrganizer, BaseChecker
from ...logger import get_module_logger

# Module-level logger used by EasyOrganizer.
logger = get_module_logger("easy_organizer")


class EasyOrganizer(BaseOrganizer):
    """Organizer that keeps learnwares on disk and indexes them in a database.

    Zip archives and their unzipped folders live under
    ``<market_root_path>/<market_id>/learnware_pool``; ids, paths, semantic
    specifications and check status are mirrored through ``DatabaseOperations``.
    """

    def reload_market(self, rebuild=False) -> bool:
        """Reload the learnware organizer when server restared.

        Parameters
        ----------
        rebuild : bool, optional
            When True, the database table and the on-disk learnware pool are
            cleared before loading, by default False.

        Returns
        -------
        bool
            A flag indicating whether the market is reload successfully.
        """
        self.market_store_path = os.path.join(conf.market_root_path, self.market_id)
        self.learnware_pool_path = os.path.join(self.market_store_path, "learnware_pool")
        self.learnware_zip_pool_path = os.path.join(self.learnware_pool_path, "zips")
        self.learnware_folder_pool_path = os.path.join(self.learnware_pool_path, "unzipped_learnwares")
        self.learnware_list = {}  # id: Learnware
        self.learnware_zip_list = {}
        self.learnware_folder_list = {}
        self.use_flags = {}
        self.count = 0
        self.semantic_spec_list = conf.semantic_specs
        self.dbops = DatabaseOperations(conf.database_url, "market_" + self.market_id)

        if rebuild:
            logger.warning("Warning! You are trying to clear current database!")
            try:
                self.dbops.clear_learnware_table()
                rmtree(self.learnware_pool_path)
            except FileNotFoundError:
                # Nothing stored on disk yet: rebuilding an empty pool is fine.
                pass
            except Exception as err:
                # Best effort, as before, but no longer silent.
                logger.warning(f"Failed to fully clear the market: {err!r}")

        os.makedirs(self.learnware_pool_path, exist_ok=True)
        os.makedirs(self.learnware_zip_pool_path, exist_ok=True)
        os.makedirs(self.learnware_folder_pool_path, exist_ok=True)
        (
            self.learnware_list,
            self.learnware_zip_list,
            self.learnware_folder_list,
            self.use_flags,
            self.count,
        ) = self.dbops.load_market()
        return True

    @staticmethod
    def _remove_learnware_files(zip_file: str, folder: str) -> None:
        """Best-effort removal of a learnware's zip archive and unzipped folder."""
        try:
            os.remove(zip_file)
        except OSError:
            pass
        rmtree(folder, ignore_errors=True)

    @staticmethod
    def _lookup_by_id(table: dict, id):
        """Return ``table[id]``, or log a warning and return None when it is absent."""
        try:
            return table[id]
        except (KeyError, TypeError):
            logger.warning("Learnware ID '%s' NOT Found!" % (id,))
            return None

    def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, int]:
        """Add a learnware into the market.

        Parameters
        ----------
        zip_path : str
            Filepath for learnware model, a zipped file.
        semantic_spec : dict
            semantic_spec for new learnware, in dictionary format.
        check_status: int
            A flag indicating whether the learnware is usable.

        Returns
        -------
        Tuple[str, int]
            - str indicating model_id
            - int indicating the final learnware check_status
        """
        if check_status == BaseChecker.INVALID_LEARNWARE:
            logger.warning("Learnware is invalid!")
            return None, BaseChecker.INVALID_LEARNWARE

        semantic_spec = copy.deepcopy(semantic_spec)
        logger.info("Get new learnware from %s" % (zip_path))

        learnware_id = "%08d" % (self.count)
        target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id))
        target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id)
        copyfile(zip_path, target_zip_dir)

        with zipfile.ZipFile(target_zip_dir, "r") as z_file:
            z_file.extractall(target_folder_dir)
        logger.info("Learnware move to %s, and unzip to %s" % (target_zip_dir, target_folder_dir))

        try:
            new_learnware = get_learnware_from_dirpath(
                id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
            )
        except Exception as err:
            logger.warning(f"Failed to load learnware from {target_folder_dir}: {err!r}")
            new_learnware = None

        if new_learnware is None:
            # The id is not consumed on failure, so stale files must not stay behind
            # for the next learnware that receives the same id.
            self._remove_learnware_files(target_zip_dir, target_folder_dir)
            return None, BaseChecker.INVALID_LEARNWARE

        learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE

        self.dbops.add_learnware(
            id=learnware_id,
            semantic_spec=semantic_spec,
            zip_path=target_zip_dir,
            folder_path=target_folder_dir,
            use_flag=learnwere_status,
        )

        self.learnware_list[learnware_id] = new_learnware
        self.learnware_zip_list[learnware_id] = target_zip_dir
        self.learnware_folder_list[learnware_id] = target_folder_dir
        self.use_flags[learnware_id] = learnwere_status
        self.count += 1
        return learnware_id, learnwere_status

    def delete_learnware(self, id: str) -> bool:
        """Delete Learnware from market

        Parameters
        ----------
        id : str
            Learnware to be deleted

        Returns
        -------
        bool
            True for successful operation.
            False for id not found.
        """
        if id not in self.learnware_list:
            logger.warning("Learnware id:'{}' NOT Found!".format(id))
            return False

        zip_dir = self.learnware_zip_list[id]
        os.remove(zip_dir)
        folder_dir = self.learnware_folder_list[id]
        rmtree(folder_dir)
        self.learnware_list.pop(id)
        self.learnware_zip_list.pop(id)
        self.learnware_folder_list.pop(id)
        self.use_flags.pop(id)
        self.dbops.delete_learnware(id=id)

        return True

    def update_learnware(self, id: str, zip_path: str = None, semantic_spec: dict = None, check_status: int = None):
        """Update learnware with zip_path, semantic_specification and check_status

        Parameters
        ----------
        id : str
            Learnware id
        zip_path : str, optional
            Filepath for learnware model, a zipped file.
        semantic_spec : dict, optional
            semantic_spec for new learnware, in dictionary format.
        check_status : int, optional
            A flag indicating whether the learnware is usable.

        Returns
        -------
        int
            The final learnware check_status.
        """
        if check_status == BaseChecker.INVALID_LEARNWARE:
            logger.warning("Learnware is invalid!")
            return BaseChecker.INVALID_LEARNWARE

        if zip_path is None and semantic_spec is None and check_status is None:
            logger.warning(
                "At least one of 'zip_path', 'semantic_spec' and 'check_status' should not be None when update learnware"
            )
            return BaseChecker.INVALID_LEARNWARE

        if id not in self.learnware_list:
            logger.warning("Learnware id:'{}' NOT Found!".format(id))
            return BaseChecker.INVALID_LEARNWARE

        if semantic_spec is None:
            # Keep the specification the learnware already has.
            semantic_spec = self.learnware_list[id].get_specification().get_semantic_spec()

        target_zip_dir = self.learnware_zip_list[id]
        target_folder_dir = self.learnware_folder_list[id]

        # Validate and install the new archive BEFORE touching the database, so a
        # rejected archive leaves the stored state untouched.
        if zip_path is not None:
            with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
                try:
                    with zipfile.ZipFile(zip_path, "r") as z_file:
                        z_file.extractall(tempdir)
                    new_learnware = get_learnware_from_dirpath(
                        id=id, semantic_spec=semantic_spec, learnware_dirpath=tempdir
                    )
                except Exception as err:
                    logger.warning(f"Failed to load learnware from {zip_path}: {err!r}")
                    return BaseChecker.INVALID_LEARNWARE

                if new_learnware is None:
                    return BaseChecker.INVALID_LEARNWARE

            copyfile(zip_path, target_zip_dir)
            with zipfile.ZipFile(target_zip_dir, "r") as z_file:
                z_file.extractall(target_folder_dir)

        # Update semantic_specification
        self.dbops.update_learnware_semantic_specification(id, semantic_spec)

        # Update check_status
        if check_status is not None:
            self.use_flags[id] = check_status
        self.dbops.update_learnware_use_flag(id, self.use_flags[id])

        # Update learnware list
        self.learnware_list[id] = get_learnware_from_dirpath(
            id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
        )

        return self.use_flags[id]

    def get_learnware_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]:
        """Search learnware by id or list of ids.

        Parameters
        ----------
        ids : Union[str, List[str]]
            Give a id or a list of ids
            str: id of targer learware
            List[str]: A list of ids of target learnwares

        Returns
        -------
        Union[Learnware, List[Learnware]]
            Return target learnware or list of target learnwares.
            None for Learnware NOT Found.
        """
        if isinstance(ids, list):
            return [self._lookup_by_id(self.learnware_list, id) for id in ids]
        return self._lookup_by_id(self.learnware_list, ids)

    def get_learnware_path_by_ids(self, ids: Union[str, List[str]]) -> Union[str, List[str]]:
        """Get Zipped Learnware file by id

        Parameters
        ----------
        ids : Union[str, List[str]]
            Give a id or a list of ids
            str: id of targer learware
            List[str]: A list of ids of target learnwares

        Returns
        -------
        Union[str, List[str]]
            Return the path for target learnware or list of path.
            None for Learnware NOT Found.
        """
        if isinstance(ids, list):
            return [self._lookup_by_id(self.learnware_zip_list, id) for id in ids]
        return self._lookup_by_id(self.learnware_zip_list, ids)

    def get_learnware_ids(self, top: int = None) -> List[str]:
        """Return the stored learnware ids, limited to the first ``top`` when given."""
        # Slicing with None keeps the whole list.
        return list(self.learnware_list.keys())[:top]

    def get_learnwares(self, top: int = None) -> List[Learnware]:
        """Return the stored learnwares, limited to the first ``top`` when given."""
        return list(self.learnware_list.values())[:top]

    def __len__(self):
        """Number of learnwares currently held by the organizer."""
        return len(self.learnware_list)

+ 637
- 0
learnware/market/easy2/searcher.py View File

@@ -0,0 +1,637 @@
import torch
import numpy as np
from rapidfuzz import fuzz
from cvxopt import solvers, matrix
from typing import Tuple, List

from .organizer import EasyOrganizer
from ..base import BaseUserInfo, BaseSearcher
from ...learnware import Learnware
from ...specification import RKMEStatSpecification
from ...logger import get_module_logger

# Module-level logger used by the searcher classes in this file.
logger = get_module_logger("easy_seacher")


class EasyExactSemanticSearcher(BaseSearcher):
    """Searcher that keeps the learnwares whose semantic spec matches the user's exactly."""

    def _match_semantic_spec(self, semantic_spec1, semantic_spec2):
        """Judge whether a stored semantic spec satisfies the user's semantic spec.

        Parameters
        ----------
        semantic_spec1 :
            semantic spec input by user
        semantic_spec2 :
            semantic spec in database (may contain more keys than the user input)

        Returns
        -------
        bool
            True when every non-empty user entry is matched:
            "Name" / "Description" must occur (case-insensitively) as a substring of the
            stored name or description, "Class" values must be equal, and "Tag"
            values must share at least one element.
        """
        name2 = semantic_spec2.get("Name", {}).get("Values", "").lower()
        description2 = semantic_spec2.get("Description", {}).get("Values", "").lower()

        for key, user_item in semantic_spec1.items():
            v1 = user_item.get("Values", "")
            if len(v1) == 0:
                # user input is empty, no need to search
                continue

            if key in ("Name", "Description"):
                v1 = v1.lower()
                if v1 not in name2 and v1 not in description2:
                    return False
                continue

            # A key missing from the stored spec is treated like an empty value.
            v2 = semantic_spec2.get(key, {}).get("Values", "")
            if len(v2) == 0:
                # user input contains some key that is not in database
                return False

            spec_type = user_item["Type"]
            if spec_type == "Class":
                if isinstance(v1, list):
                    v1 = v1[0]
                if isinstance(v2, list):
                    v2 = v2[0]
                if v1 != v2:
                    return False
            elif spec_type == "Tag":
                if not (set(v1) & set(v2)):
                    return False

        return True

    def __call__(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]:
        """Return the learnwares of ``learnware_list`` matching the user's semantic spec.

        Parameters
        ----------
        learnware_list : List[Learnware]
            The list of candidate learnwares
        user_info : BaseUserInfo
            user_info contains semantic_spec

        Returns
        -------
        List[Learnware]
            The matching learnwares, in their original order
        """
        # The user's spec does not depend on the candidate, so fetch it once.
        user_semantic_spec = user_info.get_semantic_spec()
        match_learnwares = [
            learnware
            for learnware in learnware_list
            if self._match_semantic_spec(user_semantic_spec, learnware.get_specification().get_semantic_spec())
        ]
        logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list)))
        return match_learnwares


class EasyFuzzSemanticSearcher(BaseSearcher):
    """Searcher that filters by tags first, then matches the name exactly or fuzzily."""

    def _match_semantic_spec_tag(self, semantic_spec1, semantic_spec2) -> bool:
        """Judge if tags of two semantic specs are consistent

        Parameters
        ----------
        semantic_spec1 :
            semantic spec input by user
        semantic_spec2 :
            semantic spec in database

        Returns
        -------
        bool
            consistent (True) or not consistent (False)
        """
        for key, user_item in semantic_spec1.items():
            v1 = user_item.get("Values", "")
            if len(v1) == 0:
                # user input is empty, no need to search
                continue

            # "Name" is handled by the exact / fuzzy text search in __call__.
            # (The previous `key not in "Name"` was a substring test on the key.)
            if key == "Name":
                continue

            # A key missing from the stored spec is treated like an empty value.
            v2 = semantic_spec2.get(key, {}).get("Values", "")
            if len(v2) == 0:
                # user input contains some key that is not in database
                return False

            spec_type = user_item["Type"]
            if spec_type == "Class":
                if isinstance(v1, list):
                    v1 = v1[0]
                if isinstance(v2, list):
                    v2 = v2[0]
                if v1 != v2:
                    return False
            elif spec_type == "Tag":
                if not (set(v1) & set(v2)):
                    return False
        return True

    def _search_by_name(
        self, candidates: List[Learnware], name_user: str, max_num: int, min_score: float
    ) -> List[Learnware]:
        """Match ``name_user`` against candidate names/descriptions: exact first, fuzzy as fallback."""
        specs = [learnware.get_specification().get_semantic_spec() for learnware in candidates]
        name_list = [spec["Name"]["Values"].lower() for spec in specs]
        des_list = [spec["Description"]["Values"].lower() for spec in specs]

        # Exact search: the query is a substring of the name or of the description.
        matched_exact = [
            learnware
            for learnware, name, des in zip(candidates, name_list, des_list)
            if name_user in name or name_user in des
        ]
        if matched_exact:
            return matched_exact

        # Fuzzy search: keep candidates whose best partial ratio reaches min_score.
        matched_fuzz, fuzz_scores = [], []
        for learnware, name, des in zip(candidates, name_list, des_list):
            final_score = max(fuzz.partial_ratio(name_user, name), fuzz.partial_ratio(name_user, des))
            if final_score >= min_score:
                matched_fuzz.append(learnware)
                fuzz_scores.append(final_score)

        # Sort by score (best first) and keep at most max_num results.
        sort_idx = sorted(range(len(fuzz_scores)), key=lambda k: fuzz_scores[k], reverse=True)[:max_num]
        return [matched_fuzz[idx] for idx in sort_idx]

    def __call__(
        self, learnware_list: List[Learnware], user_info: BaseUserInfo, max_num: int = 50000, min_score: float = 75.0
    ) -> List[Learnware]:
        """Search learnware by fuzzy matching of semantic spec

        Parameters
        ----------
        learnware_list : List[Learnware]
            The list of learnwares
        user_info : BaseUserInfo
            user_info contains semantic_spec
        max_num : int, optional
            maximum number of learnwares returned by the fuzzy search, by default 50000
        min_score : float, optional
            Minimum fuzzy matching score of learnwares returned, by default 75.0

        Returns
        -------
        List[Learnware]
            The list of returned learnwares
        """
        user_semantic_spec = user_info.get_semantic_spec()

        matched_learnware_tag = [
            learnware
            for learnware in learnware_list
            if self._match_semantic_spec_tag(user_semantic_spec, learnware.get_specification().get_semantic_spec())
        ]

        final_result = matched_learnware_tag
        if matched_learnware_tag and "Name" in user_semantic_spec:
            name_user = user_semantic_spec["Name"]["Values"].lower()
            if len(name_user) > 0:
                final_result = self._search_by_name(matched_learnware_tag, name_user, max_num, min_score)

        logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list)))
        return final_result


class EasyTableSearcher(BaseSearcher):
def _convert_dist_to_score(
self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92
) -> List[float]:
"""Convert mmd dist list into min_max score list

Parameters
----------
dist_list : List[float]
The list of mmd distances from learnware rkmes to user rkme
dist_epsilon: float
The paramter for converting mmd dist to score
min_score: float
The minimum score for maximum returned score

Returns
-------
List[float]
The list of min_max scores of each learnware
"""
if len(dist_list) == 0:
return []

min_dist, max_dist = min(dist_list), max(dist_list)
if min_dist == max_dist:
return [1 for dist in dist_list]
else:
max_score = (max_dist - min_dist) / (max_dist - dist_epsilon)

if min_dist < dist_epsilon:
dist_epsilon = min_dist
elif max_score < min_score:
dist_epsilon = max_dist - (max_dist - min_dist) / min_score

return [(max_dist - dist) / (max_dist - dist_epsilon) for dist in dist_list]

def _calculate_rkme_spec_mixture_weight(
self,
learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
intermediate_K: np.ndarray = None,
intermediate_C: np.ndarray = None,
) -> Tuple[List[float], float]:
"""Calculate mixture weight for the learnware_list based on a user's rkme

Parameters
----------
learnware_list : List[Learnware]
A list of existing learnwares
user_rkme : RKMEStatSpecification
User RKME statistical specification
intermediate_K : np.ndarray, optional
Intermediate kernel matrix K, by default None
intermediate_C : np.ndarray, optional
Intermediate inner product vector C, by default None

Returns
-------
Tuple[List[float], float]
The first is the list of mixture weights
The second is the mmd dist between the mixture of learnware rkmes and the user's rkme
"""
learnware_num = len(learnware_list)
RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
]

if type(intermediate_K) == np.ndarray:
K = intermediate_K
else:
K = np.zeros((learnware_num, learnware_num))
for i in range(K.shape[0]):
K[i, i] = RKME_list[i].inner_prod(RKME_list[i])
for j in range(i + 1, K.shape[0]):
K[i, j] = K[j, i] = RKME_list[i].inner_prod(RKME_list[j])

if type(intermediate_C) == np.ndarray:
C = intermediate_C
else:
C = np.zeros((learnware_num, 1))
for i in range(C.shape[0]):
C[i, 0] = user_rkme.inner_prod(RKME_list[i])

K = torch.from_numpy(K).double().to(user_rkme.device)
C = torch.from_numpy(C).double().to(user_rkme.device)

# beta can be negative
# weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C

# beta must be nonnegative
n = K.shape[0]
P = matrix(K.cpu().numpy())
q = matrix(-C.cpu().numpy())
G = matrix(-np.eye(n))
h = matrix(np.zeros((n, 1)))
A = matrix(np.ones((1, n)))
b = matrix(np.ones((1, 1)))
solvers.options["show_progress"] = False
sol = solvers.qp(P, q, G, h, A, b)
weight = np.array(sol["x"])
weight = torch.from_numpy(weight).reshape(-1).double().to(user_rkme.device)
score = user_rkme.inner_prod(user_rkme) + 2 * sol["primal objective"]

return weight.detach().cpu().numpy().reshape(-1), score

def _calculate_intermediate_K_and_C(
self,
learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
intermediate_K: np.ndarray = None,
intermediate_C: np.ndarray = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""Incrementally update the values of intermediate_K and intermediate_C

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares up till now
user_rkme : RKMEStatSpecification
User RKME statistical specification
intermediate_K : np.ndarray, optional
Intermediate kernel matrix K, by default None
intermediate_C : np.ndarray, optional
Intermediate inner product vector C, by default None

Returns
-------
Tuple[np.ndarray, np.ndarray]
The first is the intermediate value of K
The second is the intermediate value of C
"""
num = intermediate_K.shape[0] - 1
RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
]
for i in range(intermediate_K.shape[0]):
intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i])
intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1])
return intermediate_K, intermediate_C

def _search_by_rkme_spec_mixture_auto(
self,
learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
max_search_num: int,
weight_cutoff: float = 0.98,
) -> Tuple[float, List[float], List[Learnware]]:
"""Select learnwares based on a total mixture ratio, then recalculate their mixture weights

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
User RKME statistical specification
max_search_num : int
The maximum number of the returned learnwares
weight_cutoff : float, optional
The ratio for selecting out the mose relevant learnwares, by default 0.9

Returns
-------
Tuple[float, List[float], List[Learnware]]
The first is the mixture mmd dist
The second is the list of weight
The third is the list of Learnware
"""
learnware_num = len(learnware_list)
if learnware_num == 0:
return [], []
if learnware_num < max_search_num:
logger.warning("Available Learnware num less than search_num!")
max_search_num = learnware_num

weight, _ = self._calculate_rkme_spec_mixture_weight(learnware_list, user_rkme)
sort_by_weight_idx_list = sorted(range(learnware_num), key=lambda k: weight[k], reverse=True)

weight_sum = 0
mixture_list = []
for idx in sort_by_weight_idx_list:
weight_sum += weight[idx]
if weight_sum <= weight_cutoff:
mixture_list.append(learnware_list[idx])
else:
break

if len(mixture_list) <= 1:
mixture_list = [learnware_list[sort_by_weight_idx_list[0]]]
mixture_weight = [1]
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification"))
else:
if len(mixture_list) > max_search_num:
mixture_list = mixture_list[:max_search_num]
mixture_weight, mmd_dist = self._calculate_rkme_spec_mixture_weight(mixture_list, user_rkme)

return mmd_dist, mixture_weight, mixture_list

def _filter_by_rkme_spec_single(
self,
sorted_score_list: List[float],
learnware_list: List[Learnware],
filter_score: float = 0.5,
min_num: int = 15,
) -> Tuple[List[float], List[Learnware]]:
"""Filter search result of _search_by_rkme_spec_single

Parameters
----------
sorted_score_list : List[float]
The list of score transformed by mmd dist
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
filter_score: float
The learnware whose score is lower than filter_score will be filtered
min_num: int
The minimum number of returned learnwares

Returns
-------
Tuple[List[float], List[Learnware]]
the first is the list of score
the second is the list of Learnware
"""
idx = min(min_num, len(learnware_list))
while idx < len(learnware_list):
if sorted_score_list[idx] < filter_score:
break
idx = idx + 1
return sorted_score_list[:idx], learnware_list[:idx]

def _filter_by_rkme_spec_dimension(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification
) -> List[Learnware]:
"""Filter learnwares whose rkme dimension different from user_rkme

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
User RKME statistical specification

Returns
-------
List[Learnware]
Learnwares whose rkme dimensions equal user_rkme in user_info
"""
filtered_learnware_list = []
user_rkme_dim = str(list(user_rkme.get_z().shape)[1:])

for learnware in learnware_list:
rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification")
rkme_dim = str(list(rkme.get_z().shape)[1:])
if rkme_dim == user_rkme_dim:
filtered_learnware_list.append(learnware)

return filtered_learnware_list

def _search_by_rkme_spec_mixture_greedy(
self,
learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
max_search_num: int,
score_cutoff: float = 0.001,
) -> Tuple[float, List[float], List[Learnware]]:
"""Greedily match learnwares such that their mixture become closer and closer to user's rkme

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
User RKME statistical specification
max_search_num : int
The maximum number of the returned learnwares
score_cutof: float
The minimum mmd dist as threshold to stop further rkme_spec matching

Returns
-------
Tuple[float, List[float], List[Learnware]]
The first is the mixture mmd dist
The second is the list of weight
The third is the list of Learnware
"""
learnware_num = len(learnware_list)
if learnware_num == 0:
return None, [], []
if learnware_num < max_search_num:
logger.warning("Available Learnware num less than search_num!")
max_search_num = learnware_num

flag_list = [0 for _ in range(learnware_num)]
mixture_list, mmd_dist = [], None
intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1))

for k in range(max_search_num):
idx_min, score_min = -1, -1
weight_min = None
mixture_list.append(None)

if k != 0:
intermediate_K = np.c_[intermediate_K, np.zeros((k, 1))]
intermediate_K = np.r_[intermediate_K, np.zeros((1, k + 1))]
intermediate_C = np.r_[intermediate_C, np.zeros((1, 1))]

for idx in range(len(learnware_list)):
if flag_list[idx] == 0:
mixture_list[-1] = learnware_list[idx]
intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C(
mixture_list, user_rkme, intermediate_K, intermediate_C
)
weight, score = self._calculate_rkme_spec_mixture_weight(
mixture_list, user_rkme, intermediate_K, intermediate_C
)
if idx_min == -1 or score < score_min:
idx_min, score_min, weight_min = idx, score, weight

mmd_dist = score_min
mixture_list[-1] = learnware_list[idx_min]
if score_min < score_cutoff:
break
else:
flag_list[idx_min] = 1
intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C(
mixture_list, user_rkme, intermediate_K, intermediate_C
)

return mmd_dist, weight_min, mixture_list

def _search_by_rkme_spec_single(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification
) -> Tuple[List[float], List[Learnware]]:
"""Calculate the distances between learnwares in the given learnware_list and user_rkme

Parameters
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user RKME statistical specification

Returns
-------
Tuple[List[float], List[Learnware]]
the first is the list of mmd dist
the second is the list of Learnware
both lists are sorted by mmd dist
"""
RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
]
mmd_dist_list = []
for RKME in RKME_list:
mmd_dist = RKME.dist(user_rkme)
mmd_dist_list.append(mmd_dist)

sorted_idx_list = sorted(range(len(learnware_list)), key=lambda k: mmd_dist_list[k])
sorted_dist_list = [mmd_dist_list[idx] for idx in sorted_idx_list]
sorted_learnware_list = [learnware_list[idx] for idx in sorted_idx_list]

return sorted_dist_list, sorted_learnware_list

def __call__(
    self,
    learnware_list: List[Learnware],
    user_info: BaseUserInfo,
    max_search_num: int = 5,
    search_method: str = "greedy",
) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
    """Search learnwares by RKME statistical specification.

    Parameters
    ----------
    learnware_list : List[Learnware]
        Candidate learnwares to search over
    user_info : BaseUserInfo
        Must carry the user's specification under ``stat_info["RKMEStatSpecification"]``
    max_search_num : int
        Forwarded to the mixture search as its search budget
    search_method : str
        "auto" or "greedy"; any other value skips the mixture search and logs a warning

    Returns
    -------
    Tuple[List[float], List[Learnware], float, List[Learnware]]
        - scores of the single learnwares that survive the score filter
        - the corresponding single learnwares
        - the mixture score (``None`` when ``search_method`` is unsupported)
        - the learnwares forming the mixture (empty when ``search_method`` is unsupported)
    """
    user_rkme = user_info.stat_info["RKMEStatSpecification"]
    learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme)
    logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}")

    sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme)
    if search_method == "auto":
        mixture_dist, _, mixture_learnware_list = self._search_by_rkme_spec_mixture_auto(
            learnware_list, user_rkme, max_search_num
        )
    elif search_method == "greedy":
        mixture_dist, _, mixture_learnware_list = self._search_by_rkme_spec_mixture_greedy(
            learnware_list, user_rkme, max_search_num
        )
    else:
        # The "f" prefix used to sit inside the quotes, so the method name was never interpolated.
        logger.warning(f"{search_method} not supported!")
        mixture_dist = None
        mixture_learnware_list = []

    if mixture_dist is None:
        sorted_score_list = self._convert_dist_to_score(sorted_dist_list)
        mixture_score = None
    else:
        # Convert single and mixture distances in one call, then split the mixture entry back off.
        merge_score_list = self._convert_dist_to_score(sorted_dist_list + [mixture_dist])
        sorted_score_list = merge_score_list[:-1]
        mixture_score = merge_score_list[-1]

    logger.info(f"After search by rkme spec, learnware_list length is {len(single_learnware_list)}")
    # filter learnware with low score
    sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single(
        sorted_score_list, single_learnware_list
    )

    # Report the filtered list; ``learnware_list`` is not modified by the score filter.
    logger.info(f"After filter by rkme spec, learnware_list length is {len(single_learnware_list)}")
    return sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list


class EasySearcher(BaseSearcher):
    """Searcher that narrows candidates semantically, then ranks them by RKME specification."""

    def __init__(self, organizer: EasyOrganizer = None):
        super().__init__(organizer)
        self.semantic_searcher = EasyFuzzSemanticSearcher(organizer)
        self.table_searcher = EasyTableSearcher(organizer)

    def reset(self, organizer):
        """Point this searcher and both of its sub-searchers at ``organizer``."""
        # NOTE(review): the attribute spelling ("oganizer") is kept as-is; presumably it matches BaseSearcher.
        self.learnware_oganizer = organizer
        for sub_searcher in (self.semantic_searcher, self.table_searcher):
            sub_searcher.reset(organizer)

    def __call__(
        self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy"
    ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
        """Search learnwares based on user_info

        Parameters
        ----------
        user_info : BaseUserInfo
            user_info contains semantic_spec and stat_info
        max_search_num : int
            Forwarded to the table searcher
        search_method : str
            Forwarded to the table searcher

        Returns
        -------
        Tuple[List[float], List[Learnware], float, List[Learnware]]
            - ``([], [], 0.0, [])`` when the semantic search matches nothing
            - the table searcher's result when the user supplies "RKMEStatSpecification"
            - ``(None, matched_learnwares, 0.0, None)`` when only semantic information is given
        """
        candidates = self.semantic_searcher(self.learnware_oganizer.get_learnwares(), user_info)

        if len(candidates) == 0:
            return [], [], 0.0, []
        if "RKMEStatSpecification" not in user_info.stat_info:
            # Semantic-only query: no scores and no mixture are available.
            return None, candidates, 0.0, None
        return self.table_searcher(candidates, user_info, max_search_num, search_method)

+ 1
- 0
learnware/market/evolve/__init__.py View File

@@ -0,0 +1 @@
from .organizer import EvolvedOrganizer

learnware/market/evolve.py → learnware/market/evolve/organizer.py View File

@@ -1,21 +1,18 @@
from typing import Tuple, Any, List, Union, Dict
from typing import List

from .base import BaseMarket
from ..learnware import Learnware
from ..specification import BaseStatSpecification
from ..easy2.organizer import EasyOrganizer
from ...learnware import Learnware
from ...specification import BaseStatSpecification
from ...logger import get_module_logger

logger = get_module_logger("evolve_organizer")

class EvolvedMarket(BaseMarket):
"""Organize learnwares and enable them to continuously evolve

Parameters
----------
BaseMarket : _type_
Basic market version
"""
class EvolvedOrganizer(EasyOrganizer):
"""Organize learnwares and enable them to continuously evolve"""

def __init__(self, *args, **kwargs):
super(EvolvedMarket, self).__init__(*args, **kwargs)
super(EvolvedOrganizer, self).__init__(*args, **kwargs)

def generate_new_stat_specification(self, learnware: Learnware) -> BaseStatSpecification:
"""Generate new statistical specification for learnwares

+ 1
- 0
learnware/market/evolve_anchor/__init__.py View File

@@ -0,0 +1 @@
from .organizer import EvolvedAnchoredOrganizer

learnware/market/evolve_anchor.py → learnware/market/evolve_anchor/organizer.py View File

@@ -1,22 +1,17 @@
from typing import Tuple, Any, List, Union, Dict
from typing import List

from .anchor import AnchoredUserInfo, AnchoredMarket
from .evolve import EvolvedMarket
from ..evolve import EvolvedOrganizer
from ..anchor import AnchoredOrganizer, AnchoredUserInfo
from ...logger import get_module_logger

logger = get_module_logger("evolve_anchor_organizer")

class EvolvedAnchoredMarket(AnchoredMarket, EvolvedMarket):
"""Organize learnwares with anchors and enable them to continuously evolve

Parameters
----------
AnchoredMarket : _type_
Market version with anchors
EvolvedMarket : _type_
Market version with evolved learnwares
"""
class EvolvedAnchoredOrganizer(AnchoredOrganizer, EvolvedOrganizer):
"""Organize learnwares and enable them to continuously evolve"""

def __init__(self, *args, **kwargs):
super(EvolvedAnchoredMarket, self).__init__(*args, **kwargs)
AnchoredOrganizer.__init__(self, *args, **kwargs)

def evolve_anchor_learnware_list(self, anchor_id_list: List[str]):
"""Enable anchor learnwares to evolve, e.g., new stat_spec

+ 1
- 0
learnware/market/hetergeneous/__init__.py View File

@@ -0,0 +1 @@
from .organizer import MappingFunction, HeterogeneousOrganizer

learnware/market/heterogeneous_feature.py → learnware/market/hetergeneous/organizer.py View File

@@ -1,8 +1,8 @@
import numpy as np
from typing import Tuple, Any, List, Union, Dict
from typing import List

from .evolve import EvolvedMarket
from ..learnware import Learnware
from ..evolve.organizer import EvolvedOrganizer
from ...learnware import Learnware


class MappingFunction:
@@ -25,17 +25,11 @@ class MappingFunction:
pass


class HeterogeneousFeatureMarket(EvolvedMarket):
"""Organize learnwares with heterogeneous feature spaces

Parameters
----------
EvolvedMarket : _type_
Market version with evolved learnwares
"""
class HeterogeneousOrganizer(EvolvedOrganizer):
"""Organize learnwares with heterogeneous feature spaces, organizer version with evolved learnwares"""

def __init__(self, *args, **kwargs):
super(HeterogeneousFeatureMarket, self).__init__(*args, **kwargs)
super(HeterogeneousOrganizer, self).__init__(*args, **kwargs)
self.mapping_function_list = {}

def _mapping_function_list_initialization(self, learnware_list: List[Learnware]):

+ 20
- 0
learnware/market/module.py View File

@@ -0,0 +1,20 @@
from .base import LearnwareMarket
from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatisticalChecker

# Registry of market presets keyed by preset name.
# NOTE: the organizer, searcher and checker objects below are instantiated once,
# when this module is imported; the registry stores those instances, not factories.
MARKET_CONFIG = {
"easy": {
"organizer": EasyOrganizer(),
"searcher": EasySearcher(),
"checker_list": [EasySemanticChecker(), EasyStatisticalChecker()],
}
}


def instatiate_learnware_market(market_id, name="easy", **kwargs):
    """Build a ``LearnwareMarket`` from the preset registered under ``name``.

    Parameters
    ----------
    market_id
        Identifier forwarded to ``LearnwareMarket``
    name : str
        Key into ``MARKET_CONFIG``; an unknown key raises ``KeyError``
    **kwargs
        Extra keyword arguments forwarded unchanged to ``LearnwareMarket``

    Returns
    -------
    LearnwareMarket
        A market wired with the preset's organizer, searcher and checker list
    """
    preset = MARKET_CONFIG[name]
    organizer = preset["organizer"]
    searcher = preset["searcher"]
    checker_list = preset["checker_list"]
    return LearnwareMarket(
        market_id=market_id,
        organizer=organizer,
        searcher=searcher,
        checker_list=checker_list,
        **kwargs
    )

+ 1
- 1
learnware/specification/__init__.py View File

@@ -1,4 +1,4 @@
from .utils import generate_stat_spec
from .base import Specification, BaseStatSpecification
from .rkme import RKMEStatSpecification
from .image import RKMEImageStatSpecification
from .table import RKMEStatSpecification

+ 9
- 0
learnware/specification/base.py View File

@@ -6,6 +6,15 @@ from typing import Dict
class BaseStatSpecification:
"""The Statistical Specification Interface, which provide save and load method"""

def __init__(self, type: str):
    """Initialize the type of the statistical specification.

    Parameters
    ----------
    type : str
        The type of the statistical specification; stored on the instance as ``self.type``
    """
    self.type = type

def generate_stat_spec_from_data(self, **kwargs):
"""Construct statistical specification from raw dataset
- kwargs may include the feature, label and model


+ 1
- 0
learnware/specification/table/__init__.py View File

@@ -0,0 +1 @@
from .rkme import RKMEStatSpecification

learnware/specification/rkme.py → learnware/specification/table/rkme.py View File

@@ -20,8 +20,8 @@ try:
except ImportError:
_FAISS_INSTALLED = False

from .base import BaseStatSpecification
from ..logger import get_module_logger
from ..base import BaseStatSpecification
from ...logger import get_module_logger

logger = get_module_logger("rkme")

@@ -51,6 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification):
torch.cuda.empty_cache()
self.device = choose_device(cuda_idx=cuda_idx)
setup_seed(0)
super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__)

def get_beta(self) -> np.ndarray:
"""Move beta(RKME weights) back to memory accessible to the CPU.
@@ -427,6 +428,7 @@ class RKMEStatSpecification(BaseStatSpecification):
rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy()
rkme_to_save["beta"] = rkme_to_save["beta"].tolist()
rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu"
rkme_to_save["type"] = self.type
json.dump(
rkme_to_save,
codecs.open(save_path, "w", encoding="utf-8"),

+ 1
- 1
learnware/specification/utils.py View File

@@ -4,7 +4,7 @@ import pandas as pd
from typing import Union

from .base import BaseStatSpecification
from .rkme import RKMEStatSpecification
from .table import RKMEStatSpecification
from ..config import C




+ 10
- 0
tests/test_market/learnware_example/README.md View File

@@ -0,0 +1,10 @@
## How to Generate Environment Yaml

* Create the environment config for conda:
```shell
conda env export | grep -v "^prefix: " > environment.yaml
```
* Recover the environment from the config:
```shell
conda env create -f environment.yaml
```

+ 27
- 0
tests/test_market/learnware_example/environment.yaml View File

@@ -0,0 +1,27 @@
name: learnware_example_env
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2023.01.10=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.2=h6a678d5_6
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.4=h6a678d5_0
- openssl=1.1.1t=h7f8727e_0
- pip=23.0.1=py38h06a4308_0
- python=3.8.16=h7a1cb2a_3
- readline=8.2=h5eee18b_0
- setuptools=66.0.0=py38h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.38.4=py38h06a4308_0
- xz=5.2.10=h5eee18b_1
- zlib=1.2.13=h5eee18b_0
- pip:
- joblib==1.2.0
- learnware==0.0.1.99
- numpy==1.19.5

+ 8
- 0
tests/test_market/learnware_example/example.yaml View File

@@ -0,0 +1,8 @@
model:
class_name: SVM
kwargs: {}
stat_specifications:
- module_path: learnware.specification
class_name: RKMEStatSpecification
file_name: svm.json
kwargs: {}

+ 20
- 0
tests/test_market/learnware_example/example_init.py View File

@@ -0,0 +1,20 @@
import os
import joblib
import numpy as np
from learnware.model import BaseModel


class SVM(BaseModel):
    """Example learnware model backed by the estimator pickled next to this file as ``svm.pkl``."""

    def __init__(self):
        super().__init__(input_shape=(64,), output_shape=(10,))
        # Resolve the pickle relative to this file rather than the current working directory.
        package_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(package_dir, "svm.pkl")
        self.model = joblib.load(model_path)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """No-op: the wrapped estimator is loaded already fitted."""
        pass

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the wrapped estimator's ``predict_proba`` output for ``X``."""
        return self.model.predict_proba(X)

    def finetune(self, X: np.ndarray, y: np.ndarray):
        """No-op: fine-tuning is not implemented for this example."""
        pass

+ 205
- 0
tests/test_market/test_easy.py View File

@@ -0,0 +1,205 @@
import sys
import unittest
import os
import copy
import joblib
import zipfile
import numpy as np
from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from shutil import copyfile, rmtree

import learnware
from learnware.market import instatiate_learnware_market, BaseUserInfo
import learnware.specification as specification

curr_root = os.path.dirname(os.path.abspath(__file__))

user_semantic = {
"Data": {"Values": ["Image"], "Type": "Class"},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Library": {"Values": ["Scikit-learn"], "Type": "Class"},
"Scenario": {"Values": ["Education"], "Type": "Tag"},
"Description": {"Values": "", "Type": "String"},
"Name": {"Values": "", "Type": "String"},
"Output": {
"Dimension": 10,
"Description": {
"0": "the probability of the label is zero",
},
},
}


class TestMarket(unittest.TestCase):
    """End-to-end checks of the "easy" learnware market: upload, delete, semantic and statistical search.

    The test methods call one another to build their fixtures (prepare -> upload -> search),
    sharing state through ``self.zip_path_list`` and ``self.learnware_num``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        np.random.seed(2023)
        learnware.init()

    def _init_learnware_market(self):
        """initialize learnware market"""
        easy_market = instatiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True)
        return easy_market

    def test_prepare_learnware_randomly(self, learnware_num=5):
        """Train ``learnware_num`` SVMs on the digits data and package each one as a learnware zip."""
        self.zip_path_list = []
        X, y = load_digits(return_X_y=True)

        for i in range(learnware_num):
            dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i))
            os.makedirs(dir_path, exist_ok=True)

            print("Preparing Learnware: %d" % (i))

            data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
            clf = svm.SVC(kernel="linear", probability=True)
            clf.fit(data_X, data_y)

            joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))

            spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
            spec.save(os.path.join(dir_path, "svm.json"))

            init_file = os.path.join(dir_path, "__init__.py")
            copyfile(
                os.path.join(curr_root, "learnware_example/example_init.py"), init_file
            )  # cp example_init.py init_file

            yaml_file = os.path.join(dir_path, "learnware.yaml")
            copyfile(os.path.join(curr_root, "learnware_example/example.yaml"), yaml_file)  # cp example.yaml yaml_file

            env_file = os.path.join(dir_path, "environment.yaml")
            copyfile(os.path.join(curr_root, "learnware_example/environment.yaml"), env_file)

            zip_file = dir_path + ".zip"
            # zip -q -r -j zip_file dir_path
            with zipfile.ZipFile(zip_file, "w") as zip_obj:
                for foldername, _, filenames in os.walk(dir_path):
                    for filename in filenames:
                        file_path = os.path.join(foldername, filename)
                        # Archive entries use the bare file name, i.e. a flat layout (zip -j).
                        zip_info = zipfile.ZipInfo(filename)
                        zip_info.compress_type = zipfile.ZIP_STORED
                        with open(file_path, "rb") as file:
                            zip_obj.writestr(zip_info, file.read())

            rmtree(dir_path)  # rm -r dir_path

            self.zip_path_list.append(zip_file)

    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
        """Upload the prepared learnwares, optionally delete them again, and return the market."""
        easy_market = self._init_learnware_market()
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

        print("Total Item:", len(easy_market))
        assert len(easy_market) == 0, "The market should be empty!"

        for idx, zip_path in enumerate(self.zip_path_list):
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
            semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
            easy_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(easy_market))
        assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        curr_inds = easy_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        if delete:
            for learnware_id in curr_inds:
                easy_market.delete_learnware(learnware_id)
                self.learnware_num -= 1
                assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

            curr_inds = easy_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)
            assert len(curr_inds) == 0, "The market should be empty!"

        return easy_market

    def test_search_semantics(self, learnware_num=5):
        """Check exact-name and fuzzy-name semantic search against a populated market."""
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))
        assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        semantic_spec = copy.deepcopy(user_semantic)
        semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}"

        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        _, single_learnware_list, _, _ = easy_market.search_learnware(user_info)

        print("User info:", user_info.get_semantic_spec())
        print("Search result:")
        assert len(single_learnware_list) == 1, "Exact semantic search failed!"
        # The loop variable is not called ``learnware`` so it does not shadow the imported module.
        for candidate in single_learnware_list:
            semantic_spec1 = candidate.get_specification().get_semantic_spec()
            print("Choose learnware:", candidate.id, semantic_spec1)
            assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], "Exact semantic search failed!"

        # A misspelled name is expected to fall back to fuzzy matching and return every learnware.
        semantic_spec["Name"]["Values"] = "laernwaer"
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        _, single_learnware_list, _, _ = easy_market.search_learnware(user_info)

        print("User info:", user_info.get_semantic_spec())
        print("Search result:")
        assert len(single_learnware_list) == self.learnware_num, "Fuzzy semantic search failed!"
        for candidate in single_learnware_list:
            semantic_spec1 = candidate.get_specification().get_semantic_spec()
            print("Choose learnware:", candidate.id, semantic_spec1)

    def test_stat_search(self, learnware_num=5):
        """Use each uploaded learnware's own RKME spec as the user spec and run a statistical search."""
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

        test_folder = os.path.join(curr_root, "test_stat")

        for idx, zip_path in enumerate(self.zip_path_list):
            unzip_dir = os.path.join(test_folder, f"{idx}")

            # unzip -o -q zip_path -d unzip_dir
            if os.path.exists(unzip_dir):
                rmtree(unzip_dir)
            os.makedirs(unzip_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_obj:
                zip_obj.extractall(path=unzip_dir)

            # ``specification.rkme`` no longer exists after the move to ``specification.table``;
            # use the class re-exported by the package, as tests/test_workflow does.
            user_spec = specification.RKMEStatSpecification()
            user_spec.load(os.path.join(unzip_dir, "svm.json"))
            user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
            (
                sorted_score_list,
                single_learnware_list,
                mixture_score,
                mixture_learnware_list,
            ) = easy_market.search_learnware(user_info)

            assert len(single_learnware_list) == self.learnware_num, "Statistical search failed!"
            print(f"search result of user{idx}:")
            for score, candidate in zip(sorted_score_list, single_learnware_list):
                print(f"score: {score}, learnware_id: {candidate.id}")
            print(f"mixture_score: {mixture_score}\n")
            mixture_id = " ".join([candidate.id for candidate in mixture_learnware_list])
            print(f"mixture_learnware: {mixture_id}\n")

        rmtree(test_folder)  # rm -r test_folder


def suite():
    """Assemble the market tests into a suite, in their intended execution order."""
    ordered_cases = (
        "test_prepare_learnware_randomly",
        "test_upload_delete_learnware",
        "test_search_semantics",
        "test_stat_search",
    )
    _suite = unittest.TestSuite()
    for case_name in ordered_cases:
        _suite.addTest(TestMarket(case_name))
    return _suite


if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())

+ 31
- 0
tests/test_specification/test_rkme.py View File

@@ -0,0 +1,31 @@
import os
import json
import unittest
import tempfile
import numpy as np

import learnware
import learnware.specification as specification
from learnware.specification import RKMEStatSpecification


class TestRKME(unittest.TestCase):
    """Round-trip check that an RKME specification records its type when saved and loaded."""

    def test_rkme(self):
        samples = np.random.uniform(-10000, 10000, size=(5000, 200))
        generated_spec = specification.utils.generate_rkme_spec(samples)

        with tempfile.TemporaryDirectory(prefix="learnware_") as workdir:
            spec_file = os.path.join(workdir, "rkme.json")
            generated_spec.save(spec_file)

            # The serialized JSON must carry the specification type.
            with open(spec_file, "r") as handle:
                payload = json.load(handle)
            assert payload["type"] == "RKMEStatSpecification"

            # A fresh instance loaded from the same file must expose the type too.
            reloaded_spec = RKMEStatSpecification()
            reloaded_spec.load(spec_file)
            assert reloaded_spec.type == "RKMEStatSpecification"

if __name__ == "__main__":
unittest.main()

+ 1
- 1
tests/test_workflow/test_workflow.py View File

@@ -155,7 +155,7 @@ class TestAllWorkflow(unittest.TestCase):
with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir)

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
(


Loading…
Cancel
Save