diff --git a/data/test/images/image-text-retrieval.jpg b/data/test/images/image-text-retrieval.jpg new file mode 100644 index 00000000..2d20374a --- /dev/null +++ b/data/test/images/image-text-retrieval.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b012c7e966f6550874ccb85ef9602d483aa89b8623dff9ffcdb0faab8f2ca9ab +size 218143 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index eab870ae..b4d005a7 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -170,6 +170,7 @@ class Pipelines(object): multi_modal_similarity = 'multi-modal-similarity' text_to_image_synthesis = 'text-to-image-synthesis' video_multi_modal_embedding = 'video-multi-modal-embedding' + image_text_retrieval = 'image-text-retrieval' class Trainers(object): diff --git a/modelscope/models/multi_modal/mplug/configuration_mplug.py b/modelscope/models/multi_modal/mplug/configuration_mplug.py index c275ed15..914678c5 100644 --- a/modelscope/models/multi_modal/mplug/configuration_mplug.py +++ b/modelscope/models/multi_modal/mplug/configuration_mplug.py @@ -64,6 +64,10 @@ class MPlugConfig(PretrainedConfig): clip_transformer_width=768, clip_transformer_heads=12, clip_transformer_layers=12, + # retrieval + queue_size=65536, + embed_dim=256, + temp=0.07, **kwargs): super().__init__(**kwargs) @@ -99,6 +103,10 @@ class MPlugConfig(PretrainedConfig): self.clip_transformer_width = clip_transformer_width self.clip_transformer_heads = clip_transformer_heads self.clip_transformer_layers = clip_transformer_layers + # retrieval + self.queue_size = queue_size + self.embed_dim = embed_dim + self.temp = temp @classmethod def from_yaml_file(cls, yaml_file: Union[str, diff --git a/modelscope/models/multi_modal/mplug/modeling_mplug.py b/modelscope/models/multi_modal/mplug/modeling_mplug.py index 6311bd31..78f60f9b 100755 --- a/modelscope/models/multi_modal/mplug/modeling_mplug.py +++ b/modelscope/models/multi_modal/mplug/modeling_mplug.py @@ -1855,7 +1855,8 @@ class MPlug(PreTrainedModel): task_mapping = { Tasks.visual_question_answering: MPlugForVisualQuestionAnswering, - Tasks.image_captioning: MPLUGForImageCaption + Tasks.image_captioning: MPlugForImageCaption, + Tasks.image_text_retrieval: MPlugForImageTextRetrieval, } config = cls.config_class.from_yaml_file( os.path.join(model_dir, CONFIG_NAME)) @@ -1915,6 +1916,33 @@ class MPlug(PreTrainedModel): clip_model.visual.positional_embedding = pos_embed return clip_model + def init_distill(self, config): + self.distill = config.distill + if self.distill: + self.visual_encoder_m = self._initialize_clip(config) + self.text_encoder_m = BertModel( + self.config_encoder, add_pooling_layer=False) + self.fusion_encoder_m = FusionModel( + self.config_fusion, add_pooling_layer=False) + self.text_decoder_m = BertLMHeadModel(self.config_decoder) + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.text_decoder, self.text_decoder_m], + ] + if self.config_encoder.hidden_size != config.vision_width: + self.visn_fc_m = nn.Linear(config.vision_width, + self.config_encoder.hidden_size) + self.visn_layer_norm_m = nn.LayerNorm( + self.config_encoder.hidden_size, eps=1e-12) + self.dropout_m = nn.Dropout( + self.config_encoder.hidden_dropout_prob) + self.model_pairs.extend( + [[self.visn_fc, self.visn_fc_m], + [self.visn_layer_norm, self.visn_layer_norm_m]]) + self.copy_params() + self.momentum = 0.995 + def forward(self, *args, **kwargs): raise NotImplementedError @@ -1978,33 +2006,6 @@ class 
MPlugForVisualQuestionAnswering(MPlug): self.beam_generator = TextGenerator(config, self.text_decoder) self.init_distill(config) - def init_distill(self, config): - self.distill = config.distill - if self.distill: - self.visual_encoder_m = self._initialize_clip(config) - self.text_encoder_m = BertModel( - self.config_encoder, add_pooling_layer=False) - self.fusion_encoder_m = FusionModel( - self.config_fusion, add_pooling_layer=False) - self.text_decoder_m = BertLMHeadModel(self.config_decoder) - self.model_pairs = [ - [self.visual_encoder, self.visual_encoder_m], - [self.text_encoder, self.text_encoder_m], - [self.text_decoder, self.text_decoder_m], - ] - if self.config_encoder.hidden_size != config.vision_width: - self.visn_fc_m = nn.Linear(config.vision_width, - self.config_encoder.hidden_size) - self.visn_layer_norm_m = nn.LayerNorm( - self.config_encoder.hidden_size, eps=1e-12) - self.dropout_m = nn.Dropout( - self.config_encoder.hidden_dropout_prob) - self.model_pairs.extend( - [[self.visn_fc, self.visn_fc_m], - [self.visn_layer_norm, self.visn_layer_norm_m]]) - self.copy_params() - self.momentum = 0.995 - def forward(self, image, question, @@ -2142,7 +2143,7 @@ class MPlugForVisualQuestionAnswering(MPlug): return topk_ids, topk_probs -class MPLUGForImageCaption(MPlug): +class MPlugForImageCaption(MPlug): def __init__(self, config): super().__init__(config) @@ -2215,3 +2216,264 @@ class MPLUGForImageCaption(MPlug): else: topk_ids, topk_probs = self.generation(image_embeds, image_atts) return topk_ids, topk_probs + + +class MPlugForImageTextRetrieval(MPlug): + + def __init__(self, config): + super().__init__(config) + self.embed_dim = config.embed_dim + self.temp = nn.Parameter(torch.ones([]) * config.temp) + self.queue_size = config.queue_size + self.momentum = config.momentum + self.alpha = config.alpha + + self.queue_size = config.queue_size + self.text_width = self.config_encoder.hidden_size + self.embed_dim = config.embed_dim + + self.vision_proj = nn.Linear(self.text_width, self.embed_dim) + self.text_proj = nn.Linear(self.text_width, self.embed_dim) + self.itm_head = nn.Linear(self.text_width, 2) + + self.register_buffer('image_queue', + torch.randn(self.embed_dim, self.queue_size)) + self.register_buffer('text_queue', + torch.randn(self.embed_dim, self.queue_size)) + self.register_buffer('idx_queue', torch.full((1, self.queue_size), + -100)) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + self.image_queue = F.normalize(self.image_queue, dim=0) + self.text_queue = F.normalize(self.text_queue, dim=0) + self.init_distill(config) + + def init_distill(self, config): + self.distill = config.distill + if self.distill: + self.visual_encoder_m = self._initialize_clip(config) + self.text_encoder_m = BertModel( + self.config_encoder, add_pooling_layer=False) + self.fusion_encoder_m = FusionModel( + self.config_fusion, add_pooling_layer=False) + self.vision_proj_m = nn.Linear(self.text_width, self.embed_dim) + self.text_proj_m = nn.Linear(self.text_width, self.embed_dim) + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.text_proj, self.text_proj_m], + [self.vision_proj, self.vision_proj_m], + ] + if self.config_encoder.hidden_size != config.vision_width: + self.visn_fc_m = nn.Linear(config.vision_width, + self.config_encoder.hidden_size) + self.visn_layer_norm_m = nn.LayerNorm( + self.config_encoder.hidden_size, eps=1e-12) + self.dropout_m = nn.Dropout( + 
self.config_encoder.hidden_dropout_prob) + self.model_pairs.extend( + [[self.visn_fc, self.visn_fc_m], + [self.visn_layer_norm, self.visn_layer_norm_m]]) + self.copy_params() + self.momentum = 0.995 + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idx): + + def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + if not torch.distributed.is_initialized(): + return tensor + tensors_gather = [ + torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather( + tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + idxs = concat_all_gather(idx) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + # assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + self.idx_queue[:, ptr:ptr + batch_size] = idxs.T + ptr = (ptr + batch_size) % self.queue_size # move pointer + + self.queue_ptr[0] = ptr + + def forward(self, image, text, idx=None, train=True): + if train: + image_embeds = self.visual_encoder.visual( + image, skip_last_layer=True) + if self.large: + image_embeds = self.dropout( + self.visn_layer_norm(self.visn_fc(image_embeds))) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + image_feat = F.normalize( + self.vision_proj(image_embeds[:, 0, :]), dim=-1) + text_output = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize( + self.text_proj(text_embeds[:, 0, :]), dim=-1) + + idx = idx.view(-1, 1) + idx_all = torch.cat( + [idx.t(), self.idx_queue.clone().detach()], dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) + + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m.visual( + image, skip_last_layer=True) + if self.large: + image_embeds_m = self.dropout_m( + self.visn_layer_norm_m(self.visn_fc_m(image_embeds_m))) + image_feat_m = F.normalize( + self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1) + image_feat_all = torch.cat( + [image_feat_m.t(), + self.image_queue.clone().detach()], + dim=1) + text_output_m = self.text_encoder_m( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True) + text_feat_m = F.normalize( + self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]), + dim=-1) + text_feat_all = torch.cat( + [text_feat_m.t(), + self.text_queue.clone().detach()], dim=1) + + if self.distill: + sim_i2t_m = image_feat_m @ text_feat_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_all / self.temp + + sim_i2t_targets = self.alpha * F.softmax( + sim_i2t_m, dim=1) + (1 - self.alpha) * sim_targets + sim_t2i_targets = self.alpha * F.softmax( + sim_t2i_m, dim=1) + (1 - self.alpha) * sim_targets + + sim_i2t = image_feat @ text_feat_all / self.temp + sim_t2i = text_feat @ image_feat_all / self.temp + + if self.distill: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, + dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * 
sim_t2i_targets, + dim=1).mean() + else: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean() + + loss_ita = (loss_i2t + loss_t2i) / 2 + + self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx) + + # forward the positve image-text pair + _, output_pos = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + ) + with torch.no_grad(): + bs = image.size(0) + weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1) + weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1) + + mask = torch.eq(idx, idx.T) + weights_i2t.masked_fill_(mask, 0) + weights_t2i.masked_fill_(mask, 0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_embeds_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_embeds_neg.append(text_embeds[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + text_embeds_neg = torch.stack(text_embeds_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg], + dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], + dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + _, output_neg = self.fusion_encoder( + encoder_embeds=text_embeds_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=False, + ) + + vl_embeddings = torch.cat( + [output_pos[:, 0, :], output_neg[:, 0, :]], dim=0) + vl_output = self.itm_head(vl_embeddings) + + ones_tmp = torch.ones(bs, dtype=torch.long) + zeros_tmp = torch.zeros(2 * bs, dtype=torch.long) + itm_labels = torch.cat([ones_tmp, zeros_tmp], + dim=0).to(image.device) + loss_itm = F.cross_entropy(vl_output, itm_labels) + + return loss_ita + loss_itm + else: + text_output = self.text_encoder( + text.input_ids, attention_mask=text.attention_mask) + text_feat = text_output.last_hidden_state + image_feat = self.visual_encoder.visual( + image, skip_last_layer=True) + image_feat = self.visn_layer_norm(self.visn_fc(image_feat)) + image_att = torch.ones( + image_feat.size()[:-1], + dtype=torch.long, + device=image_feat.device) + _, output = self.fusion_encoder( + encoder_embeds=text_feat, + attention_mask=text.attention_mask, + encoder_hidden_states=image_feat, + encoder_attention_mask=image_att, + return_dict=False, + ) + scores = self.itm_head(output[:, 0, :]) + scores = F.softmax(scores, dim=-1) + + return scores diff --git a/modelscope/models/multi_modal/mplug_for_all_tasks.py b/modelscope/models/multi_modal/mplug_for_all_tasks.py index fb460714..608cc733 100644 --- a/modelscope/models/multi_modal/mplug_for_all_tasks.py +++ b/modelscope/models/multi_modal/mplug_for_all_tasks.py @@ -12,6 +12,7 @@ __all__ = ['MPlugForAllTasks'] @MODELS.register_module( Tasks.visual_question_answering, module_name=Models.mplug) @MODELS.register_module(Tasks.image_captioning, module_name=Models.mplug) +@MODELS.register_module(Tasks.image_text_retrieval, module_name=Models.mplug) 
class MPlugForAllTasks(TorchModel): def __init__(self, model_dir: str, *args, **kwargs): @@ -43,39 +44,50 @@ class MPlugForAllTasks(TorchModel): ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) - if not self.training and 'answer_input_ids' not in input: - topk_ids, _ = self.model(**input) + # inference + if not self.training and 'question' in input: + output = self.model(input['image'], input['question'], train=False) + if not isinstance(output, tuple): + return output + topk_ids, _ = output pred_string: str = self.tokenizer.decode(topk_ids[0][0]) for _old, _new in replace_tokens_bert: pred_string = pred_string.replace(_old, _new) pred_string = pred_string.strip() return pred_string - else: - import addict + + # train and evaluate + import addict + image = input['image'] + answer = addict.Dict( + input_ids=input['answer_input_ids'], + attention_mask=input['answer_attention_mask']) + if 'index' not in input: question = addict.Dict( input_ids=input['question_input_ids'], attention_mask=input['question_attention_mask']) - answer = addict.Dict( - input_ids=input['answer_input_ids'], - attention_mask=input['answer_attention_mask']) - output = self.model( - input['image'], question, answer, train=self.training) - if self.training: - return {'loss': output} - topk_ids, _ = output - preds: List[str] = [ - self.tokenizer.decode(batch[0]) for batch in topk_ids - ] - for i in range(len(preds)): - for _old, _new in replace_tokens_bert: - preds[i] = preds[i].replace(_old, _new) - preds[i] = preds[i].strip() - tgts: List[str] = [ - self.tokenizer.decode(batch) - for batch in input['answer_input_ids'].cpu().numpy().tolist() - ] - for i in range(len(tgts)): - for _old, _new in replace_tokens_bert: - tgts[i] = tgts[i].replace(_old, _new) - preds[i] = preds[i].strip() - return {'preds': preds, 'tgts': tgts} + output = self.model(image, question, answer, train=self.training) + else: + index = input['index'] + output = self.model(image, answer, index, train=self.training) + if self.training: + return {'loss': output} + + # evaluate + topk_ids, _ = output + preds: List[str] = [ + self.tokenizer.decode(batch[0]) for batch in topk_ids + ] + for i in range(len(preds)): + for _old, _new in replace_tokens_bert: + preds[i] = preds[i].replace(_old, _new) + preds[i] = preds[i].strip() + tgts: List[str] = [ + self.tokenizer.decode(batch) + for batch in input['answer_input_ids'].cpu().numpy().tolist() + ] + for i in range(len(tgts)): + for _old, _new in replace_tokens_bert: + tgts[i] = tgts[i].replace(_old, _new) + preds[i] = preds[i].strip() + return {'preds': preds, 'tgts': tgts} diff --git a/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py b/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py new file mode 100644 index 00000000..1ebcf526 --- /dev/null +++ b/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import MPlugPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_text_retrieval, module_name=Pipelines.image_text_retrieval) +class ImageTextRetrievalPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a + image text retrieval pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + f'model must be a single str or Model, but got {type(model)}' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None: + preprocessor = MPlugPreprocessor(pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return {OutputKeys.SCORES: inputs[0].tolist()} diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 4f0cb977..9873a62c 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os.path as osp -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Tuple, Union import torch from PIL import Image @@ -104,6 +104,7 @@ class MPlugPreprocessor(Preprocessor): self._tokenizer = None self._patch_resize_transform = None + self._image_map = {} @property def tokenizer(self): @@ -133,31 +134,31 @@ class MPlugPreprocessor(Preprocessor): ]) return self._patch_resize_transform - def __call__(self, *args, **kwargs): - call_mapping = { - Tasks.visual_question_answering: self.image_text_call, - Tasks.image_captioning: self.image_text_call, - } + def image_open(self, path: str) -> Tuple[Image.Image, int]: + if path not in self._image_map: + index = len(self._image_map) + self._image_map[path] = (Image.open(path), index) + return self._image_map[path] + def __call__( + self, data: Union[Image.Image, tuple, + Dict[str, Any]]) -> Dict[str, Any]: self.cfg = Config.from_file( osp.join(self.model_dir, ModelFile.CONFIGURATION)) - return call_mapping[self.cfg.task](*args, **kwargs) - def image_text_call( - self, data: Union[Image.Image, tuple, - Dict[str, Any]]) -> Dict[str, Any]: if isinstance(data, (Image.Image, str)): image = data elif isinstance(data, tuple): image = data[0] else: image = data['image'] + index = 0 if isinstance(image, str): - image = Image.open(image) - question = '' if self.cfg.task != Tasks.visual_question_answering \ - else data[1 if isinstance(data, tuple) else 'question'] + image, index = self.image_open(image) image = image.convert('RGB') image = self.patch_resize_transform(image) + question = '' if self.cfg.task == Tasks.image_captioning \ + else data[1 if isinstance(data, tuple) else 'question'] question = self.tokenizer( question.lower(), padding='max_length', @@ -167,7 +168,7 @@ class MPlugPreprocessor(Preprocessor): if self.mode == ModeKeys.INFERENCE: image = torch.stack([image], dim=0) - return {'image': image, 'question': question, 'train': False} + return {'image': image, 'question': question} else: answer = data['answer'] answer = self.tokenizer( @@ -176,10 +177,13 @@ class MPlugPreprocessor(Preprocessor): truncation=True, max_length=self.tokenizer_max_length, return_tensors='pt') - return { + output = { 'image': image, 'question_input_ids': question.input_ids.squeeze(), 'question_attention_mask': question.attention_mask.squeeze(), 'answer_input_ids': answer.input_ids.squeeze(), 'answer_attention_mask': answer.attention_mask.squeeze(), } + if self.cfg.task == Tasks.image_text_retrieval: + output['index'] = index + return output diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 6d419a7e..66f734f9 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -121,6 +121,7 @@ class MultiModalTasks(object): visual_question_answering = 'visual-question-answering' visual_entailment = 'visual-entailment' video_multi_modal_embedding = 'video-multi-modal-embedding' + image_text_retrieval = 'image-text-retrieval' class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): diff --git a/tests/pipelines/test_mplug_tasks.py b/tests/pipelines/test_mplug_tasks.py index 4b8a813a..642ac11d 100644 --- a/tests/pipelines/test_mplug_tasks.py +++ b/tests/pipelines/test_mplug_tasks.py @@ -54,6 +54,27 @@ class MplugTasksTest(unittest.TestCase): result = pipeline_vqa(input) print(result) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_text_retrieval_with_model(self): + model = Model.from_pretrained( + 
'damo/mplug_image-text-retrieval_flickr30k_large_en') + pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model) + image = Image.open('data/test/images/image-text-retrieval.jpg') + question = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.' + input = {'image': image, 'question': question} + result = pipeline_retrieval(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_image_text_retrieval_with_name(self): + model = 'damo/mplug_image-text-retrieval_flickr30k_large_en' + pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model) + image = Image.open('data/test/images/image-text-retrieval.jpg') + question = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.' + input = {'image': image, 'question': question} + result = pipeline_retrieval(input) + print(result) + if __name__ == '__main__': unittest.main() diff --git a/tests/trainers/test_finetune_mplug.py b/tests/trainers/test_finetune_mplug.py index 5776141c..1298f1cd 100644 --- a/tests/trainers/test_finetune_mplug.py +++ b/tests/trainers/test_finetune_mplug.py @@ -4,8 +4,6 @@ import shutil import tempfile import unittest -from PIL import Image - from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Trainers from modelscope.models.multi_modal import MPlugForAllTasks @@ -23,7 +21,10 @@ class TestFinetuneMPlug(unittest.TestCase): if not os.path.exists(self.tmp_dir): os.makedirs(self.tmp_dir) - datadict = MsDataset.load('coco_captions_small_slice') + from modelscope.utils.constant import DownloadMode + datadict = MsDataset.load( + 'coco_captions_small_slice', + download_mode=DownloadMode.FORCE_REDOWNLOAD) self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map( lambda _: { 'question': 'what the picture describes?' 
@@ -35,17 +36,19 @@ class TestFinetuneMPlug(unittest.TestCase): }).rename_column('image:FILE', 'image').rename_column('answer:Value', 'answer')) + self.max_epochs = 3 + def tearDown(self): shutil.rmtree(self.tmp_dir) super().tearDown() @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_with_caption(self): - kwargs = dict( model='damo/mplug_image-captioning_coco_base_en', train_dataset=self.train_dataset, eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, work_dir=self.tmp_dir) trainer: EpochBasedTrainer = build_trainer( @@ -53,15 +56,11 @@ class TestFinetuneMPlug(unittest.TestCase): trainer.train() results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files) - for i in range(3): + for i in range(self.max_epochs): self.assertIn(f'epoch_{i+1}.pth', results_files) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_caption_with_model_and_args(self): - tmp_dir = tempfile.TemporaryDirectory().name - if not os.path.exists(tmp_dir): - os.makedirs(tmp_dir) - cache_path = snapshot_download( 'damo/mplug_image-captioning_coco_base_en') model = MPlugForAllTasks.from_pretrained(cache_path) @@ -70,7 +69,7 @@ class TestFinetuneMPlug(unittest.TestCase): model=model, train_dataset=self.train_dataset, eval_dataset=self.test_dataset, - max_epochs=2, + max_epochs=self.max_epochs, work_dir=self.tmp_dir) trainer: EpochBasedTrainer = build_trainer( @@ -78,16 +77,16 @@ class TestFinetuneMPlug(unittest.TestCase): trainer.train() results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files) - for i in range(2): + for i in range(self.max_epochs): self.assertIn(f'epoch_{i+1}.pth', results_files) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_with_vqa(self): - kwargs = dict( model='damo/mplug_visual-question-answering_coco_large_en', train_dataset=self.train_dataset, eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, work_dir=self.tmp_dir) trainer: EpochBasedTrainer = build_trainer( @@ -95,15 +94,11 @@ class TestFinetuneMPlug(unittest.TestCase): trainer.train() results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files) - for i in range(3): + for i in range(self.max_epochs): self.assertIn(f'epoch_{i+1}.pth', results_files) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_vqa_with_model_and_args(self): - tmp_dir = tempfile.TemporaryDirectory().name - if not os.path.exists(tmp_dir): - os.makedirs(tmp_dir) - cache_path = snapshot_download( 'damo/mplug_visual-question-answering_coco_large_en') model = MPlugForAllTasks.from_pretrained(cache_path) @@ -112,7 +107,45 @@ class TestFinetuneMPlug(unittest.TestCase): model=model, train_dataset=self.train_dataset, eval_dataset=self.test_dataset, - max_epochs=2, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_retrieval(self): + kwargs = dict( + model='damo/mplug_image-text-retrieval_flickr30k_large_en', + train_dataset=self.train_dataset, + 
eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_retrieval_with_model_and_args(self): + cache_path = snapshot_download( + 'damo/mplug_image-text-retrieval_flickr30k_large_en') + model = MPlugForAllTasks.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, work_dir=self.tmp_dir) trainer: EpochBasedTrainer = build_trainer( @@ -120,7 +153,7 @@ class TestFinetuneMPlug(unittest.TestCase): trainer.train() results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files) - for i in range(2): + for i in range(self.max_epochs): self.assertIn(f'epoch_{i+1}.pth', results_files)
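
The retrieval head added to modeling_mplug.py above computes an image-text contrastive (ITA) loss over the in-batch features plus the memory queues, mixing in soft targets from the momentum encoders when distill is enabled. A minimal, self-contained sketch of that computation follows; the ita_loss helper and its argument layout are illustrative rather than repository code, and alpha comes from the model config rather than this diff (only temp=0.07 is a default added here).

# Illustrative re-implementation of the ITA loss computed in
# MPlugForImageTextRetrieval.forward (not repository code).
import torch
import torch.nn.functional as F


def ita_loss(image_feat, text_feat,        # in-batch features, (bs, dim), L2-normalized
             image_feat_m, text_feat_m,    # momentum features, (bs, dim)
             image_queue, text_queue,      # feature queues, (dim, queue_size)
             idx, idx_queue,               # image ids, (bs, 1) and (1, queue_size)
             alpha,                        # distillation weight, config.alpha in the diff
             temp=0.07):                   # config.temp default added in this diff
    # Candidate features: current momentum batch followed by the queue.
    image_feat_all = torch.cat([image_feat_m.t(), image_queue], dim=1)
    text_feat_all = torch.cat([text_feat_m.t(), text_queue], dim=1)

    # Hard targets: every candidate sharing the same image id is a positive.
    idx_all = torch.cat([idx.t(), idx_queue], dim=1)
    pos_idx = torch.eq(idx, idx_all).float()
    sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)

    # Soft targets from the momentum encoders (the `distill` branch).
    with torch.no_grad():
        sim_i2t_m = image_feat_m @ text_feat_all / temp
        sim_t2i_m = text_feat_m @ image_feat_all / temp
        i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
        t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

    sim_i2t = image_feat @ text_feat_all / temp
    sim_t2i = text_feat @ image_feat_all / temp

    loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * i2t_targets, dim=1).mean()
    loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * t2i_targets, dim=1).mean()
    return (loss_i2t + loss_t2i) / 2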
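
Around that loss, _dequeue_and_enqueue and the _momentum_update()/copy_params() calls (the latter two live on the MPlug base class and are not shown in this diff) keep the momentum encoders and the feature queues up to date. The sketch below illustrates both mechanisms under a single-process assumption; in the diff the features are first all_gather-ed across distributed ranks, and the helper names here are hypothetical.

# Sketch of the momentum-encoder and memory-queue bookkeeping
# (hypothetical helpers, single process only).
import torch


@torch.no_grad()
def momentum_update(model_pairs, momentum=0.995):
    # EMA update of each (online, momentum) module pair: p_m <- m * p_m + (1 - m) * p.
    # 0.995 matches the value hard-coded at the end of init_distill above.
    for online, momentum_model in model_pairs:
        for p, p_m in zip(online.parameters(), momentum_model.parameters()):
            p_m.data.mul_(momentum).add_(p.data, alpha=1 - momentum)


@torch.no_grad()
def dequeue_and_enqueue(queue, idx_queue, queue_ptr, feats, idxs, queue_size):
    # queue: (dim, queue_size) ring buffer; idx_queue: (1, queue_size);
    # queue_ptr: 1-element long tensor; feats: (bs, dim); idxs: (bs, 1).
    # As in the diff, this assumes queue_size is divisible by the batch size,
    # so a batch never wraps around the end of the buffer.
    batch_size = feats.shape[0]
    ptr = int(queue_ptr)
    queue[:, ptr:ptr + batch_size] = feats.T        # overwrite the oldest features
    idx_queue[:, ptr:ptr + batch_size] = idxs.T     # keep image ids aligned with them
    queue_ptr[0] = (ptr + batch_size) % queue_size  # advance the pointer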
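
For the ITM loss, the forward pass mines in-batch hard negatives by sampling from the softmaxed similarity rows after masking out true positives (pairs that share an image id). An illustrative standalone version of that sampling step is given below; mine_hard_negatives is a hypothetical helper, and idx is expected as the (batch, 1) column produced by idx.view(-1, 1) in the diff.

# Illustrative standalone version of the ITM hard-negative sampling
# (not repository code).
import torch
import torch.nn.functional as F


def mine_hard_negatives(sim_i2t, sim_t2i, idx, image_embeds, text_embeds, text_atts):
    # sim_i2t / sim_t2i: (bs, bs + queue_size) similarities from the ITA step;
    # idx: (bs, 1) image ids; embeds / atts: per-sample encoder outputs and masks.
    bs = sim_i2t.size(0)
    with torch.no_grad():
        # Restrict to in-batch candidates and zero out true positives.
        weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1)
        weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1)
        mask = torch.eq(idx, idx.T)
        weights_i2t = weights_i2t.masked_fill(mask, 0)
        weights_t2i = weights_t2i.masked_fill(mask, 0)

    image_embeds_neg, text_embeds_neg, text_atts_neg = [], [], []
    for b in range(bs):
        neg_img = torch.multinomial(weights_t2i[b], 1).item()  # hard negative image for text b
        image_embeds_neg.append(image_embeds[neg_img])
        neg_txt = torch.multinomial(weights_i2t[b], 1).item()  # hard negative text for image b
        text_embeds_neg.append(text_embeds[neg_txt])
        text_atts_neg.append(text_atts[neg_txt])
    return (torch.stack(image_embeds_neg, dim=0),
            torch.stack(text_embeds_neg, dim=0),
            torch.stack(text_atts_neg, dim=0))

The positive and negative fusion outputs are then classified by itm_head, with label 1 for the bs matched pairs and label 0 for the 2*bs mismatched ones, giving loss_itm; the model returns loss_ita + loss_itm during training.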
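
End-to-end, the new pipeline can be exercised the same way the added test does. A usage sketch, assuming the damo/mplug_image-text-retrieval_flickr30k_large_en checkpoint referenced in the tests is available on the hub:

# Usage sketch mirroring tests/pipelines/test_mplug_tasks.py; the model id,
# test image path and input keys ('image', 'question') come from that test.
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    Tasks.image_text_retrieval,
    model='damo/mplug_image-text-retrieval_flickr30k_large_en')

image = Image.open('data/test/images/image-text-retrieval.jpg')
caption = ('Two young guys with shaggy hair look at their hands '
           'while hanging out in the yard.')

result = pipe({'image': image, 'question': caption})
# ImageTextRetrievalPipeline.postprocess returns the softmaxed ITM scores for
# the first sample; index 1 corresponds to the 'matched' label used in training.
print(result[OutputKeys.SCORES])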
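
Fine-tuning goes through the same nlp_base_trainer path as the captioning and VQA tests. The sketch below mirrors test_trainer_with_retrieval; the dataset id, column renames and trainer name are taken from that test file, while the split names, the prepare helper and the work_dir value are assumptions for illustration.

# Fine-tuning sketch following tests/trainers/test_finetune_mplug.py.
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer


def prepare(split):
    # Same column mapping as TestFinetuneMPlug.setUp: the preprocessor expects
    # 'image', 'answer' and (for non-captioning tasks) a 'question' column.
    return MsDataset(
        split.to_hf_dataset()
        .map(lambda _: {'question': 'what the picture describes?'})
        .rename_column('image:FILE', 'image')
        .rename_column('answer:Value', 'answer'))


datadict = MsDataset.load('coco_captions_small_slice')
train_dataset = prepare(datadict['train'])
eval_dataset = prepare(datadict['test'])

trainer = build_trainer(
    name=Trainers.nlp_base_trainer,
    default_args=dict(
        model='damo/mplug_image-text-retrieval_flickr30k_large_en',
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        max_epochs=3,
        work_dir='./mplug_retrieval_work_dir'))
trainer.train()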