From 4c6eea16e033a159ef3a0374726cb40d658f937c Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Sat, 18 Nov 2023 19:36:36 +0800 Subject: [PATCH 1/2] [MNT] Add choice: download sentence emb model from market --- learnware/specification/regular/text/rkme.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index a452810..2f5b63a 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -5,6 +5,7 @@ from sentence_transformers import SentenceTransformer from ..table import RKMETableSpecification from ....logger import get_module_logger +from ....config import C logger = get_module_logger("RKMETextSpecification", "INFO") @@ -71,9 +72,18 @@ class RKMETextSpecification(RKMETableSpecification): @staticmethod def get_sentence_embedding(X): + from ....client import LearnwareClient os.environ["TOKENIZERS_PARALLELISM"] = "false" - model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") - X = model.encode(X) + cache_dir = C["cache_path"] + try: + model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", cache_folder=cache_dir) + X = model.encode(X) + except: + client = LearnwareClient() + zip_path = os.path.join(cache_dir, "MiniLM.zip") + if not os.path.exists(zip_path): + client.download_learnware("00000662", zip_path) + miniLM_learnware = client.load_learnware(zip_path) + X = miniLM_learnware.predict(X) X = np.array(X) - # X /= np.sqrt(np.sum(X ** 2, axis=1)).reshape((-1, 1)) return X From 7425009be613e5f88ed49a5e28cffaaf90f48c0f Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 19 Nov 2023 00:22:45 +0800 Subject: [PATCH 2/2] [MNT] add logger and modify details --- learnware/specification/regular/text/rkme.py | 30 ++++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index 2f5b63a..59f80ef 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -72,18 +72,30 @@ class RKMETextSpecification(RKMETableSpecification): @staticmethod def get_sentence_embedding(X): - from ....client import LearnwareClient os.environ["TOKENIZERS_PARALLELISM"] = "false" - cache_dir = C["cache_path"] - try: - model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", cache_folder=cache_dir) - X = model.encode(X) - except: + cache_dir = C.cache_path + zip_path = os.path.join(cache_dir, "MiniLM.zip") + + def _get_from_client(zip_path, X): + from ....client import LearnwareClient + client = LearnwareClient() - zip_path = os.path.join(cache_dir, "MiniLM.zip") if not os.path.exists(zip_path): + logger.info("Download the necessary feature extractor from Beimingwu system.") client.download_learnware("00000662", zip_path) miniLM_learnware = client.load_learnware(zip_path) - X = miniLM_learnware.predict(X) - X = np.array(X) + return np.array(miniLM_learnware.predict(X)) + + logger.info("Load the necessary feature extractor for RKMETextSpecification.") + if os.path.exists(zip_path): + X = _get_from_client(zip_path, X) + else: + try: + model = SentenceTransformer( + "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", cache_folder=cache_dir + ) + X = np.array(model.encode(X)) + except: + X = _get_from_client(zip_path, X) + return X