Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10608664master
| @@ -73,10 +73,12 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| super().__init__(model_dir, mode=mode, **kwargs) | |||
| if 'is_split_into_words' in kwargs: | |||
| self.is_split_into_words = kwargs.pop('is_split_into_words') | |||
| self.tokenize_kwargs['is_split_into_words'] = kwargs.pop( | |||
| 'is_split_into_words') | |||
| else: | |||
| self.is_split_into_words = self.tokenizer.init_kwargs.get( | |||
| 'is_split_into_words', False) | |||
| self.tokenize_kwargs[ | |||
| 'is_split_into_words'] = self.tokenizer.init_kwargs.get( | |||
| 'is_split_into_words', False) | |||
| if 'label2id' in kwargs: | |||
| kwargs.pop('label2id') | |||
| @@ -99,7 +101,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| if isinstance(data, str): | |||
| # for inference inputs without label | |||
| text = data | |||
| self.tokenize_kwargs['add_special_tokens'] = False | |||
| elif isinstance(data, dict): | |||
| # for finetune inputs with label | |||
| text = data.get(self.first_sequence) | |||
| @@ -107,11 +108,15 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| if isinstance(text, list): | |||
| self.tokenize_kwargs['is_split_into_words'] = True | |||
| if self._mode == ModeKeys.INFERENCE: | |||
| self.tokenize_kwargs['add_special_tokens'] = False | |||
| input_ids = [] | |||
| label_mask = [] | |||
| offset_mapping = [] | |||
| token_type_ids = [] | |||
| if self.is_split_into_words and self._mode == ModeKeys.INFERENCE: | |||
| if self.tokenize_kwargs[ | |||
| 'is_split_into_words'] and self._mode == ModeKeys.INFERENCE: | |||
| for offset, token in enumerate(list(text)): | |||
| subtoken_ids = self.tokenizer.encode(token, | |||
| **self.tokenize_kwargs) | |||
| @@ -125,7 +130,8 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| encodings = self.tokenizer( | |||
| text, return_offsets_mapping=True, **self.tokenize_kwargs) | |||
| attention_mask = encodings['attention_mask'] | |||
| token_type_ids = encodings['token_type_ids'] | |||
| if 'token_type_ids' in encodings: | |||
| token_type_ids = encodings['token_type_ids'] | |||
| input_ids = encodings['input_ids'] | |||
| word_ids = encodings.word_ids() | |||
| for i in range(len(word_ids)): | |||
| @@ -27,6 +27,9 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase, | |||
| viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title' | |||
| viet_sentence = 'Nón vành dễ thương cho bé gái' | |||
| multilingual_model_id = 'damo/nlp_raner_named-entity-recognition_multilingual-large-generic' | |||
| ml_stc = 'সমস্ত বেতন নিলামের সাধারণ ব্যবহারিক উদাহরণ বিভিন্ন পেনি নিলাম / বিডিং ফি নিলাম ওয়েবসাইটে পাওয়া যাবে।' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_tcrf_by_direct_model_download_thai(self): | |||
| cache_path = snapshot_download(self.thai_tcrf_model_id) | |||
| @@ -60,6 +63,13 @@ class MultilingualNamedEntityRecognitionTest(unittest.TestCase, | |||
| task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id) | |||
| print(pipeline_ins(input=self.thai_sentence)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_tcrf_with_model_name_multilingual(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, | |||
| model=self.multilingual_model_id) | |||
| print(pipeline_ins(input=self.ml_stc)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_tcrf_by_direct_model_download_viet(self): | |||
| cache_path = snapshot_download(self.viet_tcrf_model_id) | |||
| @@ -20,10 +20,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' | |||
| english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom' | |||
| chinese_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-large-generic' | |||
| tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' | |||
| lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' | |||
| sentence = '这与温岭市新河镇的一个神秘的传说有关。' | |||
| sentence_en = 'pizza shovel' | |||
| sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_tcrf_by_direct_model_download(self): | |||
| @@ -91,11 +93,17 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| task=Tasks.named_entity_recognition, model=self.lcrf_model_id) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_lcrf_with_chinese_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, model=self.chinese_model_id) | |||
| print(pipeline_ins(input=self.sentence_zh)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_english_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, model=self.english_model_id) | |||
| print(pipeline_ins(input='pizza shovel')) | |||
| print(pipeline_ins(input=self.sentence_en)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||