
support prompt ner

Modify the preprocessor to add support for prompt models.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10972542
master^2
xuanjie.wxb, yingda.chen · 3 years ago
Parent commit: a3a942352e
2 changed files with 21 additions and 2 deletions:
  1. modelscope/preprocessors/nlp/token_classification_preprocessor.py (+20, -1)
  2. tests/pipelines/test_named_entity_recognition.py (+1, -1)

modelscope/preprocessors/nlp/token_classification_preprocessor.py (+20, -1)

@@ -238,7 +238,16 @@ class TokenClassificationTransformersPreprocessor(
         is_split_into_words = self.nlp_tokenizer.get_tokenizer_kwarg(
             'is_split_into_words', False)
         if is_split_into_words:
-            tokens = list(tokens)
+            # To support the prompt separator, split twice; '[SEP]' is the default separator.
+            sep_idx = tokens.find('[SEP]')
+            if sep_idx == -1 or self.is_lstm_model:
+                tokens = list(tokens)
+            else:
+                tmp_tokens = []
+                tmp_tokens.extend(list(tokens[:sep_idx]))
+                tmp_tokens.append('[SEP]')
+                tmp_tokens.extend(list(tokens[sep_idx + 5:]))
+                tokens = tmp_tokens

         if is_split_into_words and self.mode == ModeKeys.INFERENCE:
             encodings, word_ids = self._tokenize_text_by_words(
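
For readers who want to see the split behavior in isolation, here is a minimal standalone sketch of the same character-level split around the prompt separator. The helper name split_prompt_tokens and the sample string are illustrative only; in the patch this logic runs inline on tokens inside the preprocessor.

    def split_prompt_tokens(tokens: str, is_lstm_model: bool = False) -> list:
        """Split a raw string into characters, keeping '[SEP]' as a single token."""
        sep_idx = tokens.find('[SEP]')
        if sep_idx == -1 or is_lstm_model:
            # No prompt separator (or an LSTM model): plain character split.
            return list(tokens)
        # Characters before the separator, the separator itself
        # (len('[SEP]') == 5, hence the + 5), then the prompt characters.
        return list(tokens[:sep_idx]) + ['[SEP]'] + list(tokens[sep_idx + 5:])

    print(split_prompt_tokens('温岭市新河镇[SEP]地名'))
    # -> ['温', '岭', '市', '新', '河', '镇', '[SEP]', '地', '名']

Note that only the first '[SEP]' occurrence is treated as the separator, and its length is hard-coded via sep_idx + 5.
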
@@ -250,6 +259,16 @@ class TokenClassificationTransformersPreprocessor(
             encodings, word_ids = self._tokenize_text_with_slow_tokenizer(
                 tokens, **kwargs)

+        # Modify the label mask: mask all prompt tokens (every token after the sep token).
+        sep_idx = -1
+        for idx, token_id in enumerate(encodings['input_ids']):
+            if token_id == self.nlp_tokenizer.tokenizer.sep_token_id:
+                sep_idx = idx
+                break
+        if sep_idx != -1:
+            for i in range(sep_idx, len(encodings['label_mask'])):
+                encodings['label_mask'][i] = False
+
         if self.mode == ModeKeys.INFERENCE:
             for key in encodings.keys():
                 encodings[key] = torch.tensor(encodings[key]).unsqueeze(0)
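
The masking step can be sketched the same way. Assuming input_ids and label_mask are flat parallel lists (as they are before the INFERENCE-mode unsqueeze), a hypothetical mask_prompt_labels helper would read:

    def mask_prompt_labels(input_ids, label_mask, sep_token_id):
        """Disable labels from the first sep token to the end of the sequence."""
        sep_idx = -1
        for idx, token_id in enumerate(input_ids):
            if token_id == sep_token_id:
                sep_idx = idx
                break
        if sep_idx != -1:
            for i in range(sep_idx, len(label_mask)):
                label_mask[i] = False
        return label_mask

    # With BERT-style vocabularies, [SEP] is commonly id 102 (an assumption here).
    print(mask_prompt_labels(
        [101, 984, 985, 102, 402, 403, 102],
        [False, True, True, True, True, True, False],
        sep_token_id=102))
    # -> [False, True, True, False, False, False, False]

The first sep token itself and everything after it, including the prompt and the trailing sep, are excluded from the label mask, so prompt tokens never receive entity predictions.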


tests/pipelines/test_named_entity_recognition.py (+1, -1)

@@ -262,7 +262,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
         self.lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
         self.addr_model_id = 'damo/nlp_structbert_address-parsing_chinese_base'
         self.lstm_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-generic'
-        self.sentence = '这与温岭市新河镇的一个神秘的传说有关。'
+        self.sentence = '这与温岭市新河镇的一个神秘的传说有关。[SEP]地名'
         self.sentence_en = 'pizza shovel'
         self.sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。'
         self.addr = '浙江省杭州市余杭区文一西路969号亲橙里'
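
End to end, the updated test sentence appends the entity-type prompt 地名 ("place name") after the separator. A minimal usage sketch, assuming a prompt-aware NER checkpoint is published under some model id (the one below is a placeholder, not a real id):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Placeholder model id; substitute a checkpoint trained for prompt-based NER
    # (LSTM models bypass the prompt split via is_lstm_model).
    ner_pipeline = pipeline(Tasks.named_entity_recognition,
                            model='damo/<prompt-ner-checkpoint>')
    print(ner_pipeline('这与温岭市新河镇的一个神秘的传说有关。[SEP]地名'))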

