Browse Source

Merge pull request #29 from Learnware-LAMDA/dev_spec

[MNT] add "type" when saving specification and move rkme to table folder
tags/v0.3.2
bxdd GitHub 2 years ago
parent
commit
3c7eaa7fa2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 55 additions and 14 deletions
  1. +1
    -1
      README.md
  2. +2
    -2
      docs/start/client.rst
  3. +1
    -1
      docs/start/quick.rst
  4. +1
    -1
      docs/workflow/identify.rst
  5. +1
    -3
      examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py
  6. +1
    -1
      examples/workflow_by_code/main.py
  7. +1
    -1
      learnware/specification/__init__.py
  8. +9
    -0
      learnware/specification/base.py
  9. +1
    -0
      learnware/specification/table/__init__.py
  10. +4
    -2
      learnware/specification/table/rkme.py
  11. +1
    -1
      learnware/specification/utils.py
  12. +31
    -0
      tests/test_specification/test_rkme.py
  13. +1
    -1
      tests/test_workflow/test_workflow.py

+ 1
- 1
README.md View File

@@ -178,7 +178,7 @@ For example, the following code is designed to work with Reduced Set Kernel Embe
```python
import learnware.specification as specification

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_path, "rkme.json"))
user_info = BaseUserInfo(
semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}


+ 2
- 2
docs/start/client.rst View File

@@ -123,7 +123,7 @@ You can search learnware by providing a statistical specification. The statistic

import learnware.specification as specification

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_path, "rkme.json"))
specification = learnware.specification.Specification()
@@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares
senarioes=[],
input_description={}, output_description={})

stat_spec = specification.rkme.RKMEStatSpecification()
stat_spec = specification.RKMEStatSpecification()
stat_spec.load(os.path.join(unzip_path, "rkme.json"))
specification = learnware.specification.Specification()
specification.update_semantic_spec(semantic_spec)


+ 1
- 1
docs/start/quick.rst View File

@@ -170,7 +170,7 @@ For example, the code below executes learnware search when using Reduced Set Ker

import learnware.specification as specification

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()

# unzip_path: directory for unzipped learnware zipfile
user_spec.load(os.path.join(unzip_path, "rkme.json"))


+ 1
- 1
docs/workflow/identify.rst View File

@@ -73,7 +73,7 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb

import learnware.specification as specification

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join("rkme.json"))
user_info = BaseUserInfo(
semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}


+ 1
- 3
examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py View File

@@ -85,9 +85,7 @@ def get_split_errs(algo):
split = train_xs.shape[0] - proportion_list[tmp]
model.fit(
train_xs[
split:,
],
train_xs[split:,],
train_ys[split:],
eval_set=[(val_xs, val_ys)],
early_stopping_rounds=50,


+ 1
- 1
examples/workflow_by_code/main.py View File

@@ -148,7 +148,7 @@ class LearnwareMarketWorkflow:
with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir)

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
(


+ 1
- 1
learnware/specification/__init__.py View File

@@ -1,3 +1,3 @@
from .utils import generate_stat_spec
from .base import Specification, BaseStatSpecification
from .rkme import RKMEStatSpecification
from .table import RKMEStatSpecification

+ 9
- 0
learnware/specification/base.py View File

@@ -6,6 +6,15 @@ from typing import Dict
class BaseStatSpecification:
"""The Statistical Specification Interface, which provide save and load method"""

def __init__(self, type: str):
"""initilize the type of stats specification
Parameters
----------
type : str
the type of the stats specification
"""
self.type = type

def generate_stat_spec_from_data(self, **kwargs):
"""Construct statistical specification from raw dataset
- kwargs may include the feature, label and model


+ 1
- 0
learnware/specification/table/__init__.py View File

@@ -0,0 +1 @@
from .rkme import RKMEStatSpecification

learnware/specification/rkme.py → learnware/specification/table/rkme.py View File

@@ -20,8 +20,8 @@ try:
except ImportError:
_FAISS_INSTALLED = False

from .base import BaseStatSpecification
from ..logger import get_module_logger
from ..base import BaseStatSpecification
from ...logger import get_module_logger

logger = get_module_logger("rkme")

@@ -51,6 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification):
torch.cuda.empty_cache()
self.device = choose_device(cuda_idx=cuda_idx)
setup_seed(0)
super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__)

def get_beta(self) -> np.ndarray:
"""Move beta(RKME weights) back to memory accessible to the CPU.
@@ -427,6 +428,7 @@ class RKMEStatSpecification(BaseStatSpecification):
rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy()
rkme_to_save["beta"] = rkme_to_save["beta"].tolist()
rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu"
rkme_to_save["type"] = self.type
json.dump(
rkme_to_save,
codecs.open(save_path, "w", encoding="utf-8"),

+ 1
- 1
learnware/specification/utils.py View File

@@ -4,7 +4,7 @@ import pandas as pd
from typing import Union

from .base import BaseStatSpecification
from .rkme import RKMEStatSpecification
from .table import RKMEStatSpecification
from ..config import C




+ 31
- 0
tests/test_specification/test_rkme.py View File

@@ -0,0 +1,31 @@
import os
import json
import unittest
import tempfile
import numpy as np

import learnware
import learnware.specification as specification
from learnware.specification import RKMEStatSpecification


class TestRKME(unittest.TestCase):
def test_rkme(self):
X = np.random.uniform(-10000, 10000, size=(5000, 200))
rkme = specification.utils.generate_rkme_spec(X)

with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
rkme_path = os.path.join(tempdir, "rkme.json")
rkme.save(rkme_path)

with open(rkme_path, "r") as f:
data = json.load(f)
assert data["type"] == "RKMEStatSpecification"

rkme2 = RKMEStatSpecification()
rkme2.load(rkme_path)
assert rkme2.type == "RKMEStatSpecification"


if __name__ == "__main__":
unittest.main()

+ 1
- 1
tests/test_workflow/test_workflow.py View File

@@ -155,7 +155,7 @@ class TestAllWorkflow(unittest.TestCase):
with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir)

user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
(


Loading…
Cancel
Save