diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 05950378..45bafde9 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -112,8 +112,6 @@ class OfaForAllTasks(TorchModel): OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, OutputKeys.LABELS, OutputKeys.SCORES ]: - if key in ret and len(ret[key]) == 1: - ret[key] = ret[key][0] if key not in ret: ret[key] = None return ret @@ -121,8 +119,10 @@ class OfaForAllTasks(TorchModel): def postprocess(self, input: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]: if self.cfg.task == Tasks.image_captioning: - caption = input[OutputKeys.CAPTION] - caption = caption.translate(self.transtab).strip() + caption = [ + cap.translate(self.transtab).strip() + for cap in input[OutputKeys.CAPTION] + ] input[OutputKeys.CAPTION] = caption return input diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index 9a72d1ff..e6638dfa 100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -147,8 +147,10 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): result = ofa_pipe(input) print(result) image_name = image.split('/')[-2] - self.save_img(image, result[OutputKeys.BOXES], - osp.join('large_en_model_' + image_name + '.png')) + self.save_img( + image, + result[OutputKeys.BOXES][0], # just one box + osp.join('large_en_model_' + image_name + '.png')) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_visual_grounding_with_name(self): @@ -161,7 +163,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): result = ofa_pipe(input) print(result) image_name = image.split('/')[-2] - self.save_img(image, result[OutputKeys.BOXES], + self.save_img(image, result[OutputKeys.BOXES][0], osp.join('large_en_name_' + image_name + '.png')) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -174,7 +176,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): result = ofa_pipe(input) print(result) image_name = image.split('/')[-1] - self.save_img(image, result[OutputKeys.BOXES], + self.save_img(image, result[OutputKeys.BOXES][0], osp.join('large_zh_name_' + image_name)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')