From c81fdc348809ea97bb64aa15ea2becdeef5862e9 Mon Sep 17 00:00:00 2001 From: chenzx Date: Thu, 30 Mar 2023 18:32:58 +0800 Subject: [PATCH] [MNT] Add methods for specification --- learnware/specification/__init__.py | 2 +- learnware/specification/base.py | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 4d8b1b6..7655320 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec -from .base import Specification, StatSpecification +from .base import Specification, BaseStatSpecification from .rkme import RKMESpecification \ No newline at end of file diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 3d96fa0..c2060f3 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -10,17 +10,29 @@ class Specification: def get_property(self): return self.property + + def add_stat_spec(self, name, new_stat_spec:BaseStatSpecification): + self.stat_spec[name] = new_stat_spec + + def get_stat_spec_by_name(self, name:str): + if not name in self.stat_spec: + return None + + return self.stat_spec[name] def update_stat_spec(self): # update specification method pass -class StatSpecification: +class BaseStatSpecification: + def __init__(self): + pass + def generate_stat_spec_from_data(self, X: np.ndarray): raise NotImplementedError("generate_stat_spec_from_data is not implemented") - def save(self, filepath: str = "./stat_spec.npy"): + def save(self, filepath: str): raise NotImplementedError("save is not implemented") - def load(self, filepath: str = "./stat_spec.npy"): + def load(self, filepath: str): raise NotImplementedError("load is not implemented")