diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index c8ec1e4..aa7d72e 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -5,6 +5,7 @@ from sentence_transformers import SentenceTransformer from ..table import RKMETableSpecification from ....logger import get_module_logger +from ....config import C logger = get_module_logger("RKMETextSpecification", "INFO") @@ -72,8 +73,29 @@ class RKMETextSpecification(RKMETableSpecification): @staticmethod def get_sentence_embedding(X): os.environ["TOKENIZERS_PARALLELISM"] = "false" - model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") - X = model.encode(X) - X = np.array(X) - # X /= np.sqrt(np.sum(X ** 2, axis=1)).reshape((-1, 1)) + cache_dir = C.cache_path + zip_path = os.path.join(cache_dir, "MiniLM.zip") + + def _get_from_client(zip_path, X): + from ....client import LearnwareClient + + client = LearnwareClient() + if not os.path.exists(zip_path): + logger.info("Download the necessary feature extractor from Beimingwu system.") + client.download_learnware("00000662", zip_path) + miniLM_learnware = client.load_learnware(zip_path) + return np.array(miniLM_learnware.predict(X)) + + logger.info("Load the necessary feature extractor for RKMETextSpecification.") + if os.path.exists(zip_path): + X = _get_from_client(zip_path, X) + else: + try: + model = SentenceTransformer( + "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", cache_folder=cache_dir + ) + X = np.array(model.encode(X)) + except: + X = _get_from_client(zip_path, X) + return X