diff --git a/.github/workflow/install_learnware_with_source.yaml b/.github/workflow/install_learnware_with_source.yaml index bfe4e9c..e9589e3 100644 --- a/.github/workflow/install_learnware_with_source.yaml +++ b/.github/workflow/install_learnware_with_source.yaml @@ -1,4 +1,4 @@ -name: Test leanrnware from pip +name: Test leanrnware from source code on: push: diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 0bb0502..28100e7 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec from .base import Specification, BaseStatSpecification -from .table import RKMEStatSpecification +from .regular import RKMEStatSpecification diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 655e5a4..1c340fa 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -15,11 +15,8 @@ class BaseStatSpecification: """ self.type = type - def generate_stat_spec_from_data(self, **kwargs): - """Construct statistical specification from raw dataset - - kwargs may include the feature, label and model - - kwargs also can include hyperparameters of specific method for specifaction generation - """ + def generate_stat_spec(self, **kwargs): + """Construct statistical specification""" raise NotImplementedError("generate_stat_spec_from_data is not implemented") def save(self, filepath: str): diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py new file mode 100644 index 0000000..64fe041 --- /dev/null +++ b/learnware/specification/regular/__init__.py @@ -0,0 +1 @@ +from .table import RKMEStatSpecification diff --git a/learnware/specification/regular/base.py b/learnware/specification/regular/base.py new file mode 100644 index 0000000..48a7e1f --- /dev/null +++ b/learnware/specification/regular/base.py @@ -0,0 +1,13 @@ +from ..base import BaseStatSpecification + + +class RegularStatsSpecification(BaseStatSpecification): + def generate_stat_spec(self, **kwargs): + self.generate_stat_spec_from_data(**kwargs) + + def generate_stat_spec_from_data(self, **kwargs): + """Construct statistical specification from raw dataset + - kwargs may include the feature, label and model + - kwargs also can include hyperparameters of specific method for specifaction generation + """ + raise NotImplementedError("generate_stat_spec_from_data is not implemented") diff --git a/learnware/specification/table/__init__.py b/learnware/specification/regular/table/__init__.py similarity index 100% rename from learnware/specification/table/__init__.py rename to learnware/specification/regular/table/__init__.py diff --git a/learnware/specification/table/rkme.py b/learnware/specification/regular/table/rkme.py similarity index 99% rename from learnware/specification/table/rkme.py rename to learnware/specification/regular/table/rkme.py index 9769800..32dea8d 100644 --- a/learnware/specification/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -20,8 +20,8 @@ try: except ImportError: _FAISS_INSTALLED = False -from ..base import BaseStatSpecification -from ...logger import get_module_logger +from ..base import RegularStatsSpecification +from ....logger import get_module_logger logger = get_module_logger("rkme") @@ -30,7 +30,7 @@ if not _FAISS_INSTALLED: logger.warning('Please run "conda install -c pytorch faiss-cpu" first.') -class RKMEStatSpecification(BaseStatSpecification): +class RKMEStatSpecification(RegularStatsSpecification): """Reduced Kernel Mean Embedding (RKME) Specification""" def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): diff --git a/learnware/specification/system/__init__.py b/learnware/specification/system/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/learnware/specification/system/heter_table.py b/learnware/specification/system/heter_table.py new file mode 100644 index 0000000..ae24a1c --- /dev/null +++ b/learnware/specification/system/heter_table.py @@ -0,0 +1,15 @@ +from ..base import BaseStatSpecification + + +class HeterMapTableSpecification(BaseStatSpecification): + def generate_stat_spec(self, **kwargs): + pass + + def save(self, filepath: str): + pass + + def load(self, filepath: str): + pass + + def dist(self, other_spec): + pass diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py index c3693b7..5f1a689 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -4,7 +4,7 @@ import pandas as pd from typing import Union from .base import BaseStatSpecification -from .table import RKMEStatSpecification +from .regular import RKMEStatSpecification from ..config import C