Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10608664 (master)
@@ -73,10 +73,12 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         super().__init__(model_dir, mode=mode, **kwargs)
         if 'is_split_into_words' in kwargs:
-            self.is_split_into_words = kwargs.pop('is_split_into_words')
+            self.tokenize_kwargs['is_split_into_words'] = kwargs.pop(
+                'is_split_into_words')
         else:
-            self.is_split_into_words = self.tokenizer.init_kwargs.get(
-                'is_split_into_words', False)
+            self.tokenize_kwargs[
+                'is_split_into_words'] = self.tokenizer.init_kwargs.get(
+                    'is_split_into_words', False)
         if 'label2id' in kwargs:
             kwargs.pop('label2id')
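Moving the flag into self.tokenize_kwargs means it now flows straight into the tokenizer call via **self.tokenize_kwargs instead of living as a separate attribute. A minimal sketch of what the flag does, assuming a stock Hugging Face fast tokenizer (the model name here is illustrative, not part of this change):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
    tokenize_kwargs = {'is_split_into_words': True}

    # Pre-split input: each list element is treated as one word.
    encodings = tokenizer(['pizza', 'shovel'], **tokenize_kwargs)
    print(encodings.word_ids())  # maps each sub-token back to its word index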
@@ -99,7 +101,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         if isinstance(data, str):
             # for inference inputs without label
             text = data
-            self.tokenize_kwargs['add_special_tokens'] = False
         elif isinstance(data, dict):
             # for finetune inputs with label
             text = data.get(self.first_sequence)
@@ -107,11 +108,15 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         if isinstance(text, list):
             self.tokenize_kwargs['is_split_into_words'] = True
+        if self._mode == ModeKeys.INFERENCE:
+            self.tokenize_kwargs['add_special_tokens'] = False
         input_ids = []
         label_mask = []
         offset_mapping = []
         token_type_ids = []
-        if self.is_split_into_words and self._mode == ModeKeys.INFERENCE:
+        if self.tokenize_kwargs[
+                'is_split_into_words'] and self._mode == ModeKeys.INFERENCE:
             for offset, token in enumerate(list(text)):
                 subtoken_ids = self.tokenizer.encode(token,
                                                      **self.tokenize_kwargs)
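Hoisting add_special_tokens out of the str-only branch and keying it on ModeKeys.INFERENCE matters for the per-token loop above: each token is encoded separately and the sub-token ids are concatenated, so [CLS]/[SEP] must not be wrapped around every individual token. A quick sketch, again with an illustrative model name:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

    # Default: special tokens are wrapped around the text.
    print(tokenizer.encode('pizza'))  # [CLS] id, 'pizza' ids, [SEP] id

    # Suppressed: only the sub-token ids, safe to concatenate token by token.
    print(tokenizer.encode('pizza', add_special_tokens=False))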
@@ -125,7 +130,8 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
             encodings = self.tokenizer(
                 text, return_offsets_mapping=True, **self.tokenize_kwargs)
             attention_mask = encodings['attention_mask']
-            token_type_ids = encodings['token_type_ids']
+            if 'token_type_ids' in encodings:
+                token_type_ids = encodings['token_type_ids']
             input_ids = encodings['input_ids']
             word_ids = encodings.word_ids()
             for i in range(len(word_ids)):
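The guard around token_type_ids is what lets RoBERTa/XLM-R-style tokenizers through this path: their model_input_names contain only input_ids and attention_mask, so the key is absent from the encoding and the old unconditional lookup would raise KeyError. A sketch of the difference, assuming stock Hugging Face tokenizers:

    from transformers import AutoTokenizer

    bert = AutoTokenizer.from_pretrained('bert-base-cased')
    xlmr = AutoTokenizer.from_pretrained('xlm-roberta-base')

    print('token_type_ids' in bert('pizza shovel'))  # True
    print('token_type_ids' in xlmr('pizza shovel'))  # False: old lookup raised KeyError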
@@ -27,6 +27,9 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase,
     viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title'
     viet_sentence = 'Nón vành dễ thương cho bé gái'
+    multilingual_model_id = 'damo/nlp_raner_named-entity-recognition_multilingual-large-generic'
+    ml_stc = 'সমস্ত বেতন নিলামের সাধারণ ব্যবহারিক উদাহরণ বিভিন্ন পেনি নিলাম / বিডিং ফি নিলাম ওয়েবসাইটে পাওয়া যাবে।'

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_tcrf_by_direct_model_download_thai(self):
         cache_path = snapshot_download(self.thai_tcrf_model_id)
@@ -60,6 +63,13 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase,
             task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id)
         print(pipeline_ins(input=self.thai_sentence))

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_tcrf_with_model_name_multilingual(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition,
+            model=self.multilingual_model_id)
+        print(pipeline_ins(input=self.ml_stc))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_tcrf_by_direct_model_download_viet(self):
         cache_path = snapshot_download(self.viet_tcrf_model_id)
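The new multilingual case can also be exercised outside the test harness with the same ModelScope calls the test file already imports; a standalone sketch (downloads the model on first run):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ner = pipeline(
        task=Tasks.named_entity_recognition,
        model='damo/nlp_raner_named-entity-recognition_multilingual-large-generic')
    print(ner(input='সমস্ত বেতন নিলামের সাধারণ ব্যবহারিক উদাহরণ বিভিন্ন পেনি নিলাম / বিডিং ফি নিলাম ওয়েবসাইটে পাওয়া যাবে।'))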
@@ -20,10 +20,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
         self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'

     english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom'
+    chinese_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-large-generic'
     tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
     lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
     sentence = '这与温岭市新河镇的一个神秘的传说有关。'
     sentence_en = 'pizza shovel'
+    sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。'

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_tcrf_by_direct_model_download(self):
@@ -91,11 +93,17 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
             task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
         print(pipeline_ins(input=self.sentence))

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_with_chinese_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition, model=self.chinese_model_id)
+        print(pipeline_ins(input=self.sentence_zh))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_english_with_model_name(self):
         pipeline_ins = pipeline(
             task=Tasks.named_entity_recognition, model=self.english_model_id)
-        print(pipeline_ins(input='pizza shovel'))
+        print(pipeline_ins(input=self.sentence_en))

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_default_model(self):