Browse Source

Merge pull request #32 from Learnware-LAMDA/refactor_specs

[MNT] Refactor Specification
tags/v0.3.2
bxdd GitHub 2 years ago
parent
commit
3f2e50c5ea
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 37 additions and 11 deletions
  1. +1
    -1
      .github/workflow/install_learnware_with_source.yaml
  2. +1
    -1
      learnware/specification/__init__.py
  3. +2
    -5
      learnware/specification/base.py
  4. +1
    -0
      learnware/specification/regular/__init__.py
  5. +13
    -0
      learnware/specification/regular/base.py
  6. +0
    -0
      learnware/specification/regular/table/__init__.py
  7. +3
    -3
      learnware/specification/regular/table/rkme.py
  8. +0
    -0
      learnware/specification/system/__init__.py
  9. +15
    -0
      learnware/specification/system/heter_table.py
  10. +1
    -1
      learnware/specification/utils.py

+ 1
- 1
.github/workflow/install_learnware_with_source.yaml View File

@@ -1,4 +1,4 @@
name: Test leanrnware from pip
name: Test leanrnware from source code

on:
push:


+ 1
- 1
learnware/specification/__init__.py View File

@@ -1,3 +1,3 @@
from .utils import generate_stat_spec
from .base import Specification, BaseStatSpecification
from .table import RKMEStatSpecification
from .regular import RKMEStatSpecification

+ 2
- 5
learnware/specification/base.py View File

@@ -15,11 +15,8 @@ class BaseStatSpecification:
"""
self.type = type

def generate_stat_spec_from_data(self, **kwargs):
"""Construct statistical specification from raw dataset
- kwargs may include the feature, label and model
- kwargs also can include hyperparameters of specific method for specifaction generation
"""
def generate_stat_spec(self, **kwargs):
"""Construct statistical specification"""
raise NotImplementedError("generate_stat_spec_from_data is not implemented")

def save(self, filepath: str):


+ 1
- 0
learnware/specification/regular/__init__.py View File

@@ -0,0 +1 @@
from .table import RKMEStatSpecification

+ 13
- 0
learnware/specification/regular/base.py View File

@@ -0,0 +1,13 @@
from ..base import BaseStatSpecification


class RegularStatsSpecification(BaseStatSpecification):
def generate_stat_spec(self, **kwargs):
self.generate_stat_spec_from_data(**kwargs)

def generate_stat_spec_from_data(self, **kwargs):
"""Construct statistical specification from raw dataset
- kwargs may include the feature, label and model
- kwargs also can include hyperparameters of specific method for specifaction generation
"""
raise NotImplementedError("generate_stat_spec_from_data is not implemented")

learnware/specification/table/__init__.py → learnware/specification/regular/table/__init__.py View File


learnware/specification/table/rkme.py → learnware/specification/regular/table/rkme.py View File

@@ -20,8 +20,8 @@ try:
except ImportError:
_FAISS_INSTALLED = False

from ..base import BaseStatSpecification
from ...logger import get_module_logger
from ..base import RegularStatsSpecification
from ....logger import get_module_logger

logger = get_module_logger("rkme")

@@ -30,7 +30,7 @@ if not _FAISS_INSTALLED:
logger.warning('Please run "conda install -c pytorch faiss-cpu" first.')


class RKMEStatSpecification(BaseStatSpecification):
class RKMEStatSpecification(RegularStatsSpecification):
"""Reduced Kernel Mean Embedding (RKME) Specification"""

def __init__(self, gamma: float = 0.1, cuda_idx: int = -1):

+ 0
- 0
learnware/specification/system/__init__.py View File


+ 15
- 0
learnware/specification/system/heter_table.py View File

@@ -0,0 +1,15 @@
from ..base import BaseStatSpecification


class HeterMapTableSpecification(BaseStatSpecification):
def generate_stat_spec(self, **kwargs):
pass

def save(self, filepath: str):
pass

def load(self, filepath: str):
pass

def dist(self, other_spec):
pass

+ 1
- 1
learnware/specification/utils.py View File

@@ -4,7 +4,7 @@ import pandas as pd
from typing import Union

from .base import BaseStatSpecification
from .table import RKMEStatSpecification
from .regular import RKMEStatSpecification
from ..config import C




Loading…
Cancel
Save