From 2b49b322a2b452b96413fe70c678c78be7b5b61a Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Thu, 20 Oct 2022 19:50:40 +0800 Subject: [PATCH] [to #42322933] Add palm ut MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为以下三个模型补充 ut damo/nlp_palm2.0_text-generation_chinese-large damo/nlp_palm2.0_text-generation_commodity_chinese-base damo/nlp_palm2.0_text-generation_weather_chinese-base Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10435599 --- tests/pipelines/test_text_generation.py | 50 +++++++++++++++++++++---- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index 5a270f83..4b0ebd47 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -15,12 +15,17 @@ from modelscope.utils.test_utils import test_level class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: - self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base' + self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base' + self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large' + self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base' + self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base' self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' self.palm_input_zh = """ 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 """ + self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮' + self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'" self.palm_input_en = """ The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders , @@ -51,8 +56,8 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): print(pipeline_ins(input)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_palm_zh_with_model_name(self): - self.run_pipeline_with_model_id(self.palm_model_id_zh, + def test_palm_zh_base_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_base, self.palm_input_zh) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -71,10 +76,40 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.gpt3_input) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_palm_zh_with_model_instance(self): - self.run_pipeline_with_model_instance(self.palm_model_id_zh, + def test_palm_zh_large_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_large, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_commodity_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity, + self.palm_input_commodity) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_weather_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_weather, + self.palm_input_weather) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_base_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_base, self.palm_input_zh) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_large_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_large, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_commodity_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity, + self.palm_input_commodity) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_weather_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather, + self.palm_input_weather) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_palm_en_with_model_instance(self): self.run_pipeline_with_model_instance(self.palm_model_id_en, @@ -92,8 +127,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_palm(self): - for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), - (self.palm_model_id_en, self.palm_input_en)): + for model_id, input in ((self.palm_model_id_zh_base, + self.palm_input_zh), (self.palm_model_id_en, + self.palm_input_en)): cache_path = snapshot_download(model_id) model = PalmForTextGeneration.from_pretrained(cache_path) preprocessor = TextGenerationPreprocessor(