Browse Source

[to #42322933] Add palm ut

为以下三个模型补充 ut
damo/nlp_palm2.0_text-generation_chinese-large
damo/nlp_palm2.0_text-generation_commodity_chinese-base
damo/nlp_palm2.0_text-generation_weather_chinese-base
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10435599
master
hemu.zp yingda.chen 3 years ago
parent
commit
2b49b322a2
1 changed files with 43 additions and 7 deletions
  1. +43
    -7
      tests/pipelines/test_text_generation.py

+ 43
- 7
tests/pipelines/test_text_generation.py View File

@@ -15,12 +15,17 @@ from modelscope.utils.test_utils import test_level
class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base'
self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large'
self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base'
self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base'
self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
self.palm_input_zh = """
本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
"""
self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮'
self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'"
self.palm_input_en = """
The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
@@ -51,8 +56,8 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
print(pipeline_ins(input))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_palm_zh_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh,
def test_palm_zh_base_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_base,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -71,10 +76,40 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.gpt3_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh,
def test_palm_zh_large_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_large,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_commodity_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity,
self.palm_input_commodity)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_weather_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_weather,
self.palm_input_weather)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_base_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_base,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_large_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_large,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_commodity_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity,
self.palm_input_commodity)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_weather_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather,
self.palm_input_weather)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_en_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_en,
@@ -92,8 +127,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_palm(self):
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
(self.palm_model_id_en, self.palm_input_en)):
for model_id, input in ((self.palm_model_id_zh_base,
self.palm_input_zh), (self.palm_model_id_en,
self.palm_input_en)):
cache_path = snapshot_download(model_id)
model = PalmForTextGeneration.from_pretrained(cache_path)
preprocessor = TextGenerationPreprocessor(


Loading…
Cancel
Save