Browse Source

[MNT] language settings for text

tags/v0.3.2
nju-xy 2 years ago
parent
commit
49b6b475d2
4 changed files with 28 additions and 5 deletions
  1. +5
    -2
      learnware/market/easy2/searcher.py
  2. +1
    -1
      learnware/reuse/job_selector.py
  3. +21
    -2
      learnware/specification/regular/text/rkme.py
  4. +1
    -0
      setup.py

+ 5
- 2
learnware/market/easy2/searcher.py View File

@@ -438,6 +438,9 @@ class EasyStatSearcher(BaseSearcher):
if self.stat_info_name not in learnware.specification.stat_spec:
continue
rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name)
if self.stat_info_name == "RKMETextStatSpecification":
if not set(user_rkme.language).issubset(set(rkme.language)):
continue
rkme_dim = str(list(rkme.get_z().shape)[1:])
if rkme_dim == user_rkme_dim:
filtered_learnware_list.append(learnware)
@@ -632,8 +635,8 @@ class EasySearcher(BaseSearcher):
if len(learnware_list) == 0:
return [], [], 0.0, []
elif "RKMETableSpecification" in user_info.stat_info:
return self.table_searcher(learnware_list, user_info, max_search_num, search_method)
return self.stat_searcher(learnware_list, user_info, max_search_num, search_method)
elif "RKMETextStatSpecification" in user_info.stat_info:
return self.table_searcher(learnware_list, user_info, max_search_num, search_method)
return self.stat_searcher(learnware_list, user_info, max_search_num, search_method)
else:
return None, learnware_list, 0.0, None

+ 1
- 1
learnware/reuse/job_selector.py View File

@@ -95,7 +95,7 @@ class JobSelectorReuser(BaseReuser):
if isinstance(user_data[0], str):
user_data = RKMETextStatSpecification.get_sentence_embedding(user_data)
spec_name = "RKMETableSpecification"
if "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec:
if len(self.learnware_list) and "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec:
spec_name = "RKMETextStatSpecification"
learnware_rkme_spec_list = [
learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list


+ 21
- 2
learnware/specification/regular/text/rkme.py View File

@@ -2,10 +2,17 @@ from sentence_transformers import SentenceTransformer
from ..table import RKMETableSpecification
import numpy as np
import os
import langdetect
from ....logger import get_module_logger

logger = get_module_logger("RKMETextStatSpecification", "INFO")


class RKMETextStatSpecification(RKMETableSpecification):
"""Reduced Kernel Mean Embedding (RKME) Specification for Text"""
def __init__(self, gamma: float = 0.1, cuda_idx: int = -1):
    """Initialize the text specification.

    ``gamma`` and ``cuda_idx`` are forwarded unchanged to
    ``RKMETableSpecification.__init__``; their meaning is defined there.
    """
    RKMETableSpecification.__init__(self, gamma, cuda_idx)
    # Language codes of the text behind this specification. Starts empty and
    # is filled by generate_stat_spec_from_data via get_language_ids.
    self.language = []

def generate_stat_spec_from_data(
self,
@@ -35,6 +42,8 @@ class RKMETextStatSpecification(RKMETableSpecification):
"""

# Sentence embedding for Text
self.language = self.get_language_ids(X)
logger.info("The text learnware's language: %s" % (self.language))
X = self.get_sentence_embedding(X)

# Generate specification
@@ -47,12 +56,22 @@ class RKMETextStatSpecification(RKMETableSpecification):
reduce,
)

@staticmethod
def get_language_ids(X):
    """Detect the language(s) present in a collection of texts.

    All items of ``X`` are joined with single spaces into one string, which
    is passed to ``langdetect.detect_langs`` once. The ``lang`` attribute of
    every candidate it reports is returned.

    Parameters
    ----------
    X : iterable of str
        Texts to inspect.

    Returns
    -------
    list of str
        Language codes in the order reported by ``langdetect.detect_langs``.
        An empty list is returned when detection fails for any reason
        (for example, ``X`` contains non-string items and cannot be joined).

    Notes
    -----
    This is deliberately best-effort: failures are logged, never raised.
    NOTE(review): langdetect documents its results as non-deterministic
    unless ``DetectorFactory.seed`` is set -- confirm whether callers need
    reproducible language lists.
    """
    try:
        text = " ".join(X)
        # One detect_langs() call is sufficient; the previous additional
        # langdetect.detect() call produced a value that was never used and
        # only repeated the detection work.
        candidates = langdetect.detect_langs(text)
        return [candidate.lang for candidate in candidates]
    except Exception as e:
        # Keep the best-effort contract, but report why detection failed
        # instead of silently discarding the exception.
        logger.warning("Language detection failed: %s" % e)
        return []

@staticmethod
def get_sentence_embedding(X):
    """Encode texts with a multilingual sentence-transformers model.

    Sets the ``TOKENIZERS_PARALLELISM`` environment variable to ``"false"``,
    instantiates the ``paraphrase-multilingual-MiniLM-L12-v2`` model and
    returns ``model.encode(X)`` wrapped in a NumPy array.

    The embeddings are returned as produced by the model: no
    standardization or L2 normalization is applied here.

    NOTE(review): the model is constructed on every call -- confirm whether
    callers rely on that before introducing any caching.
    """
    # Presumably silences the tokenizers fork/parallelism warning -- verify.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    encoder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    return np.array(encoder.encode(X))

+ 1
- 0
setup.py View File

@@ -73,6 +73,7 @@ REQUIRED = [
"torchtext>=0.16.0",
"sentence_transformers>=2.2.2",
"torch-optimizer>=0.3.0",
"langdetect>=1.0.9",
]

if get_platform() != MACOS:


Loading…
Cancel
Save