|
|
|
@@ -5,6 +5,7 @@ from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
from ..table import RKMETableSpecification |
|
|
|
from ....logger import get_module_logger |
|
|
|
from ....config import C |
|
|
|
|
|
|
|
logger = get_module_logger("RKMETextSpecification", "INFO") |
|
|
|
|
|
|
|
@@ -72,8 +73,29 @@ class RKMETextSpecification(RKMETableSpecification): |
|
|
|
@staticmethod |
|
|
|
def get_sentence_embedding(X): |
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") |
|
|
|
X = model.encode(X) |
|
|
|
X = np.array(X) |
|
|
|
# X /= np.sqrt(np.sum(X ** 2, axis=1)).reshape((-1, 1)) |
|
|
|
cache_dir = C.cache_path |
|
|
|
zip_path = os.path.join(cache_dir, "MiniLM.zip") |
|
|
|
|
|
|
|
def _get_from_client(zip_path, X): |
|
|
|
from ....client import LearnwareClient |
|
|
|
|
|
|
|
client = LearnwareClient() |
|
|
|
if not os.path.exists(zip_path): |
|
|
|
logger.info("Download the necessary feature extractor from Beimingwu system.") |
|
|
|
client.download_learnware("00000662", zip_path) |
|
|
|
miniLM_learnware = client.load_learnware(zip_path) |
|
|
|
return np.array(miniLM_learnware.predict(X)) |
|
|
|
|
|
|
|
logger.info("Load the necessary feature extractor for RKMETextSpecification.") |
|
|
|
if os.path.exists(zip_path): |
|
|
|
X = _get_from_client(zip_path, X) |
|
|
|
else: |
|
|
|
try: |
|
|
|
model = SentenceTransformer( |
|
|
|
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", cache_folder=cache_dir |
|
|
|
) |
|
|
|
X = np.array(model.encode(X)) |
|
|
|
except: |
|
|
|
X = _get_from_client(zip_path, X) |
|
|
|
|
|
|
|
return X |