|
|
|
@@ -20,8 +20,8 @@ try: |
|
|
|
except ImportError: |
|
|
|
_FAISS_INSTALLED = False |
|
|
|
|
|
|
|
from .base import BaseStatSpecification |
|
|
|
from ..logger import get_module_logger |
|
|
|
from ..base import BaseStatSpecification |
|
|
|
from ...logger import get_module_logger |
|
|
|
|
|
|
|
logger = get_module_logger("rkme") |
|
|
|
|
|
|
|
@@ -427,6 +427,7 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() |
|
|
|
rkme_to_save["beta"] = rkme_to_save["beta"].tolist() |
|
|
|
rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" |
|
|
|
rkme_to_save["type"] = self.__class__.__name__ |
|
|
|
json.dump( |
|
|
|
rkme_to_save, |
|
|
|
codecs.open(save_path, "w", encoding="utf-8"), |