From e0c43041cfec418ebba3f5af3e9935df2d1b10fd Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 30 Oct 2023 15:58:26 +0800 Subject: [PATCH] [MNT] modift stats spec type --- learnware/specification/base.py | 4 +--- learnware/specification/table/rkme.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 2935a26..655e5a4 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -7,9 +7,7 @@ class BaseStatSpecification: """The Statistical Specification Interface, which provide save and load method""" def __init__(self, type: str): - """initilize the type of stats specification, current the type only supports the following values: - - 'table_rkme': the RKME specification for table dataset - + """initilize the type of stats specification Parameters ---------- type : str diff --git a/learnware/specification/table/rkme.py b/learnware/specification/table/rkme.py index 0c3613c..9769800 100644 --- a/learnware/specification/table/rkme.py +++ b/learnware/specification/table/rkme.py @@ -51,7 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification): torch.cuda.empty_cache() self.device = choose_device(cuda_idx=cuda_idx) setup_seed(0) - super(RKMEStatSpecification, self).__init__(type="table_rkme") + super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__) def get_beta(self) -> np.ndarray: """Move beta(RKME weights) back to memory accessible to the CPU.