zhangzhicheng.zzc, yingda.chen · 3 years ago
commit d3519bcbca
3 changed files with 31 additions and 7 deletions
  1. modelscope/preprocessors/nlp/token_classification_preprocessor.py (+12 -6)
  2. tests/pipelines/test_multilingual_named_entity_recognition.py (+10 -0)
  3. tests/pipelines/test_named_entity_recognition.py (+9 -1)

modelscope/preprocessors/nlp/token_classification_preprocessor.py (+12 -6)

@@ -73,10 +73,12 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         super().__init__(model_dir, mode=mode, **kwargs)

         if 'is_split_into_words' in kwargs:
-            self.is_split_into_words = kwargs.pop('is_split_into_words')
+            self.tokenize_kwargs['is_split_into_words'] = kwargs.pop(
+                'is_split_into_words')
         else:
-            self.is_split_into_words = self.tokenizer.init_kwargs.get(
-                'is_split_into_words', False)
+            self.tokenize_kwargs[
+                'is_split_into_words'] = self.tokenizer.init_kwargs.get(
+                    'is_split_into_words', False)
         if 'label2id' in kwargs:
             kwargs.pop('label2id')

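Note: the point of this change is that the flag now lives in tokenize_kwargs, the dict that is actually forwarded to the tokenizer call, instead of sitting on self where the call never sees it. A minimal sketch of that forwarding pattern, outside ModelScope and assuming only the transformers package (model name purely illustrative):

from transformers import AutoTokenizer

# Hypothetical stand-ins for the preprocessor's tokenizer and tokenize_kwargs.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
tokenize_kwargs = {'is_split_into_words': True}

# Because the flag is in tokenize_kwargs, it reaches the actual tokenizer call,
# so pre-tokenized input (a list of words) is handled correctly.
encodings = tokenizer(['Hello', 'world', '!'], **tokenize_kwargs)
print(encodings.word_ids())  # e.g. [None, 0, 1, 2, None] with special tokens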
@@ -99,7 +101,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         if isinstance(data, str):
             # for inference inputs without label
             text = data
-            self.tokenize_kwargs['add_special_tokens'] = False
         elif isinstance(data, dict):
             # for finetune inputs with label
             text = data.get(self.first_sequence)
@@ -107,11 +108,15 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         if isinstance(text, list):
             self.tokenize_kwargs['is_split_into_words'] = True

+        if self._mode == ModeKeys.INFERENCE:
+            self.tokenize_kwargs['add_special_tokens'] = False
+
         input_ids = []
         label_mask = []
         offset_mapping = []
         token_type_ids = []
-        if self.is_split_into_words and self._mode == ModeKeys.INFERENCE:
+        if self.tokenize_kwargs[
+                'is_split_into_words'] and self._mode == ModeKeys.INFERENCE:
             for offset, token in enumerate(list(text)):
                 subtoken_ids = self.tokenizer.encode(token,
                                                      **self.tokenize_kwargs)
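Note: moving add_special_tokens = False under the ModeKeys.INFERENCE check (together with the deletion in the previous hunk) disables special tokens for every inference input, not only plain strings, and the branch now reads the flag from tokenize_kwargs. A rough, self-contained sketch of the per-token encoding idea used in this branch, assuming only transformers; the model name and the bookkeeping lists are illustrative, not the actual preprocessor code:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')  # illustrative model
text = list('这与温岭市有关')  # inference input already split into characters

input_ids, offset_mapping = [], []
for offset, token in enumerate(text):
    # Encode each piece without special tokens, mirroring add_special_tokens=False
    # at inference, so offsets map one-to-one back to the original characters.
    subtoken_ids = tokenizer.encode(token, add_special_tokens=False)
    input_ids.extend(subtoken_ids)
    offset_mapping.append((offset, offset + 1))

print(input_ids, offset_mapping)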
@@ -125,7 +130,8 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
             encodings = self.tokenizer(
                 text, return_offsets_mapping=True, **self.tokenize_kwargs)
             attention_mask = encodings['attention_mask']
-            token_type_ids = encodings['token_type_ids']
+            if 'token_type_ids' in encodings:
+                token_type_ids = encodings['token_type_ids']
             input_ids = encodings['input_ids']
             word_ids = encodings.word_ids()
             for i in range(len(word_ids)):
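Note: the guard matters because not every tokenizer returns token_type_ids; RoBERTa/XLM-R style tokenizers, as used by, e.g., the XLM-R based models in these test suites, omit that key, so the old unconditional lookup could raise a KeyError. A quick check, assuming transformers is installed (model names are just examples):

from transformers import AutoTokenizer

for name in ('bert-base-cased', 'xlm-roberta-base'):
    tok = AutoTokenizer.from_pretrained(name)
    enc = tok('pizza shovel')
    # BERT-style tokenizers include token_type_ids; XLM-R-style ones do not.
    print(name, 'token_type_ids' in enc)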


tests/pipelines/test_multilingual_named_entity_recognition.py (+10 -0)

@@ -27,6 +27,9 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase,
     viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title'
     viet_sentence = 'Nón vành dễ thương cho bé gái'

+    multilingual_model_id = 'damo/nlp_raner_named-entity-recognition_multilingual-large-generic'
+    ml_stc = 'সমস্ত বেতন নিলামের সাধারণ ব্যবহারিক উদাহরণ বিভিন্ন পেনি নিলাম / বিডিং ফি নিলাম ওয়েবসাইটে পাওয়া যাবে।'
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_tcrf_by_direct_model_download_thai(self):
         cache_path = snapshot_download(self.thai_tcrf_model_id)
@@ -60,6 +63,13 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase,
             task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id)
         print(pipeline_ins(input=self.thai_sentence))

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_tcrf_with_model_name_multilingual(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition,
+            model=self.multilingual_model_id)
+        print(pipeline_ins(input=self.ml_stc))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_tcrf_by_direct_model_download_viet(self):
         cache_path = snapshot_download(self.viet_tcrf_model_id)
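For reference, the new multilingual test boils down to the following standalone snippet (a sketch only; it assumes the modelscope package is installed and the model can be downloaded from the hub):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Same model id and Bengali sentence as in the new test case.
ner = pipeline(
    task=Tasks.named_entity_recognition,
    model='damo/nlp_raner_named-entity-recognition_multilingual-large-generic')
print(ner(input='সমস্ত বেতন নিলামের সাধারণ ব্যবহারিক উদাহরণ বিভিন্ন পেনি নিলাম / বিডিং ফি নিলাম ওয়েবসাইটে পাওয়া যাবে।'))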


tests/pipelines/test_named_entity_recognition.py (+9 -1)

@@ -20,10 +20,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
         self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'

     english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom'
+    chinese_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-large-generic'
     tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
     lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
     sentence = '这与温岭市新河镇的一个神秘的传说有关。'
     sentence_en = 'pizza shovel'
+    sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。'

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_tcrf_by_direct_model_download(self):
@@ -91,11 +93,17 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
             task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
         print(pipeline_ins(input=self.sentence))

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_with_chinese_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition, model=self.chinese_model_id)
+        print(pipeline_ins(input=self.sentence_zh))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_english_with_model_name(self):
         pipeline_ins = pipeline(
             task=Tasks.named_entity_recognition, model=self.english_model_id)
-        print(pipeline_ins(input='pizza shovel'))
+        print(pipeline_ins(input=self.sentence_en))

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_default_model(self):
