diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py
index bf240bbd..52181274 100644
--- a/modelscope/preprocessors/nlp/token_classification_preprocessor.py
+++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py
@@ -238,7 +238,16 @@ class TokenClassificationTransformersPreprocessor(
         is_split_into_words = self.nlp_tokenizer.get_tokenizer_kwarg(
             'is_split_into_words', False)
         if is_split_into_words:
-            tokens = list(tokens)
+            # support a prompt separator: split the text around it; '[SEP]' is the default
+            sep_idx = tokens.find('[SEP]')
+            if sep_idx == -1 or self.is_lstm_model:
+                tokens = list(tokens)
+            else:
+                tmp_tokens = []
+                tmp_tokens.extend(list(tokens[:sep_idx]))
+                tmp_tokens.append('[SEP]')
+                tmp_tokens.extend(list(tokens[sep_idx + 5:]))
+                tokens = tmp_tokens
 
         if is_split_into_words and self.mode == ModeKeys.INFERENCE:
             encodings, word_ids = self._tokenize_text_by_words(
@@ -250,6 +259,16 @@ class TokenClassificationTransformersPreprocessor(
             encodings, word_ids = self._tokenize_text_with_slow_tokenizer(
                 tokens, **kwargs)
 
+        # modify the label mask: mask the sep token and all prompt tokens after it
+        sep_idx = -1
+        for idx, token_id in enumerate(encodings['input_ids']):
+            if token_id == self.nlp_tokenizer.tokenizer.sep_token_id:
+                sep_idx = idx
+                break
+        if sep_idx != -1:
+            for i in range(sep_idx, len(encodings['label_mask'])):
+                encodings['label_mask'][i] = False
+
         if self.mode == ModeKeys.INFERENCE:
             for key in encodings.keys():
                 encodings[key] = torch.tensor(encodings[key]).unsqueeze(0)
diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py
index abc6634a..01a00f2a 100644
--- a/tests/pipelines/test_named_entity_recognition.py
+++ b/tests/pipelines/test_named_entity_recognition.py
@@ -262,7 +262,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
         self.lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
         self.addr_model_id = 'damo/nlp_structbert_address-parsing_chinese_base'
         self.lstm_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-generic'
-        self.sentence = '这与温岭市新河镇的一个神秘的传说有关。'
+        self.sentence = '这与温岭市新河镇的一个神秘的传说有关。[SEP]地名'
        self.sentence_en = 'pizza shovel'
        self.sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。'
        self.addr = '浙江省杭州市余杭区文一西路969号亲橙里'
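
For reference, a minimal standalone sketch of the logic the preprocessor change introduces: split the input at the prompt separator while keeping '[SEP]' as a single token, then mask the labels for the prompt part. The helper names (split_with_prompt, mask_prompt) are hypothetical and only mirror the change; the real code runs inside TokenClassificationTransformersPreprocessor.

# Minimal sketch, assuming character-level splitting (is_split_into_words)
# and '[SEP]' as the default separator; not the modelscope API itself.

SEP = '[SEP]'

def split_with_prompt(text):
    """Split text into characters while keeping a single '[SEP]' token intact."""
    sep_idx = text.find(SEP)
    if sep_idx == -1:
        return list(text)
    return list(text[:sep_idx]) + [SEP] + list(text[sep_idx + len(SEP):])

def mask_prompt(tokens, label_mask, sep_token=SEP):
    """Set label_mask to False for the separator and every prompt token after it."""
    if sep_token in tokens:
        sep_idx = tokens.index(sep_token)
        for i in range(sep_idx, len(label_mask)):
            label_mask[i] = False
    return label_mask

tokens = split_with_prompt('这与温岭市新河镇的一个神秘的传说有关。[SEP]地名')
label_mask = mask_prompt(tokens, [True] * len(tokens))
# Labels are kept only for the sentence; the '[SEP]地名' prompt part is masked out.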