Browse Source

Merge branch 'main' of https://github.com/Learnware-LAMDA/Learnware into main

tags/v0.3.2
bxdd 2 years ago
parent
commit
a1fd45999b
10 changed files with 190 additions and 85 deletions
  1. +4
    -4
      examples/dataset_text_workflow/main.py
  2. +1
    -0
      learnware/client/container.py
  3. +6
    -5
      learnware/client/learnware_client.py
  4. +1
    -0
      learnware/client/utils.py
  5. +1
    -1
      learnware/market/__init__.py
  6. +107
    -47
      learnware/market/base.py
  7. +60
    -21
      learnware/market/easy2/organizer.py
  8. +7
    -4
      learnware/market/easy2/searcher.py
  9. +1
    -1
      learnware/market/module.py
  10. +2
    -2
      tests/test_market/test_easy.py

+ 4
- 4
examples/dataset_text_workflow/main.py View File

@@ -9,7 +9,7 @@ from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningR
import time
import pickle

from learnware.market import instatiate_learnware_market, BaseUserInfo
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.market import database_ops
from learnware.learnware import Learnware
import learnware.specification as specification
@@ -120,7 +120,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo


def prepare_market():
text_market = instatiate_learnware_market(market_id="sst2", rebuild=True)
text_market = instantiate_learnware_market(market_id="sst2", rebuild=True)
try:
rmtree(learnware_pool_dir)
except:
@@ -144,10 +144,10 @@ def prepare_market():

def test_search(gamma=0.1, load_market=True):
if load_market:
text_market = instatiate_learnware_market(market_id="sst2")
text_market = instantiate_learnware_market(market_id="sst2")
else:
prepare_market()
text_market = instatiate_learnware_market(market_id="sst2")
text_market = instantiate_learnware_market(market_id="sst2")
logger.info("Number of items in the market: %d" % len(text_market))

select_list = []


+ 1
- 0
learnware/client/container.py View File

@@ -337,6 +337,7 @@ class ModelDockerContainer(ModelContainer):
"install",
"-r",
f"{requirements_path_filter}",
"--no-dependencies",
]
)
)


+ 6
- 5
learnware/client/learnware_client.py View File

@@ -92,7 +92,7 @@ class LearnwareClient:

@require_login
def upload_learnware(self, learnware_zip_path, semantic_specification):
assert self._check_semantic_specification(semantic_specification)
assert self._check_semantic_specification(semantic_specification), "Semantic specification check failed!"
file_hash = compute_file_hash(learnware_zip_path)
url_upload = f"{self.host}/user/chunked_upload"

@@ -276,8 +276,7 @@ class LearnwareClient:
response = requests.get(url, headers=self.headers)
result = response.json()
semantic_conf = result["data"]["semantic_specification"]

return semantic_conf[key]["Values"]
return semantic_conf[key.value]["Values"]

def load_learnware(
self,
@@ -386,14 +385,16 @@ class LearnwareClient:

@staticmethod
def _check_semantic_specification(semantic_spec):
return EasySemanticChecker.check_semantic_spec(semantic_spec)
return EasySemanticChecker.check_semantic_spec(semantic_spec) != EasySemanticChecker.INVALID_LEARNWARE

@staticmethod
def check_learnware(learnware_zip_path, semantic_specification=None):
semantic_specification = (
get_semantic_specification() if semantic_specification is None else semantic_specification
)
LearnwareClient._check_semantic_specification(semantic_specification)
assert LearnwareClient._check_semantic_specification(
semantic_specification
), "Semantic specification check failed!"

with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file:


+ 1
- 0
learnware/client/utils.py View File

@@ -76,6 +76,7 @@ def install_environment(zip_path, conda_env):
"install",
"-r",
f"{requirements_path_filter}",
"--no-dependencies",
]
)
else:


+ 1
- 1
learnware/market/__init__.py View File

@@ -6,4 +6,4 @@ from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatist
from .hetergeneous import HeterogeneousOrganizer, MappingFunction

from .easy import EasyMarket
from .module import instatiate_learnware_market
from .module import instantiate_learnware_market

+ 107
- 47
learnware/market/base.py View File

@@ -1,7 +1,7 @@
from __future__ import annotations

import zipfile
import tempfile


from typing import Tuple, Any, List, Union
from ..learnware import Learnware, get_learnware_from_dirpath
from ..logger import get_module_logger
@@ -47,10 +47,10 @@ class LearnwareMarket:

def __init__(
self,
market_id: str = None,
organizer: "BaseOrganizer" = None,
searcher: "BaseSearcher" = None,
checker_list: List["BaseChecker"] = None,
market_id: str = "default",
organizer: BaseOrganizer = None,
searcher: BaseSearcher = None,
checker_list: List[BaseChecker] = None,
rebuild=False,
):
self.market_id = market_id
@@ -70,29 +70,27 @@ class LearnwareMarket:
def reload_market(self, **kwargs) -> bool:
self.learnware_organizer.reload_market(**kwargs)

def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool:
def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool:
try:
with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir:
with zipfile.ZipFile(zip_path, mode="r") as z_file:
z_file.extractall(tempdir)

pending_learnware = get_learnware_from_dirpath(
id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir
)

final_status = BaseChecker.INVALID_LEARNWARE
checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names

for name in checker_names:
checker = self.learnware_checker[name]
check_status = checker(pending_learnware)
final_status = max(final_status, check_status)

if check_status == BaseChecker.INVALID_LEARNWARE:
return BaseChecker.INVALID_LEARNWARE

return final_status

final_status = BaseChecker.NONUSABLE_LEARNWARE
if len(checker_names):
with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir:
with zipfile.ZipFile(zip_path, mode="r") as z_file:
z_file.extractall(tempdir)

pending_learnware = get_learnware_from_dirpath(
id="pending", semantic_spec=semantic_spec, learnware_dirpath=tempdir
)
checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names

for name in checker_names:
checker = self.learnware_checker[name]
check_status = checker(pending_learnware)
final_status = max(final_status, check_status)

if check_status == BaseChecker.INVALID_LEARNWARE:
return BaseChecker.INVALID_LEARNWARE
return final_status
except Exception as err:
logger.warning(f"Check learnware failed! Due to {err}.")
return BaseChecker.INVALID_LEARNWARE
@@ -122,8 +120,25 @@ class LearnwareMarket:
zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]:
return self.learnware_searcher(user_info, **kwargs)
def search_learnware(
self, user_info: BaseUserInfo, check_status: int = None, **kwargs
) -> Tuple[Any, List[Learnware]]:
"""Search learnwares based on user_info from learnwares with check_status

Parameters
----------
user_info : BaseUserInfo
User information for searching learnwares
check_status : int, optional
- None: search from all learnwares
- Others: search from learnwares with check_status

Returns
-------
Tuple[Any, List[Learnware]]
Search results
"""
return self.learnware_searcher(user_info, check_status, **kwargs)

def delete_learnware(self, id: str, **kwargs) -> bool:
return self.learnware_organizer.delete_learnware(id, **kwargs)
@@ -131,8 +146,8 @@ class LearnwareMarket:
def update_learnware(
self,
id: str,
zip_path: str,
semantic_spec: dict,
zip_path: str = None,
semantic_spec: dict = None,
checker_names: List[str] = None,
check_status: int = None,
**kwargs,
@@ -157,6 +172,12 @@ class LearnwareMarket:
int
The final learnware check_status.
"""
zip_path = self.get_learnware_path_by_ids(id) if zip_path is None else zip_path
semantic_spec = (
self.get_learnware_by_ids(id).get_specification().get_semantic_spec()
if semantic_spec is None
else semantic_spec
)
update_status = self.check_learnware(zip_path, semantic_spec, checker_names)
check_status = (
update_status if check_status is None or update_status == BaseChecker.INVALID_LEARNWARE else check_status
@@ -166,14 +187,44 @@ class LearnwareMarket:
id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def get_learnware_ids(self, top: int = None, **kwargs):
return self.learnware_organizer.get_learnware_ids(top, **kwargs)
def get_learnware_ids(self, top: int = None, check_status: int = None, **kwargs) -> List[str]:
"""get the list of learnware ids

Parameters
----------
top : int, optional
The first top element to return, by default None
check_status : int, optional
- None: return all learnware ids
- Others: return learnware ids with check_status

Returns
-------
List[str]
the first top ids
"""
return self.learnware_organizer.get_learnware_ids(top, check_status, **kwargs)

def get_learnwares(self, top: int = None, **kwargs):
return self.learnware_organizer.get_learnwares(top, **kwargs)
def get_learnwares(self, top: int = None, check_status: int = None, **kwargs) -> List[Learnware]:
"""get the list of learnwares

Parameters
----------
top : int, optional
The first top element to return, by default None
check_status : int, optional
- None: return all learnwares
- Others: return learnwares with check_status

Raises
------
List[Learnware]
the first top learnwares
"""
return self.learnware_organizer.get_learnwares(top, check_status, **kwargs)

def get_learnware_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]:
raise self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs)
return self.learnware_organizer.get_learnware_path_by_ids(ids, **kwargs)

def get_learnware_by_ids(self, id: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]:
return self.learnware_organizer.get_learnware_by_ids(id, **kwargs)
@@ -298,13 +349,16 @@ class BaseOrganizer:
"""
raise NotImplementedError("get_learnware_path_by_ids is not implemented in BaseOrganizer")

def get_learnware_ids(self, top: int = None) -> List[str]:
def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]:
"""get the list of learnware ids

Parameters
----------
top : int, optional
the first top element to return, by default None
The first top element to return, by default None
check_status : int, optional
- None: return all learnware ids
- Others: return learnware ids with check_status

Raises
------
@@ -313,13 +367,16 @@ class BaseOrganizer:
"""
raise NotImplementedError("get_learnware_ids is not implemented in BaseOrganizer")

def get_learnwares(self, top: int = None) -> List[Learnware]:
def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]:
"""get the list of learnwares

Parameters
----------
top : int, optional
the first top element to return, by default None
The first top element to return, by default None
check_status : int, optional
- None: return all learnwares
- Others: return learnwares with check_status

Raises
------
@@ -334,18 +391,21 @@ class BaseOrganizer:

class BaseSearcher:
def __init__(self, organizer: BaseOrganizer = None):
self.learnware_oganizer = organizer
self.learnware_organizer = organizer

def reset(self, organizer):
self.learnware_oganizer = organizer
self.learnware_organizer = organizer

def __call__(self, user_info: BaseUserInfo):
"""Search learnwares based on user_info
def __call__(self, user_info: BaseUserInfo, check_status: int = None):
"""Search learnwares based on user_info from learnwares with check_status

Parameters
----------
user_info : BaseUserInfo
user_info contains semantic_spec and stat_info
check_status : int, optional
- None: search from all learnwares
- Others: search from learnwares with check_status
"""
raise NotImplementedError("'__call__' method is not implemented in BaseSearcher")

@@ -356,10 +416,10 @@ class BaseChecker:
USABLE_LEARWARE = 1

def __init__(self, organizer: BaseOrganizer = None):
self.learnware_oganizer = organizer
self.learnware_organizer = organizer

def reset(self, organizer):
self.learnware_oganizer = organizer
self.learnware_organizer = organizer

def __call__(self, learnware: Learnware) -> int:
"""Check the utility of a learnware


+ 60
- 21
learnware/market/easy2/organizer.py View File

@@ -55,7 +55,9 @@ class EasyOrganizer(BaseOrganizer):
self.count,
) = self.dbops.load_market()

def add_learnware(self, zip_path: str, semantic_spec: dict, check_status: int) -> Tuple[str, int]:
def add_learnware(
self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str = None
) -> Tuple[str, int]:
"""Add a learnware into the market.

Parameters
@@ -66,7 +68,8 @@ class EasyOrganizer(BaseOrganizer):
semantic_spec for new learnware, in dictionary format.
check_status: int
A flag indicating whether the learnware is usable.

learnware_id: str, optional
An id for the learnware in the database; when None, an id is generated from the organizer's internal counter.
Returns
-------
Tuple[str, int]
@@ -80,9 +83,9 @@ class EasyOrganizer(BaseOrganizer):
semantic_spec = copy.deepcopy(semantic_spec)
logger.info("Get new learnware from %s" % (zip_path))

id = "%08d" % (self.count)
target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (id))
target_folder_dir = os.path.join(self.learnware_folder_pool_path, id)
learnware_id = "%08d" % (self.count) if learnware_id is None else learnware_id
target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id))
target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id)
copyfile(zip_path, target_zip_dir)

with zipfile.ZipFile(target_zip_dir, "r") as z_file:
@@ -91,7 +94,7 @@ class EasyOrganizer(BaseOrganizer):

try:
new_learnware = get_learnware_from_dirpath(
id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
)
except:
try:
@@ -107,19 +110,19 @@ class EasyOrganizer(BaseOrganizer):
learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE

self.dbops.add_learnware(
id=id,
id=learnware_id,
semantic_spec=semantic_spec,
zip_path=target_zip_dir,
folder_path=target_folder_dir,
use_flag=learnwere_status,
)

self.learnware_list[id] = new_learnware
self.learnware_zip_list[id] = target_zip_dir
self.learnware_folder_list[id] = target_folder_dir
self.use_flags[id] = learnwere_status
self.learnware_list[learnware_id] = new_learnware
self.learnware_zip_list[learnware_id] = target_zip_dir
self.learnware_folder_list[learnware_id] = target_folder_dir
self.use_flags[learnware_id] = learnwere_status
self.count += 1
return id, learnwere_status
return learnware_id, learnwere_status

def delete_learnware(self, id: str) -> bool:
"""Delete Learnware from market
@@ -205,7 +208,8 @@ class EasyOrganizer(BaseOrganizer):
if new_learnware is None:
return BaseChecker.INVALID_LEARNWARE

copyfile(zip_path, target_zip_dir)
if zip_path != target_zip_dir:
copyfile(zip_path, target_zip_dir)
with zipfile.ZipFile(target_zip_dir, "r") as z_file:
z_file.extractall(target_folder_dir)

@@ -284,17 +288,52 @@ class EasyOrganizer(BaseOrganizer):
logger.warning("Learnware ID '%s' NOT Found!" % (ids))
return None

def get_learnware_ids(self, top: int = None) -> List[str]:
if top is None:
return list(self.learnware_list.keys())
else:
return list(self.learnware_list.keys())[:top]
def get_learnware_ids(self, top: int = None, check_status: int = None) -> List[str]:
"""Get learnware ids

Parameters
----------
top : int, optional
The first top learnware ids to return, by default None
check_status : bool, optional
- None: return all learnware ids
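- True: return only learnware ids whose flag is BaseChecker.USABLE_LEARWARE
- False: return only learnware ids whose flag is BaseChecker.NONUSABLE_LEARNWARE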

Returns
-------
List[str]
Learnware ids
"""
if check_status is None:
filtered_ids = self.use_flags.keys()
elif check_status is True:
filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.USABLE_LEARWARE]
elif check_status is False:
filtered_ids = [key for key, value in self.use_flags.items() if value == BaseChecker.NONUSABLE_LEARNWARE]

def get_learnwares(self, top: int = None) -> List[str]:
if top is None:
return list(self.learnware_list.values())
return filtered_ids
else:
return list(self.learnware_list.values())[:top]
return filtered_ids[:top]

def get_learnwares(self, top: int = None, check_status: int = None) -> List[Learnware]:
"""Get learnware list

Parameters
----------
top : int, optional
The first top learnwares to return, by default None
check_status : bool, optional
- None: return all learnwares
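- True: return only learnwares whose flag is BaseChecker.USABLE_LEARWARE
- False: return only learnwares whose flag is BaseChecker.NONUSABLE_LEARNWARE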

Returns
-------
List[Learnware]
Learnware list
"""
learnware_ids = self.get_learnware_ids(top, check_status)
return [self.learnware_list[idx] for idx in learnware_ids]

def __len__(self):
return len(self.learnware_list)

+ 7
- 4
learnware/market/easy2/searcher.py View File

@@ -613,14 +613,14 @@ class EasySearcher(BaseSearcher):
self.stat_searcher = EasyStatSearcher(organizer)

def reset(self, organizer):
self.learnware_oganizer = organizer
self.learnware_organizer = organizer
self.semantic_searcher.reset(organizer)
self.stat_searcher.reset(organizer)

def __call__(
self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy"
self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy"
) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
"""Search learnwares based on user_info
"""Search learnwares based on user_info from learnwares with check_status

Parameters
----------
@@ -628,6 +628,9 @@ class EasySearcher(BaseSearcher):
user_info contains semantic_spec and stat_info
max_search_num : int
The maximum number of the returned learnwares
check_status : int, optional
- None: search from all learnwares
- Others: search from learnwares with check_status

Returns
-------
@@ -637,7 +640,7 @@ class EasySearcher(BaseSearcher):
the third is the score of Learnware (mixture)
the fourth is the list of Learnware (mixture), the size is search_num
"""
learnware_list = self.learnware_oganizer.get_learnwares()
learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status)
learnware_list = self.semantic_searcher(learnware_list, user_info)

if len(learnware_list) == 0:


+ 1
- 1
learnware/market/module.py View File

@@ -10,7 +10,7 @@ MARKET_CONFIG = {
}


def instatiate_learnware_market(market_id, name="easy", **kwargs):
def instantiate_learnware_market(market_id="default", name="easy", **kwargs):
return LearnwareMarket(
market_id=market_id,
organizer=MARKET_CONFIG[name]["organizer"],


+ 2
- 2
tests/test_market/test_easy.py View File

@@ -11,7 +11,7 @@ from sklearn.model_selection import train_test_split
from shutil import copyfile, rmtree

import learnware
from learnware.market import instatiate_learnware_market, BaseUserInfo
from learnware.market import instantiate_learnware_market, BaseUserInfo
import learnware.specification as specification

curr_root = os.path.dirname(os.path.abspath(__file__))
@@ -43,7 +43,7 @@ class TestMarket(unittest.TestCase):

def _init_learnware_market(self):
"""initialize learnware market"""
easy_market = instatiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True)
easy_market = instantiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True)
return easy_market

def test_prepare_learnware_randomly(self, learnware_num=5):


Loading…
Cancel
Save