From a3a942352eddf0d3be23814801858c2fa93ce833 Mon Sep 17 00:00:00 2001 From: "xuanjie.wxb" Date: Tue, 6 Dec 2022 10:39:37 +0800 Subject: [PATCH] support prompt ner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改preprocessor增加对prompt模型的支持。 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10972542 --- .../nlp/token_classification_preprocessor.py | 21 ++++++++++++++++++- .../test_named_entity_recognition.py | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py index bf240bbd..52181274 100644 --- a/modelscope/preprocessors/nlp/token_classification_preprocessor.py +++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py @@ -238,7 +238,16 @@ class TokenClassificationTransformersPreprocessor( is_split_into_words = self.nlp_tokenizer.get_tokenizer_kwarg( 'is_split_into_words', False) if is_split_into_words: - tokens = list(tokens) + # for supporting prompt separator, should split twice. [SEP] for default. 
+ sep_idx = tokens.find('[SEP]') + if sep_idx == -1 or self.is_lstm_model: + tokens = list(tokens) + else: + tmp_tokens = [] + tmp_tokens.extend(list(tokens[:sep_idx])) + tmp_tokens.append('[SEP]') + tmp_tokens.extend(list(tokens[sep_idx + 5:])) + tokens = tmp_tokens if is_split_into_words and self.mode == ModeKeys.INFERENCE: encodings, word_ids = self._tokenize_text_by_words( @@ -250,6 +259,16 @@ class TokenClassificationTransformersPreprocessor( encodings, word_ids = self._tokenize_text_with_slow_tokenizer( tokens, **kwargs) + # modify label mask, mask all prompt tokens (tokens after sep token) + sep_idx = -1 + for idx, token_id in enumerate(encodings['input_ids']): + if token_id == self.nlp_tokenizer.tokenizer.sep_token_id: + sep_idx = idx + break + if sep_idx != -1: + for i in range(sep_idx, len(encodings['label_mask'])): + encodings['label_mask'][i] = False + if self.mode == ModeKeys.INFERENCE: for key in encodings.keys(): encodings[key] = torch.tensor(encodings[key]).unsqueeze(0) diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py index abc6634a..01a00f2a 100644 --- a/tests/pipelines/test_named_entity_recognition.py +++ b/tests/pipelines/test_named_entity_recognition.py @@ -262,7 +262,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): self.lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' self.addr_model_id = 'damo/nlp_structbert_address-parsing_chinese_base' self.lstm_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-generic' - self.sentence = '这与温岭市新河镇的一个神秘的传说有关。' + self.sentence = '这与温岭市新河镇的一个神秘的传说有关。[SEP]地名' self.sentence_en = 'pizza shovel' self.sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。' self.addr = '浙江省杭州市余杭区文一西路969号亲橙里'