| @@ -14,7 +14,7 @@ jobs: | |||
| strategy: | |||
| matrix: | |||
| os: [ubuntu-20.04] | |||
| python-version: [3.8, 3.9] | |||
| python-version: [3.9] | |||
| steps: | |||
| - name: Test learnware from pip | |||
| @@ -50,5 +50,4 @@ jobs: | |||
| - name: Test workflow | |||
| run: | | |||
| cd tests | |||
| conda run -n learnware python -m pytest test_workflow/test_workflow.py | |||
| conda run -n learnware python -m pytest tests/test_workflow/ | |||
| @@ -14,7 +14,7 @@ jobs: | |||
| strategy: | |||
| matrix: | |||
| os: [ubuntu-20.04] | |||
| python-version: [3.8, 3.9] | |||
| python-version: [3.9] | |||
| steps: | |||
| - name: Test learnware from pip | |||
| @@ -55,4 +55,4 @@ jobs: | |||
| - name: Test workflow | |||
| run: | | |||
| conda run -n learnware python -m pytest tests/test_workflow/test_workflow.py | |||
| conda run -n learnware python -m pytest tests/test_workflow/ | |||
| @@ -6,13 +6,13 @@ from get_data import * | |||
| import os | |||
| import random | |||
| from learnware.specification.image import RKMEImageSpecification | |||
| from learnware.specification import RKMEImageSpecification | |||
| from learnware.reuse.averaging import AveragingReuser | |||
| from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction | |||
| from learnware.learnware import Learnware | |||
| import time | |||
| from learnware.market import EasyMarket, BaseUserInfo | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.market import database_ops | |||
| from learnware.learnware import Learnware | |||
| import learnware.specification as specification | |||
| @@ -122,7 +122,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo | |||
| def prepare_market(): | |||
| image_market = EasyMarket(market_id="cifar10", rebuild=True) | |||
| image_market = instantiate_learnware_market(market_id="cifar10", name="easy", rebuild=True) | |||
| try: | |||
| rmtree(learnware_pool_dir) | |||
| except: | |||
| @@ -148,10 +148,10 @@ def prepare_market(): | |||
| def test_search(gamma=0.1, load_market=True): | |||
| if load_market: | |||
| image_market = EasyMarket(market_id="cifar10") | |||
| image_market = instantiate_learnware_market(market_id="cifar10", name="easy") | |||
| else: | |||
| prepare_market() | |||
| image_market = EasyMarket(market_id="cifar10") | |||
| image_market = instantiate_learnware_market(market_id="cifar10", name="easy") | |||
| logger.info("Number of items in the market: %d" % len(image_market)) | |||
| select_list = [] | |||
| @@ -7,7 +7,7 @@ from tqdm import tqdm | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| from learnware.market import EasyMarket, BaseUserInfo | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.market import database_ops | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser | |||
| from learnware.specification import generate_rkme_spec | |||
| @@ -53,7 +53,7 @@ class M5DatasetWorkflow: | |||
| # database_ops.clear_learnware_table() | |||
| learnware.init() | |||
| easy_market = EasyMarket(rebuild=True) | |||
| easy_market = instantiate_learnware_market(name="easy", rebuild=True) | |||
| print("Total Item:", len(easy_market)) | |||
| zip_path_list = [] | |||
| @@ -125,7 +125,7 @@ class M5DatasetWorkflow: | |||
| self.prepare_learnware(regenerate_flag) | |||
| self._init_learnware_market() | |||
| easy_market = EasyMarket() | |||
| easy_market = instantiate_learnware_market(name="easy") | |||
| print("Total Item:", len(easy_market)) | |||
| m5 = DataLoader() | |||
| @@ -7,7 +7,7 @@ from tqdm import tqdm | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| from learnware.market import EasyMarket, BaseUserInfo | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser | |||
| from learnware.specification import generate_rkme_spec | |||
| from pfs import Dataloader | |||
| @@ -50,7 +50,7 @@ class PFSDatasetWorkflow: | |||
| def _init_learnware_market(self): | |||
| """initialize learnware market""" | |||
| learnware.init() | |||
| easy_market = EasyMarket(market_id="pfs", rebuild=True) | |||
| easy_market = instantiate_learnware_market(market_id="pfs", name="easy", rebuild=True) | |||
| print("Total Item:", len(easy_market)) | |||
| zip_path_list = [] | |||
| @@ -122,7 +122,7 @@ class PFSDatasetWorkflow: | |||
| self.prepare_learnware(regenerate_flag) | |||
| self._init_learnware_market() | |||
| easy_market = EasyMarket(market_id="pfs") | |||
| easy_market = instantiate_learnware_market(market_id="pfs", name="easy") | |||
| print("Total Item:", len(easy_market)) | |||
| pfs = Dataloader() | |||
| @@ -0,0 +1,3 @@ | |||
| torch==2.0.1 | |||
| torchdata | |||
| torchtext | |||
| @@ -85,12 +85,13 @@ def prepare_model(): | |||
| logger.info("Model saved to '%s'" % (model_save_path)) | |||
| def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name): | |||
| def prepare_learnware(data_path, model_path, init_file_path, yaml_path, env_file_path, save_root, zip_name): | |||
| os.makedirs(save_root, exist_ok=True) | |||
| tmp_spec_path = os.path.join(save_root, "rkme.json") | |||
| tmp_model_path = os.path.join(save_root, "model.pth") | |||
| tmp_yaml_path = os.path.join(save_root, "learnware.yaml") | |||
| tmp_init_path = os.path.join(save_root, "__init__.py") | |||
| tmp_env_path = os.path.join(save_root, "requirements.txt") | |||
| with open(data_path, "rb") as f: | |||
| X = pickle.load(f) | |||
| @@ -105,12 +106,14 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo | |||
| copyfile(model_path, tmp_model_path) | |||
| copyfile(yaml_path, tmp_yaml_path) | |||
| copyfile(init_file_path, tmp_init_path) | |||
| copyfile(env_file_path, tmp_env_path) | |||
| zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name)) | |||
| with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj: | |||
| zip_obj.write(tmp_spec_path, "rkme.json") | |||
| zip_obj.write(tmp_model_path, "model.pth") | |||
| zip_obj.write(tmp_yaml_path, "learnware.yaml") | |||
| zip_obj.write(tmp_init_path, "__init__.py") | |||
| zip_obj.write(tmp_env_path, "requirements.txt") | |||
| rmtree(save_root) | |||
| logger.info("New Learnware Saved to %s" % (zip_file_name)) | |||
| return zip_file_name | |||
| @@ -128,8 +131,9 @@ def prepare_market(): | |||
| model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) | |||
| init_file_path = "./example_files/example_init.py" | |||
| yaml_file_path = "./example_files/example_yaml.yaml" | |||
| env_file_path = "./example_files/requirements.txt" | |||
| new_learnware_path = prepare_learnware( | |||
| data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i) | |||
| data_path, model_path, init_file_path, yaml_file_path, env_file_path, tmp_dir, "%s_%d" % (dataset, i) | |||
| ) | |||
| semantic_spec = semantic_specs[0] | |||
| semantic_spec["Name"]["Values"] = "learnware_%d" % (i) | |||
| @@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| from learnware.market import EasyMarket, BaseUserInfo | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser | |||
| from learnware.specification import generate_rkme_spec, RKMETableSpecification | |||
| @@ -34,7 +34,7 @@ class LearnwareMarketWorkflow: | |||
| """initialize learnware market""" | |||
| learnware.init() | |||
| np.random.seed(2023) | |||
| easy_market = EasyMarket(market_id="sklearn_digits", rebuild=True) | |||
| easy_market = instantiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True) | |||
| return easy_market | |||
| def prepare_learnware_randomly(self, learnware_num=5): | |||
| @@ -92,13 +92,13 @@ class LearnwareMarketWorkflow: | |||
| easy_market.add_learnware(zip_path, semantic_spec) | |||
| print("Total Item:", len(easy_market)) | |||
| curr_inds = easy_market._get_ids() | |||
| curr_inds = easy_market.get_learnware_ids() | |||
| print("Available ids After Uploading Learnwares:", curr_inds) | |||
| if delete: | |||
| for learnware_id in curr_inds: | |||
| easy_market.delete_learnware(learnware_id) | |||
| curr_inds = easy_market._get_ids() | |||
| curr_inds = easy_market.get_learnware_ids() | |||
| print("Available ids After Deleting Learnwares:", curr_inds) | |||
| return easy_market | |||
| @@ -1,4 +1,4 @@ | |||
| __version__ = "0.1.0.999" | |||
| __version__ = "0.1.1.99" | |||
| import os | |||
| from .logger import get_module_logger | |||
| @@ -80,19 +80,13 @@ semantic_config = { | |||
| "Values": [ | |||
| "Classification", | |||
| "Regression", | |||
| "Clustering", | |||
| "Feature Extraction", | |||
| # "Generation", | |||
| "Segmentation", | |||
| "Object Detection", | |||
| "Others", | |||
| ], | |||
| "Type": "Class", # Choose only one class | |||
| }, | |||
| # "Device": { | |||
| # "Values": ["CPU", "GPU"], | |||
| # "Type": "Tag", | |||
| # }, # Choose one or more tags | |||
| "Library": { | |||
| "Values": ["Scikit-learn", "PyTorch", "TensorFlow", "Others"], | |||
| "Type": "Class", | |||
| @@ -2,9 +2,8 @@ from .anchor import AnchoredUserInfo, AnchoredOrganizer | |||
| from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher | |||
| from .evolve_anchor import EvolvedAnchoredOrganizer | |||
| from .evolve import EvolvedOrganizer | |||
| from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .hetergeneous import HeteroMapTableOrganizer, HeteroSearcher | |||
| from .easy import EasyMarket | |||
| from .classes import CondaChecker | |||
| from .module import instantiate_learnware_market | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import List, Dict, Tuple, Any | |||
| from ..easy2.organizer import EasyOrganizer | |||
| from ..easy.organizer import EasyOrganizer | |||
| from ...logger import get_module_logger | |||
| from ...learnware import Learnware | |||
| from ...specification import BaseStatSpecification | |||
| @@ -1,7 +1,7 @@ | |||
| from typing import List, Dict, Tuple, Any, Union | |||
| from ..base import BaseUserInfo | |||
| from ..easy2.searcher import EasySearcher | |||
| from ..easy.searcher import EasySearcher | |||
| from ...logger import get_module_logger | |||
| from ...learnware import Learnware | |||
| @@ -225,6 +225,9 @@ class LearnwareMarket: | |||
| """ | |||
| return self.learnware_organizer.get_learnwares(top, check_status, **kwargs) | |||
| def reload_learnware(self, learnware_id: str): | |||
| self.learnware_organizer.reload_learnware(learnware_id) | |||
| def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: | |||
| return self.learnware_organizer.get_learnware_zip_path_by_ids(ids, **kwargs) | |||
| @@ -1,175 +0,0 @@ | |||
| from sqlalchemy.ext.declarative import declarative_base | |||
| from sqlalchemy import create_engine, text | |||
| from sqlalchemy import Column, Integer, Text, DateTime, String | |||
| import os | |||
| import json | |||
| from ..learnware import get_learnware_from_dirpath | |||
| from ..logger import get_module_logger | |||
| logger = get_module_logger("database") | |||
| DeclarativeBase = declarative_base() | |||
| class Learnware(DeclarativeBase): | |||
| __tablename__ = "tb_learnware" | |||
| id = Column(String(10), primary_key=True, nullable=False) | |||
| semantic_spec = Column(Text, nullable=False) | |||
| zip_path = Column(Text, nullable=False) | |||
| folder_path = Column(Text, nullable=False) | |||
| use_flag = Column(Text, nullable=False) | |||
| pass | |||
| class DatabaseOperations(object): | |||
| def __init__(self, url: str, database_name: str): | |||
| if url.startswith("sqlite"): | |||
| url = os.path.join(url, f"{database_name}.db") | |||
| else: | |||
| url = f"{url}/{database_name}" | |||
| pass | |||
| self.url = url | |||
| self.create_database_if_not_exists(url) | |||
| pass | |||
| def create_database_if_not_exists(self, url): | |||
| database_exists = True | |||
| if url.startswith("sqlite"): | |||
| # it is sqlite | |||
| start = url.find(":///") | |||
| path = url[start + 4 :] | |||
| if os.path.exists(path): | |||
| database_exists = True | |||
| pass | |||
| else: | |||
| database_exists = False | |||
| os.makedirs(os.path.dirname(path), exist_ok=True) | |||
| pass | |||
| pass | |||
| elif self.url.startswith("postgresql"): | |||
| # it is postgresql | |||
| dbname_start = url.rfind("/") | |||
| dbname = url[dbname_start + 1 :] | |||
| url_no_dbname = url[:dbname_start] + "/postgres" | |||
| engine = create_engine(url_no_dbname) | |||
| with engine.connect() as conn: | |||
| result = conn.execute(text("SELECT datname FROM pg_database;")) | |||
| db_list = set() | |||
| for row in result.fetchall(): | |||
| db_list.add(row[0].lower()) | |||
| pass | |||
| if dbname.lower() not in db_list: | |||
| database_exists = False | |||
| conn.execution_options(isolation_level="AUTOCOMMIT").execute( | |||
| text("CREATE DATABASE {0};".format(dbname)) | |||
| ) | |||
| pass | |||
| else: | |||
| database_exists = True | |||
| pass | |||
| pass | |||
| engine.dispose() | |||
| pass | |||
| else: | |||
| raise Exception(f"Unsupported database url: {self.url}") | |||
| pass | |||
| self.engine = create_engine(url, future=True) | |||
| if not database_exists: | |||
| DeclarativeBase.metadata.create_all(self.engine) | |||
| pass | |||
| pass | |||
| def clear_learnware_table(self): | |||
| with self.engine.connect() as conn: | |||
| conn.execute(text("DELETE FROM tb_learnware;")) | |||
| conn.commit() | |||
| pass | |||
| pass | |||
| def add_learnware(self, id: str, semantic_spec: dict, zip_path, folder_path, use_flag: str): | |||
| with self.engine.connect() as conn: | |||
| semantic_spec_str = json.dumps(semantic_spec) | |||
| conn.execute( | |||
| text( | |||
| ( | |||
| "INSERT INTO tb_learnware (id, semantic_spec, zip_path, folder_path, use_flag)" | |||
| "VALUES (:id, :semantic_spec, :zip_path, :folder_path, :use_flag);" | |||
| ) | |||
| ), | |||
| dict( | |||
| id=id, | |||
| semantic_spec=semantic_spec_str, | |||
| zip_path=zip_path, | |||
| folder_path=folder_path, | |||
| use_flag=use_flag, | |||
| ), | |||
| ) | |||
| conn.commit() | |||
| pass | |||
| pass | |||
| def delete_learnware(self, id: str): | |||
| with self.engine.connect() as conn: | |||
| conn.execute(text("DELETE FROM tb_learnware WHERE id=:id;"), dict(id=id)) | |||
| conn.commit() | |||
| pass | |||
| pass | |||
| def update_learnware_semantic_specification(self, id: str, semantic_spec: dict): | |||
| with self.engine.connect() as conn: | |||
| semantic_spec_str = json.dumps(semantic_spec) | |||
| r = conn.execute( | |||
| text("UPDATE tb_learnware SET semantic_spec=:semantic_spec WHERE id=:id;"), | |||
| dict(id=id, semantic_spec=semantic_spec_str), | |||
| ) | |||
| conn.commit() | |||
| pass | |||
| pass | |||
| def update_learnware_use_flag(self, id: str, semantic_spec: dict): | |||
| with self.engine.connect() as conn: | |||
| semantic_spec_str = json.dumps(semantic_spec) | |||
| r = conn.execute( | |||
| text("UPDATE tb_learnware SET semantic_spec=:semantic_spec WHERE id=:id;"), | |||
| dict(id=id, semantic_spec=semantic_spec_str), | |||
| ) | |||
| conn.commit() | |||
| pass | |||
| pass | |||
| def load_market(self): | |||
| with self.engine.connect() as conn: | |||
| cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) | |||
| learnware_list = {} | |||
| zip_list = {} | |||
| folder_list = {} | |||
| max_count = 0 | |||
| for id, semantic_spec, zip_path, folder_path, use_flag in cursor: | |||
| id = id.strip() | |||
| semantic_spec_dict = json.loads(semantic_spec) | |||
| new_learnware = get_learnware_from_dirpath( | |||
| id=id, semantic_spec=semantic_spec_dict, learnware_dirpath=folder_path | |||
| ) | |||
| logger.info(f"Load learnware: {id}") | |||
| learnware_list[id] = new_learnware | |||
| # assert new_learnware is not None | |||
| zip_list[id] = zip_path | |||
| folder_list[id] = folder_path | |||
| max_count = max(max_count, int(id)) | |||
| pass | |||
| return learnware_list, zip_list, folder_list, max_count + 1 | |||
| pass | |||
| pass | |||
| @@ -1,987 +0,0 @@ | |||
| import os | |||
| import json | |||
| import copy | |||
| import torch | |||
| import zipfile | |||
| import traceback | |||
| import numpy as np | |||
| import pandas as pd | |||
| from rapidfuzz import fuzz | |||
| from cvxopt import solvers, matrix | |||
| from shutil import copyfile, rmtree | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from .base import LearnwareMarket, BaseUserInfo | |||
| from .database_ops import DatabaseOperations | |||
| from .. import utils | |||
| from ..config import C as conf | |||
| from ..logger import get_module_logger | |||
| from ..learnware import Learnware, get_learnware_from_dirpath | |||
| from ..specification import RKMETableSpecification, Specification | |||
| logger = get_module_logger("market", "INFO") | |||
| class EasyMarket(LearnwareMarket): | |||
| """EasyMarket provide an easy and simple implementation for LearnwareMarket | |||
| - EasyMarket stores learnwares with file system and database | |||
| - EasyMarket search the learnwares with the match of semantical tag and the statistical RKME | |||
| - EasyMarket does not support the search between heterogeneous features learnwars | |||
| """ | |||
| INVALID_LEARNWARE = -1 | |||
| NONUSABLE_LEARNWARE = 0 | |||
| USABLE_LEARWARE = 1 | |||
| def __init__(self, market_id: str = "default", rebuild: bool = False): | |||
| """Initialize Learnware Market. | |||
| Automatically reload from db if available. | |||
| Build an empty db otherwise. | |||
| Parameters | |||
| ---------- | |||
| market_id : str, optional, by default 'default' | |||
| The unique market id for market database | |||
| rebuild : bool, optional | |||
| Clear current database if set to True, by default False | |||
| !!! Do NOT set to True unless highly necessary !!! | |||
| """ | |||
| self.market_id = market_id | |||
| self.market_store_path = os.path.join(conf.market_root_path, self.market_id) | |||
| self.learnware_pool_path = os.path.join(self.market_store_path, "learnware_pool") | |||
| self.learnware_zip_pool_path = os.path.join(self.learnware_pool_path, "zips") | |||
| self.learnware_folder_pool_path = os.path.join(self.learnware_pool_path, "unzipped_learnwares") | |||
| self.learnware_list = {} # id: Learnware | |||
| self.learnware_zip_list = {} | |||
| self.learnware_folder_list = {} | |||
| self.count = 0 | |||
| self.semantic_spec_list = conf.semantic_specs | |||
| self.dbops = DatabaseOperations(conf.database_url, "market_" + self.market_id) | |||
| self.reload_market(rebuild=rebuild) # Automatically reload the market | |||
| logger.info("Market Initialized!") | |||
| def reload_market(self, rebuild: bool = False) -> bool: | |||
| if rebuild: | |||
| logger.warning("Warning! You are trying to clear current database!") | |||
| try: | |||
| self.dbops.clear_learnware_table() | |||
| rmtree(self.learnware_pool_path) | |||
| except: | |||
| pass | |||
| os.makedirs(self.learnware_pool_path, exist_ok=True) | |||
| os.makedirs(self.learnware_zip_pool_path, exist_ok=True) | |||
| os.makedirs(self.learnware_folder_pool_path, exist_ok=True) | |||
| self.learnware_list, self.learnware_zip_list, self.learnware_folder_list, self.count = self.dbops.load_market() | |||
| @classmethod | |||
| def check_learnware(cls, learnware: Learnware) -> int: | |||
| """Check the utility of a learnware | |||
| Parameters | |||
| ---------- | |||
| learnware : Learnware | |||
| Returns | |||
| ------- | |||
| int | |||
| A flag indicating whether the learnware can be accepted. | |||
| - The INVALID_LEARNWARE denotes the learnware does not pass the check | |||
| - The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency | |||
| - The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction | |||
| """ | |||
| semantic_spec = learnware.get_specification().get_semantic_spec() | |||
| try: | |||
| # check model instantiation | |||
| learnware.instantiate_model() | |||
| except Exception as e: | |||
| traceback.print_exc() | |||
| logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}") | |||
| return cls.NONUSABLE_LEARNWARE | |||
| try: | |||
| learnware_model = learnware.get_model() | |||
| # check input shape | |||
| if semantic_spec["Data"]["Values"][0] == "Table": | |||
| input_shape = (semantic_spec["Input"]["Dimension"],) | |||
| else: | |||
| input_shape = learnware_model.input_shape | |||
| pass | |||
| # check rkme dimension | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") | |||
| if stat_spec is not None: | |||
| if stat_spec.get_z().shape[1:] != input_shape: | |||
| logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") | |||
| return cls.NONUSABLE_LEARNWARE | |||
| pass | |||
| inputs = np.random.randn(10, *input_shape) | |||
| outputs = learnware.predict(inputs) | |||
| # check output | |||
| if outputs.ndim == 1: | |||
| outputs = outputs.reshape(-1, 1) | |||
| pass | |||
| if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"): | |||
| # check output type | |||
| if isinstance(outputs, torch.Tensor): | |||
| outputs = outputs.detach().cpu().numpy() | |||
| if not isinstance(outputs, np.ndarray): | |||
| logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor") | |||
| return cls.NONUSABLE_LEARNWARE | |||
| # check output shape | |||
| output_dim = int(semantic_spec["Output"]["Dimension"]) | |||
| if outputs[0].shape[0] != output_dim: | |||
| logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") | |||
| return cls.NONUSABLE_LEARNWARE | |||
| pass | |||
| else: | |||
| if outputs.shape[1:] != learnware_model.output_shape: | |||
| logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") | |||
| return cls.NONUSABLE_LEARNWARE | |||
| except Exception as e: | |||
| logger.exception | |||
| logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}") | |||
| raise e | |||
| return cls.NONUSABLE_LEARNWARE | |||
| return cls.USABLE_LEARWARE | |||
| def add_learnware(self, zip_path: str, semantic_spec: dict) -> Tuple[str, bool]: | |||
| """Add a learnware into the market. | |||
| .. note:: | |||
| Given a prediction of a certain time, all signals before this time will be prepared well. | |||
| Parameters | |||
| ---------- | |||
| zip_path : str | |||
| Filepath for learnware model, a zipped file. | |||
| semantic_spec : dict | |||
| semantic_spec for new learnware, in dictionary format. | |||
| Returns | |||
| ------- | |||
| Tuple[str, int] | |||
| - str indicating model_id | |||
| - int indicating what the flag of learnware is added. | |||
| """ | |||
| semantic_spec = copy.deepcopy(semantic_spec) | |||
| if not os.path.exists(zip_path): | |||
| logger.warning("Zip Path NOT Found! Fail to add learnware.") | |||
| return None, self.INVALID_LEARNWARE | |||
| try: | |||
| if len(semantic_spec["Data"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please choose Data.") | |||
| return None, self.INVALID_LEARNWARE | |||
| if len(semantic_spec["Task"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please choose Task.") | |||
| return None, self.INVALID_LEARNWARE | |||
| if len(semantic_spec["Library"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please choose Device.") | |||
| return None, self.INVALID_LEARNWARE | |||
| if len(semantic_spec["Name"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please provide Name.") | |||
| return None, self.INVALID_LEARNWARE | |||
| if len(semantic_spec["Description"]["Values"]) == 0 and len(semantic_spec["Scenario"]["Values"]) == 0: | |||
| logger.warning("Illegal semantic specification, please provide Scenario or Description.") | |||
| return None, self.INVALID_LEARNWARE | |||
| if ( | |||
| semantic_spec["Data"]["Type"] != "Class" | |||
| or semantic_spec["Task"]["Type"] != "Class" | |||
| or semantic_spec["Library"]["Type"] != "Class" | |||
| or semantic_spec["Scenario"]["Type"] != "Tag" | |||
| or semantic_spec["Name"]["Type"] != "String" | |||
| or semantic_spec["Description"]["Type"] != "String" | |||
| ): | |||
| logger.warning("Illegal semantic specification, please provide the right type.") | |||
| return None, self.INVALID_LEARNWARE | |||
| except: | |||
| logger.info(f"Semantic specification: {semantic_spec}") | |||
| logger.warning("Illegal semantic specification, some keys are missing.") | |||
| return None, self.INVALID_LEARNWARE | |||
| logger.info("Get new learnware from %s" % (zip_path)) | |||
| id = "%08d" % (self.count) | |||
| target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (id)) | |||
| target_folder_dir = os.path.join(self.learnware_folder_pool_path, id) | |||
| copyfile(zip_path, target_zip_dir) | |||
| with zipfile.ZipFile(target_zip_dir, "r") as z_file: | |||
| z_file.extractall(target_folder_dir) | |||
| logger.info("Learnware move to %s, and unzip to %s" % (target_zip_dir, target_folder_dir)) | |||
| try: | |||
| new_learnware = get_learnware_from_dirpath( | |||
| id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir | |||
| ) | |||
| except: | |||
| try: | |||
| os.remove(target_zip_dir) | |||
| rmtree(target_folder_dir) | |||
| except: | |||
| pass | |||
| return None, self.INVALID_LEARNWARE | |||
| if new_learnware is None: | |||
| return None, self.INVALID_LEARNWARE | |||
| check_flag = self.check_learnware(new_learnware) | |||
| self.dbops.add_learnware( | |||
| id=id, | |||
| semantic_spec=semantic_spec, | |||
| zip_path=target_zip_dir, | |||
| folder_path=target_folder_dir, | |||
| use_flag=check_flag, | |||
| ) | |||
| self.learnware_list[id] = new_learnware | |||
| self.learnware_zip_list[id] = target_zip_dir | |||
| self.learnware_folder_list[id] = target_folder_dir | |||
| self.count += 1 | |||
| return id, check_flag | |||
| def _convert_dist_to_score( | |||
| self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92 | |||
| ) -> List[float]: | |||
| """Convert mmd dist list into min_max score list | |||
| Parameters | |||
| ---------- | |||
| dist_list : List[float] | |||
| The list of mmd distances from learnware rkmes to user rkme | |||
| dist_epsilon: float | |||
| The paramter for converting mmd dist to score | |||
| min_score: float | |||
| The minimum score for maximum returned score | |||
| Returns | |||
| ------- | |||
| List[float] | |||
| The list of min_max scores of each learnware | |||
| """ | |||
| if len(dist_list) == 0: | |||
| return [] | |||
| min_dist, max_dist = min(dist_list), max(dist_list) | |||
| if min_dist == max_dist: | |||
| return [1 for dist in dist_list] | |||
| else: | |||
| max_score = (max_dist - min_dist) / (max_dist - dist_epsilon) | |||
| if min_dist < dist_epsilon: | |||
| dist_epsilon = min_dist | |||
| elif max_score < min_score: | |||
| dist_epsilon = max_dist - (max_dist - min_dist) / min_score | |||
| return [(max_dist - dist) / (max_dist - dist_epsilon) for dist in dist_list] | |||
| def _calculate_rkme_spec_mixture_weight( | |||
| self, | |||
| learnware_list: List[Learnware], | |||
| user_rkme: RKMETableSpecification, | |||
| intermediate_K: np.ndarray = None, | |||
| intermediate_C: np.ndarray = None, | |||
| ) -> Tuple[List[float], float]: | |||
| """Calculate mixture weight for the learnware_list based on a user's rkme | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| A list of existing learnwares | |||
| user_rkme : RKMETableSpecification | |||
| User RKME statistical specification | |||
| intermediate_K : np.ndarray, optional | |||
| Intermediate kernel matrix K, by default None | |||
| intermediate_C : np.ndarray, optional | |||
| Intermediate inner product vector C, by default None | |||
| Returns | |||
| ------- | |||
| Tuple[List[float], float] | |||
| The first is the list of mixture weights | |||
| The second is the mmd dist between the mixture of learnware rkmes and the user's rkme | |||
| """ | |||
| learnware_num = len(learnware_list) | |||
| RKME_list = [ | |||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||
| ] | |||
| if type(intermediate_K) == np.ndarray: | |||
| K = intermediate_K | |||
| else: | |||
| K = np.zeros((learnware_num, learnware_num)) | |||
| for i in range(K.shape[0]): | |||
| K[i, i] = RKME_list[i].inner_prod(RKME_list[i]) | |||
| for j in range(i + 1, K.shape[0]): | |||
| K[i, j] = K[j, i] = RKME_list[i].inner_prod(RKME_list[j]) | |||
| if type(intermediate_C) == np.ndarray: | |||
| C = intermediate_C | |||
| else: | |||
| C = np.zeros((learnware_num, 1)) | |||
| for i in range(C.shape[0]): | |||
| C[i, 0] = user_rkme.inner_prod(RKME_list[i]) | |||
| K = torch.from_numpy(K).double().to(user_rkme.device) | |||
| C = torch.from_numpy(C).double().to(user_rkme.device) | |||
| # beta can be negative | |||
| # weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C | |||
| # beta must be nonnegative | |||
| n = K.shape[0] | |||
| P = matrix(K.cpu().numpy()) | |||
| q = matrix(-C.cpu().numpy()) | |||
| G = matrix(-np.eye(n)) | |||
| h = matrix(np.zeros((n, 1))) | |||
| A = matrix(np.ones((1, n))) | |||
| b = matrix(np.ones((1, 1))) | |||
| solvers.options["show_progress"] = False | |||
| sol = solvers.qp(P, q, G, h, A, b) | |||
| weight = np.array(sol["x"]) | |||
| weight = torch.from_numpy(weight).reshape(-1).double().to(user_rkme.device) | |||
| score = user_rkme.inner_prod(user_rkme) + 2 * sol["primal objective"] | |||
| return weight.detach().cpu().numpy().reshape(-1), score | |||
| def _calculate_intermediate_K_and_C( | |||
| self, | |||
| learnware_list: List[Learnware], | |||
| user_rkme: RKMETableSpecification, | |||
| intermediate_K: np.ndarray = None, | |||
| intermediate_C: np.ndarray = None, | |||
| ) -> Tuple[np.ndarray, np.ndarray]: | |||
| """Incrementally update the values of intermediate_K and intermediate_C | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares up till now | |||
| user_rkme : RKMETableSpecification | |||
| User RKME statistical specification | |||
| intermediate_K : np.ndarray, optional | |||
| Intermediate kernel matrix K, by default None | |||
| intermediate_C : np.ndarray, optional | |||
| Intermediate inner product vector C, by default None | |||
| Returns | |||
| ------- | |||
| Tuple[np.ndarray, np.ndarray] | |||
| The first is the intermediate value of K | |||
| The second is the intermediate value of C | |||
| """ | |||
| num = intermediate_K.shape[0] - 1 | |||
| RKME_list = [ | |||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||
| ] | |||
| for i in range(intermediate_K.shape[0]): | |||
| intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) | |||
| intermediate_C[num, 0] = user_rkme.inner_prod(RKME_list[-1]) | |||
| return intermediate_K, intermediate_C | |||
| def _search_by_rkme_spec_mixture_auto( | |||
| self, | |||
| learnware_list: List[Learnware], | |||
| user_rkme: RKMETableSpecification, | |||
| max_search_num: int, | |||
| weight_cutoff: float = 0.98, | |||
| ) -> Tuple[float, List[float], List[Learnware]]: | |||
| """Select learnwares based on a total mixture ratio, then recalculate their mixture weights | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMETableSpecification | |||
| User RKME statistical specification | |||
| max_search_num : int | |||
| The maximum number of the returned learnwares | |||
| weight_cutoff : float, optional | |||
| The ratio for selecting out the mose relevant learnwares, by default 0.9 | |||
| Returns | |||
| ------- | |||
| Tuple[float, List[float], List[Learnware]] | |||
| The first is the mixture mmd dist | |||
| The second is the list of weight | |||
| The third is the list of Learnware | |||
| """ | |||
| learnware_num = len(learnware_list) | |||
| if learnware_num == 0: | |||
| return [], [] | |||
| if learnware_num < max_search_num: | |||
| logger.warning("Available Learnware num less than search_num!") | |||
| max_search_num = learnware_num | |||
| weight, _ = self._calculate_rkme_spec_mixture_weight(learnware_list, user_rkme) | |||
| sort_by_weight_idx_list = sorted(range(learnware_num), key=lambda k: weight[k], reverse=True) | |||
| weight_sum = 0 | |||
| mixture_list = [] | |||
| for idx in sort_by_weight_idx_list: | |||
| weight_sum += weight[idx] | |||
| if weight_sum <= weight_cutoff: | |||
| mixture_list.append(learnware_list[idx]) | |||
| else: | |||
| break | |||
| if len(mixture_list) <= 1: | |||
| mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] | |||
| mixture_weight = [1] | |||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) | |||
| else: | |||
| if len(mixture_list) > max_search_num: | |||
| mixture_list = mixture_list[:max_search_num] | |||
| mixture_weight, mmd_dist = self._calculate_rkme_spec_mixture_weight(mixture_list, user_rkme) | |||
| return mmd_dist, mixture_weight, mixture_list | |||
| def _filter_by_rkme_spec_single( | |||
| self, | |||
| sorted_score_list: List[float], | |||
| learnware_list: List[Learnware], | |||
| filter_score: float = 0.5, | |||
| min_num: int = 15, | |||
| ) -> Tuple[List[float], List[Learnware]]: | |||
| """Filter search result of _search_by_rkme_spec_single | |||
| Parameters | |||
| ---------- | |||
| sorted_score_list : List[float] | |||
| The list of score transformed by mmd dist | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| filter_score: float | |||
| The learnware whose score is lower than filter_score will be filtered | |||
| min_num: int | |||
| The minimum number of returned learnwares | |||
| Returns | |||
| ------- | |||
| Tuple[List[float], List[Learnware]] | |||
| the first is the list of score | |||
| the second is the list of Learnware | |||
| """ | |||
| idx = min(min_num, len(learnware_list)) | |||
| while idx < len(learnware_list): | |||
| if sorted_score_list[idx] < filter_score: | |||
| break | |||
| idx = idx + 1 | |||
| return sorted_score_list[:idx], learnware_list[:idx] | |||
| def _filter_by_rkme_spec_dimension( | |||
| self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification | |||
| ) -> List[Learnware]: | |||
| """Filter learnwares whose rkme dimension different from user_rkme | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMETableSpecification | |||
| User RKME statistical specification | |||
| Returns | |||
| ------- | |||
| List[Learnware] | |||
| Learnwares whose rkme dimensions equal user_rkme in user_info | |||
| """ | |||
| filtered_learnware_list = [] | |||
| user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) | |||
| for learnware in learnware_list: | |||
| rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") | |||
| rkme_dim = str(list(rkme.get_z().shape)[1:]) | |||
| if rkme_dim == user_rkme_dim: | |||
| filtered_learnware_list.append(learnware) | |||
| return filtered_learnware_list | |||
| def _search_by_rkme_spec_mixture_greedy( | |||
| self, | |||
| learnware_list: List[Learnware], | |||
| user_rkme: RKMETableSpecification, | |||
| max_search_num: int, | |||
| score_cutoff: float = 0.001, | |||
| ) -> Tuple[float, List[float], List[Learnware]]: | |||
| """Greedily match learnwares such that their mixture become closer and closer to user's rkme | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMETableSpecification | |||
| User RKME statistical specification | |||
| max_search_num : int | |||
| The maximum number of the returned learnwares | |||
| score_cutof: float | |||
| The minimum mmd dist as threshold to stop further rkme_spec matching | |||
| Returns | |||
| ------- | |||
| Tuple[float, List[float], List[Learnware]] | |||
| The first is the mixture mmd dist | |||
| The second is the list of weight | |||
| The third is the list of Learnware | |||
| """ | |||
| learnware_num = len(learnware_list) | |||
| if learnware_num == 0: | |||
| return None, [], [] | |||
| if learnware_num < max_search_num: | |||
| logger.warning("Available Learnware num less than search_num!") | |||
| max_search_num = learnware_num | |||
| flag_list = [0 for _ in range(learnware_num)] | |||
| mixture_list, mmd_dist = [], None | |||
| intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1)) | |||
| for k in range(max_search_num): | |||
| idx_min, score_min = -1, -1 | |||
| weight_min = None | |||
| mixture_list.append(None) | |||
| if k != 0: | |||
| intermediate_K = np.c_[intermediate_K, np.zeros((k, 1))] | |||
| intermediate_K = np.r_[intermediate_K, np.zeros((1, k + 1))] | |||
| intermediate_C = np.r_[intermediate_C, np.zeros((1, 1))] | |||
| for idx in range(len(learnware_list)): | |||
| if flag_list[idx] == 0: | |||
| mixture_list[-1] = learnware_list[idx] | |||
| intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( | |||
| mixture_list, user_rkme, intermediate_K, intermediate_C | |||
| ) | |||
| weight, score = self._calculate_rkme_spec_mixture_weight( | |||
| mixture_list, user_rkme, intermediate_K, intermediate_C | |||
| ) | |||
| if idx_min == -1 or score < score_min: | |||
| idx_min, score_min, weight_min = idx, score, weight | |||
| mmd_dist = score_min | |||
| mixture_list[-1] = learnware_list[idx_min] | |||
| if score_min < score_cutoff: | |||
| break | |||
| else: | |||
| flag_list[idx_min] = 1 | |||
| intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( | |||
| mixture_list, user_rkme, intermediate_K, intermediate_C | |||
| ) | |||
| return mmd_dist, weight_min, mixture_list | |||
| def _search_by_rkme_spec_single( | |||
| self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification | |||
| ) -> Tuple[List[float], List[Learnware]]: | |||
| """Calculate the distances between learnwares in the given learnware_list and user_rkme | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMETableSpecification | |||
| user RKME statistical specification | |||
| Returns | |||
| ------- | |||
| Tuple[List[float], List[Learnware]] | |||
| the first is the list of mmd dist | |||
| the second is the list of Learnware | |||
| both lists are sorted by mmd dist | |||
| """ | |||
| RKME_list = [ | |||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||
| ] | |||
| mmd_dist_list = [] | |||
| for RKME in RKME_list: | |||
| mmd_dist = RKME.dist(user_rkme) | |||
| mmd_dist_list.append(mmd_dist) | |||
| sorted_idx_list = sorted(range(len(learnware_list)), key=lambda k: mmd_dist_list[k]) | |||
| sorted_dist_list = [mmd_dist_list[idx] for idx in sorted_idx_list] | |||
| sorted_learnware_list = [learnware_list[idx] for idx in sorted_idx_list] | |||
| return sorted_dist_list, sorted_learnware_list | |||
| def _search_by_semantic_spec_exact( | |||
| self, learnware_list: List[Learnware], user_info: BaseUserInfo | |||
| ) -> List[Learnware]: | |||
| def match_semantic_spec(semantic_spec1, semantic_spec2): | |||
| """ | |||
| semantic_spec1: semantic spec input by user | |||
| semantic_spec2: semantic spec in database | |||
| """ | |||
| if semantic_spec1.keys() != semantic_spec2.keys(): | |||
| # sematic spec in database may contain more keys than user input | |||
| pass | |||
| name2 = semantic_spec2["Name"]["Values"].lower() | |||
| description2 = semantic_spec2["Description"]["Values"].lower() | |||
| for key in semantic_spec1.keys(): | |||
| v1 = semantic_spec1[key]["Values"] | |||
| v2 = semantic_spec2[key]["Values"] | |||
| if len(v1) == 0: | |||
| # user input is empty, no need to search | |||
| continue | |||
| if key in ("Name", "Description"): | |||
| v1 = v1.lower() | |||
| if v1 not in name2 and v1 not in description2: | |||
| return False | |||
| pass | |||
| else: | |||
| if len(v2) == 0: | |||
| # user input contains some key that is not in database | |||
| return False | |||
| if semantic_spec1[key]["Type"] == "Class": | |||
| if isinstance(v1, list): | |||
| v1 = v1[0] | |||
| if isinstance(v2, list): | |||
| v2 = v2[0] | |||
| if v1 != v2: | |||
| return False | |||
| elif semantic_spec1[key]["Type"] == "Tag": | |||
| if not (set(v1) & set(v2)): | |||
| return False | |||
| pass | |||
| pass | |||
| pass | |||
| return True | |||
| match_learnwares = [] | |||
| for learnware in learnware_list: | |||
| learnware_semantic_spec = learnware.get_specification().get_semantic_spec() | |||
| user_semantic_spec = user_info.get_semantic_spec() | |||
| if match_semantic_spec(user_semantic_spec, learnware_semantic_spec): | |||
| match_learnwares.append(learnware) | |||
| logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list))) | |||
| return match_learnwares | |||
| def _search_by_semantic_spec_fuzz( | |||
| self, learnware_list: List[Learnware], user_info: BaseUserInfo, max_num: int = 50000, min_score: float = 75.0 | |||
| ) -> List[Learnware]: | |||
| """Search learnware by fuzzy matching of semantic spec | |||
| Parameters | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares | |||
| user_info : BaseUserInfo | |||
| user_info contains semantic_spec | |||
| max_num : int, optional | |||
| maximum number of learnwares returned, by default 50000 | |||
| min_score : float, optional | |||
| Minimum fuzzy matching score of learnwares returned, by default 30.0 | |||
| Returns | |||
| ------- | |||
| List[Learnware] | |||
| The list of returned learnwares | |||
| """ | |||
| def _match_semantic_spec_tag(semantic_spec1, semantic_spec2) -> bool: | |||
| """Judge if tags of two semantic specs are consistent | |||
| Parameters | |||
| ---------- | |||
| semantic_spec1 : | |||
| semantic spec input by user | |||
| semantic_spec2 : | |||
| semantic spec in database | |||
| Returns | |||
| ------- | |||
| bool | |||
| consistent (True) or not consistent (False) | |||
| """ | |||
| for key in semantic_spec1.keys(): | |||
| v1 = semantic_spec1[key]["Values"] | |||
| v2 = semantic_spec2[key]["Values"] | |||
| if len(v1) == 0: | |||
| # user input is empty, no need to search | |||
| continue | |||
| if key not in "Name": | |||
| if len(v2) == 0: | |||
| # user input contains some key that is not in database | |||
| return False | |||
| if semantic_spec1[key]["Type"] == "Class": | |||
| if isinstance(v1, list): | |||
| v1 = v1[0] | |||
| if isinstance(v2, list): | |||
| v2 = v2[0] | |||
| if v1 != v2: | |||
| return False | |||
| elif semantic_spec1[key]["Type"] == "Tag": | |||
| if not (set(v1) & set(v2)): | |||
| return False | |||
| return True | |||
| matched_learnware_tag = [] | |||
| final_result = [] | |||
| user_semantic_spec = user_info.get_semantic_spec() | |||
| for learnware in learnware_list: | |||
| learnware_semantic_spec = learnware.get_specification().get_semantic_spec() | |||
| if _match_semantic_spec_tag(user_semantic_spec, learnware_semantic_spec): | |||
| matched_learnware_tag.append(learnware) | |||
| if len(matched_learnware_tag) > 0: | |||
| if "Name" in user_semantic_spec: | |||
| name_user = user_semantic_spec["Name"]["Values"].lower() | |||
| if len(name_user) > 0: | |||
| # Exact search | |||
| name_list = [ | |||
| learnware.get_specification().get_semantic_spec()["Name"]["Values"].lower() | |||
| for learnware in matched_learnware_tag | |||
| ] | |||
| des_list = [ | |||
| learnware.get_specification().get_semantic_spec()["Description"]["Values"].lower() | |||
| for learnware in matched_learnware_tag | |||
| ] | |||
| matched_learnware_exact = [] | |||
| for i in range(len(name_list)): | |||
| if name_user in name_list[i] or name_user in des_list[i]: | |||
| matched_learnware_exact.append(matched_learnware_tag[i]) | |||
| if len(matched_learnware_exact) == 0: | |||
| # Fuzzy search | |||
| matched_learnware_fuzz, fuzz_scores = [], [] | |||
| for i in range(len(name_list)): | |||
| score_name = fuzz.partial_ratio(name_user, name_list[i]) | |||
| score_des = fuzz.partial_ratio(name_user, des_list[i]) | |||
| final_score = max(score_name, score_des) | |||
| if final_score >= min_score: | |||
| matched_learnware_fuzz.append(matched_learnware_tag[i]) | |||
| fuzz_scores.append(final_score) | |||
| # Sort by score | |||
| sort_idx = sorted(list(range(len(fuzz_scores))), key=lambda k: fuzz_scores[k], reverse=True)[ | |||
| :max_num | |||
| ] | |||
| final_result = [matched_learnware_fuzz[idx] for idx in sort_idx] | |||
| else: | |||
| final_result = matched_learnware_exact | |||
| else: | |||
| final_result = matched_learnware_tag | |||
| else: | |||
| final_result = matched_learnware_tag | |||
| logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list))) | |||
| return final_result | |||
| def search_learnware( | |||
| self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" | |||
| ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: | |||
| """Search learnwares based on user_info | |||
| Parameters | |||
| ---------- | |||
| user_info : BaseUserInfo | |||
| user_info contains semantic_spec and stat_info | |||
| max_search_num : int | |||
| The maximum number of the returned learnwares | |||
| Returns | |||
| ------- | |||
| Tuple[List[float], List[Learnware], float, List[Learnware]] | |||
| the first is the sorted list of rkme dist | |||
| the second is the sorted list of Learnware (single) by the rkme dist | |||
| the third is the score of Learnware (mixture) | |||
| the fourth is the list of Learnware (mixture), the size is search_num | |||
| """ | |||
| learnware_list = [self.learnware_list[key] for key in self.learnware_list] | |||
| # learnware_list = self._search_by_semantic_spec_exact(learnware_list, user_info) | |||
| # if len(learnware_list) == 0: | |||
| logger.info(f"stat_info in user_info: {user_info.stat_info}") | |||
| learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info) | |||
| logger.info(f"Number of learnwares after semantic fuzzy search: {len(learnware_list)}") | |||
| if "RKMETableSpecification" not in user_info.stat_info: | |||
| return None, learnware_list, 0.0, None | |||
| elif len(learnware_list) == 0: | |||
| return [], [], 0.0, [] | |||
| else: | |||
| user_rkme = user_info.stat_info["RKMETableSpecification"] | |||
| learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) | |||
| logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") | |||
| sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) | |||
| if search_method == "auto": | |||
| mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_auto( | |||
| learnware_list, user_rkme, max_search_num | |||
| ) | |||
| elif search_method == "greedy": | |||
| mixture_dist, weight_list, mixture_learnware_list = self._search_by_rkme_spec_mixture_greedy( | |||
| learnware_list, user_rkme, max_search_num | |||
| ) | |||
| else: | |||
| logger.warning("f{search_method} not supported!") | |||
| mixture_dist = None | |||
| weight_list = [] | |||
| mixture_learnware_list = [] | |||
| if mixture_dist is None: | |||
| sorted_score_list = self._convert_dist_to_score(sorted_dist_list) | |||
| mixture_score = None | |||
| else: | |||
| merge_score_list = self._convert_dist_to_score(sorted_dist_list + [mixture_dist]) | |||
| sorted_score_list = merge_score_list[:-1] | |||
| mixture_score = merge_score_list[-1] | |||
| logger.info(f"After search by rkme spec, learnware_list length is {len(learnware_list)}") | |||
| # filter learnware with low score | |||
| sorted_score_list, single_learnware_list = self._filter_by_rkme_spec_single( | |||
| sorted_score_list, single_learnware_list | |||
| ) | |||
| logger.info(f"After filter by rkme spec, learnware_list length is {len(learnware_list)}") | |||
| return sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list | |||
| def delete_learnware(self, id: str) -> bool: | |||
| """Delete Learnware from market | |||
| Parameters | |||
| ---------- | |||
| id : str | |||
| Learnware to be deleted | |||
| Returns | |||
| ------- | |||
| bool | |||
| True for successful operation. | |||
| False for id not found. | |||
| """ | |||
| if not id in self.learnware_list: | |||
| logger.warning("Learnware id:'{}' NOT Found!".format(id)) | |||
| return False | |||
| zip_dir = self.learnware_zip_list[id] | |||
| os.remove(zip_dir) | |||
| folder_dir = self.learnware_folder_list[id] | |||
| rmtree(folder_dir) | |||
| self.learnware_list.pop(id) | |||
| self.learnware_zip_list.pop(id) | |||
| self.learnware_folder_list.pop(id) | |||
| self.dbops.delete_learnware(id=id) | |||
| return True | |||
| def get_semantic_spec_list(self) -> dict: | |||
| return self.semantic_spec_list | |||
| def get_learnware_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]: | |||
| """Search learnware by id or list of ids. | |||
| Parameters | |||
| ---------- | |||
| ids : Union[str, List[str]] | |||
| Give a id or a list of ids | |||
| str: id of targer learware | |||
| List[str]: A list of ids of target learnwares | |||
| Returns | |||
| ------- | |||
| Union[Learnware, List[Learnware]] | |||
| Return target learnware or list of target learnwares. | |||
| None for Learnware NOT Found. | |||
| """ | |||
| if isinstance(ids, list): | |||
| ret = [] | |||
| for id in ids: | |||
| if id in self.learnware_list: | |||
| ret.append(self.learnware_list[id]) | |||
| else: | |||
| logger.warning("Learnware ID '%s' NOT Found!" % (id)) | |||
| ret.append(None) | |||
| return ret | |||
| else: | |||
| try: | |||
| return self.learnware_list[ids] | |||
| except: | |||
| logger.warning("Learnware ID '%s' NOT Found!" % (ids)) | |||
| return None | |||
| def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]: | |||
| """Get Zipped Learnware file by id | |||
| Parameters | |||
| ---------- | |||
| ids : Union[str, List[str]] | |||
| Give a id or a list of ids | |||
| str: id of targer learware | |||
| List[str]: A list of ids of target learnwares | |||
| Returns | |||
| ------- | |||
| Union[Learnware, List[Learnware]] | |||
| Return the path for target learnware or list of path. | |||
| None for Learnware NOT Found. | |||
| """ | |||
| if isinstance(ids, list): | |||
| ret = [] | |||
| for id in ids: | |||
| if id in self.learnware_zip_list: | |||
| ret.append(self.learnware_zip_list[id]) | |||
| else: | |||
| logger.warning("Learnware ID '%s' NOT Found!" % (id)) | |||
| ret.append(None) | |||
| return ret | |||
| else: | |||
| try: | |||
| return self.learnware_zip_list[ids] | |||
| except: | |||
| logger.warning("Learnware ID '%s' NOT Found!" % (ids)) | |||
| return None | |||
| def update_learnware_semantic_specification(self, learnware_id: str, semantic_spec: dict) -> bool: | |||
| """Update Learnware semantic_spec""" | |||
| # update database | |||
| self.dbops.update_learnware_semantic_specification(learnware_id=learnware_id, semantic_spec=semantic_spec) | |||
| # update file | |||
| folder_path = self.learnware_folder_list[learnware_id] | |||
| with open(os.path.join(folder_path, "semantic_specification.json"), "w") as f: | |||
| json.dump(semantic_spec, f) | |||
| pass | |||
| # update zip | |||
| zip_path = self.learnware_zip_list[learnware_id] | |||
| utils.zip_learnware_folder(folder_path, zip_path) | |||
| # update learnware | |||
| self.learnware_list[learnware_id].update_semantic_spec(semantic_spec) | |||
| pass | |||
| def __len__(self): | |||
| return len(self.learnware_list.keys()) | |||
| def _get_ids(self, top=None): | |||
| if top is None: | |||
| return list(self.learnware_list.keys()) | |||
| else: | |||
| return list(self.learnware_list.keys())[:top] | |||
| @@ -43,7 +43,7 @@ class EasySemanticChecker(BaseChecker): | |||
| assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" | |||
| assert isinstance(v, str), "Description must be string" | |||
| if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]: | |||
| if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression"]: | |||
| assert semantic_spec["Output"] is not None, "Lack of output semantics" | |||
| dim = semantic_spec["Output"]["Dimension"] | |||
| for k, v in semantic_spec["Output"]["Description"].items(): | |||
| @@ -126,7 +126,7 @@ class EasyStatChecker(BaseChecker): | |||
| logger.warning(f"learnware {learnware} prediction method is not valid!") | |||
| return self.INVALID_LEARNWARE | |||
| if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"): | |||
| if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"): | |||
| # Check output type | |||
| if isinstance(outputs, torch.Tensor): | |||
| outputs = outputs.detach().cpu().numpy() | |||
| @@ -145,6 +145,28 @@ class DatabaseOperations(object): | |||
| pass | |||
| pass | |||
| def get_learnware_semantic_specification(self, id: str): | |||
| with self.engine.connect() as conn: | |||
| r = conn.execute(text("SELECT semantic_spec FROM tb_learnware WHERE id=:id;"), dict(id=id)) | |||
| row = r.fetchone() | |||
| if row is None: | |||
| return None | |||
| else: | |||
| return json.loads(row[0]) | |||
| pass | |||
| pass | |||
| def get_learnware_use_flag(self, id: str): | |||
| with self.engine.connect() as conn: | |||
| r = conn.execute(text("SELECT use_flag FROM tb_learnware WHERE id=:id;"), dict(id=id)) | |||
| row = r.fetchone() | |||
| if row is None: | |||
| return None | |||
| else: | |||
| return int(row[0]) | |||
| pass | |||
| pass | |||
| def load_market(self): | |||
| with self.engine.connect() as conn: | |||
| cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) | |||
| @@ -143,9 +143,11 @@ class EasyOrganizer(BaseOrganizer): | |||
| return False | |||
| zip_dir = self.learnware_zip_list[id] | |||
| os.remove(zip_dir) | |||
| if os.path.exists(zip_dir): | |||
| os.remove(zip_dir) | |||
| pass | |||
| folder_dir = self.learnware_folder_list[id] | |||
| rmtree(folder_dir) | |||
| rmtree(folder_dir, ignore_errors=True) | |||
| self.learnware_list.pop(id) | |||
| self.learnware_zip_list.pop(id) | |||
| self.learnware_folder_list.pop(id) | |||
| @@ -370,5 +372,24 @@ class EasyOrganizer(BaseOrganizer): | |||
| learnware_ids = self.get_learnware_ids(top, check_status) | |||
| return [self.learnware_list[idx] for idx in learnware_ids] | |||
| def reload_learnware(self, learnware_id: str): | |||
| current_learnware = self.learnware_list.get(learnware_id) | |||
| if current_learnware is None: | |||
| # add learnware | |||
| self.count += 1 | |||
| else: | |||
| pass | |||
| target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id)) | |||
| target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id) | |||
| self.learnware_zip_list[learnware_id] = target_zip_dir | |||
| self.learnware_folder_list[learnware_id] = target_folder_dir | |||
| semantic_spec = self.dbops.get_learnware_semantic_specification(learnware_id) | |||
| self.learnware_list[learnware_id] = get_learnware_from_dirpath( | |||
| id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir) | |||
| self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id) | |||
| pass | |||
| def __len__(self): | |||
| return len(self.learnware_list) | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import List | |||
| from ..easy2.organizer import EasyOrganizer | |||
| from ..easy.organizer import EasyOrganizer | |||
| from ...learnware import Learnware | |||
| from ...specification import BaseStatSpecification | |||
| from ...logger import get_module_logger | |||
| @@ -1,5 +1,5 @@ | |||
| from .base import LearnwareMarket | |||
| from .easy2 import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .hetergeneous import HeteroMapTableOrganizer, HeteroSearcher | |||
| MARKET_CONFIG = { | |||
| @@ -1,10 +0,0 @@ | |||
| ## How to Generate Environment Yaml | |||
| * create env config for conda: | |||
| ```shell | |||
| conda env export | grep -v "^prefix: " > environment.yml | |||
| ``` | |||
| * recover env from config | |||
| ``` | |||
| conda env create -f environment.yml | |||
| ``` | |||
| @@ -1,27 +0,0 @@ | |||
| name: learnware_example_env | |||
| channels: | |||
| - defaults | |||
| dependencies: | |||
| - _libgcc_mutex=0.1=main | |||
| - _openmp_mutex=5.1=1_gnu | |||
| - ca-certificates=2023.01.10=h06a4308_0 | |||
| - ld_impl_linux-64=2.38=h1181459_1 | |||
| - libffi=3.4.2=h6a678d5_6 | |||
| - libgcc-ng=11.2.0=h1234567_1 | |||
| - libgomp=11.2.0=h1234567_1 | |||
| - libstdcxx-ng=11.2.0=h1234567_1 | |||
| - ncurses=6.4=h6a678d5_0 | |||
| - openssl=1.1.1t=h7f8727e_0 | |||
| - pip=23.0.1=py38h06a4308_0 | |||
| - python=3.8.16=h7a1cb2a_3 | |||
| - readline=8.2=h5eee18b_0 | |||
| - setuptools=66.0.0=py38h06a4308_0 | |||
| - sqlite=3.41.2=h5eee18b_0 | |||
| - tk=8.6.12=h1ccaba5_0 | |||
| - wheel=0.38.4=py38h06a4308_0 | |||
| - xz=5.2.10=h5eee18b_1 | |||
| - zlib=1.2.13=h5eee18b_0 | |||
| - pip: | |||
| - joblib==1.2.0 | |||
| - learnware==0.0.1.99 | |||
| - numpy==1.19.5 | |||
| @@ -1,8 +0,0 @@ | |||
| model: | |||
| class_name: SVM | |||
| kwargs: {} | |||
| stat_specifications: | |||
| - module_path: learnware.specification | |||
| class_name: RKMETableSpecification | |||
| file_name: svm.json | |||
| kwargs: {} | |||
| @@ -1,20 +0,0 @@ | |||
| import os | |||
| import joblib | |||
| import numpy as np | |||
| from learnware.model import BaseModel | |||
| class SVM(BaseModel): | |||
| def __init__(self): | |||
| super(SVM, self).__init__(input_shape=(64,), output_shape=(10,)) | |||
| dir_path = os.path.dirname(os.path.abspath(__file__)) | |||
| self.model = joblib.load(os.path.join(dir_path, "svm.pkl")) | |||
| def fit(self, X: np.ndarray, y: np.ndarray): | |||
| pass | |||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||
| return self.model.predict_proba(X) | |||
| def finetune(self, X: np.ndarray, y: np.ndarray): | |||
| pass | |||
| @@ -1,239 +0,0 @@ | |||
| import sys | |||
| import unittest | |||
| import os | |||
| import copy | |||
| import joblib | |||
| import zipfile | |||
| import numpy as np | |||
| from sklearn import svm | |||
| from sklearn.datasets import load_digits | |||
| from sklearn.model_selection import train_test_split | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_spec | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| user_semantic = { | |||
| "Data": {"Values": ["Table"], "Type": "Class"}, | |||
| "Task": { | |||
| "Values": ["Classification"], | |||
| "Type": "Class", | |||
| }, | |||
| "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, | |||
| "Scenario": {"Values": ["Education"], "Type": "Tag"}, | |||
| "Description": {"Values": "", "Type": "String"}, | |||
| "Name": {"Values": "", "Type": "String"}, | |||
| } | |||
| class TestMarket(unittest.TestCase): | |||
| @classmethod | |||
| def setUpClass(cls) -> None: | |||
| np.random.seed(2023) | |||
| learnware.init() | |||
| def _init_learnware_market(self): | |||
| """initialize learnware market""" | |||
| easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True) | |||
| return easy_market | |||
| def test_prepare_learnware_randomly(self, learnware_num=5): | |||
| self.zip_path_list = [] | |||
| X, y = load_digits(return_X_y=True) | |||
| for i in range(learnware_num): | |||
| dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i)) | |||
| os.makedirs(dir_path, exist_ok=True) | |||
| print("Preparing Learnware: %d" % (i)) | |||
| data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True) | |||
| clf = svm.SVC(kernel="linear", probability=True) | |||
| clf.fit(data_X, data_y) | |||
| joblib.dump(clf, os.path.join(dir_path, "svm.pkl")) | |||
| spec = generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||
| spec.save(os.path.join(dir_path, "svm.json")) | |||
| init_file = os.path.join(dir_path, "__init__.py") | |||
| copyfile( | |||
| os.path.join(curr_root, "learnware_example/example_init.py"), init_file | |||
| ) # cp example_init.py init_file | |||
| yaml_file = os.path.join(dir_path, "learnware.yaml") | |||
| copyfile(os.path.join(curr_root, "learnware_example/example.yaml"), yaml_file) # cp example.yaml yaml_file | |||
| env_file = os.path.join(dir_path, "environment.yaml") | |||
| copyfile(os.path.join(curr_root, "learnware_example/environment.yaml"), env_file) | |||
| zip_file = dir_path + ".zip" | |||
| # zip -q -r -j zip_file dir_path | |||
| with zipfile.ZipFile(zip_file, "w") as zip_obj: | |||
| for foldername, subfolders, filenames in os.walk(dir_path): | |||
| for filename in filenames: | |||
| file_path = os.path.join(foldername, filename) | |||
| zip_info = zipfile.ZipInfo(filename) | |||
| zip_info.compress_type = zipfile.ZIP_STORED | |||
| with open(file_path, "rb") as file: | |||
| zip_obj.writestr(zip_info, file.read()) | |||
| rmtree(dir_path) # rm -r dir_path | |||
| self.zip_path_list.append(zip_file) | |||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | |||
| easy_market = self._init_learnware_market() | |||
| self.test_prepare_learnware_randomly(learnware_num) | |||
| self.learnware_num = learnware_num | |||
| print("Total Item:", len(easy_market)) | |||
| assert len(easy_market) == 0, f"The market should be empty!" | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) | |||
| semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) | |||
| semantic_spec["Input"] = { | |||
| "Dimension": 64, | |||
| "Description": { | |||
| f"{i}": f"The value in the grid {i // 8}{i % 8} of the image of hand-written digit." | |||
| for i in range(64) | |||
| }, | |||
| } | |||
| semantic_spec["Output"] = { | |||
| "Dimension": 10, | |||
| "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)}, | |||
| } | |||
| easy_market.add_learnware(zip_path, semantic_spec) | |||
| print("Total Item:", len(easy_market)) | |||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = easy_market.get_learnware_ids() | |||
| print("Available ids After Uploading Learnwares:", curr_inds) | |||
| assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| if delete: | |||
| for learnware_id in curr_inds: | |||
| easy_market.delete_learnware(learnware_id) | |||
| self.learnware_num -= 1 | |||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = easy_market.get_learnware_ids() | |||
| print("Available ids After Deleting Learnwares:", curr_inds) | |||
| assert len(curr_inds) == 0, f"The market should be empty!" | |||
| return easy_market | |||
| def test_search_semantics(self, learnware_num=5): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(easy_market)) | |||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| test_folder = os.path.join(curr_root, "test_semantics") | |||
| # unzip -o -q zip_path -d unzip_dir | |||
| if os.path.exists(test_folder): | |||
| rmtree(test_folder) | |||
| os.makedirs(test_folder, exist_ok=True) | |||
| with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj: | |||
| zip_obj.extractall(path=test_folder) | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" | |||
| semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| _, single_learnware_list, _, _ = easy_market.search_learnware(user_info) | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| for learnware in single_learnware_list: | |||
| print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) | |||
| rmtree(test_folder) # rm -r test_folder | |||
| def test_stat_search(self, learnware_num=5): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(easy_market)) | |||
| test_folder = os.path.join(curr_root, "test_stat") | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| unzip_dir = os.path.join(test_folder, f"{idx}") | |||
| # unzip -o -q zip_path -d unzip_dir | |||
| if os.path.exists(unzip_dir): | |||
| rmtree(unzip_dir) | |||
| os.makedirs(unzip_dir, exist_ok=True) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=unzip_dir) | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = easy_market.search_learnware(user_info) | |||
| assert len(single_learnware_list) == self.learnware_num, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||
| print(f"mixture_score: {mixture_score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| rmtree(test_folder) # rm -r test_folder | |||
| def test_learnware_reuse(self, learnware_num=5): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(easy_market)) | |||
| X, y = load_digits(return_X_y=True) | |||
| train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) | |||
| stat_spec = generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) | |||
| _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) | |||
| # Based on user information, the learnware market returns a list of learnwares (learnware_list) | |||
| # Use jobselector reuser to reuse the searched learnwares to make prediction | |||
| reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list) | |||
| job_selector_predict_y = reuse_job_selector.predict(user_data=data_X) | |||
| # Use averaging ensemble reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob") | |||
| ensemble_predict_y = reuse_ensemble.predict(user_data=data_X) | |||
| # Use ensemble pruning reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification") | |||
| reuse_ensemble.fit(train_X[-200:], train_y[-200:]) | |||
| ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X) | |||
| print("Job Selector Acc:", np.sum(np.argmax(job_selector_predict_y, axis=1) == data_y) / len(data_y)) | |||
| print("Averaging Reuser Acc:", np.sum(np.argmax(ensemble_predict_y, axis=1) == data_y) / len(data_y)) | |||
| print("Ensemble Pruning Reuser Acc:", np.sum(ensemble_pruning_predict_y == data_y) / len(data_y)) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestMarket("test_prepare_learnware_randomly")) | |||
| _suite.addTest(TestMarket("test_upload_delete_learnware")) | |||
| _suite.addTest(TestMarket("test_search_semantics")) | |||
| _suite.addTest(TestMarket("test_stat_search")) | |||
| _suite.addTest(TestMarket("test_learnware_reuse")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
| @@ -11,14 +11,14 @@ from sklearn.model_selection import train_test_split | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| from learnware.market import EasyMarket, BaseUserInfo | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_spec | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| user_semantic = { | |||
| "Data": {"Values": ["Image"], "Type": "Class"}, | |||
| "Data": {"Values": ["Table"], "Type": "Class"}, | |||
| "Task": { | |||
| "Values": ["Classification"], | |||
| "Type": "Class", | |||
| @@ -30,7 +30,7 @@ user_semantic = { | |||
| } | |||
| class TestAllWorkflow(unittest.TestCase): | |||
| class TestMarket(unittest.TestCase): | |||
| @classmethod | |||
| def setUpClass(cls) -> None: | |||
| np.random.seed(2023) | |||
| @@ -38,7 +38,7 @@ class TestAllWorkflow(unittest.TestCase): | |||
| def _init_learnware_market(self): | |||
| """initialize learnware market""" | |||
| easy_market = EasyMarket(market_id="sklearn_digits", rebuild=True) | |||
| easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True) | |||
| return easy_market | |||
| def test_prepare_learnware_randomly(self, learnware_num=5): | |||
| @@ -86,16 +86,25 @@ class TestAllWorkflow(unittest.TestCase): | |||
| self.zip_path_list.append(zip_file) | |||
| def test_upload_delete_learnware(self, learnware_num=5, delete=False): | |||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | |||
| easy_market = self._init_learnware_market() | |||
| self.test_prepare_learnware_randomly(learnware_num) | |||
| self.learnware_num = learnware_num | |||
| print("Total Item:", len(easy_market)) | |||
| assert len(easy_market) == 0, f"The market should be empty!" | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) | |||
| semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) | |||
| semantic_spec["Input"] = { | |||
| "Dimension": 64, | |||
| "Description": { | |||
| f"{i}": f"The value in the grid {i // 8}{i % 8} of the image of hand-written digit." | |||
| for i in range(64) | |||
| }, | |||
| } | |||
| semantic_spec["Output"] = { | |||
| "Dimension": 10, | |||
| "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)}, | |||
| @@ -103,21 +112,27 @@ class TestAllWorkflow(unittest.TestCase): | |||
| easy_market.add_learnware(zip_path, semantic_spec) | |||
| print("Total Item:", len(easy_market)) | |||
| curr_inds = easy_market._get_ids() | |||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = easy_market.get_learnware_ids() | |||
| print("Available ids After Uploading Learnwares:", curr_inds) | |||
| assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| if delete: | |||
| for learnware_id in curr_inds: | |||
| easy_market.delete_learnware(learnware_id) | |||
| curr_inds = easy_market._get_ids() | |||
| self.learnware_num -= 1 | |||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = easy_market.get_learnware_ids() | |||
| print("Available ids After Deleting Learnwares:", curr_inds) | |||
| assert len(curr_inds) == 0, f"The market should be empty!" | |||
| return easy_market | |||
| def test_search_semantics(self, learnware_num=5): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(easy_market)) | |||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| test_folder = os.path.join(curr_root, "test_semantics") | |||
| # unzip -o -q zip_path -d unzip_dir | |||
| @@ -168,6 +183,7 @@ class TestAllWorkflow(unittest.TestCase): | |||
| mixture_learnware_list, | |||
| ) = easy_market.search_learnware(user_info) | |||
| assert len(single_learnware_list) == self.learnware_num, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||
| @@ -210,11 +226,11 @@ class TestAllWorkflow(unittest.TestCase): | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestAllWorkflow("test_prepare_learnware_randomly")) | |||
| _suite.addTest(TestAllWorkflow("test_upload_delete_learnware")) | |||
| _suite.addTest(TestAllWorkflow("test_search_semantics")) | |||
| _suite.addTest(TestAllWorkflow("test_stat_search")) | |||
| _suite.addTest(TestAllWorkflow("test_learnware_reuse")) | |||
| _suite.addTest(TestMarket("test_prepare_learnware_randomly")) | |||
| _suite.addTest(TestMarket("test_upload_delete_learnware")) | |||
| _suite.addTest(TestMarket("test_search_semantics")) | |||
| _suite.addTest(TestMarket("test_stat_search")) | |||
| _suite.addTest(TestMarket("test_learnware_reuse")) | |||
| return _suite | |||