From 56c3cd03a95c40398265736effc588e9f6395268 Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Thu, 4 Aug 2022 09:30:53 +0800
Subject: [PATCH] [to #42322933] Fix bug for TextGenerationPreprocessor
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Fix the bug in TextGenerationPreprocessor where padding set to True causes a training error when input sequences have different lengths
2. Change the input of the vqa pipeline from image_path to Image.Image

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9630285
---
 .../models/nlp/gpt3/gpt3_for_text_generation.py   |  8 +++++++-
 modelscope/preprocessors/multi_modal.py           | 10 ++++++----
 modelscope/preprocessors/nlp.py                   |  2 +-
 tests/pipelines/test_visual_question_answering.py | 14 +++++++++-----
 4 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py b/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py
index 6bdcb431..9f6874c6 100644
--- a/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py
+++ b/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py
@@ -43,8 +43,14 @@ class GPT3ForTextGeneration(TorchModel):
 
     def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
         assert 'input_ids' in input, "generate function must accept 'input_ids' key"
+        input_ids = input['input_ids']
+        if 'attention_mask' in input:
+            attention_mask = input['attention_mask']
+            input_ids = input_ids[0][attention_mask[0].nonzero()] \
+                .squeeze().unsqueeze(0)
+
         gen_params = dict()
-        gen_params['inputs'] = input['input_ids']
+        gen_params['inputs'] = input_ids
         gen_params['do_sample'] = input.pop('do_sample', True)
         gen_params['max_length'] = input.pop('max_length', 128)
         gen_params['top_k'] = input.pop('top_k', 10)
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 2a5cd259..2f62c6af 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -118,10 +118,12 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
             transforms.Normalize(mean=mean, std=std),
         ])
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image, question = data['image'], data['question']
-        image = Image.open(image).convert('RGB') if isinstance(image,
-                                                               str) else image
+    def __call__(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]:
+        image: Image.Image = data[0] if isinstance(data,
+                                                   tuple) else data['image']
+        question: str = data[1] if isinstance(data,
+                                              tuple) else data['question']
+        image = image.convert('RGB')
         image = self.patch_resize_transform(image)
         image = torch.stack([image], dim=0)
         question = self.tokenizer([question.lower()],
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index f0951f38..58ad3dbe 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -286,7 +286,7 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
         self.tokenizer = self.build_tokenizer(
             model_dir) if tokenizer is None else tokenizer
         kwargs['truncation'] = kwargs.get('truncation', True)
-        kwargs['padding'] = kwargs.get('padding', True)
+        kwargs['padding'] = kwargs.get('padding', 'max_length')
         kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
                                                      False)
         kwargs['max_length'] = kwargs.pop('sequence_length', 128)
diff --git a/tests/pipelines/test_visual_question_answering.py b/tests/pipelines/test_visual_question_answering.py
index de7edbba..748a86b9 100644
--- a/tests/pipelines/test_visual_question_answering.py
+++ b/tests/pipelines/test_visual_question_answering.py
@@ -2,6 +2,8 @@
 
 import unittest
 
+from PIL import Image
+
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
 from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
@@ -13,11 +15,13 @@ from modelscope.utils.test_utils import test_level
 
 
 class VisualQuestionAnsweringTest(unittest.TestCase):
-    model_id = 'damo/mplug_visual-question-answering_coco_large_en'
-    input_vqa = {
-        'image': 'data/test/images/image_mplug_vqa.jpg',
-        'question': 'What is the woman doing?',
-    }
+
+    def setUp(self):
+        self.model_id = 'damo/mplug_visual-question-answering_coco_large_en'
+        self.input_vqa = {
+            'image': Image.open('data/test/images/image_mplug_vqa.jpg'),
+            'question': 'What is the woman doing?',
+        }
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
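
A note on the nlp.py change: with a Hugging Face style tokenizer, padding=True only pads to the longest sequence in the current call, so samples tokenized one at a time keep their own lengths and cannot be collated into a training batch; padding='max_length' pads every sample to the same fixed length. A minimal sketch of the difference, assuming a Hugging Face tokenizer (the 'bert-base-uncased' checkpoint is illustrative only, not the tokenizer TextGenerationPreprocessor actually builds):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    texts = ['a short prompt',
             'a noticeably longer prompt with quite a few more tokens']

    # padding=True pads to the longest sequence in the *current* call, so
    # samples encoded one by one keep unequal lengths and cannot be stacked.
    lens_true = [len(tokenizer(t, padding=True, truncation=True,
                               max_length=128)['input_ids']) for t in texts]

    # padding='max_length' pads every sample to max_length, giving uniform
    # shapes that collate cleanly into a training batch.
    lens_max = [len(tokenizer(t, padding='max_length', truncation=True,
                              max_length=128)['input_ids']) for t in texts]

    print(lens_true)  # e.g. [5, 12] -- unequal
    print(lens_max)   # [128, 128] -- uniform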
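
Because every sample is now padded to max_length, generate() would otherwise be fed a run of pad tokens; the lines added in gpt3_for_text_generation.py strip that padding back off using the attention mask. A toy sketch of the same indexing, with illustrative tensor values:

    import torch

    input_ids = torch.tensor([[5, 6, 7, 0, 0]])       # padded to max_length
    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # 1 = token, 0 = padding

    # Keep only the positions the mask marks as real tokens, then restore
    # the batch dimension expected by the generation call.
    trimmed = input_ids[0][attention_mask[0].nonzero()].squeeze().unsqueeze(0)
    print(trimmed)  # tensor([[5, 6, 7]])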
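
With the preprocessor change in multi_modal.py, the vqa pipeline is fed a PIL.Image.Image instead of an image path, as the updated test shows. A usage sketch under that assumption, using the standard ModelScope pipeline factory (not part of this patch):

    from PIL import Image

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the pipeline with the same model id the test uses.
    vqa = pipeline(
        Tasks.visual_question_answering,
        model='damo/mplug_visual-question-answering_coco_large_en')

    # The caller opens the image and passes a PIL.Image.Image; the
    # preprocessor converts it to RGB and resizes it internally.
    result = vqa({
        'image': Image.open('data/test/images/image_mplug_vqa.jpg'),
        'question': 'What is the woman doing?',
    })
    print(result)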