Browse Source

[MNT] add type in save() and move rkme to table folder

tags/v0.3.2
Gene 2 years ago
parent
commit
a56efc51c2
4 changed files with 6 additions and 4 deletions
  1. +1
    -1
      learnware/specification/__init__.py
  2. +1
    -0
      learnware/specification/table/__init__.py
  3. +3
    -2
      learnware/specification/table/rkme.py
  4. +1
    -1
      learnware/specification/utils.py

+ 1
- 1
learnware/specification/__init__.py View File

@@ -1,3 +1,3 @@
from .utils import generate_stat_spec
from .base import Specification, BaseStatSpecification
from .rkme import RKMEStatSpecification
from .table import RKMEStatSpecification

+ 1
- 0
learnware/specification/table/__init__.py View File

@@ -0,0 +1 @@
from .rkme import RKMEStatSpecification

learnware/specification/rkme.py → learnware/specification/table/rkme.py View File

@@ -20,8 +20,8 @@ try:
except ImportError:
_FAISS_INSTALLED = False

from .base import BaseStatSpecification
from ..logger import get_module_logger
from ..base import BaseStatSpecification
from ...logger import get_module_logger

logger = get_module_logger("rkme")

@@ -427,6 +427,7 @@ class RKMEStatSpecification(BaseStatSpecification):
rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy()
rkme_to_save["beta"] = rkme_to_save["beta"].tolist()
rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu"
rkme_to_save["type"] = self.__class__.__name__
json.dump(
rkme_to_save,
codecs.open(save_path, "w", encoding="utf-8"),

+ 1
- 1
learnware/specification/utils.py View File

@@ -4,7 +4,7 @@ import pandas as pd
from typing import Union

from .base import BaseStatSpecification
from .rkme import RKMEStatSpecification
from .table import RKMEStatSpecification
from ..config import C




Loading…
Cancel
Save