From b92e2ca0a05bf45012713ea4f425e5c6a00adf91 Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Thu, 25 Aug 2022 21:26:51 +0800
Subject: [PATCH] [to #42322933] add vqa and caption finetuning for mplug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add finetuning support for the caption and vqa tasks of the mplug model.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9858028
---
 modelscope/metrics/builder.py                 |   2 +
 .../multi_modal/mplug/modeling_mplug.py       | 110 ++-------------
 .../models/multi_modal/mplug_for_all_tasks.py |  50 +++++--
 .../nlp/gpt3/gpt3_for_text_generation.py      |   3 +-
 .../nlp/palm_v2/palm_for_text_generation.py   |  15 +-
 modelscope/preprocessors/multi_modal.py       |  66 +++++----
 tests/trainers/test_finetune_mplug.py         | 128 ++++++++++++++++++
 7 files changed, 226 insertions(+), 148 deletions(-)
 create mode 100644 tests/trainers/test_finetune_mplug.py

diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py
index c76fe386..ad41fd87 100644
--- a/modelscope/metrics/builder.py
+++ b/modelscope/metrics/builder.py
@@ -30,6 +30,8 @@ task_default_metrics = {
     Tasks.image_portrait_enhancement:
    [Metrics.image_portrait_enhancement_metric],
     Tasks.video_summarization: [Metrics.video_summarization_metric],
+    Tasks.image_captioning: [Metrics.text_gen_metric],
+    Tasks.visual_question_answering: [Metrics.text_gen_metric],
 }
 
 
diff --git a/modelscope/models/multi_modal/mplug/modeling_mplug.py b/modelscope/models/multi_modal/mplug/modeling_mplug.py
index 50622cc0..6311bd31 100755
--- a/modelscope/models/multi_modal/mplug/modeling_mplug.py
+++ b/modelscope/models/multi_modal/mplug/modeling_mplug.py
@@ -1969,71 +1969,6 @@ class MPlug(PreTrainedModel):
                 [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
         return torch.index_select(x, dim, order_index.to(x.device))
 
-    def rank_answer(self, question_states, question_atts, answer_ids,
-                    answer_atts, k):
-
-        num_ques = question_states.size(0)
-        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token
-
-        start_output = self.text_decoder(
-            start_ids,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            return_dict=True,
-            reduction='none')
-        logits = start_output.logits[:, 0, :]  # first token's logit
-
-        # topk_probs: top-k probability
-        # topk_ids: [num_question, k]
-        answer_first_token = answer_ids[:, 1]
-        prob_first_token = F.softmax(
-            logits, dim=1).index_select(
-                dim=1, index=answer_first_token)
-        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
-
-        # answer input: [num_question*k, answer_len]
-        input_ids = []
-        input_atts = []
-        for b, topk_id in enumerate(topk_ids):
-            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
-            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
-        input_ids = torch.cat(input_ids, dim=0)
-        input_atts = torch.cat(input_atts, dim=0)
-
-        targets_ids = input_ids.masked_fill(
-            input_ids == self.tokenizer.pad_token_id, -100)
-
-        # repeat encoder's output for top-k answers
-        question_states = self._tile(question_states, 0, k)
-        question_atts = self._tile(question_atts, 0, k)
-
-        output = self.text_decoder(
-            input_ids,
-            attention_mask=input_atts,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            labels=targets_ids,
-            return_dict=True,
-            reduction='none')
-
-        answer_loss = output.loss
-        answer_loss = answer_loss.view(input_ids.size(0), -1)
-
-        # topk_prob: first token probability
-        topk_probs = topk_probs.view(-1, 1)
-        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
-
-        # 
re-calculate log probabilities for the answer sequences using chain rule - log_probs_sum = log_probs.sum(1) - log_probs_sum = log_probs_sum.view(num_ques, k) - - topk_probs = F.softmax(log_probs_sum, dim=-1) - # get top-k after re-ranking - topk_probs, rerank_id = topk_probs.topk(k, dim=1) - topk_ids = torch.gather(topk_ids, 1, rerank_id) - - return topk_ids, topk_probs - class MPlugForVisualQuestionAnswering(MPlug): @@ -2111,6 +2046,8 @@ class MPlugForVisualQuestionAnswering(MPlug): merge_text_attention = torch.cat( [image_atts, question.attention_mask], 1) + if k is None: + k = [1] * question_output.shape[0] question_states = [] question_atts = [] for b, n in enumerate(k): @@ -2177,6 +2114,8 @@ class MPlugForVisualQuestionAnswering(MPlug): return_dict=True, reduction='none', ) + if weights is None: + weights = 1 loss = weights * answer_output.loss loss = loss.sum() / image.size(0) @@ -2262,50 +2201,17 @@ class MPLUGForImageCaption(MPlug): if train: answer_targets = answer.input_ids.masked_fill( answer.input_ids == self.tokenizer.pad_token_id, -100) - text_output = self.text_encoder( - question.input_ids, - attention_mask=question.attention_mask, - return_dict=True) - text_embeds = text_output.last_hidden_state - fusion_output = self.fusion_encoder( - encoder_embeds=text_embeds, - attention_mask=question.attention_mask, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_atts, - return_dict=False) - - image_output, question_output = fusion_output - - question_output = torch.cat([image_output, question_output], 1) - merge_text_attention = torch.cat( - [image_atts, question.attention_mask], 1) - answer_output = self.text_decoder( answer.input_ids, attention_mask=answer.attention_mask, - encoder_hidden_states=question_output, - encoder_attention_mask=merge_text_attention, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, labels=answer_targets, return_dict=True, reduction='none') loss = answer_output.loss + return loss else: - text_output = self.text_encoder( - question.input_ids, - attention_mask=question.attention_mask, - return_dict=True) - text_embeds = text_output.last_hidden_state - fusion_output = self.fusion_encoder( - encoder_embeds=text_embeds, - attention_mask=question.attention_mask, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_atts, - return_dict=False) - image_output, question_output = fusion_output - question_output = torch.cat([image_output, question_output], 1) - merge_text_attention = torch.cat( - [image_atts, question.attention_mask], 1) - topk_ids, topk_probs = self.generation(question_output, - merge_text_attention) + topk_ids, topk_probs = self.generation(image_embeds, image_atts) return topk_ids, topk_probs diff --git a/modelscope/models/multi_modal/mplug_for_all_tasks.py b/modelscope/models/multi_modal/mplug_for_all_tasks.py index bb5a9c46..fb460714 100644 --- a/modelscope/models/multi_modal/mplug_for_all_tasks.py +++ b/modelscope/models/multi_modal/mplug_for_all_tasks.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, List from modelscope.metainfo import Models from modelscope.models import TorchModel @@ -25,12 +25,6 @@ class MPlugForAllTasks(TorchModel): self.model = MPlug.from_pretrained(model_dir) self.tokenizer = self.model.tokenizer - def train(self): - return self.model.train() - - def eval(self): - return self.model.eval() - def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: """return the result by the model @@ -45,13 +39,43 @@ class MPlugForAllTasks(TorchModel): 
}
         """
-        topk_ids, _ = self.model(**input)
         replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
                                ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                                ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
-        pred_string = self.tokenizer.decode(topk_ids[0][0])
-        for _old, _new in replace_tokens_bert:
-            pred_string = pred_string.replace(_old, _new)
-        pred_string = pred_string.strip()
-        return pred_string
+        if not self.training and 'answer_input_ids' not in input:
+            topk_ids, _ = self.model(**input)
+            pred_string: str = self.tokenizer.decode(topk_ids[0][0])
+            for _old, _new in replace_tokens_bert:
+                pred_string = pred_string.replace(_old, _new)
+            pred_string = pred_string.strip()
+            return pred_string
+        else:
+            import addict
+            question = addict.Dict(
+                input_ids=input['question_input_ids'],
+                attention_mask=input['question_attention_mask'])
+            answer = addict.Dict(
+                input_ids=input['answer_input_ids'],
+                attention_mask=input['answer_attention_mask'])
+            output = self.model(
+                input['image'], question, answer, train=self.training)
+            if self.training:
+                return {'loss': output}
+            topk_ids, _ = output
+            preds: List[str] = [
+                self.tokenizer.decode(batch[0]) for batch in topk_ids
+            ]
+            for i in range(len(preds)):
+                for _old, _new in replace_tokens_bert:
+                    preds[i] = preds[i].replace(_old, _new)
+                preds[i] = preds[i].strip()
+            tgts: List[str] = [
+                self.tokenizer.decode(batch)
+                for batch in input['answer_input_ids'].cpu().numpy().tolist()
+            ]
+            for i in range(len(tgts)):
+                for _old, _new in replace_tokens_bert:
+                    tgts[i] = tgts[i].replace(_old, _new)
+                tgts[i] = tgts[i].strip()
+            return {'preds': preds, 'tgts': tgts}
diff --git a/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py b/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py
index 7cff9ad4..fe1402e8 100644
--- a/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py
+++ b/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py
@@ -60,5 +60,6 @@ class GPT3ForTextGeneration(TorchModel):
         sample_output = self.model.generate(**gen_params)
         return {
             OutputKeys.TEXT:
-            self.tokenizer.decode(sample_output[0], skip_special_tokens=True)
+            self.tokenizer.decode(sample_output[0],
+                                  skip_special_tokens=True).replace(' ', '')
         }
diff --git a/modelscope/models/nlp/palm_v2/palm_for_text_generation.py b/modelscope/models/nlp/palm_v2/palm_for_text_generation.py
index e432cc58..98aa56c7 100644
--- a/modelscope/models/nlp/palm_v2/palm_for_text_generation.py
+++ b/modelscope/models/nlp/palm_v2/palm_for_text_generation.py
@@ -29,20 +29,19 @@ class PalmForTextGeneration(TorchModel):
         self.generator = Translator(self.model)
 
     def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
-        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
-                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
-                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]',
+                                                                  ''),
+                               (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''),
+                               ('[CLS]', ''), ('[UNK]', ''), (' ', ''))
         replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '),
                                   ('<pad>', ''), ('<s>', ''), ('</s>', ''),
                                   ('<unk>', ' '), ('<q>', '. 
')) + replace_tokens = replace_tokens_roberta \ + if self.model.config.encoder == 'roberta' else replace_tokens_bert strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list] - for _old, _new in replace_tokens_bert: + for _old, _new in replace_tokens: strings = [s.replace(_old, _new) for s in strings] - for _old, _new in replace_tokens_roberta: - strings = [s.replace(_old, _new) for s in strings] - for s in strings: - s.strip() return strings def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 56b10c3a..4f0cb977 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -9,7 +9,7 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Preprocessors from modelscope.pipelines.base import Input from modelscope.utils.config import Config -from modelscope.utils.constant import Fields, ModelFile, Tasks +from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks from .base import Preprocessor from .builder import PREPROCESSORS from .ofa import * # noqa @@ -91,9 +91,16 @@ class OfaPreprocessor(Preprocessor): Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor) class MPlugPreprocessor(Preprocessor): - def __init__(self, model_dir: str, *args, **kwargs): + def __init__(self, + model_dir: str, + mode: str = ModeKeys.INFERENCE, + tokenizer_max_length: int = 25, + *args, + **kwargs): super().__init__(*args, **kwargs) self.model_dir = model_dir + self.mode = mode + self.tokenizer_max_length = tokenizer_max_length self._tokenizer = None self._patch_resize_transform = None @@ -128,40 +135,51 @@ class MPlugPreprocessor(Preprocessor): def __call__(self, *args, **kwargs): call_mapping = { - Tasks.visual_question_answering: self.vqa_call, - Tasks.image_captioning: self.caption_call + Tasks.visual_question_answering: self.image_text_call, + Tasks.image_captioning: self.image_text_call, } self.cfg = Config.from_file( osp.join(self.model_dir, ModelFile.CONFIGURATION)) return call_mapping[self.cfg.task](*args, **kwargs) - def vqa_call(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]: - image: Image.Image = data[0] if isinstance(data, - tuple) else data['image'] - question: str = data[1] if isinstance(data, - tuple) else data['question'] - image = image.convert('RGB') - image = self.patch_resize_transform(image) - image = torch.stack([image], dim=0) - question = self.tokenizer([question.lower()], - padding='longest', - return_tensors='pt') - - return {'image': image, 'question': question, 'train': False} - - def caption_call( + def image_text_call( self, data: Union[Image.Image, tuple, Dict[str, Any]]) -> Dict[str, Any]: - if isinstance(data, Image.Image): + if isinstance(data, (Image.Image, str)): image = data elif isinstance(data, tuple): image = data[0] else: image = data['image'] + if isinstance(image, str): + image = Image.open(image) + question = '' if self.cfg.task != Tasks.visual_question_answering \ + else data[1 if isinstance(data, tuple) else 'question'] image = image.convert('RGB') image = self.patch_resize_transform(image) - image = torch.stack([image], dim=0) - question = self.tokenizer('', return_tensors='pt') - - return {'image': image, 'question': question, 'train': False} + question = self.tokenizer( + question.lower(), + padding='max_length', + truncation=True, + max_length=self.tokenizer_max_length, + return_tensors='pt') + + if self.mode == 
ModeKeys.INFERENCE: + image = torch.stack([image], dim=0) + return {'image': image, 'question': question, 'train': False} + else: + answer = data['answer'] + answer = self.tokenizer( + answer, + padding='max_length', + truncation=True, + max_length=self.tokenizer_max_length, + return_tensors='pt') + return { + 'image': image, + 'question_input_ids': question.input_ids.squeeze(), + 'question_attention_mask': question.attention_mask.squeeze(), + 'answer_input_ids': answer.input_ids.squeeze(), + 'answer_attention_mask': answer.attention_mask.squeeze(), + } diff --git a/tests/trainers/test_finetune_mplug.py b/tests/trainers/test_finetune_mplug.py new file mode 100644 index 00000000..5776141c --- /dev/null +++ b/tests/trainers/test_finetune_mplug.py @@ -0,0 +1,128 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from PIL import Image + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.multi_modal import MPlugForAllTasks +from modelscope.msdatasets import MsDataset +from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestFinetuneMPlug(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + datadict = MsDataset.load('coco_captions_small_slice') + self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map( + lambda _: { + 'question': 'what the picture describes?' + }).rename_column('image:FILE', + 'image').rename_column('answer:Value', 'answer')) + self.test_dataset = MsDataset(datadict['test'].to_hf_dataset().map( + lambda _: { + 'question': 'what the picture describes?' 
+ }).rename_column('image:FILE', + 'image').rename_column('answer:Value', 'answer')) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_caption(self): + + kwargs = dict( + model='damo/mplug_image-captioning_coco_base_en', + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(3): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_caption_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download( + 'damo/mplug_image-captioning_coco_base_en') + model = MPlugForAllTasks.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=2, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_vqa(self): + + kwargs = dict( + model='damo/mplug_visual-question-answering_coco_large_en', + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(3): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_vqa_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download( + 'damo/mplug_visual-question-answering_coco_large_en') + model = MPlugForAllTasks.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=2, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main()
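
Note for reviewers (not part of the patch): the sketch below shows how the new finetuning path is meant to be driven end to end. It simply mirrors test_trainer_with_caption() from tests/trainers/test_finetune_mplug.py above; the dataset slice, column renames and model id are the ones used in that test, while the small to_mplug_columns() helper and the throwaway work_dir are only for brevity here.

import tempfile

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer


def to_mplug_columns(split):
    # As in the test: add a constant 'question' column so the same data can
    # feed both the captioning and the VQA trainer, and rename the raw
    # columns to the 'image' / 'answer' names the MPlug preprocessor reads.
    return MsDataset(
        split.to_hf_dataset().map(lambda _: {
            'question': 'what the picture describes?'
        }).rename_column('image:FILE',
                         'image').rename_column('answer:Value', 'answer'))


datadict = MsDataset.load('coco_captions_small_slice')
trainer = build_trainer(
    name=Trainers.nlp_base_trainer,
    default_args=dict(
        model='damo/mplug_image-captioning_coco_base_en',
        train_dataset=to_mplug_columns(datadict['train']),
        eval_dataset=to_mplug_columns(datadict['test']),
        work_dir=tempfile.mkdtemp()))
trainer.train()

Swapping the model id for 'damo/mplug_visual-question-answering_coco_large_en' exercises the VQA branch instead, exactly as test_trainer_with_vqa() does.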