From b537bb8c270bef36b1ea5d8f0c8c3e2df67aff9d Mon Sep 17 00:00:00 2001 From: "yichang.zyc" Date: Wed, 21 Sep 2022 18:57:34 +0800 Subject: [PATCH] fix vg return value Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10207239 --- modelscope/models/multi_modal/ofa_for_all_tasks.py | 8 ++++---- tests/pipelines/test_ofa_tasks.py | 10 ++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 05950378..45bafde9 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -112,8 +112,6 @@ class OfaForAllTasks(TorchModel): OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, OutputKeys.LABELS, OutputKeys.SCORES ]: - if key in ret and len(ret[key]) == 1: - ret[key] = ret[key][0] if key not in ret: ret[key] = None return ret @@ -121,8 +119,10 @@ class OfaForAllTasks(TorchModel): def postprocess(self, input: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]: if self.cfg.task == Tasks.image_captioning: - caption = input[OutputKeys.CAPTION] - caption = caption.translate(self.transtab).strip() + caption = [ + cap.translate(self.transtab).strip() + for cap in input[OutputKeys.CAPTION] + ] input[OutputKeys.CAPTION] = caption return input diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index 9a72d1ff..e6638dfa 100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -147,8 +147,10 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): result = ofa_pipe(input) print(result) image_name = image.split('/')[-2] - self.save_img(image, result[OutputKeys.BOXES], - osp.join('large_en_model_' + image_name + '.png')) + self.save_img( + image, + result[OutputKeys.BOXES][0], # just one box + osp.join('large_en_model_' + image_name + '.png')) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_visual_grounding_with_name(self): @@ -161,7 +163,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): result = ofa_pipe(input) print(result) image_name = image.split('/')[-2] - self.save_img(image, result[OutputKeys.BOXES], + self.save_img(image, result[OutputKeys.BOXES][0], osp.join('large_en_name_' + image_name + '.png')) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -174,7 +176,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): result = ofa_pipe(input) print(result) image_name = image.split('/')[-1] - self.save_img(image, result[OutputKeys.BOXES], + self.save_img(image, result[OutputKeys.BOXES][0], osp.join('large_zh_name_' + image_name)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')