Browse Source

[to #42322933]split text generation tests

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9684735

    * split text generation tests
master
yingda.chen 3 years ago
parent
commit
5ea690d743
1 changed files with 55 additions and 27 deletions
  1. +55
    -27
      tests/pipelines/test_text_generation.py

+ 55
- 27
tests/pipelines/test_text_generation.py View File

@@ -34,6 +34,61 @@ class TextGenerationTest(unittest.TestCase):
self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large'
self.gpt3_input = '我很好奇'

def run_pipeline_with_model_instance(self, model_id, input):
    """Exercise the Model-instance code path for text generation.

    Loads the model explicitly via ``Model.from_pretrained``, builds its
    preprocessor by hand, assembles the pipeline from those parts, and
    prints the generation result for *input*.
    """
    loaded_model = Model.from_pretrained(model_id)
    text_preprocessor = TextGenerationPreprocessor(
        loaded_model.model_dir,
        loaded_model.tokenizer,
        first_sequence='sentence',
        second_sequence=None)
    generator = pipeline(
        task=Tasks.text_generation,
        model=loaded_model,
        preprocessor=text_preprocessor)
    print(generator(input))

def run_pipeline_with_model_id(self, model_id, input):
    """Exercise the model-id code path: let ``pipeline`` resolve the
    model and preprocessor itself, then print the result for *input*."""
    generator = pipeline(task=Tasks.text_generation, model=model_id)
    print(generator(input))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_palm_zh_with_model_name(self):
    """Smoke-test the Chinese PALM model via its model-hub id."""
    self.run_pipeline_with_model_id(
        self.palm_model_id_zh, self.palm_input_zh)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_palm_en_with_model_name(self):
    """Smoke-test the English PALM model via its model-hub id."""
    self.run_pipeline_with_model_id(
        self.palm_model_id_en, self.palm_input_en)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_gpt_base_with_model_name(self):
    """Smoke-test the GPT-3 base model via its model-hub id."""
    self.run_pipeline_with_model_id(
        self.gpt3_base_model_id, self.gpt3_input)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_gpt_large_with_model_name(self):
    """Smoke-test the GPT-3 large model via its model-hub id."""
    self.run_pipeline_with_model_id(
        self.gpt3_large_model_id, self.gpt3_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_with_model_instance(self):
    """Run the Chinese PALM model through the explicit Model-instance path."""
    self.run_pipeline_with_model_instance(
        self.palm_model_id_zh, self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_en_with_model_instance(self):
    """Run the English PALM model through the explicit Model-instance path."""
    self.run_pipeline_with_model_instance(
        self.palm_model_id_en, self.palm_input_en)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_base_with_model_instance(self):
    """Run the GPT-3 base model through the explicit Model-instance path."""
    self.run_pipeline_with_model_instance(
        self.gpt3_base_model_id, self.gpt3_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_large_with_model_instance(self):
    """Run the GPT-3 large model through the explicit Model-instance path."""
    self.run_pipeline_with_model_instance(
        self.gpt3_large_model_id, self.gpt3_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_palm(self):
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
@@ -68,33 +123,6 @@ class TextGenerationTest(unittest.TestCase):
f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}'
)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
    """Run every configured model through the Model-instance code path.

    Delegates to ``run_pipeline_with_model_instance`` rather than
    duplicating the model / preprocessor / pipeline construction inline,
    so there is a single place that defines how an explicit Model
    instance is wired into a text-generation pipeline.
    """
    for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
                            (self.palm_model_id_en, self.palm_input_en),
                            (self.gpt3_base_model_id, self.gpt3_input),
                            (self.gpt3_large_model_id, self.gpt3_input)):
        self.run_pipeline_with_model_instance(model_id, input)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
    """Run every configured model through the model-id code path.

    Delegates to ``run_pipeline_with_model_id`` instead of repeating the
    ``pipeline(task=..., model=model_id)`` construction inline, keeping
    this loop consistent with the per-model tests above.
    """
    for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
                            (self.palm_model_id_en, self.palm_input_en),
                            (self.gpt3_base_model_id, self.gpt3_input),
                            (self.gpt3_large_model_id, self.gpt3_input)):
        self.run_pipeline_with_model_id(model_id, input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
    # Build a text-generation pipeline without naming a model, so the
    # task's default model is resolved by the pipeline factory.
    # NOTE(review): the visible body never invokes pipeline_ins — the
    # remainder of this method appears to lie beyond this view; confirm
    # the pipeline is actually exercised there.
    pipeline_ins = pipeline(task=Tasks.text_generation)


Loading…
Cancel
Save