From a56efc51c24ac85142bbd7ff24aaf0bb2762a86e Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 30 Oct 2023 13:58:46 +0800 Subject: [PATCH 1/8] [MNT] add type in save() and move rkme to table folder --- learnware/specification/__init__.py | 2 +- learnware/specification/table/__init__.py | 1 + learnware/specification/{ => table}/rkme.py | 5 +++-- learnware/specification/utils.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) create mode 100644 learnware/specification/table/__init__.py rename learnware/specification/{ => table}/rkme.py (99%) diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 556aefb..0bb0502 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec from .base import Specification, BaseStatSpecification -from .rkme import RKMEStatSpecification +from .table import RKMEStatSpecification diff --git a/learnware/specification/table/__init__.py b/learnware/specification/table/__init__.py new file mode 100644 index 0000000..8c56c8e --- /dev/null +++ b/learnware/specification/table/__init__.py @@ -0,0 +1 @@ +from .rkme import RKMEStatSpecification \ No newline at end of file diff --git a/learnware/specification/rkme.py b/learnware/specification/table/rkme.py similarity index 99% rename from learnware/specification/rkme.py rename to learnware/specification/table/rkme.py index 68c572f..59ac436 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/table/rkme.py @@ -20,8 +20,8 @@ try: except ImportError: _FAISS_INSTALLED = False -from .base import BaseStatSpecification -from ..logger import get_module_logger +from ..base import BaseStatSpecification +from ...logger import get_module_logger logger = get_module_logger("rkme") @@ -427,6 +427,7 @@ class RKMEStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" + rkme_to_save["type"] = self.__class__.__name__ json.dump( rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py index c9a00be..c3693b7 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -4,7 +4,7 @@ import pandas as pd from typing import Union from .base import BaseStatSpecification -from .rkme import RKMEStatSpecification +from .table import RKMEStatSpecification from ..config import C From 395e0f3e3cd12fc7fe20b92a62f4ee76de1ca381 Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 30 Oct 2023 14:04:45 +0800 Subject: [PATCH 2/8] [ENH] add test for RKMEStatSpecification --- tests/test_specification/test_rkme.py | 30 +++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/test_specification/test_rkme.py diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py new file mode 100644 index 0000000..2f9614b --- /dev/null +++ b/tests/test_specification/test_rkme.py @@ -0,0 +1,30 @@ +import os +import json +import unittest +import tempfile +import numpy as np + +import learnware +import learnware.specification as specification +from learnware.specification import RKMEStatSpecification + + +class TestRKME(unittest.TestCase): + def test_rkme(self): + X = np.random.uniform(-10000, 10000, size=(5000, 200)) + rkme = specification.utils.generate_rkme_spec(X) + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + rkme_path = os.path.join(tempdir, "rkme.json") + rkme.save(rkme_path) + + with open(rkme_path, "r") as f: + data = json.load(f) + assert data["type"] == "RKMEStatSpecification" + + rkme2 = RKMEStatSpecification() + rkme2.load(rkme_path) + + +if __name__ == "__main__": + unittest.main() From 14efe1658c834f077006a61f90cc6125c727524f Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 30 Oct 2023 14:05:57 +0800 Subject: [PATCH 3/8] [FIX] change rkme.RKMEStatSpecification into RKMEStatSpecification --- README.md | 2 +- docs/start/client.rst | 4 ++-- docs/start/quick.rst | 2 +- docs/workflow/identify.rst | 2 +- examples/workflow_by_code/main.py | 2 +- tests/test_workflow/test_workflow.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9b0069b..5629530 100644 --- a/README.md +++ b/README.md @@ -178,7 +178,7 @@ For example, the following code is designed to work with Reduced Set Kernel Embe ```python import learnware.specification as specification -user_spec = specification.rkme.RKMEStatSpecification() +user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join(unzip_path, "rkme.json")) user_info = BaseUserInfo( semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} diff --git a/docs/start/client.rst b/docs/start/client.rst index b788c93..cbe1f2e 100644 --- a/docs/start/client.rst +++ b/docs/start/client.rst @@ -123,7 +123,7 @@ You can search learnware by providing a statistical specification. The statistic import learnware.specification as specification - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join(unzip_path, "rkme.json")) specification = learnware.specification.Specification() @@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares senarioes=[], input_description={}, output_description={}) - stat_spec = specification.rkme.RKMEStatSpecification() + stat_spec = specification.RKMEStatSpecification() stat_spec.load(os.path.join(unzip_path, "rkme.json")) specification = learnware.specification.Specification() specification.update_semantic_spec(semantic_spec) diff --git a/docs/start/quick.rst b/docs/start/quick.rst index 9d4c3ae..2140aaa 100644 --- a/docs/start/quick.rst +++ b/docs/start/quick.rst @@ -170,7 +170,7 @@ For example, the code below executes learnware search when using Reduced Set Ker import learnware.specification as specification - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() # unzip_path: directory for unzipped learnware zipfile user_spec.load(os.path.join(unzip_path, "rkme.json")) diff --git a/docs/workflow/identify.rst b/docs/workflow/identify.rst index e439e32..ffd7dbb 100644 --- a/docs/workflow/identify.rst +++ b/docs/workflow/identify.rst @@ -73,7 +73,7 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb import learnware.specification as specification - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join("rkme.json")) user_info = BaseUserInfo( semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} diff --git a/examples/workflow_by_code/main.py b/examples/workflow_by_code/main.py index 8473209..29d2e69 100644 --- a/examples/workflow_by_code/main.py +++ b/examples/workflow_by_code/main.py @@ -148,7 +148,7 @@ class LearnwareMarketWorkflow: with zipfile.ZipFile(zip_path, "r") as zip_obj: zip_obj.extractall(path=unzip_dir) - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) ( diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 9aa3e86..1da7db3 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -155,7 +155,7 @@ class TestAllWorkflow(unittest.TestCase): with zipfile.ZipFile(zip_path, "r") as zip_obj: zip_obj.extractall(path=unzip_dir) - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.RKMEStatSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) ( From f0aec7fb1312b7afb994c24ea66910656336cffb Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 30 Oct 2023 14:06:16 +0800 Subject: [PATCH 4/8] [MNT] format code by black --- learnware/specification/table/__init__.py | 2 +- tests/test_specification/test_rkme.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/learnware/specification/table/__init__.py b/learnware/specification/table/__init__.py index 8c56c8e..dc94b1e 100644 --- a/learnware/specification/table/__init__.py +++ b/learnware/specification/table/__init__.py @@ -1 +1 @@ -from .rkme import RKMEStatSpecification \ No newline at end of file +from .rkme import RKMEStatSpecification diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index 2f9614b..27a9033 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -13,15 +13,15 @@ class TestRKME(unittest.TestCase): def test_rkme(self): X = np.random.uniform(-10000, 10000, size=(5000, 200)) rkme = specification.utils.generate_rkme_spec(X) - + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: rkme_path = os.path.join(tempdir, "rkme.json") rkme.save(rkme_path) - + with open(rkme_path, "r") as f: data = json.load(f) assert data["type"] == "RKMEStatSpecification" - + rkme2 = RKMEStatSpecification() rkme2.load(rkme_path) From a0352f42f7f07eb6e16fab1f415437b3b6dcb0dd Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 30 Oct 2023 15:43:04 +0800 Subject: [PATCH 5/8] [MNT] add type attr for stat specification --- examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py | 4 +--- learnware/specification/base.py | 3 +++ learnware/specification/table/rkme.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py index 5f69127..93a3fa3 100644 --- a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py +++ b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py @@ -85,9 +85,7 @@ def get_split_errs(algo): split = train_xs.shape[0] - proportion_list[tmp] model.fit( - train_xs[ - split:, - ], + train_xs[split:,], train_ys[split:], eval_set=[(val_xs, val_ys)], early_stopping_rounds=50, diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 56c1ad9..b689d41 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -6,6 +6,9 @@ from typing import Dict class BaseStatSpecification: """The Statistical Specification Interface, which provide save and load method""" + def __init__(self, type): + self.type = type + def generate_stat_spec_from_data(self, **kwargs): """Construct statistical specification from raw dataset - kwargs may include the feature, label and model diff --git a/learnware/specification/table/rkme.py b/learnware/specification/table/rkme.py index 59ac436..0c3613c 100644 --- a/learnware/specification/table/rkme.py +++ b/learnware/specification/table/rkme.py @@ -51,6 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification): torch.cuda.empty_cache() self.device = choose_device(cuda_idx=cuda_idx) setup_seed(0) + super(RKMEStatSpecification, self).__init__(type="table_rkme") def get_beta(self) -> np.ndarray: """Move beta(RKME weights) back to memory accessible to the CPU. @@ -427,7 +428,7 @@ class RKMEStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" - rkme_to_save["type"] = self.__class__.__name__ + rkme_to_save["type"] = self.type json.dump( rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), From 00156b72beb668ac196dc29bf7ec8cd7c36a88e8 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 30 Oct 2023 15:45:02 +0800 Subject: [PATCH 6/8] [DOC] add docstring --- learnware/specification/base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/learnware/specification/base.py b/learnware/specification/base.py index b689d41..2935a26 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -6,7 +6,15 @@ from typing import Dict class BaseStatSpecification: """The Statistical Specification Interface, which provide save and load method""" - def __init__(self, type): + def __init__(self, type: str): + """initilize the type of stats specification, current the type only supports the following values: + - 'table_rkme': the RKME specification for table dataset + + Parameters + ---------- + type : str + the type of the stats specification + """ self.type = type def generate_stat_spec_from_data(self, **kwargs): From e0c43041cfec418ebba3f5af3e9935df2d1b10fd Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 30 Oct 2023 15:58:26 +0800 Subject: [PATCH 7/8] [MNT] modift stats spec type --- learnware/specification/base.py | 4 +--- learnware/specification/table/rkme.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 2935a26..655e5a4 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -7,9 +7,7 @@ class BaseStatSpecification: """The Statistical Specification Interface, which provide save and load method""" def __init__(self, type: str): - """initilize the type of stats specification, current the type only supports the following values: - - 'table_rkme': the RKME specification for table dataset - + """initilize the type of stats specification Parameters ---------- type : str diff --git a/learnware/specification/table/rkme.py b/learnware/specification/table/rkme.py index 0c3613c..9769800 100644 --- a/learnware/specification/table/rkme.py +++ b/learnware/specification/table/rkme.py @@ -51,7 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification): torch.cuda.empty_cache() self.device = choose_device(cuda_idx=cuda_idx) setup_seed(0) - super(RKMEStatSpecification, self).__init__(type="table_rkme") + super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__) def get_beta(self) -> np.ndarray: """Move beta(RKME weights) back to memory accessible to the CPU. From ba7df7b6d04e14b646e360936190380b8432ca8e Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 30 Oct 2023 16:05:55 +0800 Subject: [PATCH 8/8] [MNT] fix test --- tests/test_specification/test_rkme.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index 27a9033..4cdc246 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -24,6 +24,7 @@ class TestRKME(unittest.TestCase): rkme2 = RKMEStatSpecification() rkme2.load(rkme_path) + assert rkme2.type == "RKMEStatSpecification" if __name__ == "__main__":