|
|
|
@@ -20,8 +20,8 @@ try: |
|
|
|
except ImportError: |
|
|
|
_FAISS_INSTALLED = False |
|
|
|
|
|
|
|
from ..base import BaseStatSpecification |
|
|
|
from ...logger import get_module_logger |
|
|
|
from ..base import RegularStatsSpecification |
|
|
|
from ....logger import get_module_logger |
|
|
|
|
|
|
|
logger = get_module_logger("rkme") |
|
|
|
|
|
|
|
@@ -30,7 +30,7 @@ if not _FAISS_INSTALLED: |
|
|
|
logger.warning('Please run "conda install -c pytorch faiss-cpu" first.') |
|
|
|
|
|
|
|
|
|
|
|
class RKMEStatSpecification(BaseStatSpecification): |
|
|
|
class RKMEStatSpecification(RegularStatsSpecification): |
|
|
|
"""Reduced Kernel Mean Embedding (RKME) Specification""" |
|
|
|
|
|
|
|
def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): |