diff --git a/.github/workflows/install_learnware_with_pip.yaml b/.github/workflows/install_learnware_with_pip.yaml index 137909e..4b4f86a 100644 --- a/.github/workflows/install_learnware_with_pip.yaml +++ b/.github/workflows/install_learnware_with_pip.yaml @@ -13,8 +13,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest] - python-version: [3.8, 3.9, 3.10] + os: [ubuntu-20.04] + python-version: [3.8, 3.9] steps: - name: Test learnware from pip @@ -25,7 +25,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Add conda to system path + - name: Add conda to system path run: | # $CONDA is an environment variable pointing to the root of the miniconda directory echo $CONDA/bin >> $GITHUB_PATH @@ -33,7 +33,7 @@ jobs: - name: Create conda env for macos run: | conda create -n learnware python=${{ matrix.python-version }} - conda create activate learnware + conda activate learnware - name: Update pip to the latest version run: | diff --git a/.github/workflows/install_learnware_with_source.yaml b/.github/workflows/install_learnware_with_source.yaml index e9589e3..d0fb6af 100644 --- a/.github/workflows/install_learnware_with_source.yaml +++ b/.github/workflows/install_learnware_with_source.yaml @@ -13,8 +13,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest] - python-version: [3.8, 3.9, 3.10] + os: [ubuntu-20.04] + python-version: [3.8, 3.9] steps: - name: Test learnware from pip @@ -25,7 +25,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Add conda to system path + - name: Add conda to system path run: | # $CONDA is an environment variable pointing to the root of the miniconda directory echo $CONDA/bin >> $GITHUB_PATH @@ -33,7 +33,7 @@ jobs: - name: Create conda env for macos run: | conda create -n learnware python=${{ matrix.python-version }} - conda create activate learnware + conda activate learnware - name: Update pip to the latest version run: | diff --git a/README.md b/README.md index 5629530..a01c2d7 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ is composed of the following four parts. - ``learnware.yaml`` - A config file describing your model class name, type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMEStatSpecification``), and + A config file describing your model class name, type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMETableSpecification``), and the file name of your statistical specification file. - ``environment.yaml`` @@ -178,10 +178,10 @@ For example, the following code is designed to work with Reduced Set Kernel Embe ```python import learnware.specification as specification -user_spec = specification.RKMEStatSpecification() +user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join(unzip_path, "rkme.json")) user_info = BaseUserInfo( - semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} + semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} ) (sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) diff --git a/docs/references/api.rst b/docs/references/api.rst index a2f723b..20de7bc 100644 --- a/docs/references/api.rst +++ b/docs/references/api.rst @@ -50,7 +50,7 @@ Specification .. autoclass:: learnware.specification.BaseStatSpecification :members: -.. 
autoclass:: learnware.specification.RKMEStatSpecification +.. autoclass:: learnware.specification.RKMETableSpecification :members: Model diff --git a/docs/start/client.rst b/docs/start/client.rst index cbe1f2e..664fdf1 100644 --- a/docs/start/client.rst +++ b/docs/start/client.rst @@ -117,13 +117,13 @@ You can search learnwares in official market using semantic specification. All t Statistical Specification Search --------------------------------- -You can search learnware by providing a statistical specification. The statistical specification is a json file that contains the statistical information of your training data. For example, the code below searches learnwares with `RKMEStatSpecification`: +You can search learnware by providing a statistical specification. The statistical specification is a json file that contains the statistical information of your training data. For example, the code below searches learnwares with `RKMETableSpecification`: .. code-block:: python import learnware.specification as specification - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join(unzip_path, "rkme.json")) specification = learnware.specification.Specification() @@ -138,7 +138,7 @@ You can search learnware by providing a statistical specification. The statistic Combine Semantic and Statistical Search ---------------------------------------- -You can provide both semantic and statistical specification to search learnwares. The engine will first filter learnwares by semantic specification and then search by statistical specification. For example, the code below searches learnwares with `Table` data type and `RKMEStatSpecification`: +You can provide both semantic and statistical specification to search learnwares. The engine will first filter learnwares by semantic specification and then search by statistical specification. For example, the code below searches learnwares with `Table` data type and `RKMETableSpecification`: .. code-block:: python @@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares senarioes=[], input_description={}, output_description={}) - stat_spec = specification.RKMEStatSpecification() + stat_spec = specification.RKMETableSpecification() stat_spec.load(os.path.join(unzip_path, "rkme.json")) specification = learnware.specification.Specification() specification.update_semantic_spec(semantic_spec) diff --git a/docs/start/quick.rst b/docs/start/quick.rst index 2140aaa..6d8a7a8 100644 --- a/docs/start/quick.rst +++ b/docs/start/quick.rst @@ -47,7 +47,7 @@ includes the following four components: - ``learnware.yaml`` - A configuration file that details your model's class name, the type of statistical specification(e.g. ``RKMEStatSpecification`` for Reduced Kernel Mean Embedding), and + A configuration file that details your model's class name, the type of statistical specification(e.g. ``RKMETableSpecification`` for Reduced Kernel Mean Embedding), and the file name of your statistical specification file. 
- ``environment.yaml`` or ``requirements.txt`` @@ -170,12 +170,12 @@ For example, the code below executes learnware search when using Reduced Set Ker import learnware.specification as specification - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() # unzip_path: directory for unzipped learnware zipfile user_spec.load(os.path.join(unzip_path, "rkme.json")) user_info = BaseUserInfo( - semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} + semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} ) (sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) diff --git a/docs/workflow/identify.rst b/docs/workflow/identify.rst index ffd7dbb..ed7bb55 100644 --- a/docs/workflow/identify.rst +++ b/docs/workflow/identify.rst @@ -73,10 +73,10 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb import learnware.specification as specification - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join("rkme.json")) user_info = BaseUserInfo( - semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} + semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} ) (sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) diff --git a/docs/workflow/submit.rst b/docs/workflow/submit.rst index 928e108..fe097c3 100644 --- a/docs/workflow/submit.rst +++ b/docs/workflow/submit.rst @@ -94,7 +94,7 @@ guaranteeing the security and privacy of your local original data. ------------------ Additionally, you are asked to prepare a configuration file in YAML format. -The file should detail your model's class name, the type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMEStatSpecification``), and +The file should detail your model's class name, the type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMETableSpecification``), and the file name of your statistical specification file. 
The following ``learnware.yaml`` provides an example of how your learnware configuration file should be structured, based on our previous discussion: @@ -105,7 +105,7 @@ how your learnware configuration file should be structured, based on our previou kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: stat.json kwargs: {} diff --git a/examples/dataset_image_workflow/example_files/example_yaml.yaml b/examples/dataset_image_workflow/example_files/example_yaml.yaml index 6ca01c9..9aaf820 100644 --- a/examples/dataset_image_workflow/example_files/example_yaml.yaml +++ b/examples/dataset_image_workflow/example_files/example_yaml.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMEImageSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/examples/dataset_image_workflow/main.py b/examples/dataset_image_workflow/main.py index 26e639a..74e125b 100644 --- a/examples/dataset_image_workflow/main.py +++ b/examples/dataset_image_workflow/main.py @@ -1,11 +1,15 @@ import numpy as np import torch +from tqdm import tqdm + from get_data import * import os import random + +from learnware.specification.image import RKMEImageSpecification +from learnware.reuse.averaging import AveragingReuser from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction from learnware.learnware import Learnware -from learnware.reuse import JobSelectorReuser, AveragingReuser import time from learnware.market import EasyMarket, BaseUserInfo @@ -23,7 +27,7 @@ processed_data_root = "./data/processed_data" tmp_dir = "./data/tmp" learnware_pool_dir = "./data/learnware_pool" dataset = "cifar10" -n_uploaders = 50 +n_uploaders = 30 n_users = 20 n_classes = 10 data_root = os.path.join(origin_data_root, dataset) @@ -45,6 +49,7 @@ semantic_specs = [ "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "learnware_1", "Type": "String"}, + "Output": {"Dimension": 10}, } ] @@ -88,9 +93,15 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo tmp_init_path = os.path.join(save_root, "__init__.py") tmp_model_file_path = os.path.join(save_root, "model.py") mmodel_file_path = "./example_files/model.py" + + # Computing the specification from the whole dataset is too costly. 
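+    # The RKMEImageSpecification below is therefore fitted on a random subsample (2,000 images) of the uploader's data.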
X = np.load(data_path) + indices = np.random.choice(len(X), size=2000, replace=False) + X_sampled = X[indices] + st = time.time() - user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) + user_spec = RKMEImageSpecification(cuda_idx=0) + user_spec.generate_stat_spec_from_data(X=X_sampled) ed = time.time() logger.info("Stat spec generated in %.3f s" % (ed - st)) user_spec.save(tmp_spec_path) @@ -117,7 +128,7 @@ def prepare_market(): except: pass os.makedirs(learnware_pool_dir, exist_ok=True) - for i in range(n_uploaders): + for i in tqdm(range(n_uploaders), total=n_uploaders, desc="Preparing..."): data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i)) model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) init_file_path = "./example_files/example_init.py" @@ -148,40 +159,38 @@ def test_search(gamma=0.1, load_market=True): improve_list = [] job_selector_score_list = [] ensemble_score_list = [] - for i in range(n_users): + for i in tqdm(range(n_users), total=n_users, desc="Searching..."): user_data_path = os.path.join(user_save_root, "user_%d_X.npy" % (i)) user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i)) user_data = np.load(user_data_path) user_label = np.load(user_label_path) - user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=gamma, cuda_idx=0) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec}) - logger.info("Searching Market for user: %d" % (i)) + user_stat_spec = RKMEImageSpecification(cuda_idx=0) + user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}) + logger.info("Searching Market for user: %d" % i) sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( user_info ) - l = len(sorted_score_list) acc_list = [] - for idx in range(l): - learnware = single_learnware_list[idx] - score = sorted_score_list[idx] + for idx, (score, learnware) in enumerate(zip(sorted_score_list[:5], single_learnware_list[:5])): pred_y = learnware.predict(user_data) acc = eval_prediction(pred_y, user_label) acc_list.append(acc) - logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) + logger.info("Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) # test reuse (job selector) - reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) - reuse_predict = reuse_baseline.predict(user_data=user_data) - reuse_score = eval_prediction(reuse_predict, user_label) - job_selector_score_list.append(reuse_score) - print(f"mixture reuse loss: {reuse_score}") + # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) + # reuse_predict = reuse_baseline.predict(user_data=user_data) + # reuse_score = eval_prediction(reuse_predict, user_label) + # job_selector_score_list.append(reuse_score) + # print(f"mixture reuse loss: {reuse_score}") # test reuse (ensemble) - reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote") + reuse_ensemble = AveragingReuser(learnware_list=single_learnware_list[:3], mode="vote_by_prob") ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) ensemble_score = eval_prediction(ensemble_predict_y, user_label) ensemble_score_list.append(ensemble_score) - print(f"mixture reuse 
accuracy (ensemble): {ensemble_score}\n") + print(f"reuse accuracy (vote_by_prob): {ensemble_score}\n") select_list.append(acc_list[0]) avg_list.append(np.mean(acc_list)) @@ -191,17 +200,17 @@ def test_search(gamma=0.1, load_market=True): "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f" % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list)) ) - logger.info("Average performance improvement: %.3f" % (np.mean(improve_list))) - logger.info( - "Average Job Selector Reuse Performance: %.3f +/- %.3f" - % (np.mean(job_selector_score_list), np.std(job_selector_score_list)) - ) logger.info( "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list)) ) if __name__ == "__main__": + logger.info("=" * 40) + logger.info(f"n_uploaders:\t{n_uploaders}") + logger.info(f"n_users:\t{n_users}") + logger.info("=" * 40) + prepare_data() prepare_model() test_search(load_market=False) diff --git a/examples/dataset_m5_workflow/example.yaml b/examples/dataset_m5_workflow/example.yaml index 6ca01c9..cd539c8 100644 --- a/examples/dataset_m5_workflow/example.yaml +++ b/examples/dataset_m5_workflow/example.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/examples/dataset_m5_workflow/main.py b/examples/dataset_m5_workflow/main.py index a720b30..009b557 100644 --- a/examples/dataset_m5_workflow/main.py +++ b/examples/dataset_m5_workflow/main.py @@ -144,7 +144,7 @@ class M5DatasetWorkflow: user_spec_path = f"./user_spec/user_{idx}.json" user_spec.save(user_spec_path) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) ( sorted_score_list, single_learnware_list, diff --git a/examples/dataset_pfs_workflow/example.yaml b/examples/dataset_pfs_workflow/example.yaml index 6ca01c9..cd539c8 100644 --- a/examples/dataset_pfs_workflow/example.yaml +++ b/examples/dataset_pfs_workflow/example.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/examples/dataset_pfs_workflow/main.py b/examples/dataset_pfs_workflow/main.py index b3a7d36..b5cbdd8 100644 --- a/examples/dataset_pfs_workflow/main.py +++ b/examples/dataset_pfs_workflow/main.py @@ -142,7 +142,7 @@ class PFSDatasetWorkflow: user_spec_path = f"./user_spec/user_{idx}.json" user_spec.save(user_spec_path) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) ( sorted_score_list, single_learnware_list, diff --git a/examples/workflow_by_code/learnware_example/example.yaml b/examples/workflow_by_code/learnware_example/example.yaml index 254bca4..32aa52e 100644 --- a/examples/workflow_by_code/learnware_example/example.yaml +++ b/examples/workflow_by_code/learnware_example/example.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: svm.json kwargs: {} \ No newline 
at end of file diff --git a/examples/workflow_by_code/main.py b/examples/workflow_by_code/main.py index 29d2e69..2f62db0 100644 --- a/examples/workflow_by_code/main.py +++ b/examples/workflow_by_code/main.py @@ -148,9 +148,9 @@ class LearnwareMarketWorkflow: with zipfile.ZipFile(zip_path, "r") as zip_obj: zip_obj.extractall(path=unzip_dir) - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) ( sorted_score_list, single_learnware_list, @@ -175,7 +175,7 @@ class LearnwareMarketWorkflow: _, data_X, _, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": stat_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index 32ef7bd..d3dd704 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -37,7 +37,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath: "stat_specifications": [ { "module_path": "learnware.specification", - "class_name": "RKMEStatSpecification", + "class_name": "RKMETableSpecification", "file_name": "stat_spec.json", "kwargs": {}, }, diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 591cf05..957efda 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -18,7 +18,7 @@ from .. 
import utils from ..config import C as conf from ..logger import get_module_logger from ..learnware import Learnware, get_learnware_from_dirpath -from ..specification import RKMEStatSpecification, Specification +from ..specification import RKMETableSpecification, Specification logger = get_module_logger("market", "INFO") @@ -116,7 +116,7 @@ class EasyMarket(LearnwareMarket): pass # check rkme dimension - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") if stat_spec is not None: if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") @@ -296,7 +296,7 @@ class EasyMarket(LearnwareMarket): def _calculate_rkme_spec_mixture_weight( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None, ) -> Tuple[List[float], float]: @@ -306,7 +306,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] A list of existing learnwares - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification intermediate_K : np.ndarray, optional Intermediate kernel matrix K, by default None @@ -321,7 +321,7 @@ class EasyMarket(LearnwareMarket): """ learnware_num = len(learnware_list) RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] if type(intermediate_K) == np.ndarray: @@ -365,7 +365,7 @@ class EasyMarket(LearnwareMarket): def _calculate_intermediate_K_and_C( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None, ) -> Tuple[np.ndarray, np.ndarray]: @@ -375,7 +375,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares up till now - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification intermediate_K : np.ndarray, optional Intermediate kernel matrix K, by default None @@ -390,7 +390,7 @@ class EasyMarket(LearnwareMarket): """ num = intermediate_K.shape[0] - 1 RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) @@ -400,7 +400,7 @@ class EasyMarket(LearnwareMarket): def _search_by_rkme_spec_mixture_auto( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, max_search_num: int, weight_cutoff: float = 0.98, ) -> Tuple[float, List[float], List[Learnware]]: @@ -410,7 +410,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification max_search_num : int The maximum number of the returned learnwares @@ -446,7 +446,7 @@ class EasyMarket(LearnwareMarket): if len(mixture_list) <= 1: 
mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num] @@ -488,7 +488,7 @@ class EasyMarket(LearnwareMarket): return sorted_score_list[:idx], learnware_list[:idx] def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification + self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification ) -> List[Learnware]: """Filter learnwares whose rkme dimension different from user_rkme @@ -496,7 +496,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification Returns @@ -508,7 +508,7 @@ class EasyMarket(LearnwareMarket): user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) for learnware in learnware_list: - rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") + rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -518,7 +518,7 @@ class EasyMarket(LearnwareMarket): def _search_by_rkme_spec_mixture_greedy( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, max_search_num: int, score_cutoff: float = 0.001, ) -> Tuple[float, List[float], List[Learnware]]: @@ -528,7 +528,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification max_search_num : int The maximum number of the returned learnwares @@ -588,7 +588,7 @@ class EasyMarket(LearnwareMarket): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification + self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme @@ -596,7 +596,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification user RKME statistical specification Returns @@ -607,7 +607,7 @@ class EasyMarket(LearnwareMarket): both lists are sorted by mmd dist """ RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] mmd_dist_list = [] for RKME in RKME_list: @@ -819,12 +819,12 @@ class EasyMarket(LearnwareMarket): # if len(learnware_list) == 0: learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info) - if "RKMEStatSpecification" not in user_info.stat_info: + if "RKMETableSpecification" not in user_info.stat_info: return None, learnware_list, 0.0, None elif len(learnware_list) == 0: return [], [], 0.0, [] else: - 
user_rkme = user_info.stat_info["RKMEStatSpecification"] + user_rkme = user_info.stat_info["RKMETableSpecification"] learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 70a7c15..f369725 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -80,7 +80,7 @@ class EasyStatisticalChecker(BaseChecker): if is_text: stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETextStatSpecification") else: - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") if stat_spec is not None and not is_text: if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 830b5d3..18f67eb 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -1,28 +1,15 @@ import os -import json import copy -import torch import zipfile -import traceback import tempfile -import numpy as np -import pandas as pd -from rapidfuzz import fuzz -from cvxopt import solvers, matrix from shutil import copyfile, rmtree -from typing import Tuple, Any, List, Union, Dict +from typing import Tuple, List, Union from .database_ops import DatabaseOperations -from ..base import LearnwareMarket, BaseUserInfo - - -from ... import utils +from ..base import BaseOrganizer, BaseChecker from ...config import C as conf from ...logger import get_module_logger from ...learnware import Learnware, get_learnware_from_dirpath -from ...specification import RKMEStatSpecification, Specification - -from ..base import BaseOrganizer, BaseChecker from ...logger import get_module_logger logger = get_module_logger("easy_organizer") diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 93307e7..13038d2 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -2,12 +2,12 @@ import torch import numpy as np from rapidfuzz import fuzz from cvxopt import solvers, matrix -from typing import Tuple, List +from typing import Tuple, List, Union from .organizer import EasyOrganizer from ..base import BaseUserInfo, BaseSearcher from ...learnware import Learnware -from ...specification import RKMEStatSpecification +from ...specification import RKMETableSpecification, RKMEImageSpecification from ...logger import get_module_logger logger = get_module_logger("easy_seacher") @@ -188,7 +188,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): return final_result -class EasyTableSearcher(BaseSearcher): +class EasyStatSearcher(BaseSearcher): def _convert_dist_to_score( self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92 ) -> List[float]: @@ -227,7 +227,7 @@ class EasyTableSearcher(BaseSearcher): def _calculate_rkme_spec_mixture_weight( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None, ) -> Tuple[List[float], float]: @@ -237,7 +237,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] A list of existing learnwares - user_rkme : 
RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification intermediate_K : np.ndarray, optional Intermediate kernel matrix K, by default None @@ -294,7 +294,7 @@ class EasyTableSearcher(BaseSearcher): def _calculate_intermediate_K_and_C( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None, ) -> Tuple[np.ndarray, np.ndarray]: @@ -304,7 +304,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares up till now - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification intermediate_K : np.ndarray, optional Intermediate kernel matrix K, by default None @@ -327,7 +327,7 @@ class EasyTableSearcher(BaseSearcher): def _search_by_rkme_spec_mixture_auto( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, max_search_num: int, weight_cutoff: float = 0.98, ) -> Tuple[float, List[float], List[Learnware]]: @@ -337,7 +337,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification max_search_num : int The maximum number of the returned learnwares @@ -415,7 +415,7 @@ class EasyTableSearcher(BaseSearcher): return sorted_score_list[:idx], learnware_list[:idx] def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification + self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] ) -> List[Learnware]: """Filter learnwares whose rkme dimension different from user_rkme @@ -423,7 +423,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] User RKME statistical specification Returns @@ -447,7 +447,7 @@ class EasyTableSearcher(BaseSearcher): def _search_by_rkme_spec_mixture_greedy( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, max_search_num: int, score_cutoff: float = 0.001, ) -> Tuple[float, List[float], List[Learnware]]: @@ -457,7 +457,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification max_search_num : int The maximum number of the returned learnwares @@ -517,7 +517,7 @@ class EasyTableSearcher(BaseSearcher): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification + self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme @@ -525,7 +525,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : 
Union[RKMETableSpecification, RKMEImageSpecification]
            user RKME statistical specification

        Returns
@@ -557,7 +557,7 @@
             if "RKMETextStatSpecification" in user_info.stat_info:
                 self.stat_info_name = "RKMETextStatSpecification"
             else:
-                self.stat_info_name = "RKMEStatSpecification"
+                self.stat_info_name = "RKMETableSpecification"
             user_rkme = user_info.stat_info[self.stat_info_name]
             learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme)
             logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}")
@@ -599,12 +599,12 @@ class EasySearcher(BaseSearcher):
     def __init__(self, organizer: EasyOrganizer = None):
         super(EasySearcher, self).__init__(organizer)
         self.semantic_searcher = EasyFuzzSemanticSearcher(organizer)
-        self.table_searcher = EasyTableSearcher(organizer)
+        self.stat_searcher = EasyStatSearcher(organizer)

     def reset(self, organizer):
         self.learnware_oganizer = organizer
         self.semantic_searcher.reset(organizer)
-        self.table_searcher.reset(organizer)
+        self.stat_searcher.reset(organizer)

     def __call__(
         self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy"
@@ -631,9 +631,9 @@
         if len(learnware_list) == 0:
             return [], [], 0.0, []
-        elif "RKMEStatSpecification" in user_info.stat_info:
-            return self.table_searcher(learnware_list, user_info, max_search_num, search_method)
+        elif "RKMETableSpecification" in user_info.stat_info:
+            return self.stat_searcher(learnware_list, user_info, max_search_num, search_method)
         elif "RKMETextStatSpecification" in user_info.stat_info:
-            return self.table_searcher(learnware_list, user_info, max_search_num, search_method)
+            return self.stat_searcher(learnware_list, user_info, max_search_num, search_method)
         else:
             return None, learnware_list, 0.0, None
diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py
index c1acab0..84d0bb7 100644
--- a/learnware/reuse/job_selector.py
+++ b/learnware/reuse/job_selector.py
@@ -9,7 +9,7 @@
 from sklearn.metrics import accuracy_score
 from learnware.learnware import Learnware
 import learnware.specification as specification
 from .base import BaseReuser
-from ..specification import RKMEStatSpecification, RKMETextStatSpecification
+from ..specification import RKMETableSpecification, RKMETextStatSpecification
 from ..logger import get_module_logger

 logger = get_module_logger("job_selector_reuse")
@@ -165,7 +165,7 @@ class JobSelectorReuser(BaseReuser):
     def _calculate_rkme_spec_mixture_weight(
-        self, user_data: np.ndarray, task_rkme_list: List[RKMEStatSpecification], task_rkme_matrix: np.ndarray
+        self, user_data: np.ndarray, task_rkme_list: List[RKMETableSpecification], task_rkme_matrix: np.ndarray
     ) -> List[float]:
         """_summary_

         Parameters
         ----------
         user_data : np.ndarray
             Raw user data.
-        task_rkme_list : List[RKMEStatSpecification]
+        task_rkme_list : List[RKMETableSpecification]
             The list of learwares' rkmes whose mixture approximates the user's rkme
         task_rkme_matrix : np.ndarray
             Inner product matrix calculated from task_rkme_list.
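For reference, a minimal sketch of how the renamed classes fit together on the user side, assuming (as in the examples above) an `easy_market` instance, a `user_semantic` dict, and an unzipped learnware at `unzip_path`:

```python
import os

import learnware.specification as specification
from learnware.market import BaseUserInfo

# Load the user's statistical specification under its new class name.
user_spec = specification.RKMETableSpecification()
user_spec.load(os.path.join(unzip_path, "rkme.json"))

# The searcher dispatches on the stat_info key: "RKMETableSpecification"
# (or "RKMETextStatSpecification") triggers the statistical search stage.
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
sorted_scores, single_learnwares, mixture_score, mixture_learnwares = easy_market.search_learnware(user_info)
```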
diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index a8d9ace..bd94261 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ -from .utils import generate_stat_spec +from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification -from .regular import RKMEStatSpecification, RKMETextStatSpecification +from .regular import RegularStatsSpecification, RKMETableSpecification, RKMEImageSpecification, RKMETextStatSpecification diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index ca69ff4..1731a2f 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,2 +1,4 @@ -from .table import RKMEStatSpecification from .text import RKMETextStatSpecification +from .table import RKMETableSpecification, RKMEStatSpecification +from .image import RKMEImageSpecification +from .base import RegularStatsSpecification diff --git a/learnware/specification/regular/base.py b/learnware/specification/regular/base.py index 48a7e1f..6916177 100644 --- a/learnware/specification/regular/base.py +++ b/learnware/specification/regular/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from ..base import BaseStatSpecification diff --git a/learnware/specification/regular/image/__init__.py b/learnware/specification/regular/image/__init__.py new file mode 100644 index 0000000..0a18ded --- /dev/null +++ b/learnware/specification/regular/image/__init__.py @@ -0,0 +1 @@ +from .rkme import RKMEImageSpecification diff --git a/learnware/specification/regular/image/cnn_gp.py b/learnware/specification/regular/image/cnn_gp.py new file mode 100644 index 0000000..6ceb7f6 --- /dev/null +++ b/learnware/specification/regular/image/cnn_gp.py @@ -0,0 +1,304 @@ +import torch as t +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + + +__all__ = ("NNGPKernel", "Conv2d", "ReLU", "Sequential", "ConvKP", "NonlinKP") +""" +With this package, we are able to accurately and efficiently compute the kernel matrix corresponding to the NNGP during the search phase. + +Github Repository: https://github.com/cambridge-mlg/cnn-gp + +References: [1] A. Garriga-Alonso, L. Aitchison, and C. E. Rasmussen. Deep Convolutional Networks as shallow Gaussian Processes. In: International Conference on Learning Representations (ICLR'19), 2019. +""" + + +class NNGPKernel(nn.Module): + """ + Transforms one kernel matrix into another. + [N1, N2, W, H] -> [N1, N2, W, H] + """ + + def forward(self, x, y=None, same=None, diag=False): + """ + Either takes one minibatch (x), or takes two minibatches (x and y), and + a boolean indicating whether they're the same. 
+ """ + if y is None: + assert same is None + y = x + same = True + + assert not diag or len(x) == len(y), "diagonal kernels must operate with data of equal length" + + assert 4 == len(x.size()) + assert 4 == len(y.size()) + assert x.size(1) == y.size(1) + assert x.size(2) == y.size(2) + assert x.size(3) == y.size(3) + + N1 = x.size(0) + N2 = y.size(0) + C = x.size(1) + W = x.size(2) + H = x.size(3) + + # [N1, C, W, H], [N2, C, W, H] -> [N1 N2, 1, W, H] + if diag: + xy = (x * y).mean(1, keepdim=True) + else: + xy = (x.unsqueeze(1) * y).mean(2).view(N1 * N2, 1, W, H) + xx = (x**2).mean(1, keepdim=True) + yy = (y**2).mean(1, keepdim=True) + + initial_kp = ConvKP(same, diag, xy, xx, yy) + final_kp = self.propagate(initial_kp) + r = NonlinKP(final_kp).xy + if diag: + return r.view(N1) + else: + return r.view(N1, N2) + + +class Conv2d(NNGPKernel): + def __init__( + self, + kernel_size, + stride=1, + padding="same", + dilation=1, + var_weight=1.0, + var_bias=0.0, + in_channel_multiplier=1, + out_channel_multiplier=1, + ): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.var_weight = var_weight + self.var_bias = var_bias + self.kernel_has_row_of_zeros = False + if padding == "same": + self.padding = dilation * (kernel_size // 2) + if kernel_size % 2 == 0: + self.kernel_has_row_of_zeros = True + else: + self.padding = padding + + if self.kernel_has_row_of_zeros: + # We need to pad one side larger than the other. We just make a + # kernel that is slightly too large and make its last column and + # row zeros. + kernel = t.ones(1, 1, self.kernel_size + 1, self.kernel_size + 1) + kernel[:, :, 0, :] = 0.0 + kernel[:, :, :, 0] = 0.0 + else: + kernel = t.ones(1, 1, self.kernel_size, self.kernel_size) + self.register_buffer("kernel", kernel * (self.var_weight / self.kernel_size**2)) + self.in_channel_multiplier, self.out_channel_multiplier = (in_channel_multiplier, out_channel_multiplier) + + def propagate(self, kp): + kp = ConvKP(kp) + + def f(patch): + return ( + F.conv2d(patch, self.kernel, stride=self.stride, padding=self.padding, dilation=self.dilation) + + self.var_bias + ) + + return ConvKP(kp.same, kp.diag, f(kp.xy), f(kp.xx), f(kp.yy)) + + def nn(self, channels, in_channels=None, out_channels=None): + if in_channels is None: + in_channels = channels + if out_channels is None: + out_channels = channels + conv2d = nn.Conv2d( + in_channels=in_channels * self.in_channel_multiplier, + out_channels=out_channels * self.out_channel_multiplier, + kernel_size=self.kernel_size + (1 if self.kernel_has_row_of_zeros else 0), + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=(self.var_bias > 0.0), + ) + conv2d.weight.data.normal_(0, math.sqrt(self.var_weight / conv2d.in_channels) / self.kernel_size) + if self.kernel_has_row_of_zeros: + conv2d.weight.data[:, :, 0, :] = 0 + conv2d.weight.data[:, :, :, 0] = 0 + if self.var_bias > 0.0: + conv2d.bias.data.normal_(0, math.sqrt(self.var_bias)) + return conv2d + + def layers(self): + return 1 + + +class ReLU(NNGPKernel): + """ + A ReLU nonlinearity, the covariance is numerically stabilised by clamping + values. 
+ """ + + f32_tiny = np.finfo(np.float32).tiny + + def propagate(self, kp): + kp = NonlinKP(kp) + """ + We need to calculate (xy, xx, yy == c, v₁, v₂): + ⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤ + √(v₁v₂) / 2π ⎷1 - c²/v₁v₂ + (π - θ)c / √(v₁v₂) + + which is equivalent to: + 1/2π ( √(v₁v₂ - c²) + (π - θ)c ) + + # NOTE we divide by 2 to avoid multiplying the ReLU by sqrt(2) + """ + xx_yy = kp.xx * kp.yy + self.f32_tiny + + # Clamp these so the outputs are not NaN + cos_theta = (kp.xy * xx_yy.rsqrt()).clamp(-1, 1) + sin_theta = t.sqrt((xx_yy - kp.xy**2).clamp(min=0)) + theta = t.acos(cos_theta) + xy = (sin_theta + (math.pi - theta) * kp.xy) / (2 * math.pi) + + xx = kp.xx / 2.0 + if kp.same: + yy = xx + if kp.diag: + xy = xx + else: + # Make sure the diagonal agrees with `xx` + eye = t.eye(xy.size()[0]).unsqueeze(-1).unsqueeze(-1).to(kp.xy.device) + xy = (1 - eye) * xy + eye * xx + else: + yy = kp.yy / 2.0 + return NonlinKP(kp.same, kp.diag, xy, xx, yy) + + def nn(self, channels, in_channels=None, out_channels=None): + assert in_channels is None + assert out_channels is None + return nn.ReLU() + + def layers(self): + return 0 + + +class Sequential(NNGPKernel): + def __init__(self, *mods): + super().__init__() + self.mods = mods + for idx, mod in enumerate(mods): + self.add_module(str(idx), mod) + + def propagate(self, kp): + for mod in self.mods: + kp = mod.propagate(kp) + return kp + + def nn(self, channels, in_channels=None, out_channels=None): + if len(self.mods) == 0: + return nn.Sequential() + elif len(self.mods) == 1: + return self.mods[0].nn(channels, in_channels=in_channels, out_channels=out_channels) + else: + return nn.Sequential( + self.mods[0].nn(channels, in_channels=in_channels), + *[mod.nn(channels) for mod in self.mods[1:-1]], + self.mods[-1].nn(channels, out_channels=out_channels) + ) + + def layers(self): + return sum(mod.layers() for mod in self.mods) + + +class KernelPatch: + """ + Represents a block of the kernel matrix. + Critically, we need the variances of the rows and columns, even if the + diagonal isn't part of the block, and this introduces considerable + complexity. + In particular, we also need to know whether the + rows and columns of the matrix correspond, in which case, we need to do + something different when we add IID noise. 
+ """ + + def __init__(self, same_or_kp, diag=False, xy=None, xx=None, yy=None): + if isinstance(same_or_kp, KernelPatch): + same = same_or_kp.same + diag = same_or_kp.diag + xy = same_or_kp.xy + xx = same_or_kp.xx + yy = same_or_kp.yy + else: + same = same_or_kp + + self.Nx = xx.size(0) + self.Ny = yy.size(0) + self.W = xy.size(-2) + self.H = xy.size(-1) + + self.init(same, diag, xy, xx, yy) + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __add__(self, other): + return self._do_elementwise(other, "__add__") + + def __mul__(self, other): + return self._do_elementwise(other, "__mul__") + + def _do_elementwise(self, other, op): + KP = type(self) + if isinstance(other, KernelPatch): + other = KP(other) + assert self.same == other.same + assert self.diag == other.diag + return KP( + self.same, + self.diag, + getattr(self.xy, op)(other.xy), + getattr(self.xx, op)(other.xx), + getattr(self.yy, op)(other.yy), + ) + else: + return KP( + self.same, + self.diag, + getattr(self.xy, op)(other), + getattr(self.xx, op)(other), + getattr(self.yy, op)(other), + ) + + +class ConvKP(KernelPatch): + def init(self, same, diag, xy, xx, yy): + self.same = same + self.diag = diag + if diag: + self.xy = xy.view(self.Nx, 1, self.W, self.H) + else: + self.xy = xy.view(self.Nx * self.Ny, 1, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view(self.Ny, 1, self.W, self.H) + + +class NonlinKP(KernelPatch): + def init(self, same, diag, xy, xx, yy): + self.same = same + self.diag = diag + if diag: + self.xy = xy.view(self.Nx, 1, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view(self.Ny, 1, self.W, self.H) + else: + self.xy = xy.view(self.Nx, self.Ny, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view(self.Ny, self.W, self.H) diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py new file mode 100644 index 0000000..1f05382 --- /dev/null +++ b/learnware/specification/regular/image/rkme.py @@ -0,0 +1,450 @@ +from __future__ import annotations + +import codecs +import copy +import functools +import json +import os + +from typing import Any + +import numpy as np +import torch +import torch_optimizer +from torch import nn +from torch.utils.data import TensorDataset, DataLoader +from torchvision.transforms import Resize +from tqdm import tqdm + +from . import cnn_gp +from ..base import BaseStatSpecification +from ..table.rkme import solve_qp, choose_device, setup_seed + + +class RKMEImageSpecification(BaseStatSpecification): + # INNER_PRODUCT_COUNT = 0 + IMAGE_WIDTH = 32 + + def __init__(self, cuda_idx: int = -1, **kwargs): + """Initializing RKME Image specification's parameters. + + Parameters + ---------- + cuda_idx : int + A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. + """ + self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility. 
+
+        self.z = None
+        self.beta = None
+        self.cuda_idx = cuda_idx
+        self.device = choose_device(cuda_idx=cuda_idx)
+        self.cache = False
+
+        self.n_models = kwargs["n_models"] if "n_models" in kwargs else 16
+        self.model_config = (
+            {"k": 2, "mu": 0, "sigma": None, "net_width": 128, "net_depth": 3}
+            if "model_config" not in kwargs
+            else kwargs["model_config"]
+        )
+
+        setup_seed(0)
+        super(RKMEImageSpecification, self).__init__(type=self.__class__.__name__)
+
+    def _generate_models(self, n_models: int, channel: int = 3, fixed_seed=None):
+        model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config)
+
+        def __builder(i):
+            if fixed_seed is not None:
+                torch.manual_seed(fixed_seed[i])
+            return model_class().to(self.device)
+
+        return (__builder(m) for m in range(n_models))
+
+    def generate_stat_spec_from_data(
+        self,
+        X: np.ndarray,
+        K: int = 50,
+        step_size: float = 0.01,
+        steps: int = 100,
+        resize: bool = True,
+        nonnegative_beta: bool = True,
+        reduce: bool = True,
+        verbose: bool = True,
+        **kwargs,
+    ):
+        """Construct reduced set from raw dataset using iterative optimization.
+
+        Parameters
+        ----------
+        X : np.ndarray or torch.tensor
+            Raw data in [N, C, H, W] format.
+        K : int
+            Size of the constructed reduced set.
+        step_size : float
+            Step size for gradient descent in the iterative optimization.
+        steps : int
+            Total rounds in the iterative optimization.
+        resize : bool
+            Whether to scale the image to the requested size, by default True.
+        nonnegative_beta : bool, optional
+            True if weights for the reduced set are intended to be kept non-negative, by default True.
+        reduce : bool, optional
+            Whether to shrink the original data to a smaller set, by default True
+        verbose : bool, optional
+            Whether to print training progress, by default True
+        Returns
+        -------
+
+        """
+        if (
+            X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH
+        ) and not resize:
+            raise ValueError(
+                "X should be in shape of [N, C, {0:d}, {0:d}]. "
+                "Or set resize=True and the image will be automatically resized to {0:d} x {0:d}.".format(
+                    RKMEImageSpecification.IMAGE_WIDTH
+                )
+            )
+
+        if not torch.is_tensor(X):
+            X = torch.from_numpy(X)
+        X = X.to(self.device).float()
+
+        X[torch.isinf(X) | torch.isneginf(X) | torch.isposinf(X) | torch.isneginf(X)] = torch.nan
+        if torch.any(torch.isnan(X)):
+            for i, img in enumerate(X):
+                is_nan = torch.isnan(img)
+                if torch.any(is_nan):
+                    if torch.all(is_nan):
+                        raise ValueError(f"All values in image {i} are exceptional, e.g., NaN and Inf.")
+                    img_mean = torch.nanmean(img)
+                    X[i] = torch.where(is_nan, img_mean, img)
+
+        if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH:
+            X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None)(X)
+
+        num_points = X.shape[0]
+        X_shape = X.shape
+        Z_shape = tuple([K] + list(X_shape)[1:])
+
+        X_train = (X - torch.mean(X, [0, 2, 3], keepdim=True)) / (torch.std(X, [0, 2, 3], keepdim=True))
+
+        if X_train.shape[1] > 1 and ("whitening" not in kwargs or kwargs["whitening"]):
+            whitening = _get_zca_matrix(X_train)
+            X_train = X_train.reshape(num_points, -1) @ whitening
+            X_train = X_train.view(*X_shape)
+
+        if not reduce:
+            self.beta = 1 / num_points * np.ones(num_points)
+            self.z = X_train  # without reduction, the preprocessed data itself serves as the reduced set
+            self.beta = torch.from_numpy(self.beta).to(self.device)
+            return
+
+        random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1]))
+        self.z = torch.zeros(Z_shape).to(self.device).float().normal_(0, 1).requires_grad_(True)  # optimized below, so it must require grad
+        with torch.no_grad():
+            x_features = self._generate_random_feature(X_train, random_models=random_models)
+        self._update_beta(x_features, nonnegative_beta, random_models=random_models)
+
+        optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16)
+
+        for _ in tqdm(range(steps)) if verbose else range(steps):
+            # Regenerate Random Models
+            random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1]))
+
+            with torch.no_grad():
+                x_features = self._generate_random_feature(X_train, random_models=random_models)
+            self._update_z(x_features, optimizer, random_models=random_models)
+            self._update_beta(x_features, nonnegative_beta, random_models=random_models)
+
+    @torch.no_grad()
+    def _update_beta(self, x_features: Any, nonnegative_beta: bool = True, random_models=None):
+        Z = self.z
+        if not torch.is_tensor(Z):
+            Z = torch.from_numpy(Z)
+        Z = Z.to(self.device)
+
+        if not torch.is_tensor(x_features):
+            x_features = torch.from_numpy(x_features)
+        x_features = x_features.to(self.device)
+
+        z_features = self._generate_random_feature(Z, random_models=random_models)
+        K = self._calc_nngp_from_feature(z_features, z_features).to(self.device)
+        C = self._calc_nngp_from_feature(z_features, x_features).to(self.device)
+        C = torch.sum(C, dim=1) / x_features.shape[0]
+
+        if nonnegative_beta:
+            beta = solve_qp(K.double(), C.double()).to(self.device)
+        else:
+            beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C
+
+        self.beta = beta
+
+    def _update_z(self, x_features: Any, optimizer, random_models=None):
+        Z = self.z
+        beta = self.beta
+
+        if not torch.is_tensor(Z):
+            Z = torch.from_numpy(Z)
+        Z = Z.to(self.device).float()
+
+        if not torch.is_tensor(beta):
+            beta = torch.from_numpy(beta)
+        beta = beta.to(self.device)
+
+        if not torch.is_tensor(x_features):
+            x_features = torch.from_numpy(x_features)
+        x_features = x_features.to(self.device).float()
+
+        with torch.no_grad():
+            beta = 
beta.unsqueeze(0) + + for i in range(3): + z_features = self._generate_random_feature(Z, random_models=random_models) + K_z = self._calc_nngp_from_feature(z_features, z_features) + K_zx = self._calc_nngp_from_feature(x_features, z_features) + term_1 = torch.sum(K_z * (beta.T @ beta)) + term_2 = torch.sum(K_zx * beta / x_features.shape[0]) + loss = term_1 - 2 * term_2 + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def _generate_random_feature(self, data_X, data_Y=None, batch_size=4096, random_models=None): + X_features_list, Y_features_list = [], [] + + dataset_X, dataset_Y = TensorDataset(data_X), None + dataloader_X, dataloader_Y = DataLoader(dataset_X, batch_size=batch_size, shuffle=True), None + if data_Y is not None: + dataset_Y = TensorDataset(data_Y) + dataloader_Y = DataLoader(dataset_Y, batch_size=batch_size, shuffle=True) + assert data_X.shape[1] == data_Y.shape[1] + + for m, model in enumerate( + random_models if random_models else self._generate_models(n_models=self.n_models, channel=data_X.shape[1]) + ): + model.eval() + + curr_features_list = [] + for i, (X,) in enumerate(dataloader_X): + out = model(X) + curr_features_list.append(out) + curr_features = torch.cat(curr_features_list, 0) + X_features_list.append(curr_features) + + if data_Y is not None: + curr_features_list = [] + for i, (Y,) in enumerate(dataloader_Y): + out = model(Y) + curr_features_list.append(out) + curr_features = torch.cat(curr_features_list, 0) + Y_features_list.append(curr_features) + + X_features = torch.cat(X_features_list, 1) + X_features = X_features / torch.sqrt(torch.asarray(X_features.shape[1], device=self.device)) + if data_Y is None: + return X_features + else: + Y_features = torch.cat(Y_features_list, 1) + Y_features = Y_features / torch.sqrt(torch.asarray(Y_features.shape[1], device=self.device)) + return X_features, Y_features + + def inner_prod(self, Phi2: RKMEImageSpecification) -> float: + """Compute the inner product between two RKME Image specifications + + Parameters + ---------- + Phi2 : RKMEImageSpecification + The other RKME Image specification. + + Returns + ------- + float + The inner product between two RKME Image specifications. + """ + v = self._inner_prod_nngp(Phi2) + return v + + def _inner_prod_nngp(self, Phi2: RKMEImageSpecification) -> float: + beta_1 = self.beta.reshape(1, -1).detach().to(self.device) + beta_2 = Phi2.beta.reshape(1, -1).detach().to(self.device) + + Z1 = self.z.to(self.device) + Z2 = Phi2.z.to(self.device) + + kernel_fn = _build_ConvNet_NNGP(channel=Z1.shape[1], **self.model_config).to(self.device) + if id(self) == id(Phi2): + K_zz = kernel_fn(Z1) + else: + K_zz = kernel_fn(Z1, Z2) + v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() + + # RKMEImageSpecification.INNER_PRODUCT_COUNT += 1 + return v + + def dist(self, Phi2: RKMEImageSpecification, omit_term1: bool = False) -> float: + """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME Image specifications + + Parameters + ---------- + Phi2 : RKMEImageSpecification + The other RKME specification. + omit_term1 : bool, optional + True if the inner product of self with itself can be omitted, by default False. 
+ """ + + if omit_term1: + term1 = 0 + else: + term1 = self.inner_prod(self) + term2 = self.inner_prod(Phi2) + term3 = Phi2.inner_prod(Phi2) + + v = float(term1 - 2 * term2 + term3) + + return v + + @staticmethod + def _calc_nngp_from_feature(x1_feature: torch.Tensor, x2_feature: torch.Tensor): + K_12 = x1_feature @ x2_feature.T + 0.01 + return K_12 + + def herding(self, T: int) -> np.ndarray: + raise NotImplementedError("The function herding hasn't been supported in Image RKME Specification.") + + def _sampling_candidates(self, N: int) -> np.ndarray: + raise NotImplementedError() + + def get_beta(self) -> np.ndarray: + return self.beta.detach().cpu().numpy() + + def get_z(self) -> np.ndarray: + return self.z.detach().cpu().numpy() + + def save(self, filepath: str): + """Save the computed RKME Image specification to a specified path in JSON format. + + Parameters + ---------- + filepath : str + The specified saving path. + """ + save_path = filepath + rkme_to_save = copy.deepcopy(self.__dict__) + if torch.is_tensor(rkme_to_save["z"]): + rkme_to_save["z"] = rkme_to_save["z"].detach().cpu().numpy() + rkme_to_save["z"] = rkme_to_save["z"].tolist() + if torch.is_tensor(rkme_to_save["beta"]): + rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() + rkme_to_save["beta"] = rkme_to_save["beta"].tolist() + rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" + + with codecs.open(save_path, "w", encoding="utf-8") as fout: + json.dump(rkme_to_save, fout, separators=(",", ":")) + + def load(self, filepath: str) -> bool: + """Load a RKME Image specification file in JSON format from the specified path. + + Parameters + ---------- + filepath : str + The specified loading path. + + Returns + ------- + bool + True if the RKME is loaded successfully. 
+ """ + # Load JSON file: + load_path = filepath + if os.path.exists(load_path): + with codecs.open(load_path, "r", encoding="utf-8") as fin: + obj_text = fin.read() + rkme_load = json.loads(obj_text) + rkme_load["device"] = choose_device(rkme_load["cuda_idx"]) + rkme_load["z"] = torch.from_numpy(np.array(rkme_load["z"], dtype="float32")) + rkme_load["beta"] = torch.from_numpy(np.array(rkme_load["beta"], dtype="float64")) + + for d in self.__dir__(): + if d in rkme_load.keys(): + setattr(self, d, rkme_load[d]) + + self.beta = self.beta.to(self.device) + self.z = self.z.to(self.device) + + return True + else: + return False + + +def _get_zca_matrix(X, reg_coef=0.1): + X_flat = X.reshape(X.shape[0], -1) + cov = (X_flat.T @ X_flat) / X_flat.shape[0] + reg_amount = reg_coef * torch.trace(cov) / cov.shape[0] + u, s, _ = torch.svd(cov + reg_amount * torch.eye(cov.shape[0]).to(X.device)) + inv_sqrt_zca_eigs = s ** (-0.5) + whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u) + + return whitening_transform + + +class _ConvNet_wide(nn.Module): + def __init__(self, channel, mu=None, sigma=None, k=2, net_width=128, net_depth=3, im_size=(32, 32)): + self.k = k + super().__init__() + self.features, shape_feat = self._make_layers(channel, net_width, net_depth, im_size, mu, sigma) + # self.aggregation = nn.AvgPool2d(kernel_size=shape_feat[1]) + + def forward(self, x): + out = self.features(x) + out = out.reshape(out.size(0), -1) + # out = self.aggregation(out).reshape(out.size(0), -1) + return out + + def _make_layers(self, channel, net_width, net_depth, im_size, mu, sigma): + k = self.k + + layers = [] + in_channels = channel + shape_feat = [in_channels, im_size[0], im_size[1]] + for d in range(net_depth): + layers += [_build_conv2d_gaussian(in_channels, int(k * net_width), 3, 1, mean=mu, std=sigma)] + shape_feat[0] = int(k * net_width) + + layers += [nn.ReLU(inplace=True)] + in_channels = int(k * net_width) + + layers += [nn.AvgPool2d(kernel_size=2, stride=2)] + shape_feat[1] //= 2 + shape_feat[2] //= 2 + + return nn.Sequential(*layers), shape_feat + + +def _build_conv2d_gaussian(in_channels, out_channels, kernel=3, padding=1, mean=None, std=None): + layer = nn.Conv2d(in_channels, out_channels, kernel, padding=padding) + if mean is None: + mean = 0 + if std is None: + std = np.sqrt(2) / np.sqrt(layer.weight.shape[1] * layer.weight.shape[2] * layer.weight.shape[3]) + # print('Initializing Conv. 
+    torch.nn.init.normal_(layer.weight, mean, std)
+    torch.nn.init.normal_(layer.bias, 0, 0.1)
+    return layer
+
+
+def _build_ConvNet_NNGP(channel, k=2, net_width=128, net_depth=3, kernel_size=3, im_size=(32, 32), **kwargs):
+    layers = []
+    for d in range(net_depth):
+        layers += [cnn_gp.Conv2d(kernel_size=kernel_size, padding="same", var_bias=0.1, var_weight=np.sqrt(2))]
+        # /np.sqrt(kernel_size * kernel_size * channel)
+        layers += [cnn_gp.ReLU()]
+        # AvgPooling
+        layers += [cnn_gp.Conv2d(kernel_size=2, padding=0, stride=2)]
+
+    assert im_size[0] % (2**net_depth) == 0
+    layers.append(cnn_gp.Conv2d(kernel_size=im_size[0] // (2**net_depth), padding=0))
+
+    return cnn_gp.Sequential(*layers)
diff --git a/learnware/specification/regular/table/__init__.py b/learnware/specification/regular/table/__init__.py
index dc94b1e..19fa956 100644
--- a/learnware/specification/regular/table/__init__.py
+++ b/learnware/specification/regular/table/__init__.py
@@ -1 +1 @@
-from .rkme import RKMEStatSpecification
+from .rkme import RKMETableSpecification, RKMEStatSpecification
diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py
index 32dea8d..82c81a2 100644
--- a/learnware/specification/regular/table/rkme.py
+++ b/learnware/specification/regular/table/rkme.py
@@ -26,11 +26,12 @@ from ....logger import get_module_logger
 logger = get_module_logger("rkme")
 
 if not _FAISS_INSTALLED:
-    logger.warning("Required faiss version >= 1.7.1 is not detected!")
-    logger.warning('Please run "conda install -c pytorch faiss-cpu" first.')
+    logger.warning(
+        "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first"
+    )
 
 
-class RKMEStatSpecification(RegularStatsSpecification):
+class RKMETableSpecification(RegularStatsSpecification):
     """Reduced Kernel Mean Embedding (RKME) Specification"""
 
     def __init__(self, gamma: float = 0.1, cuda_idx: int = -1):
@@ -51,7 +52,7 @@ class RKMEStatSpecification(RegularStatsSpecification):
         torch.cuda.empty_cache()
         self.device = choose_device(cuda_idx=cuda_idx)
         setup_seed(0)
-        super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__)
+        super(RKMETableSpecification, self).__init__(type=self.__class__.__name__)
 
     def get_beta(self) -> np.ndarray:
         """Move beta(RKME weights) back to memory accessible to the CPU.
@@ -334,12 +335,12 @@ class RKMEStatSpecification(RegularStatsSpecification):
         else:
             logger.warning("Not enough candidates for herding!")
 
-    def inner_prod(self, Phi2: RKMEStatSpecification) -> float:
+    def inner_prod(self, Phi2: RKMETableSpecification) -> float:
         """Compute the inner product between two RKME specifications
 
         Parameters
         ----------
-        Phi2 : RKMEStatSpecification
+        Phi2 : RKMETableSpecification
             The other RKME specification.
 
         Returns
@@ -355,12 +356,12 @@ class RKMEStatSpecification(RegularStatsSpecification):
 
         return float(v)
 
-    def dist(self, Phi2: RKMEStatSpecification, omit_term1: bool = False) -> float:
+    def dist(self, Phi2: RKMETableSpecification, omit_term1: bool = False) -> float:
         """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME specifications
 
         Parameters
         ----------
-        Phi2 : RKMEStatSpecification
+        Phi2 : RKMETableSpecification
             The other RKME specification.
         omit_term1 : bool, optional
             True if the inner product of self with itself can be omitted, by default False.
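For orientation, `dist` (in both the table specification renamed in these hunks and the image specification above) evaluates the squared Maximum Mean Discrepancy between two reduced-set embeddings by expanding it into three weighted kernel sums. The sketch below restates that expansion with a generic kernel; the names `mmd_squared` and `gaussian_kernel` and the array arguments are illustrative stand-ins, not library code.

```python
import numpy as np


def mmd_squared(beta1, z1, beta2, z2, kernel):
    """Squared MMD between two reduced-set embeddings (beta1, z1) and (beta2, z2).

    beta1, beta2 : weight vectors of shape (m,) and (n,)
    z1, z2       : reduced-set points of shape (m, d) and (n, d)
    kernel       : callable mapping (A, B) to the (len(A), len(B)) pairwise kernel matrix
    """
    term1 = beta1 @ kernel(z1, z1) @ beta1  # <Phi1, Phi1>
    term2 = beta1 @ kernel(z1, z2) @ beta2  # <Phi1, Phi2>
    term3 = beta2 @ kernel(z2, z2) @ beta2  # <Phi2, Phi2>
    return float(term1 - 2 * term2 + term3)


def gaussian_kernel(A, B, gamma=0.1):
    # Pairwise Gaussian kernel; gamma plays the same role as the table specification's gamma parameter.
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * d2)
```

`omit_term1` skips `term1`, presumably because that term is constant when one fixed specification is ranked against many candidates.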
@@ -428,12 +429,8 @@ class RKMEStatSpecification(RegularStatsSpecification):
         rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy()
         rkme_to_save["beta"] = rkme_to_save["beta"].tolist()
         rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu"
-        rkme_to_save["type"] = self.type
-        json.dump(
-            rkme_to_save,
-            codecs.open(save_path, "w", encoding="utf-8"),
-            separators=(",", ":"),
-        )
+        with codecs.open(save_path, "w", encoding="utf-8") as fout:
+            json.dump(rkme_to_save, fout, separators=(",", ":"))
 
     def load(self, filepath: str) -> bool:
         """Load a RKME specification file in JSON format from the specified path.
@@ -466,6 +463,14 @@ class RKMEStatSpecification(RegularStatsSpecification):
             return False
 
 
+class RKMEStatSpecification(RKMETableSpecification):
+    """Nickname for RKMETableSpecification, kept for compatibility currently.
+    TODO: modify all learnware in database and remove this nickname
+    """
+
+    pass
+
+
 def setup_seed(seed):
     """Fix a random seed for addressing reproducibility issues.
diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py
index 19371b8..dcafb69 100644
--- a/learnware/specification/utils.py
+++ b/learnware/specification/utils.py
@@ -4,7 +4,7 @@ import pandas as pd
 from typing import Union, List
 
 from .base import BaseStatSpecification
-from .regular import RKMEStatSpecification, RKMETextStatSpecification
+from .regular import RKMETableSpecification, RKMEImageSpecification, RKMETextStatSpecification
 from ..config import C
 
 
@@ -42,10 +42,10 @@ def generate_rkme_spec(
     nonnegative_beta: bool = True,
     reduce: bool = True,
     cuda_idx: int = None,
-) -> RKMEStatSpecification:
+) -> RKMETableSpecification:
     """
     Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification.
-    Return a RKMEStatSpecification object, use .save() method to save as json file.
+    Return a RKMETableSpecification object, use .save() method to save as json file.
 
     Parameters
     ----------
@@ -73,8 +73,8 @@ def generate_rkme_spec(
 
     Returns
     -------
-    RKMEStatSpecification
-        A RKMEStatSpecification object
+    RKMETableSpecification
+        A RKMETableSpecification object
     """
     # Convert data type
     X = convert_to_numpy(X)
@@ -94,10 +94,73 @@ def generate_rkme_spec(
             cuda_idx = 0
 
     # Generate rkme spec
-    rkme_spec = RKMEStatSpecification(gamma=gamma, cuda_idx=cuda_idx)
+    rkme_spec = RKMETableSpecification(gamma=gamma, cuda_idx=cuda_idx)
     rkme_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce)
     return rkme_spec
 
+
+def generate_rkme_image_spec(
+    X: Union[np.ndarray, torch.Tensor],
+    reduced_set_size: int = 50,
+    step_size: float = 0.01,
+    steps: int = 100,
+    resize: bool = True,
+    nonnegative_beta: bool = True,
+    reduce: bool = True,
+    verbose: bool = True,
+    cuda_idx: int = None,
+) -> RKMEImageSpecification:
+    """
+    Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for image data.
+    Return a RKMEImageSpecification object, use .save() method to save as json file.
+
+    Parameters
+    ----------
+    X : np.ndarray, or torch.Tensor
+        Raw data in np.ndarray or torch.Tensor format.
+        The shape of X: [N, C, H, W]
+            N: Number of images.
+            C: Number of channels.
+            H: Height of images.
+            W: Width of images.
+        For example, if X has shape (100, 3, 32, 32), it means there are 100 samples, and each sample is a 3-channel (RGB) image of size 32x32.
+    reduced_set_size : int
+        Size of the constructed reduced set.
+    step_size : float
+        Step size for gradient descent in the iterative optimization.
+    steps : int
+        Total rounds in the iterative optimization.
+    resize : bool
+        Whether to scale the images to the requested size, by default True.
+    nonnegative_beta : bool, optional
+        True if weights for the reduced set are intended to be kept non-negative, by default True.
+    reduce : bool, optional
+        Whether to shrink the original data to a smaller reduced set, by default True.
+    cuda_idx : int
+        Index of the CUDA device used during RKME computation. -1 indicates CUDA is not used.
+        None indicates that the CUDA device is selected automatically.
+    verbose : bool, optional
+        Whether to print training progress, by default True.
+
+    Returns
+    -------
+    RKMEImageSpecification
+        A RKMEImageSpecification object
+    """
+
+    # Check cuda_idx
+    if not torch.cuda.is_available() or cuda_idx == -1:
+        cuda_idx = -1
+    else:
+        num_cuda_devices = torch.cuda.device_count()
+        if cuda_idx is None or not (0 <= cuda_idx < num_cuda_devices):
+            cuda_idx = 0
+
+    # Generate rkme spec
+    rkme_image_spec = RKMEImageSpecification(cuda_idx=cuda_idx)
+    rkme_image_spec.generate_stat_spec_from_data(
+        X, reduced_set_size, step_size, steps, resize, nonnegative_beta, reduce, verbose
+    )
+    return rkme_image_spec
 
 def generate_rkme_text_spec(
     X: List[str],
@@ -156,6 +219,8 @@ def generate_rkme_text_spec(
 
     return rkme_text_spec
 
+
+
 def generate_stat_spec(X: np.ndarray) -> BaseStatSpecification:
     """
     Interface for users to generate statistical specification.
diff --git a/setup.py b/setup.py
index 552c4f5..8119fd9 100644
--- a/setup.py
+++ b/setup.py
@@ -72,6 +72,7 @@ REQUIRED = [
     "rapidfuzz>=3.4.0",
     "torchtext>=0.16.0",
     "sentence_transformers>=2.2.2",
+    "torch-optimizer>=0.3.0",
 ]
 
 if get_platform() != MACOS:
diff --git a/tests/test_market/learnware_example/example.yaml b/tests/test_market/learnware_example/example.yaml
index 254bca4..32aa52e 100644
--- a/tests/test_market/learnware_example/example.yaml
+++ b/tests/test_market/learnware_example/example.yaml
@@ -3,6 +3,6 @@ model:
   kwargs: {}
 stat_specifications:
   - module_path: learnware.specification
-    class_name: RKMEStatSpecification
+    class_name: RKMETableSpecification
     file_name: svm.json
     kwargs: {}
\ No newline at end of file
diff --git a/tests/test_market/test_easy.py b/tests/test_market/test_easy.py
index 5f22729..16729e2 100644
--- a/tests/test_market/test_easy.py
+++ b/tests/test_market/test_easy.py
@@ -170,9 +170,9 @@ class TestMarket(unittest.TestCase):
         with zipfile.ZipFile(zip_path, "r") as zip_obj:
             zip_obj.extractall(path=unzip_dir)
 
-        user_spec = specification.rkme.RKMEStatSpecification()
+        user_spec = specification.rkme.RKMETableSpecification()
         user_spec.load(os.path.join(unzip_dir, "svm.json"))
-        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
+        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
         (
             sorted_score_list,
             single_learnware_list,
diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py
index bedd64e..680e668 100644
--- a/tests/test_specification/test_rkme.py
+++ b/tests/test_specification/test_rkme.py
@@ -2,19 +2,22 @@ import os
 import json
 import string
 import random
+import torch
 import unittest
 import tempfile
 import numpy as np
 
-import learnware
 import learnware.specification as specification
 from learnware.specification import RKMEStatSpecification, RKMETextStatSpecification
+from learnware.specification import RKMETableSpecification, RKMEImageSpecification
+from learnware.specification import generate_rkme_image_spec, generate_rkme_spec
 
 
 class TestRKME(unittest.TestCase):
     def test_rkme(self):
         X = np.random.uniform(-10000, 10000, size=(5000, 200))
-        rkme = specification.utils.generate_rkme_spec(X)
+        rkme = generate_rkme_spec(X)
+        rkme.generate_stat_spec_from_data(X)
 
         with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
             rkme_path = os.path.join(tempdir, "rkme.json")
@@ -22,11 +25,35 @@ class TestRKME(unittest.TestCase):
 
             with open(rkme_path, "r") as f:
                 data = json.load(f)
-                assert data["type"] == "RKMEStatSpecification"
+                assert data["type"] == "RKMETableSpecification"
 
-            rkme2 = RKMEStatSpecification()
+            rkme2 = RKMETableSpecification()
             rkme2.load(rkme_path)
-            assert rkme2.type == "RKMEStatSpecification"
+            assert rkme2.type == "RKMETableSpecification"
+
+    def test_image_rkme(self):
+        def _test_image_rkme(X):
+            image_rkme = generate_rkme_image_spec(X, steps=10)
+
+            with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
+                rkme_path = os.path.join(tempdir, "rkme.json")
+                image_rkme.save(rkme_path)
+
+                with open(rkme_path, "r") as f:
+                    data = json.load(f)
+                    assert data["type"] == "RKMEImageSpecification"
+
+                rkme2 = RKMEImageSpecification()
+                rkme2.load(rkme_path)
+                assert rkme2.type == "RKMEImageSpecification"
+
+        _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32)))
+        _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128)))
+        _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128)) / 255)
+
+        _test_image_rkme(torch.randint(0, 255, (2000, 3, 32, 32)))
+        _test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128)))
+        _test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128)) / 255)
 
     def test_text_rkme(self):
         def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
diff --git a/tests/test_workflow/learnware_example/example.yaml b/tests/test_workflow/learnware_example/example.yaml
index 254bca4..32aa52e 100644
--- a/tests/test_workflow/learnware_example/example.yaml
+++ b/tests/test_workflow/learnware_example/example.yaml
@@ -3,6 +3,6 @@ model:
   kwargs: {}
 stat_specifications:
   - module_path: learnware.specification
-    class_name: RKMEStatSpecification
+    class_name: RKMETableSpecification
     file_name: svm.json
     kwargs: {}
\ No newline at end of file
diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py
index 1da7db3..f4507c5 100644
--- a/tests/test_workflow/test_workflow.py
+++ b/tests/test_workflow/test_workflow.py
@@ -155,9 +155,9 @@ class TestAllWorkflow(unittest.TestCase):
         with zipfile.ZipFile(zip_path, "r") as zip_obj:
             zip_obj.extractall(path=unzip_dir)
 
-        user_spec = specification.RKMEStatSpecification()
+        user_spec = specification.RKMETableSpecification()
         user_spec.load(os.path.join(unzip_dir, "svm.json"))
-        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
+        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
         (
             sorted_score_list,
             single_learnware_list,
@@ -182,7 +182,7 @@ class TestAllWorkflow(unittest.TestCase):
         train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True)
 
         stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
-        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": stat_spec})
+        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})
 
         _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info)
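Taken together, these changes rename the table specification and introduce the image specification plus its `generate_rkme_image_spec` helper. The sketch below, distilled from the updated tests, shows how the pieces fit together; the toy data shapes and file paths are illustrative only.

```python
import numpy as np

from learnware.specification import (
    RKMEImageSpecification,
    RKMETableSpecification,
    generate_rkme_image_spec,
    generate_rkme_spec,
)

# Table data: generate, save, and reload a table RKME specification.
X_table = np.random.uniform(-1, 1, size=(1000, 20))
table_spec = generate_rkme_spec(X_table)
table_spec.save("rkme.json")

reloaded = RKMETableSpecification()
reloaded.load("rkme.json")
assert reloaded.type == "RKMETableSpecification"

# Image data in [N, C, H, W] layout; inputs are resized to the expected width when resize=True.
X_image = np.random.randint(0, 255, size=(200, 3, 32, 32))
image_spec = generate_rkme_image_spec(X_image, steps=10)
image_spec.save("rkme_image.json")
```

When searching a market, the `stat_info` key follows the new class name as well, e.g. `stat_info={"RKMETableSpecification": table_spec}`.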