| @@ -417,14 +417,12 @@ class Text2TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): | |||||
| tokenizer=None, | tokenizer=None, | ||||
| mode=ModeKeys.INFERENCE, | mode=ModeKeys.INFERENCE, | ||||
| **kwargs): | **kwargs): | ||||
| self.tokenizer = self.build_tokenizer( | |||||
| model_dir) if tokenizer is None else tokenizer | |||||
| kwargs['truncation'] = kwargs.get('truncation', 'do_not_truncate') | kwargs['truncation'] = kwargs.get('truncation', 'do_not_truncate') | ||||
| kwargs['padding'] = kwargs.get('padding', False) | kwargs['padding'] = kwargs.get('padding', False) | ||||
| kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | ||||
| False) | False) | ||||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | kwargs['max_length'] = kwargs.pop('sequence_length', 128) | ||||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||||
| super().__init__(model_dir, mode=mode, **kwargs) | |||||
| def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: | def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: | ||||
| text_a, _, _ = self.parse_text_and_label(data) | text_a, _, _ = self.parse_text_and_label(data) | ||||
| @@ -18,7 +18,7 @@ class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| self.model_id = 'damo/t5-cn-base-test' | self.model_id = 'damo/t5-cn-base-test' | ||||
| self.input = '中国的首都位于<extra_id_0>。' | self.input = '中国的首都位于<extra_id_0>。' | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_T5(self): | def test_run_T5(self): | ||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| model = T5ForConditionalGeneration(cache_path) | model = T5ForConditionalGeneration(cache_path) | ||||
| @@ -40,7 +40,7 @@ class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| print(pipeline_ins(self.input)) | print(pipeline_ins(self.input)) | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_pipeline_with_model_id(self): | def test_run_pipeline_with_model_id(self): | ||||
| pipeline_ins = pipeline( | pipeline_ins = pipeline( | ||||
| task=Tasks.text2text_generation, model=self.model_id) | task=Tasks.text2text_generation, model=self.model_id) | ||||