diff --git a/README.md b/README.md index 9b0069b..5629530 100644 --- a/README.md +++ b/README.md @@ -178,7 +178,7 @@ For example, the following code is designed to work with Reduced Set Kernel Embe ```python import learnware.specification as specification -user_spec = specification.rkme.RKMEStatSpecification() +user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join(unzip_path, "rkme.json")) user_info = BaseUserInfo( semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} diff --git a/docs/start/client.rst b/docs/start/client.rst index b788c93..cbe1f2e 100644 --- a/docs/start/client.rst +++ b/docs/start/client.rst @@ -123,7 +123,7 @@ You can search learnware by providing a statistical specification. The statistic import learnware.specification as specification - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join(unzip_path, "rkme.json")) specification = learnware.specification.Specification() @@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares senarioes=[], input_description={}, output_description={}) - stat_spec = specification.rkme.RKMEStatSpecification() + stat_spec = specification.RKMEStatSpecification() stat_spec.load(os.path.join(unzip_path, "rkme.json")) specification = learnware.specification.Specification() specification.update_semantic_spec(semantic_spec) diff --git a/docs/start/quick.rst b/docs/start/quick.rst index 9d4c3ae..2140aaa 100644 --- a/docs/start/quick.rst +++ b/docs/start/quick.rst @@ -170,7 +170,7 @@ For example, the code below executes learnware search when using Reduced Set Ker import learnware.specification as specification - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() # unzip_path: directory for unzipped learnware zipfile user_spec.load(os.path.join(unzip_path, "rkme.json")) diff --git a/docs/workflow/identify.rst b/docs/workflow/identify.rst index e439e32..ffd7dbb 100644 --- a/docs/workflow/identify.rst +++ b/docs/workflow/identify.rst @@ -73,7 +73,7 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb import learnware.specification as specification - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join("rkme.json")) user_info = BaseUserInfo( semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} diff --git a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py index 5f69127..93a3fa3 100644 --- a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py +++ b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py @@ -85,9 +85,7 @@ def get_split_errs(algo): split = train_xs.shape[0] - proportion_list[tmp] model.fit( - train_xs[ - split:, - ], + train_xs[split:,], train_ys[split:], eval_set=[(val_xs, val_ys)], early_stopping_rounds=50, diff --git a/examples/workflow_by_code/main.py b/examples/workflow_by_code/main.py index 8473209..29d2e69 100644 --- a/examples/workflow_by_code/main.py +++ b/examples/workflow_by_code/main.py @@ -148,7 +148,7 @@ class LearnwareMarketWorkflow: with zipfile.ZipFile(zip_path, "r") as zip_obj: zip_obj.extractall(path=unzip_dir) - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) ( diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 556aefb..0bb0502 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec from .base import Specification, BaseStatSpecification -from .rkme import RKMEStatSpecification +from .table import RKMEStatSpecification diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 56c1ad9..655e5a4 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -6,6 +6,15 @@ from typing import Dict class BaseStatSpecification: """The Statistical Specification Interface, which provide save and load method""" + def __init__(self, type: str): + """initilize the type of stats specification + Parameters + ---------- + type : str + the type of the stats specification + """ + self.type = type + def generate_stat_spec_from_data(self, **kwargs): """Construct statistical specification from raw dataset - kwargs may include the feature, label and model diff --git a/learnware/specification/table/__init__.py b/learnware/specification/table/__init__.py new file mode 100644 index 0000000..dc94b1e --- /dev/null +++ b/learnware/specification/table/__init__.py @@ -0,0 +1 @@ +from .rkme import RKMEStatSpecification diff --git a/learnware/specification/rkme.py b/learnware/specification/table/rkme.py similarity index 98% rename from learnware/specification/rkme.py rename to learnware/specification/table/rkme.py index 68c572f..9769800 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/table/rkme.py @@ -20,8 +20,8 @@ try: except ImportError: _FAISS_INSTALLED = False -from .base import BaseStatSpecification -from ..logger import get_module_logger +from ..base import BaseStatSpecification +from ...logger import get_module_logger logger = get_module_logger("rkme") @@ -51,6 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification): torch.cuda.empty_cache() self.device = choose_device(cuda_idx=cuda_idx) setup_seed(0) + super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__) def get_beta(self) -> np.ndarray: """Move beta(RKME weights) back to memory accessible to the CPU. @@ -427,6 +428,7 @@ class RKMEStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" + rkme_to_save["type"] = self.type json.dump( rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py index c9a00be..c3693b7 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -4,7 +4,7 @@ import pandas as pd from typing import Union from .base import BaseStatSpecification -from .rkme import RKMEStatSpecification +from .table import RKMEStatSpecification from ..config import C diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py new file mode 100644 index 0000000..4cdc246 --- /dev/null +++ b/tests/test_specification/test_rkme.py @@ -0,0 +1,31 @@ +import os +import json +import unittest +import tempfile +import numpy as np + +import learnware +import learnware.specification as specification +from learnware.specification import RKMEStatSpecification + + +class TestRKME(unittest.TestCase): + def test_rkme(self): + X = np.random.uniform(-10000, 10000, size=(5000, 200)) + rkme = specification.utils.generate_rkme_spec(X) + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + rkme_path = os.path.join(tempdir, "rkme.json") + rkme.save(rkme_path) + + with open(rkme_path, "r") as f: + data = json.load(f) + assert data["type"] == "RKMEStatSpecification" + + rkme2 = RKMEStatSpecification() + rkme2.load(rkme_path) + assert rkme2.type == "RKMEStatSpecification" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 9aa3e86..1da7db3 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -155,7 +155,7 @@ class TestAllWorkflow(unittest.TestCase): with zipfile.ZipFile(zip_path, "r") as zip_obj: zip_obj.extractall(path=unzip_dir) - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) (