From 06abae4dc6d68e99cba56608c857de5cdabd16b0 Mon Sep 17 00:00:00 2001 From: "zhangzhicheng.zzc" Date: Tue, 1 Nov 2022 09:56:15 +0800 Subject: [PATCH] [to #42322933]add token-cls test cases and bug fix Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10585502 --- .../nlp/token_classification_preprocessor.py | 3 +-- tests/pipelines/test_named_entity_recognition.py | 8 ++++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py index 5069048b..92b7c46b 100644 --- a/modelscope/preprocessors/nlp/token_classification_preprocessor.py +++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py @@ -140,8 +140,7 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): label_mask.append(1) offset_mapping.append(encodings['offset_mapping'][i]) else: - encodings = self.tokenizer( - text, add_special_tokens=False, **self.tokenize_kwargs) + encodings = self.tokenizer(text, **self.tokenize_kwargs) input_ids = encodings['input_ids'] label_mask, offset_mapping = self.get_label_mask_and_offset_mapping( text) diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py index 3658cf3f..aef4aaed 100644 --- a/tests/pipelines/test_named_entity_recognition.py +++ b/tests/pipelines/test_named_entity_recognition.py @@ -19,9 +19,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.named_entity_recognition self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' + english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom' tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' sentence = '这与温岭市新河镇的一个神秘的传说有关。' + sentence_en = 'pizza shovel' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_tcrf_by_direct_model_download(self): @@ -89,6 +91,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): task=Tasks.named_entity_recognition, model=self.lcrf_model_id) print(pipeline_ins(input=self.sentence)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_english_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.english_model_id) + print(pipeline_ins(input='pizza shovel')) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.named_entity_recognition)