|
|
|
@@ -20,8 +20,8 @@ try: |
|
|
|
except ImportError: |
|
|
|
_FAISS_INSTALLED = False |
|
|
|
|
|
|
|
from .base import BaseStatSpecification |
|
|
|
from ..logger import get_module_logger |
|
|
|
from ..base import BaseStatSpecification |
|
|
|
from ...logger import get_module_logger |
|
|
|
|
|
|
|
logger = get_module_logger("rkme") |
|
|
|
|
|
|
|
@@ -51,6 +51,7 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
torch.cuda.empty_cache() |
|
|
|
self.device = choose_device(cuda_idx=cuda_idx) |
|
|
|
setup_seed(0) |
|
|
|
super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__) |
|
|
|
|
|
|
|
def get_beta(self) -> np.ndarray: |
|
|
|
"""Move beta(RKME weights) back to memory accessible to the CPU. |
|
|
|
@@ -427,6 +428,7 @@ class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() |
|
|
|
rkme_to_save["beta"] = rkme_to_save["beta"].tolist() |
|
|
|
rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" |
|
|
|
rkme_to_save["type"] = self.type |
|
|
|
json.dump( |
|
|
|
rkme_to_save, |
|
|
|
codecs.open(save_path, "w", encoding="utf-8"), |