Browse Source

[to #42322933] Add palm ut

为以下三个模型补充 ut
damo/nlp_palm2.0_text-generation_chinese-large
damo/nlp_palm2.0_text-generation_commodity_chinese-base
damo/nlp_palm2.0_text-generation_weather_chinese-base
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10435599
master
hemu.zp yingda.chen 3 years ago
parent
commit
2b49b322a2
1 changed files with 43 additions and 7 deletions
  1. +43
    -7
      tests/pipelines/test_text_generation.py

+ 43
- 7
tests/pipelines/test_text_generation.py View File

@@ -15,12 +15,17 @@ from modelscope.utils.test_utils import test_level
class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):


def setUp(self) -> None: def setUp(self) -> None:
self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base'
self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large'
self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base'
self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base'
self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
self.palm_input_zh = """ self.palm_input_zh = """
本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
""" """
self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮'
self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'"
self.palm_input_en = """ self.palm_input_en = """
The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders , her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
@@ -51,8 +56,8 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
print(pipeline_ins(input)) print(pipeline_ins(input))


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_palm_zh_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh,
def test_palm_zh_base_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_base,
self.palm_input_zh) self.palm_input_zh)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -71,10 +76,40 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.gpt3_input) self.gpt3_input)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh,
def test_palm_zh_large_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_large,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_commodity_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity,
self.palm_input_commodity)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_weather_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_weather,
self.palm_input_weather)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_base_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_base,
self.palm_input_zh) self.palm_input_zh)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_large_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_large,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_commodity_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity,
self.palm_input_commodity)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_weather_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather,
self.palm_input_weather)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_en_with_model_instance(self): def test_palm_en_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_en, self.run_pipeline_with_model_instance(self.palm_model_id_en,
@@ -92,8 +127,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_palm(self): def test_run_palm(self):
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
(self.palm_model_id_en, self.palm_input_en)):
for model_id, input in ((self.palm_model_id_zh_base,
self.palm_input_zh), (self.palm_model_id_en,
self.palm_input_en)):
cache_path = snapshot_download(model_id) cache_path = snapshot_download(model_id)
model = PalmForTextGeneration.from_pretrained(cache_path) model = PalmForTextGeneration.from_pretrained(cache_path)
preprocessor = TextGenerationPreprocessor( preprocessor = TextGenerationPreprocessor(


Loading…
Cancel
Save