|
|
|
@@ -2,10 +2,17 @@ from sentence_transformers import SentenceTransformer |
|
|
|
from ..table import RKMETableSpecification |
|
|
|
import numpy as np |
|
|
|
import os |
|
|
|
import langdetect |
|
|
|
from ....logger import get_module_logger |
|
|
|
|
|
|
|
logger = get_module_logger("RKMETextStatSpecification", "INFO") |
|
|
|
|
|
|
|
|
|
|
|
class RKMETextStatSpecification(RKMETableSpecification): |
|
|
|
"""Reduced Kernel Mean Embedding (RKME) Specification for Text""" |
|
|
|
def __init__(self, gamma: float = 0.1, cuda_idx: int = -1):
    """Set up the text specification.

    Parameters
    ----------
    gamma : float
        Passed through unchanged to ``RKMETableSpecification.__init__``.
    cuda_idx : int
        Passed through unchanged to ``RKMETableSpecification.__init__``.
    """
    # All shared state is initialised by the table specification.
    RKMETableSpecification.__init__(self, gamma, cuda_idx)

    # Starts empty; generate_stat_spec_from_data overwrites it with the
    # result of get_language_ids.
    self.language = []
|
|
|
|
|
|
|
def generate_stat_spec_from_data( |
|
|
|
self, |
|
|
|
@@ -35,6 +42,8 @@ class RKMETextStatSpecification(RKMETableSpecification): |
|
|
|
""" |
|
|
|
|
|
|
|
# Sentence embedding for Text |
|
|
|
self.language = self.get_language_ids(X) |
|
|
|
logger.info("The text learnware's language: %s" % (self.language)) |
|
|
|
X = self.get_sentence_embedding(X) |
|
|
|
|
|
|
|
# Generate specification |
|
|
|
@@ -47,12 +56,22 @@ class RKMETextStatSpecification(RKMETableSpecification): |
|
|
|
reduce, |
|
|
|
) |
|
|
|
|
|
|
|
@staticmethod
def get_language_ids(X):
    """Detect the candidate languages of a collection of texts.

    The texts are joined with single spaces into one string and
    ``langdetect.detect_langs`` is run once on the result.

    Parameters
    ----------
    X
        Iterable of strings (it is handed directly to ``str.join``).

    Returns
    -------
    list
        The ``lang`` attribute of every candidate returned by
        ``langdetect.detect_langs``, in the order returned. An empty list
        is returned when joining or detection raises.
    """
    try:
        text = " ".join(X)
        # detect_langs() already yields every candidate language, so no
        # separate detect() call is needed (its result was never used and
        # it doubled the detection cost).
        return [candidate.lang for candidate in langdetect.detect_langs(text)]
    except Exception:
        # Deliberately best-effort: a detection failure must not abort
        # specification generation, so report it and fall back to "unknown".
        logger.warning("Language detection failed.")
        return []
|
|
|
|
|
|
|
@staticmethod
def get_sentence_embedding(X):
    """Encode texts with a multilingual sentence-transformer model.

    Parameters
    ----------
    X
        The texts to embed; forwarded as-is to ``SentenceTransformer.encode``.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of whatever ``encode`` returned. No further scaling or
        normalisation is applied in this method.
        NOTE(review): presumably one embedding row per input text - confirm.
    """
    # Process-wide setting, applied before the model (and its tokenizer)
    # is constructed.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # A fresh model instance is built on every call.
    encoder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    embeddings = encoder.encode(X)
    return np.array(embeddings)