1. Fix a bug in TextGenerationPreprocessor where padding set to True causes a training error when input sequences have different lengths
2. Change the vqa pipeline input from image_path to the Image.Image type
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9630285
master
@@ -43,8 +43,14 @@ class GPT3ForTextGeneration(TorchModel):
     def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
         assert 'input_ids' in input, "generate function must accept 'input_ids' key"
+        input_ids = input['input_ids']
+        if 'attention_mask' in input:
+            attention_mask = input['attention_mask']
+            input_ids = input_ids[0][attention_mask[0].nonzero()] \
+                .squeeze().unsqueeze(0)
         gen_params = dict()
-        gen_params['inputs'] = input['input_ids']
+        gen_params['inputs'] = input_ids
         gen_params['do_sample'] = input.pop('do_sample', True)
         gen_params['max_length'] = input.pop('max_length', 128)
         gen_params['top_k'] = input.pop('top_k', 10)
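To see what the new branch does, here is a minimal, self-contained sketch with hypothetical tensors: once inputs are padded to a fixed length, the attention mask marks which positions hold real tokens, so indexing input_ids by the mask's nonzero positions recovers the unpadded sequence before it is handed to the generation API.

import torch

# Hypothetical padded sample: four real tokens followed by four pads.
input_ids = torch.tensor([[101, 2023, 2003, 102, 0, 0, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]])

# nonzero() yields the indices of unmasked positions; indexing with them
# drops the padding, squeeze() flattens the (4, 1) result, and
# unsqueeze(0) restores the batch dimension expected by generate().
trimmed = input_ids[0][attention_mask[0].nonzero()].squeeze().unsqueeze(0)
print(trimmed)  # tensor([[ 101, 2023, 2003,  102]])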
@@ -118,10 +118,12 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
             transforms.Normalize(mean=mean, std=std),
         ])

-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image, question = data['image'], data['question']
-        image = Image.open(image).convert('RGB') if isinstance(image,
-                                                               str) else image
+    def __call__(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]:
+        image: Image.Image = data[0] if isinstance(data,
+                                                   tuple) else data['image']
+        question: str = data[1] if isinstance(data,
+                                              tuple) else data['question']
+        image = image.convert('RGB')
         image = self.patch_resize_transform(image)
         image = torch.stack([image], dim=0)
         question = self.tokenizer([question.lower()],
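With this change the preprocessor no longer opens files itself: the caller supplies a pre-loaded PIL image, either in a dict or positionally in a tuple. A hedged usage sketch (preprocessor construction elided, variable names hypothetical):

from PIL import Image

# The caller opens the image now, rather than passing a path string.
image = Image.open('data/test/images/image_mplug_vqa.jpg')
question = 'What is the woman doing?'

# Both call styles reach the same code path in __call__ above:
as_dict = {'image': image, 'question': question}
as_tuple = (image, question)
# preprocessor(as_dict) and preprocessor(as_tuple) extract the same fields.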
@@ -286,7 +286,7 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
         self.tokenizer = self.build_tokenizer(
             model_dir) if tokenizer is None else tokenizer
         kwargs['truncation'] = kwargs.get('truncation', True)
-        kwargs['padding'] = kwargs.get('padding', True)
+        kwargs['padding'] = kwargs.get('padding', 'max_length')
         kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
                                                      False)
         kwargs['max_length'] = kwargs.pop('sequence_length', 128)
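This one-line change is the actual fix for the training bug: with HuggingFace tokenizers, padding=True pads only to the longest sequence within a single call, so samples preprocessed one at a time come out with different lengths and cannot be stacked into one training batch, while padding='max_length' pads every sample to the same fixed length. A sketch assuming a standard HuggingFace tokenizer (the model name is illustrative only):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# padding=True ('longest') pads to the longest sequence in *this* call,
# so a single short sample is not padded at all.
variable = tokenizer('hi', padding=True, truncation=True, max_length=8)
# padding='max_length' always pads to max_length, so every sample has
# the same shape and the default collator can stack them into a batch.
fixed = tokenizer('hi', padding='max_length', truncation=True, max_length=8)

print(len(variable['input_ids']))  # 3  -- [CLS] hi [SEP], length varies
print(len(fixed['input_ids']))     # 8  -- fixed length, batchable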
@@ -2,6 +2,8 @@
 import unittest

+from PIL import Image
+
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
 from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
@@ -13,11 +15,13 @@ from modelscope.utils.test_utils import test_level
 class VisualQuestionAnsweringTest(unittest.TestCase):
-    model_id = 'damo/mplug_visual-question-answering_coco_large_en'
-    input_vqa = {
-        'image': 'data/test/images/image_mplug_vqa.jpg',
-        'question': 'What is the woman doing?',
-    }
+
+    def setUp(self):
+        self.model_id = 'damo/mplug_visual-question-answering_coco_large_en'
+        self.input_vqa = {
+            'image': Image.open('data/test/images/image_mplug_vqa.jpg'),
+            'question': 'What is the woman doing?',
+        }

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
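Moving the fixtures from class attributes into setUp also defers Image.open until a test actually runs, so importing the test module no longer requires the image file to exist. For context, a hedged end-to-end sketch of what the updated test exercises, assuming modelscope's standard pipeline entry point (the task constant and call convention follow modelscope's public API and are not part of this diff):

from PIL import Image
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

model_id = 'damo/mplug_visual-question-answering_coco_large_en'
vqa = pipeline(Tasks.visual_question_answering, model=model_id)

# The pipeline now takes a loaded PIL image rather than a file path.
result = vqa({
    'image': Image.open('data/test/images/image_mplug_vqa.jpg'),
    'question': 'What is the woman doing?',
})
print(result)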