|
|
@@ -34,6 +34,61 @@ class TextGenerationTest(unittest.TestCase): |
|
|
self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' |
|
|
self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' |
|
|
self.gpt3_input = '我很好奇' |
|
|
self.gpt3_input = '我很好奇' |
|
|
|
|
|
|
|
|
|
|
|
def run_pipeline_with_model_instance(self, model_id, input): |
|
|
|
|
|
model = Model.from_pretrained(model_id) |
|
|
|
|
|
preprocessor = TextGenerationPreprocessor( |
|
|
|
|
|
model.model_dir, |
|
|
|
|
|
model.tokenizer, |
|
|
|
|
|
first_sequence='sentence', |
|
|
|
|
|
second_sequence=None) |
|
|
|
|
|
pipeline_ins = pipeline( |
|
|
|
|
|
task=Tasks.text_generation, model=model, preprocessor=preprocessor) |
|
|
|
|
|
print(pipeline_ins(input)) |
|
|
|
|
|
|
|
|
|
|
|
def run_pipeline_with_model_id(self, model_id, input): |
|
|
|
|
|
pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id) |
|
|
|
|
|
print(pipeline_ins(input)) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
|
|
def test_palm_zh_with_model_name(self): |
|
|
|
|
|
self.run_pipeline_with_model_id(self.palm_model_id_zh, |
|
|
|
|
|
self.palm_input_zh) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
|
|
def test_palm_en_with_model_name(self): |
|
|
|
|
|
self.run_pipeline_with_model_id(self.palm_model_id_en, |
|
|
|
|
|
self.palm_input_en) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
|
|
def test_gpt_base_with_model_name(self): |
|
|
|
|
|
self.run_pipeline_with_model_id(self.gpt3_base_model_id, |
|
|
|
|
|
self.gpt3_input) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
|
|
def test_gpt_large_with_model_name(self): |
|
|
|
|
|
self.run_pipeline_with_model_id(self.gpt3_large_model_id, |
|
|
|
|
|
self.gpt3_input) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
|
|
def test_palm_zh_with_model_instance(self): |
|
|
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_zh, |
|
|
|
|
|
self.palm_input_zh) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
|
|
def test_palm_en_with_model_instance(self): |
|
|
|
|
|
self.run_pipeline_with_model_instance(self.palm_model_id_en, |
|
|
|
|
|
self.palm_input_en) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
|
|
def test_gpt_base_with_model_instance(self): |
|
|
|
|
|
self.run_pipeline_with_model_instance(self.gpt3_base_model_id, |
|
|
|
|
|
self.gpt3_input) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
|
|
def test_gpt_large_with_model_instance(self): |
|
|
|
|
|
self.run_pipeline_with_model_instance(self.gpt3_large_model_id, |
|
|
|
|
|
self.gpt3_input) |
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
def test_run_palm(self): |
|
|
def test_run_palm(self): |
|
|
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), |
|
|
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), |
|
|
@@ -68,33 +123,6 @@ class TextGenerationTest(unittest.TestCase): |
|
|
f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}' |
|
|
f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}' |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
|
|
def test_run_with_model_from_modelhub(self): |
|
|
|
|
|
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), |
|
|
|
|
|
(self.palm_model_id_en, self.palm_input_en), |
|
|
|
|
|
(self.gpt3_base_model_id, self.gpt3_input), |
|
|
|
|
|
(self.gpt3_large_model_id, self.gpt3_input)): |
|
|
|
|
|
model = Model.from_pretrained(model_id) |
|
|
|
|
|
preprocessor = TextGenerationPreprocessor( |
|
|
|
|
|
model.model_dir, |
|
|
|
|
|
model.tokenizer, |
|
|
|
|
|
first_sequence='sentence', |
|
|
|
|
|
second_sequence=None) |
|
|
|
|
|
pipeline_ins = pipeline( |
|
|
|
|
|
task=Tasks.text_generation, |
|
|
|
|
|
model=model, |
|
|
|
|
|
preprocessor=preprocessor) |
|
|
|
|
|
print(pipeline_ins(input)) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
|
|
def test_run_with_model_name(self): |
|
|
|
|
|
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), |
|
|
|
|
|
(self.palm_model_id_en, self.palm_input_en), |
|
|
|
|
|
(self.gpt3_base_model_id, self.gpt3_input), |
|
|
|
|
|
(self.gpt3_large_model_id, self.gpt3_input)): |
|
|
|
|
|
pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id) |
|
|
|
|
|
print(pipeline_ins(input)) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
def test_run_with_default_model(self): |
|
|
def test_run_with_default_model(self): |
|
|
pipeline_ins = pipeline(task=Tasks.text_generation) |
|
|
pipeline_ins = pipeline(task=Tasks.text_generation) |
|
|
|