Browse Source

[MNT] modift stats spec type

tags/v0.3.2
bxdd 2 years ago
parent
commit
e0c43041cf
2 changed files with 2 additions and 4 deletions
  1. +1
    -3
      learnware/specification/base.py
  2. +1
    -1
      learnware/specification/table/rkme.py

+ 1
- 3
learnware/specification/base.py View File

@@ -7,9 +7,7 @@ class BaseStatSpecification:
"""The Statistical Specification Interface, which provide save and load method"""

def __init__(self, type: str):
"""initilize the type of stats specification, current the type only supports the following values:
- 'table_rkme': the RKME specification for table dataset

"""initilize the type of stats specification
Parameters
----------
type : str


+ 1
- 1
learnware/specification/table/rkme.py View File

@@ -51,7 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification):
torch.cuda.empty_cache()
self.device = choose_device(cuda_idx=cuda_idx)
setup_seed(0)
super(RKMEStatSpecification, self).__init__(type="table_rkme")
super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__)

def get_beta(self) -> np.ndarray:
"""Move beta(RKME weights) back to memory accessible to the CPU.


Loading…
Cancel
Save