diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 2935a26..655e5a4 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -7,9 +7,7 @@ class BaseStatSpecification: """The Statistical Specification Interface, which provide save and load method""" def __init__(self, type: str): - """initilize the type of stats specification, current the type only supports the following values: - - 'table_rkme': the RKME specification for table dataset - + """initilize the type of stats specification Parameters ---------- type : str diff --git a/learnware/specification/table/rkme.py b/learnware/specification/table/rkme.py index 0c3613c..9769800 100644 --- a/learnware/specification/table/rkme.py +++ b/learnware/specification/table/rkme.py @@ -51,7 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification): torch.cuda.empty_cache() self.device = choose_device(cuda_idx=cuda_idx) setup_seed(0) - super(RKMEStatSpecification, self).__init__(type="table_rkme") + super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__) def get_beta(self) -> np.ndarray: """Move beta(RKME weights) back to memory accessible to the CPU.