Add finetuning support for the mplug model's caption and VQA tasks
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9858028
master
@@ -30,6 +30,8 @@ task_default_metrics = {
     Tasks.image_portrait_enhancement:
     [Metrics.image_portrait_enhancement_metric],
     Tasks.video_summarization: [Metrics.video_summarization_metric],
+    Tasks.image_captioning: [Metrics.text_gen_metric],
+    Tasks.visual_question_answering: [Metrics.text_gen_metric],
 }

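Reviewer note: with this mapping, caption and VQA evaluation during finetuning defaults to Metrics.text_gen_metric; the 'preds'/'tgts' lists that the updated MPlugForAllTasks.forward further down returns in evaluation mode appear to be what that metric is meant to consume.
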
@@ -1969,71 +1969,6 @@ class MPlug(PreTrainedModel):
             [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
         return torch.index_select(x, dim, order_index.to(x.device))
 
-    def rank_answer(self, question_states, question_atts, answer_ids,
-                    answer_atts, k):
-        num_ques = question_states.size(0)
-        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token
-
-        start_output = self.text_decoder(
-            start_ids,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            return_dict=True,
-            reduction='none')
-        logits = start_output.logits[:, 0, :]  # first token's logit
-
-        # topk_probs: top-k probability
-        # topk_ids: [num_question, k]
-        answer_first_token = answer_ids[:, 1]
-        prob_first_token = F.softmax(
-            logits, dim=1).index_select(
-                dim=1, index=answer_first_token)
-        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
-
-        # answer input: [num_question*k, answer_len]
-        input_ids = []
-        input_atts = []
-        for b, topk_id in enumerate(topk_ids):
-            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
-            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
-        input_ids = torch.cat(input_ids, dim=0)
-        input_atts = torch.cat(input_atts, dim=0)
-
-        targets_ids = input_ids.masked_fill(
-            input_ids == self.tokenizer.pad_token_id, -100)
-
-        # repeat encoder's output for top-k answers
-        question_states = self._tile(question_states, 0, k)
-        question_atts = self._tile(question_atts, 0, k)
-
-        output = self.text_decoder(
-            input_ids,
-            attention_mask=input_atts,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            labels=targets_ids,
-            return_dict=True,
-            reduction='none')
-
-        answer_loss = output.loss
-        answer_loss = answer_loss.view(input_ids.size(0), -1)
-
-        # topk_prob: first token probability
-        topk_probs = topk_probs.view(-1, 1)
-        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
-
-        # re-calculate log probabilities for the answer sequences using chain rule
-        log_probs_sum = log_probs.sum(1)
-        log_probs_sum = log_probs_sum.view(num_ques, k)
-
-        topk_probs = F.softmax(log_probs_sum, dim=-1)
-
-        # get top-k after re-ranking
-        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
-        topk_ids = torch.gather(topk_ids, 1, rerank_id)
-
-        return topk_ids, topk_probs
-
 
 class MPlugForVisualQuestionAnswering(MPlug):

@@ -2111,6 +2046,8 @@ class MPlugForVisualQuestionAnswering(MPlug):
             merge_text_attention = torch.cat(
                 [image_atts, question.attention_mask], 1)
+            if k is None:
+                k = [1] * question_output.shape[0]
             question_states = []
             question_atts = []
             for b, n in enumerate(k):

@@ -2177,6 +2114,8 @@ class MPlugForVisualQuestionAnswering(MPlug):
                 return_dict=True,
                 reduction='none',
             )
+            if weights is None:
+                weights = 1
             loss = weights * answer_output.loss
             loss = loss.sum() / image.size(0)

@@ -2262,50 +2201,17 @@ class MPLUGForImageCaption(MPlug):
         if train:
             answer_targets = answer.input_ids.masked_fill(
                 answer.input_ids == self.tokenizer.pad_token_id, -100)
-            text_output = self.text_encoder(
-                question.input_ids,
-                attention_mask=question.attention_mask,
-                return_dict=True)
-            text_embeds = text_output.last_hidden_state
-            fusion_output = self.fusion_encoder(
-                encoder_embeds=text_embeds,
-                attention_mask=question.attention_mask,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=False)
-            image_output, question_output = fusion_output
-            question_output = torch.cat([image_output, question_output], 1)
-            merge_text_attention = torch.cat(
-                [image_atts, question.attention_mask], 1)
             answer_output = self.text_decoder(
                 answer.input_ids,
                 attention_mask=answer.attention_mask,
-                encoder_hidden_states=question_output,
-                encoder_attention_mask=merge_text_attention,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
                 labels=answer_targets,
                 return_dict=True,
                 reduction='none')
             loss = answer_output.loss
             return loss
         else:
-            text_output = self.text_encoder(
-                question.input_ids,
-                attention_mask=question.attention_mask,
-                return_dict=True)
-            text_embeds = text_output.last_hidden_state
-            fusion_output = self.fusion_encoder(
-                encoder_embeds=text_embeds,
-                attention_mask=question.attention_mask,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=False)
-            image_output, question_output = fusion_output
-            question_output = torch.cat([image_output, question_output], 1)
-            merge_text_attention = torch.cat(
-                [image_atts, question.attention_mask], 1)
-            topk_ids, topk_probs = self.generation(question_output,
-                                                   merge_text_attention)
+            topk_ids, topk_probs = self.generation(image_embeds, image_atts)
             return topk_ids, topk_probs

@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, List
 
 from modelscope.metainfo import Models
 from modelscope.models import TorchModel

@@ -25,12 +25,6 @@ class MPlugForAllTasks(TorchModel):
         self.model = MPlug.from_pretrained(model_dir)
         self.tokenizer = self.model.tokenizer
 
-    def train(self):
-        return self.model.train()
-
-    def eval(self):
-        return self.model.eval()
-
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model

@@ -45,13 +39,43 @@ class MPlugForAllTasks(TorchModel):
                 }
         """
-        topk_ids, _ = self.model(**input)
         replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
                                ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                                ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
-        pred_string = self.tokenizer.decode(topk_ids[0][0])
-        for _old, _new in replace_tokens_bert:
-            pred_string = pred_string.replace(_old, _new)
-        pred_string = pred_string.strip()
-        return pred_string
+        if not self.training and 'answer_input_ids' not in input:
+            topk_ids, _ = self.model(**input)
+            pred_string: str = self.tokenizer.decode(topk_ids[0][0])
+            for _old, _new in replace_tokens_bert:
+                pred_string = pred_string.replace(_old, _new)
+            pred_string = pred_string.strip()
+            return pred_string
+        else:
+            import addict
+            question = addict.Dict(
+                input_ids=input['question_input_ids'],
+                attention_mask=input['question_attention_mask'])
+            answer = addict.Dict(
+                input_ids=input['answer_input_ids'],
+                attention_mask=input['answer_attention_mask'])
+            output = self.model(
+                input['image'], question, answer, train=self.training)
+            if self.training:
+                return {'loss': output}
+            topk_ids, _ = output
+            preds: List[str] = [
+                self.tokenizer.decode(batch[0]) for batch in topk_ids
+            ]
+            for i in range(len(preds)):
+                for _old, _new in replace_tokens_bert:
+                    preds[i] = preds[i].replace(_old, _new)
+                preds[i] = preds[i].strip()
+            tgts: List[str] = [
+                self.tokenizer.decode(batch)
+                for batch in input['answer_input_ids'].cpu().numpy().tolist()
+            ]
+            for i in range(len(tgts)):
+                for _old, _new in replace_tokens_bert:
+                    tgts[i] = tgts[i].replace(_old, _new)
+                tgts[i] = tgts[i].strip()
+            return {'preds': preds, 'tgts': tgts}

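Reviewer note: the BERT special-token cleanup loop now appears several times in this forward() and again in PalmForTextGeneration below. A shared helper along the following lines could remove the duplication; this is only a sketch, and the name clean_decoded_text and the simplified token table are suggestions, not part of this diff.

from typing import Iterable, List

REPLACE_TOKENS_BERT = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]', ''),
                       ('[SEP]', ''), ('[unused2]', ''), ('[CLS]', ''),
                       ('[UNK]', ''))


def clean_decoded_text(texts: Iterable[str]) -> List[str]:
    """Strip BERT special tokens and surrounding whitespace from decoded strings."""
    cleaned = []
    for text in texts:
        for old, new in REPLACE_TOKENS_BERT:
            text = text.replace(old, new)
        cleaned.append(text.strip())
    return cleaned


# e.g. clean_decoded_text(['[CLS] a dog runs on the grass [SEP]'])
#      -> ['a dog runs on the grass']
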
@@ -60,5 +60,6 @@ class GPT3ForTextGeneration(TorchModel):
         sample_output = self.model.generate(**gen_params)
         return {
             OutputKeys.TEXT:
-            self.tokenizer.decode(sample_output[0], skip_special_tokens=True)
+            self.tokenizer.decode(sample_output[0],
+                                  skip_special_tokens=True).replace(' ', '')
         }

@@ -29,20 +29,19 @@ class PalmForTextGeneration(TorchModel):
         self.generator = Translator(self.model)
 
     def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
-        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
-                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
-                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]',
+                                                                  ''),
+                               (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''),
+                               ('[CLS]', ''), ('[UNK]', ''), (' ', ''))
         replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '),
                                   ('<pad>', ''), ('<s>', ''), ('</s>', ''),
                                   ('<unk>', ' '), ('<q>', '. '))
+        replace_tokens = replace_tokens_roberta \
+            if self.model.config.encoder == 'roberta' else replace_tokens_bert
 
         strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
-        for _old, _new in replace_tokens_bert:
+        for _old, _new in replace_tokens:
             strings = [s.replace(_old, _new) for s in strings]
-        for _old, _new in replace_tokens_roberta:
-            strings = [s.replace(_old, _new) for s in strings]
-        for s in strings:
-            s.strip()
         return strings
 
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:

@@ -9,7 +9,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Preprocessors
 from modelscope.pipelines.base import Input
 from modelscope.utils.config import Config
-from modelscope.utils.constant import Fields, ModelFile, Tasks
+from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
 from .base import Preprocessor
 from .builder import PREPROCESSORS
 from .ofa import *  # noqa

@@ -91,9 +91,16 @@ class OfaPreprocessor(Preprocessor):
     Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor)
 class MPlugPreprocessor(Preprocessor):
 
-    def __init__(self, model_dir: str, *args, **kwargs):
+    def __init__(self,
+                 model_dir: str,
+                 mode: str = ModeKeys.INFERENCE,
+                 tokenizer_max_length: int = 25,
+                 *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
         self.model_dir = model_dir
+        self.mode = mode
+        self.tokenizer_max_length = tokenizer_max_length
         self._tokenizer = None
         self._patch_resize_transform = None

@@ -128,40 +135,51 @@ class MPlugPreprocessor(Preprocessor):
     def __call__(self, *args, **kwargs):
         call_mapping = {
-            Tasks.visual_question_answering: self.vqa_call,
-            Tasks.image_captioning: self.caption_call
+            Tasks.visual_question_answering: self.image_text_call,
+            Tasks.image_captioning: self.image_text_call,
         }
 
         self.cfg = Config.from_file(
             osp.join(self.model_dir, ModelFile.CONFIGURATION))
         return call_mapping[self.cfg.task](*args, **kwargs)
 
-    def vqa_call(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]:
-        image: Image.Image = data[0] if isinstance(data,
-                                                   tuple) else data['image']
-        question: str = data[1] if isinstance(data,
-                                              tuple) else data['question']
-        image = image.convert('RGB')
-        image = self.patch_resize_transform(image)
-        image = torch.stack([image], dim=0)
-        question = self.tokenizer([question.lower()],
-                                  padding='longest',
-                                  return_tensors='pt')
-        return {'image': image, 'question': question, 'train': False}
-
-    def caption_call(
+    def image_text_call(
             self, data: Union[Image.Image, tuple,
                               Dict[str, Any]]) -> Dict[str, Any]:
-        if isinstance(data, Image.Image):
+        if isinstance(data, (Image.Image, str)):
             image = data
         elif isinstance(data, tuple):
             image = data[0]
         else:
             image = data['image']
+        if isinstance(image, str):
+            image = Image.open(image)
+        question = '' if self.cfg.task != Tasks.visual_question_answering \
+            else data[1 if isinstance(data, tuple) else 'question']
 
         image = image.convert('RGB')
         image = self.patch_resize_transform(image)
-        image = torch.stack([image], dim=0)
-        question = self.tokenizer('', return_tensors='pt')
-        return {'image': image, 'question': question, 'train': False}
+        question = self.tokenizer(
+            question.lower(),
+            padding='max_length',
+            truncation=True,
+            max_length=self.tokenizer_max_length,
+            return_tensors='pt')
+
+        if self.mode == ModeKeys.INFERENCE:
+            image = torch.stack([image], dim=0)
+            return {'image': image, 'question': question, 'train': False}
+        else:
+            answer = data['answer']
+            answer = self.tokenizer(
+                answer,
+                padding='max_length',
+                truncation=True,
+                max_length=self.tokenizer_max_length,
+                return_tensors='pt')
+            return {
+                'image': image,
+                'question_input_ids': question.input_ids.squeeze(),
+                'question_attention_mask': question.attention_mask.squeeze(),
+                'answer_input_ids': answer.input_ids.squeeze(),
+                'answer_attention_mask': answer.attention_mask.squeeze(),
+            }

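Reviewer note: a minimal usage sketch of the new training-mode path, roughly what the finetuning pipeline does per dataset row. The import path for MPlugPreprocessor, the model directory and the image path are assumptions/placeholders, not taken from this diff.

from modelscope.preprocessors.multi_modal import MPlugPreprocessor  # import path assumed
from modelscope.utils.constant import ModeKeys

# model_dir must point at a downloaded mplug checkpoint containing
# configuration.json and the tokenizer files; this path is a placeholder.
preprocessor = MPlugPreprocessor(
    model_dir='/path/to/mplug_model',
    mode=ModeKeys.TRAIN,
    tokenizer_max_length=25)

# One raw example, shaped like a row of the remapped MsDataset in the tests below.
features = preprocessor({
    'image': '/path/to/example.jpg',  # placeholder image path
    'question': 'what the picture describes?',
    'answer': 'a dog runs on the grass',
})

# features now holds 'image' plus squeezed 'question_input_ids',
# 'question_attention_mask', 'answer_input_ids' and 'answer_attention_mask',
# each padded/truncated to tokenizer_max_length, ready to be batched by the trainer.
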
@@ -0,0 +1,128 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from PIL import Image
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.models.multi_modal import MPlugForAllTasks
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestFinetuneMPlug(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+        datadict = MsDataset.load('coco_captions_small_slice')
+        self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map(
+            lambda _: {
+                'question': 'what the picture describes?'
+            }).rename_column('image:FILE',
+                             'image').rename_column('answer:Value', 'answer'))
+        self.test_dataset = MsDataset(datadict['test'].to_hf_dataset().map(
+            lambda _: {
+                'question': 'what the picture describes?'
+            }).rename_column('image:FILE',
+                             'image').rename_column('answer:Value', 'answer'))
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_with_caption(self):
+        kwargs = dict(
+            model='damo/mplug_image-captioning_coco_base_en',
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(3):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_caption_with_model_and_args(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        cache_path = snapshot_download(
+            'damo/mplug_image-captioning_coco_base_en')
+        model = MPlugForAllTasks.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_with_vqa(self):
+        kwargs = dict(
+            model='damo/mplug_visual-question-answering_coco_large_en',
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(3):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_vqa_with_model_and_args(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        cache_path = snapshot_download(
+            'damo/mplug_visual-question-answering_coco_large_en')
+        model = MPlugForAllTasks.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+
+if __name__ == '__main__':
+    unittest.main()