Browse Source

[to #42322933] Add ut for mplug and bloom

为新上线的 langboat/bloom-1b4-zh,damo/mplug_visual-question-answering_coco_base_zh,damo/mplug_image-captioning_coco_base_zh 三个模型添加 ut,test_level 设置为 2
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10524221
master
hemu.zp yingda.chen 3 years ago
parent
commit
e4a0e046f9
2 changed files with 28 additions and 13 deletions
  1. +19
    -4
      tests/pipelines/test_mplug_tasks.py
  2. +9
    -9
      tests/pipelines/test_text_generation.py

+ 19
- 4
tests/pipelines/test_mplug_tasks.py View File

@@ -13,10 +13,6 @@ from modelscope.utils.test_utils import test_level

class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = 'visual-question-answering'
self.model_id = 'damo/mplug_visual-question-answering_coco_large_en'

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_with_model(self):
model = Model.from_pretrained(
@@ -80,6 +76,25 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck):
result = pipeline_retrieval(input)
print(result)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_image_captioning_zh_base_with_name(self):
pipeline_caption = pipeline(
Tasks.image_captioning,
model='damo/mplug_image-captioning_coco_base_zh')
image = Image.open('data/test/images/image_mplug_vqa.jpg')
result = pipeline_caption(image)
print(result[OutputKeys.CAPTION])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_visual_question_answering_zh_base_with_name(self):
model = 'damo/mplug_visual-question-answering_coco_base_zh'
pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model)
image = Image.open('data/test/images/image_mplug_vqa.jpg')
text = '这个女人在做什么?'
input = {'image': image, 'text': text}
result = pipeline_vqa(input)
print(result)

@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()


+ 9
- 9
tests/pipelines/test_text_generation.py View File

@@ -165,14 +165,16 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(task=Tasks.text_generation)
print(pipeline_ins(self.palm_input_zh))

@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_bloom(self):
pipe = pipeline(
task=Tasks.text_generation, model='langboat/bloom-1b4-zh')
print(pipe('中国的首都是'))

@unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
def test_gpt_neo(self):
pipe = pipeline(
task=Tasks.text_generation, model='Langboat/mengzi-gpt-neo-base')
task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base')
print(
pipe(
'我是',
@@ -182,11 +184,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
max_length=20,
repetition_penalty=0.5))

@unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
def test_bloom(self):
pipe = pipeline(
task=Tasks.text_generation, model='Langboat/bloom-1b4-zh')
print(pipe('中国的首都是'))
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()


if __name__ == '__main__':


Loading…
Cancel
Save