From 9db97366eb10f2320995b304a4b9c5ee3d2a4c02 Mon Sep 17 00:00:00 2001 From: "yichang.zyc" Date: Tue, 2 Aug 2022 22:17:00 +0800 Subject: [PATCH] [to #42322933] fix sample collate bug Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9615938 * add init and make demo compatible * make demo compatible * fix comments * add distilled ut * Merge remote-tracking branch 'origin/master' into ofa/bug_fix * fix sample collate bug * Merge remote-tracking branch 'origin/master' into ofa/bug_fix # Conflicts: # tests/pipelines/test_ofa_tasks.py * Merge remote-tracking branch 'origin/master' into ofa/bug_fix * fix sample collate bug --- modelscope/preprocessors/multi_modal.py | 5 ++++- tests/pipelines/test_ofa_tasks.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index a3411a73..2a5cd259 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -76,7 +76,10 @@ class OfaPreprocessor(Preprocessor): else: data = self._build_dict(input) sample = self.preprocess(data) - sample['sample'] = data + str_data = dict() + for k, v in data.items(): + str_data[k] = str(v) + sample['sample'] = str_data return collate_fn([sample], pad_idx=self.tokenizer.pad_token_id, eos_idx=self.tokenizer.eos_token_id) diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index 63efa334..1dc7d303 100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -1,6 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import unittest +from PIL import Image + from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline @@ -181,8 +183,8 @@ class OfaTasksTest(unittest.TestCase): task=Tasks.image_captioning, model=model, ) - result = img_captioning( - {'image': 'data/test/images/image_captioning.png'}) + image = Image.open('data/test/images/image_captioning.png') + result = img_captioning(image) print(result[OutputKeys.CAPTION]) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')