Browse Source

Merge pull request #81 from Learnware-LAMDA/xiey/dev

Add choice: download sentence emb model from market if huggingface is not accessible
tags/v0.3.2
Gene GitHub 2 years ago
parent
commit
a87dd0c01c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 4 deletions
  1. +26
    -4
      learnware/specification/regular/text/rkme.py

+ 26
- 4
learnware/specification/regular/text/rkme.py View File

@@ -5,6 +5,7 @@ from sentence_transformers import SentenceTransformer

from ..table import RKMETableSpecification
from ....logger import get_module_logger
from ....config import C

# Module-level logger shared by this file; the second argument is presumably the log level.
logger = get_module_logger("RKMETextSpecification", "INFO")

@@ -72,8 +73,29 @@ class RKMETextSpecification(RKMETableSpecification):
@staticmethod
def get_sentence_embedding(X):
    """Convert raw text input into sentence-embedding vectors.

    The feature extractor is obtained from one of two sources:

    1. If ``MiniLM.zip`` already exists under ``C.cache_path``, it is loaded
       as a learnware through ``LearnwareClient`` and its ``predict`` is used.
    2. Otherwise the ``paraphrase-multilingual-MiniLM-L12-v2`` model is loaded
       with ``SentenceTransformer`` (cached under ``C.cache_path``). If that
       fails for any reason (e.g. Hugging Face is not reachable), the
       learnware ``"00000662"`` is downloaded to ``MiniLM.zip`` and used
       instead.

    Parameters
    ----------
    X
        Text data handed unchanged to ``SentenceTransformer.encode`` or to the
        learnware's ``predict`` (presumably a list of strings -- confirm
        against callers).

    Returns
    -------
    np.ndarray
        The embeddings produced by the selected feature extractor.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    cache_dir = C.cache_path
    zip_path = os.path.join(cache_dir, "MiniLM.zip")

    def _get_from_client(zip_path, X):
        # Embed X with the feature-extractor learnware, downloading it first if needed.
        # NOTE(review): imported lazily, presumably to avoid a circular import -- confirm.
        from ....client import LearnwareClient

        client = LearnwareClient()
        if not os.path.exists(zip_path):
            logger.info("Download the necessary feature extractor from Beimingwu system.")
            client.download_learnware("00000662", zip_path)
        miniLM_learnware = client.load_learnware(zip_path)
        return np.array(miniLM_learnware.predict(X))

    logger.info("Load the necessary feature extractor for RKMETextSpecification.")

    # A previously downloaded archive takes priority, so no Hugging Face access is attempted.
    if os.path.exists(zip_path):
        return _get_from_client(zip_path, X)

    try:
        model = SentenceTransformer(
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", cache_folder=cache_dir
        )
        return np.array(model.encode(X))
    except Exception as err:
        # Catch Exception rather than everything, so KeyboardInterrupt/SystemExit
        # still propagate, and record why the fallback was taken.
        logger.warning(f"Loading the model via SentenceTransformer failed ({err!r}); falling back to the client.")
        return _get_from_client(zip_path, X)

Loading…
Cancel
Save