From d3519bcbca98c0fdf290966ff29d08e6d3698900 Mon Sep 17 00:00:00 2001 From: "zhangzhicheng.zzc" Date: Tue, 8 Nov 2022 15:42:08 +0800 Subject: [PATCH] [to #42322933]token preprocess bug fix Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10608664 --- .../nlp/token_classification_preprocessor.py | 18 ++++++++++++------ ...st_multilingual_named_entity_recognition.py | 10 ++++++++++ .../pipelines/test_named_entity_recognition.py | 10 +++++++++- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py index 92b7c46b..a7616736 100644 --- a/modelscope/preprocessors/nlp/token_classification_preprocessor.py +++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py @@ -73,10 +73,12 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): super().__init__(model_dir, mode=mode, **kwargs) if 'is_split_into_words' in kwargs: - self.is_split_into_words = kwargs.pop('is_split_into_words') + self.tokenize_kwargs['is_split_into_words'] = kwargs.pop( + 'is_split_into_words') else: - self.is_split_into_words = self.tokenizer.init_kwargs.get( - 'is_split_into_words', False) + self.tokenize_kwargs[ + 'is_split_into_words'] = self.tokenizer.init_kwargs.get( + 'is_split_into_words', False) if 'label2id' in kwargs: kwargs.pop('label2id') @@ -99,7 +101,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): if isinstance(data, str): # for inference inputs without label text = data - self.tokenize_kwargs['add_special_tokens'] = False elif isinstance(data, dict): # for finetune inputs with label text = data.get(self.first_sequence) @@ -107,11 +108,15 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): if isinstance(text, list): self.tokenize_kwargs['is_split_into_words'] = True + if self._mode == ModeKeys.INFERENCE: + self.tokenize_kwargs['add_special_tokens'] = False + input_ids = [] label_mask = [] offset_mapping = [] token_type_ids = [] - if self.is_split_into_words and self._mode == ModeKeys.INFERENCE: + if self.tokenize_kwargs[ + 'is_split_into_words'] and self._mode == ModeKeys.INFERENCE: for offset, token in enumerate(list(text)): subtoken_ids = self.tokenizer.encode(token, **self.tokenize_kwargs) @@ -125,7 +130,8 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): encodings = self.tokenizer( text, return_offsets_mapping=True, **self.tokenize_kwargs) attention_mask = encodings['attention_mask'] - token_type_ids = encodings['token_type_ids'] + if 'token_type_ids' in encodings: + token_type_ids = encodings['token_type_ids'] input_ids = encodings['input_ids'] word_ids = encodings.word_ids() for i in range(len(word_ids)): diff --git a/tests/pipelines/test_multilingual_named_entity_recognition.py b/tests/pipelines/test_multilingual_named_entity_recognition.py index 6f72c83c..cb2b32d6 100644 --- a/tests/pipelines/test_multilingual_named_entity_recognition.py +++ b/tests/pipelines/test_multilingual_named_entity_recognition.py @@ -27,6 +27,9 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase, viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title' viet_sentence = 'Nón vành dễ thương cho bé gái' + multilingual_model_id = 'damo/nlp_raner_named-entity-recognition_multilingual-large-generic' + ml_stc = 'সমস্ত বেতন নিলামের সাধারণ ব্যবহারিক উদাহরণ বিভিন্ন পেনি নিলাম / বিডিং ফি নিলাম ওয়েবসাইটে পাওয়া যাবে।' + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_tcrf_by_direct_model_download_thai(self): cache_path = snapshot_download(self.thai_tcrf_model_id) @@ -60,6 +63,13 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase, task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id) print(pipeline_ins(input=self.thai_sentence)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_with_model_name_multilingual(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=self.multilingual_model_id) + print(pipeline_ins(input=self.ml_stc)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_tcrf_by_direct_model_download_viet(self): cache_path = snapshot_download(self.viet_tcrf_model_id) diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py index aef4aaed..0df44f5b 100644 --- a/tests/pipelines/test_named_entity_recognition.py +++ b/tests/pipelines/test_named_entity_recognition.py @@ -20,10 +20,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom' + chinese_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-large-generic' tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' sentence = '这与温岭市新河镇的一个神秘的传说有关。' sentence_en = 'pizza shovel' + sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_tcrf_by_direct_model_download(self): @@ -91,11 +93,17 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): task=Tasks.named_entity_recognition, model=self.lcrf_model_id) print(pipeline_ins(input=self.sentence)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_lcrf_with_chinese_model_name(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.chinese_model_id) + print(pipeline_ins(input=self.sentence_zh)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_english_with_model_name(self): pipeline_ins = pipeline( task=Tasks.named_entity_recognition, model=self.english_model_id) - print(pipeline_ins(input='pizza shovel')) + print(pipeline_ins(input=self.sentence_en)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self):