diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 556aefb..0bb0502 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec from .base import Specification, BaseStatSpecification -from .rkme import RKMEStatSpecification +from .table import RKMEStatSpecification diff --git a/learnware/specification/table/__init__.py b/learnware/specification/table/__init__.py new file mode 100644 index 0000000..8c56c8e --- /dev/null +++ b/learnware/specification/table/__init__.py @@ -0,0 +1 @@ +from .rkme import RKMEStatSpecification \ No newline at end of file diff --git a/learnware/specification/rkme.py b/learnware/specification/table/rkme.py similarity index 99% rename from learnware/specification/rkme.py rename to learnware/specification/table/rkme.py index 68c572f..59ac436 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/table/rkme.py @@ -20,8 +20,8 @@ try: except ImportError: _FAISS_INSTALLED = False -from .base import BaseStatSpecification -from ..logger import get_module_logger +from ..base import BaseStatSpecification +from ...logger import get_module_logger logger = get_module_logger("rkme") @@ -427,6 +427,7 @@ class RKMEStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" + rkme_to_save["type"] = self.__class__.__name__ json.dump( rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py index c9a00be..c3693b7 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -4,7 +4,7 @@ import pandas as pd from typing import Union from .base import BaseStatSpecification -from .rkme import RKMEStatSpecification +from .table import RKMEStatSpecification from ..config import C