From 49b6b475d2afb4dfe6e66b72099cd950cba91a9f Mon Sep 17 00:00:00 2001 From: nju-xy <1582857295@qq.com> Date: Wed, 1 Nov 2023 19:55:53 +0800 Subject: [PATCH] [MNT] language settings for text --- learnware/market/easy2/searcher.py | 7 ++++-- learnware/reuse/job_selector.py | 2 +- learnware/specification/regular/text/rkme.py | 23 ++++++++++++++++++-- setup.py | 1 + 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 0308ae8..5ef4ae3 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -438,6 +438,9 @@ class EasyStatSearcher(BaseSearcher): if self.stat_info_name not in learnware.specification.stat_spec: continue rkme = learnware.specification.get_stat_spec_by_name(self.stat_info_name) + if self.stat_info_name == "RKMETextStatSpecification": + if not set(user_rkme.language).issubset(set(rkme.language)): + continue rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -632,8 +635,8 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] elif "RKMETableSpecification" in user_info.stat_info: - return self.table_searcher(learnware_list, user_info, max_search_num, search_method) + return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) elif "RKMETextStatSpecification" in user_info.stat_info: - return self.table_searcher(learnware_list, user_info, max_search_num, search_method) + return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 46c8951..ee58399 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -95,7 +95,7 @@ class JobSelectorReuser(BaseReuser): if isinstance(user_data[0], str): user_data = RKMETextStatSpecification.get_sentence_embedding(user_data) spec_name = "RKMETableSpecification" - if "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec: + if len(self.learnware_list) and "RKMETextStatSpecification" in self.learnware_list[0].specification.stat_spec: spec_name = "RKMETextStatSpecification" learnware_rkme_spec_list = [ learnware.specification.get_stat_spec_by_name(spec_name) for learnware in self.learnware_list diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index 1483fb4..7691bb9 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -2,10 +2,17 @@ from sentence_transformers import SentenceTransformer from ..table import RKMETableSpecification import numpy as np import os +import langdetect +from ....logger import get_module_logger + +logger = get_module_logger("RKMETextStatSpecification", "INFO") class RKMETextStatSpecification(RKMETableSpecification): """Reduced Kernel Mean Embedding (RKME) Specification for Text""" + def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): + RKMETableSpecification.__init__(self, gamma, cuda_idx) + self.language = [] def generate_stat_spec_from_data( self, @@ -35,6 +42,8 @@ class RKMETextStatSpecification(RKMETableSpecification): """ # Sentence embedding for Text + self.language = self.get_language_ids(X) + logger.info("The text learnware's language: %s" % (self.language)) X = self.get_sentence_embedding(X) # Generate specification @@ -47,12 +56,22 @@ class RKMETextStatSpecification(RKMETableSpecification): reduce, ) + @staticmethod + def get_language_ids(X): + try: + text = ' '.join(X) + lang = langdetect.detect(text) + langs = langdetect.detect_langs(text) + return [l.lang for l in langs] + except Exception as e: + logger.warning("Language detection failed.") + return [] + @staticmethod def get_sentence_embedding(X): os.environ["TOKENIZERS_PARALLELISM"] = "false" model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") X = model.encode(X) X = np.array(X) - # print(X.shape, np.mean(X, axis=1).shape, np.std(X, axis=1).reshape(X.shape[0], 1).shape) - # X = (X - np.mean(X, axis=1).reshape(X.shape[0], 1)) / np.std(X, axis=1).reshape(X.shape[0], 1) + # X /= np.sqrt(np.sum(X ** 2, axis=1)).reshape((-1, 1)) return X diff --git a/setup.py b/setup.py index 8119fd9..1e8ea72 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ REQUIRED = [ "torchtext>=0.16.0", "sentence_transformers>=2.2.2", "torch-optimizer>=0.3.0", + "langdetect>=1.0.9", ] if get_platform() != MACOS: