Browse Source

[to #42322933] Fix bug for TextGenerationPreprocessor

1. 修复 TextGenerationPreprocessor 中 padding 设置为 True 导致输入序列长度不同时训练报错的 bug
   (Fix the bug where setting padding=True in TextGenerationPreprocessor caused training errors when input sequences had different lengths)
2. 将 vqa pipeline 的输入从 image_path 调整为 Image.Image 类型
   (Change the vqa pipeline input from an image_path string to the Image.Image type)
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9630285
master
hemu.zp yingda.chen 3 years ago
parent
commit
56c3cd03a9
4 changed files with 23 additions and 11 deletions
  1. +7
    -1
      modelscope/models/nlp/gpt3/gpt3_for_text_generation.py
  2. +6
    -4
      modelscope/preprocessors/multi_modal.py
  3. +1
    -1
      modelscope/preprocessors/nlp.py
  4. +9
    -5
      tests/pipelines/test_visual_question_answering.py

+ 7
- 1
modelscope/models/nlp/gpt3/gpt3_for_text_generation.py View File

@@ -43,8 +43,14 @@ class GPT3ForTextGeneration(TorchModel):

     def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
         assert 'input_ids' in input, "generate function must accept 'input_ids' key"
+        input_ids = input['input_ids']
+        if 'attention_mask' in input:
+            attention_mask = input['attention_mask']
+            input_ids = input_ids[0][attention_mask[0].nonzero()] \
+                .squeeze().unsqueeze(0)
+
         gen_params = dict()
-        gen_params['inputs'] = input['input_ids']
+        gen_params['inputs'] = input_ids
         gen_params['do_sample'] = input.pop('do_sample', True)
         gen_params['max_length'] = input.pop('max_length', 128)
         gen_params['top_k'] = input.pop('top_k', 10)


+ 6
- 4
modelscope/preprocessors/multi_modal.py View File

@@ -118,10 +118,12 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
             transforms.Normalize(mean=mean, std=std),
         ])

-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image, question = data['image'], data['question']
-        image = Image.open(image).convert('RGB') if isinstance(image,
-                                                               str) else image
+    def __call__(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]:
+        image: Image.Image = data[0] if isinstance(data,
+                                                   tuple) else data['image']
+        question: str = data[1] if isinstance(data,
+                                              tuple) else data['question']
+        image = image.convert('RGB')
         image = self.patch_resize_transform(image)
         image = torch.stack([image], dim=0)
         question = self.tokenizer([question.lower()],


+ 1
- 1
modelscope/preprocessors/nlp.py View File

@@ -286,7 +286,7 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
         self.tokenizer = self.build_tokenizer(
             model_dir) if tokenizer is None else tokenizer
         kwargs['truncation'] = kwargs.get('truncation', True)
-        kwargs['padding'] = kwargs.get('padding', True)
+        kwargs['padding'] = kwargs.get('padding', 'max_length')
         kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
                                                      False)
         kwargs['max_length'] = kwargs.pop('sequence_length', 128)


+ 9
- 5
tests/pipelines/test_visual_question_answering.py View File

@@ -2,6 +2,8 @@

 import unittest

+from PIL import Image
+
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
 from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
@@ -13,11 +15,13 @@ from modelscope.utils.test_utils import test_level


 class VisualQuestionAnsweringTest(unittest.TestCase):
-    model_id = 'damo/mplug_visual-question-answering_coco_large_en'
-    input_vqa = {
-        'image': 'data/test/images/image_mplug_vqa.jpg',
-        'question': 'What is the woman doing?',
-    }
+
+    def setUp(self):
+        self.model_id = 'damo/mplug_visual-question-answering_coco_large_en'
+        self.input_vqa = {
+            'image': Image.open('data/test/images/image_mplug_vqa.jpg'),
+            'question': 'What is the woman doing?',
+        }

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):


Loading…
Cancel
Save