diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 4d8b1b6..7655320 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec -from .base import Specification, StatSpecification +from .base import Specification, BaseStatSpecification from .rkme import RKMESpecification \ No newline at end of file diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 3d96fa0..c2060f3 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -10,17 +10,29 @@ class Specification: def get_property(self): return self.property + + def add_stat_spec(self, name, new_stat_spec:BaseStatSpecification): + self.stat_spec[name] = new_stat_spec + + def get_stat_spec_by_name(self, name:str): + if not name in self.stat_spec: + return None + + return self.stat_spec[name] def update_stat_spec(self): # update specification method pass -class StatSpecification: +class BaseStatSpecification: + def __init__(self): + pass + def generate_stat_spec_from_data(self, X: np.ndarray): raise NotImplementedError("generate_stat_spec_from_data is not implemented") - def save(self, filepath: str = "./stat_spec.npy"): + def save(self, filepath: str): raise NotImplementedError("save is not implemented") - def load(self, filepath: str = "./stat_spec.npy"): + def load(self, filepath: str): raise NotImplementedError("load is not implemented")