|
|
|
@@ -15,12 +15,17 @@ from modelscope.utils.test_utils import test_level |
|
|
|
class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): |
|
|
|
|
|
|
|
def setUp(self) -> None: |
|
|
|
self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base' |
|
|
|
self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base' |
|
|
|
self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large' |
|
|
|
self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base' |
|
|
|
self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base' |
|
|
|
self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' |
|
|
|
self.palm_input_zh = """ |
|
|
|
本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: |
|
|
|
1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 |
|
|
|
""" |
|
|
|
self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮' |
|
|
|
self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'" |
|
|
|
self.palm_input_en = """ |
|
|
|
The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started |
|
|
|
her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders , |
|
|
|
@@ -51,8 +56,8 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): |
|
|
|
print(pipeline_ins(input)) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
def test_palm_zh_with_model_name(self): |
|
|
|
self.run_pipeline_with_model_id(self.palm_model_id_zh, |
|
|
|
def test_palm_zh_base_with_model_name(self): |
|
|
|
self.run_pipeline_with_model_id(self.palm_model_id_zh_base, |
|
|
|
self.palm_input_zh) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
@@ -71,10 +76,40 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): |
|
|
|
self.gpt3_input) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_palm_zh_with_model_instance(self): |
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_zh, |
|
|
|
def test_palm_zh_large_with_model_name(self): |
|
|
|
self.run_pipeline_with_model_id(self.palm_model_id_zh_large, |
|
|
|
self.palm_input_zh) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_palm_zh_commodity_with_model_name(self): |
|
|
|
self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity, |
|
|
|
self.palm_input_commodity) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_palm_zh_weather_with_model_name(self): |
|
|
|
self.run_pipeline_with_model_id(self.palm_model_id_zh_weather, |
|
|
|
self.palm_input_weather) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_palm_zh_base_with_model_instance(self): |
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_zh_base, |
|
|
|
self.palm_input_zh) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_palm_zh_large_with_model_instance(self): |
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_zh_large, |
|
|
|
self.palm_input_zh) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_palm_zh_commodity_with_model_instance(self): |
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity, |
|
|
|
self.palm_input_commodity) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_palm_zh_weather_with_model_instance(self): |
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather, |
|
|
|
self.palm_input_weather) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_palm_en_with_model_instance(self): |
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_en, |
|
|
|
@@ -92,8 +127,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_run_palm(self): |
|
|
|
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), |
|
|
|
(self.palm_model_id_en, self.palm_input_en)): |
|
|
|
for model_id, input in ((self.palm_model_id_zh_base, |
|
|
|
self.palm_input_zh), (self.palm_model_id_en, |
|
|
|
self.palm_input_en)): |
|
|
|
cache_path = snapshot_download(model_id) |
|
|
|
model = PalmForTextGeneration.from_pretrained(cache_path) |
|
|
|
preprocessor = TextGenerationPreprocessor( |
|
|
|
|