From 5ea690d7439c3229ee5f6f8d7af7cf329c41fa18 Mon Sep 17 00:00:00 2001 From: "yingda.chen" Date: Tue, 9 Aug 2022 17:01:53 +0800 Subject: [PATCH] [to #42322933]split text generation tests Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9684735 * split test generation tests --- tests/pipelines/test_text_generation.py | 82 +++++++++++++++++-------- 1 file changed, 55 insertions(+), 27 deletions(-) diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index c391e0a1..cebc80ae 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -34,6 +34,61 @@ class TextGenerationTest(unittest.TestCase): self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' self.gpt3_input = '我很好奇' + def run_pipeline_with_model_instance(self, model_id, input): + model = Model.from_pretrained(model_id) + preprocessor = TextGenerationPreprocessor( + model.model_dir, + model.tokenizer, + first_sequence='sentence', + second_sequence=None) + pipeline_ins = pipeline( + task=Tasks.text_generation, model=model, preprocessor=preprocessor) + print(pipeline_ins(input)) + + def run_pipeline_with_model_id(self, model_id, input): + pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id) + print(pipeline_ins(input)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_palm_zh_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_palm_en_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_en, + self.palm_input_en) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_gpt_base_with_model_name(self): + self.run_pipeline_with_model_id(self.gpt3_base_model_id, + self.gpt3_input) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') 
+ def test_gpt_large_with_model_name(self):
+ self.run_pipeline_with_model_id(self.gpt3_large_model_id,
+ self.gpt3_input)
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_palm_zh_with_model_instance(self):
+ self.run_pipeline_with_model_instance(self.palm_model_id_zh,
+ self.palm_input_zh)
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_palm_en_with_model_instance(self):
+ self.run_pipeline_with_model_instance(self.palm_model_id_en,
+ self.palm_input_en)
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_gpt_base_with_model_instance(self):
+ self.run_pipeline_with_model_instance(self.gpt3_base_model_id,
+ self.gpt3_input)
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_gpt_large_with_model_instance(self):
+ self.run_pipeline_with_model_instance(self.gpt3_large_model_id,
+ self.gpt3_input)
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_palm(self): for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), @@ -68,33 +123,6 @@ class TextGenerationTest(unittest.TestCase): f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}' ) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_run_with_model_from_modelhub(self): - for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), - (self.palm_model_id_en, self.palm_input_en), - (self.gpt3_base_model_id, self.gpt3_input), - (self.gpt3_large_model_id, self.gpt3_input)): - model = Model.from_pretrained(model_id) - preprocessor = TextGenerationPreprocessor( - model.model_dir, - model.tokenizer, - first_sequence='sentence', - second_sequence=None) - pipeline_ins = pipeline( - task=Tasks.text_generation, - model=model, - preprocessor=preprocessor) - print(pipeline_ins(input)) - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
- def test_run_with_model_name(self): - for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), - (self.palm_model_id_en, self.palm_input_en), - (self.gpt3_base_model_id, self.gpt3_input), - (self.gpt3_large_model_id, self.gpt3_input)): - pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id) - print(pipeline_ins(input)) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.text_generation)