yichang.zyc yingda.chen 3 years ago
parent
commit
b537bb8c27
2 changed files with 10 additions and 8 deletions
  1. +4
    -4
      modelscope/models/multi_modal/ofa_for_all_tasks.py
  2. +6
    -4
      tests/pipelines/test_ofa_tasks.py

+ 4
- 4
modelscope/models/multi_modal/ofa_for_all_tasks.py View File

@@ -112,8 +112,6 @@ class OfaForAllTasks(TorchModel):
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
OutputKeys.LABELS, OutputKeys.SCORES
]:
if key in ret and len(ret[key]) == 1:
ret[key] = ret[key][0]
if key not in ret:
ret[key] = None
return ret
@@ -121,8 +119,10 @@ class OfaForAllTasks(TorchModel):
def postprocess(self, input: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
if self.cfg.task == Tasks.image_captioning:
caption = input[OutputKeys.CAPTION]
caption = caption.translate(self.transtab).strip()
caption = [
cap.translate(self.transtab).strip()
for cap in input[OutputKeys.CAPTION]
]
input[OutputKeys.CAPTION] = caption
return input



+ 6
- 4
tests/pipelines/test_ofa_tasks.py View File

@@ -147,8 +147,10 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
result = ofa_pipe(input)
print(result)
image_name = image.split('/')[-2]
self.save_img(image, result[OutputKeys.BOXES],
osp.join('large_en_model_' + image_name + '.png'))
self.save_img(
image,
result[OutputKeys.BOXES][0], # just one box
osp.join('large_en_model_' + image_name + '.png'))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_grounding_with_name(self):
@@ -161,7 +163,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
result = ofa_pipe(input)
print(result)
image_name = image.split('/')[-2]
self.save_img(image, result[OutputKeys.BOXES],
self.save_img(image, result[OutputKeys.BOXES][0],
osp.join('large_en_name_' + image_name + '.png'))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -174,7 +176,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
result = ofa_pipe(input)
print(result)
image_name = image.split('/')[-1]
self.save_img(image, result[OutputKeys.BOXES],
self.save_img(image, result[OutputKeys.BOXES][0],
osp.join('large_zh_name_' + image_name))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')


Loading…
Cancel
Save