Add pipeline and finetune support for the MPLUG model's image-text-retrieval task
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9919955
master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:b012c7e966f6550874ccb85ef9602d483aa89b8623dff9ffcdb0faab8f2ca9ab | |||
| size 218143 | |||
| @@ -170,6 +170,7 @@ class Pipelines(object): | |||
| multi_modal_similarity = 'multi-modal-similarity' | |||
| text_to_image_synthesis = 'text-to-image-synthesis' | |||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
| image_text_retrieval = 'image-text-retrieval' | |||
| class Trainers(object): | |||
| @@ -64,6 +64,10 @@ class MPlugConfig(PretrainedConfig): | |||
| clip_transformer_width=768, | |||
| clip_transformer_heads=12, | |||
| clip_transformer_layers=12, | |||
| # retrieval | |||
| queue_size=65536, | |||
| embed_dim=256, | |||
| temp=0.07, | |||
| **kwargs): | |||
| super().__init__(**kwargs) | |||
| @@ -99,6 +103,10 @@ class MPlugConfig(PretrainedConfig): | |||
| self.clip_transformer_width = clip_transformer_width | |||
| self.clip_transformer_heads = clip_transformer_heads | |||
| self.clip_transformer_layers = clip_transformer_layers | |||
| # retrieval | |||
| self.queue_size = queue_size | |||
| self.embed_dim = embed_dim | |||
| self.temp = temp | |||
| @classmethod | |||
| def from_yaml_file(cls, yaml_file: Union[str, | |||
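For reference, the three retrieval fields added above can be set directly when constructing the config. A minimal sketch using exactly the defaults from this hunk (the import path is an assumption, not taken from this diff):

```python
from modelscope.models.multi_modal.mplug import MPlugConfig  # import path assumed

cfg = MPlugConfig(
    queue_size=65536,  # length of the momentum image/text feature queues
    embed_dim=256,     # dimension of the projected features used for contrastive matching
    temp=0.07,         # initial value of the learnable softmax temperature
)
```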
| @@ -1855,7 +1855,8 @@ class MPlug(PreTrainedModel): | |||
| task_mapping = { | |||
| Tasks.visual_question_answering: MPlugForVisualQuestionAnswering, | |||
| Tasks.image_captioning: MPLUGForImageCaption | |||
| Tasks.image_captioning: MPlugForImageCaption, | |||
| Tasks.image_text_retrieval: MPlugForImageTextRetrieval, | |||
| } | |||
| config = cls.config_class.from_yaml_file( | |||
| os.path.join(model_dir, CONFIG_NAME)) | |||
| @@ -1915,6 +1916,33 @@ class MPlug(PreTrainedModel): | |||
| clip_model.visual.positional_embedding = pos_embed | |||
| return clip_model | |||
| def init_distill(self, config): | |||
| self.distill = config.distill | |||
| if self.distill: | |||
| self.visual_encoder_m = self._initialize_clip(config) | |||
| self.text_encoder_m = BertModel( | |||
| self.config_encoder, add_pooling_layer=False) | |||
| self.fusion_encoder_m = FusionModel( | |||
| self.config_fusion, add_pooling_layer=False) | |||
| self.text_decoder_m = BertLMHeadModel(self.config_decoder) | |||
| self.model_pairs = [ | |||
| [self.visual_encoder, self.visual_encoder_m], | |||
| [self.text_encoder, self.text_encoder_m], | |||
| [self.text_decoder, self.text_decoder_m], | |||
| ] | |||
| if self.config_encoder.hidden_size != config.vision_width: | |||
| self.visn_fc_m = nn.Linear(config.vision_width, | |||
| self.config_encoder.hidden_size) | |||
| self.visn_layer_norm_m = nn.LayerNorm( | |||
| self.config_encoder.hidden_size, eps=1e-12) | |||
| self.dropout_m = nn.Dropout( | |||
| self.config_encoder.hidden_dropout_prob) | |||
| self.model_pairs.extend( | |||
| [[self.visn_fc, self.visn_fc_m], | |||
| [self.visn_layer_norm, self.visn_layer_norm_m]]) | |||
| self.copy_params() | |||
| self.momentum = 0.995 | |||
| def forward(self, *args, **kwargs): | |||
| raise NotImplementedError | |||
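`init_distill`, now hoisted into the `MPlug` base class, calls `copy_params` and `_momentum_update`, which are not shown in this diff. For reviewers reading the hunk in isolation, here is a minimal sketch of what these ALBEF-style helpers conventionally do (the method names match the calls above; the bodies are assumptions, not copied from the file):

```python
import torch

@torch.no_grad()
def copy_params(self):
    # Initialize every momentum ("_m") module from its online counterpart and
    # freeze it: the momentum branch is updated only by EMA, never by the optimizer.
    for model, model_m in self.model_pairs:
        for param, param_m in zip(model.parameters(), model_m.parameters()):
            param_m.data.copy_(param.data)
            param_m.requires_grad = False

@torch.no_grad()
def _momentum_update(self):
    # Exponential moving average at the rate set in init_distill (self.momentum = 0.995).
    for model, model_m in self.model_pairs:
        for param, param_m in zip(model.parameters(), model_m.parameters()):
            param_m.data = param_m.data * self.momentum + param.data * (1.0 - self.momentum)
```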
| @@ -1978,33 +2006,6 @@ class MPlugForVisualQuestionAnswering(MPlug): | |||
| self.beam_generator = TextGenerator(config, self.text_decoder) | |||
| self.init_distill(config) | |||
| def init_distill(self, config): | |||
| self.distill = config.distill | |||
| if self.distill: | |||
| self.visual_encoder_m = self._initialize_clip(config) | |||
| self.text_encoder_m = BertModel( | |||
| self.config_encoder, add_pooling_layer=False) | |||
| self.fusion_encoder_m = FusionModel( | |||
| self.config_fusion, add_pooling_layer=False) | |||
| self.text_decoder_m = BertLMHeadModel(self.config_decoder) | |||
| self.model_pairs = [ | |||
| [self.visual_encoder, self.visual_encoder_m], | |||
| [self.text_encoder, self.text_encoder_m], | |||
| [self.text_decoder, self.text_decoder_m], | |||
| ] | |||
| if self.config_encoder.hidden_size != config.vision_width: | |||
| self.visn_fc_m = nn.Linear(config.vision_width, | |||
| self.config_encoder.hidden_size) | |||
| self.visn_layer_norm_m = nn.LayerNorm( | |||
| self.config_encoder.hidden_size, eps=1e-12) | |||
| self.dropout_m = nn.Dropout( | |||
| self.config_encoder.hidden_dropout_prob) | |||
| self.model_pairs.extend( | |||
| [[self.visn_fc, self.visn_fc_m], | |||
| [self.visn_layer_norm, self.visn_layer_norm_m]]) | |||
| self.copy_params() | |||
| self.momentum = 0.995 | |||
| def forward(self, | |||
| image, | |||
| question, | |||
| @@ -2142,7 +2143,7 @@ class MPlugForVisualQuestionAnswering(MPlug): | |||
| return topk_ids, topk_probs | |||
| class MPLUGForImageCaption(MPlug): | |||
| class MPlugForImageCaption(MPlug): | |||
| def __init__(self, config): | |||
| super().__init__(config) | |||
| @@ -2215,3 +2216,264 @@ class MPLUGForImageCaption(MPlug): | |||
| else: | |||
| topk_ids, topk_probs = self.generation(image_embeds, image_atts) | |||
| return topk_ids, topk_probs | |||
| class MPlugForImageTextRetrieval(MPlug): | |||
| def __init__(self, config): | |||
| super().__init__(config) | |||
| self.embed_dim = config.embed_dim | |||
| self.temp = nn.Parameter(torch.ones([]) * config.temp) | |||
| self.queue_size = config.queue_size | |||
| self.momentum = config.momentum | |||
| self.alpha = config.alpha | |||
| self.text_width = self.config_encoder.hidden_size | |||
| self.vision_proj = nn.Linear(self.text_width, self.embed_dim) | |||
| self.text_proj = nn.Linear(self.text_width, self.embed_dim) | |||
| self.itm_head = nn.Linear(self.text_width, 2) | |||
| self.register_buffer('image_queue', | |||
| torch.randn(self.embed_dim, self.queue_size)) | |||
| self.register_buffer('text_queue', | |||
| torch.randn(self.embed_dim, self.queue_size)) | |||
| self.register_buffer('idx_queue', torch.full((1, self.queue_size), | |||
| -100)) | |||
| self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) | |||
| self.image_queue = F.normalize(self.image_queue, dim=0) | |||
| self.text_queue = F.normalize(self.text_queue, dim=0) | |||
| self.init_distill(config) | |||
| def init_distill(self, config): | |||
| self.distill = config.distill | |||
| if self.distill: | |||
| self.visual_encoder_m = self._initialize_clip(config) | |||
| self.text_encoder_m = BertModel( | |||
| self.config_encoder, add_pooling_layer=False) | |||
| self.fusion_encoder_m = FusionModel( | |||
| self.config_fusion, add_pooling_layer=False) | |||
| self.vision_proj_m = nn.Linear(self.text_width, self.embed_dim) | |||
| self.text_proj_m = nn.Linear(self.text_width, self.embed_dim) | |||
| self.model_pairs = [ | |||
| [self.visual_encoder, self.visual_encoder_m], | |||
| [self.text_encoder, self.text_encoder_m], | |||
| [self.text_proj, self.text_proj_m], | |||
| [self.vision_proj, self.vision_proj_m], | |||
| ] | |||
| if self.config_encoder.hidden_size != config.vision_width: | |||
| self.visn_fc_m = nn.Linear(config.vision_width, | |||
| self.config_encoder.hidden_size) | |||
| self.visn_layer_norm_m = nn.LayerNorm( | |||
| self.config_encoder.hidden_size, eps=1e-12) | |||
| self.dropout_m = nn.Dropout( | |||
| self.config_encoder.hidden_dropout_prob) | |||
| self.model_pairs.extend( | |||
| [[self.visn_fc, self.visn_fc_m], | |||
| [self.visn_layer_norm, self.visn_layer_norm_m]]) | |||
| self.copy_params() | |||
| self.momentum = 0.995 | |||
| @torch.no_grad() | |||
| def _dequeue_and_enqueue(self, image_feat, text_feat, idx): | |||
| def concat_all_gather(tensor): | |||
| """ | |||
| Performs all_gather operation on the provided tensors. | |||
| *** Warning ***: torch.distributed.all_gather has no gradient. | |||
| """ | |||
| if not torch.distributed.is_initialized(): | |||
| return tensor | |||
| tensors_gather = [ | |||
| torch.ones_like(tensor) | |||
| for _ in range(torch.distributed.get_world_size()) | |||
| ] | |||
| torch.distributed.all_gather( | |||
| tensors_gather, tensor, async_op=False) | |||
| output = torch.cat(tensors_gather, dim=0) | |||
| return output | |||
| # gather keys before updating queue | |||
| image_feats = concat_all_gather(image_feat) | |||
| text_feats = concat_all_gather(text_feat) | |||
| idxs = concat_all_gather(idx) | |||
| batch_size = image_feats.shape[0] | |||
| ptr = int(self.queue_ptr) | |||
| # assert self.queue_size % batch_size == 0 # for simplicity | |||
| # replace the keys at ptr (dequeue and enqueue) | |||
| self.image_queue[:, ptr:ptr + batch_size] = image_feats.T | |||
| self.text_queue[:, ptr:ptr + batch_size] = text_feats.T | |||
| self.idx_queue[:, ptr:ptr + batch_size] = idxs.T | |||
| ptr = (ptr + batch_size) % self.queue_size # move pointer | |||
| self.queue_ptr[0] = ptr | |||
| def forward(self, image, text, idx=None, train=True): | |||
| if train: | |||
| image_embeds = self.visual_encoder.visual( | |||
| image, skip_last_layer=True) | |||
| if self.large: | |||
| image_embeds = self.dropout( | |||
| self.visn_layer_norm(self.visn_fc(image_embeds))) | |||
| image_atts = torch.ones( | |||
| image_embeds.size()[:-1], dtype=torch.long).to(image.device) | |||
| image_feat = F.normalize( | |||
| self.vision_proj(image_embeds[:, 0, :]), dim=-1) | |||
| text_output = self.text_encoder( | |||
| text.input_ids, | |||
| attention_mask=text.attention_mask, | |||
| return_dict=True) | |||
| text_embeds = text_output.last_hidden_state | |||
| text_feat = F.normalize( | |||
| self.text_proj(text_embeds[:, 0, :]), dim=-1) | |||
| idx = idx.view(-1, 1) | |||
| idx_all = torch.cat( | |||
| [idx.t(), self.idx_queue.clone().detach()], dim=1) | |||
| pos_idx = torch.eq(idx, idx_all).float() | |||
| sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) | |||
| with torch.no_grad(): | |||
| self._momentum_update() | |||
| image_embeds_m = self.visual_encoder_m.visual( | |||
| image, skip_last_layer=True) | |||
| if self.large: | |||
| image_embeds_m = self.dropout_m( | |||
| self.visn_layer_norm_m(self.visn_fc_m(image_embeds_m))) | |||
| image_feat_m = F.normalize( | |||
| self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1) | |||
| image_feat_all = torch.cat( | |||
| [image_feat_m.t(), | |||
| self.image_queue.clone().detach()], | |||
| dim=1) | |||
| text_output_m = self.text_encoder_m( | |||
| text.input_ids, | |||
| attention_mask=text.attention_mask, | |||
| return_dict=True) | |||
| text_feat_m = F.normalize( | |||
| self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]), | |||
| dim=-1) | |||
| text_feat_all = torch.cat( | |||
| [text_feat_m.t(), | |||
| self.text_queue.clone().detach()], dim=1) | |||
| if self.distill: | |||
| sim_i2t_m = image_feat_m @ text_feat_all / self.temp | |||
| sim_t2i_m = text_feat_m @ image_feat_all / self.temp | |||
| sim_i2t_targets = self.alpha * F.softmax( | |||
| sim_i2t_m, dim=1) + (1 - self.alpha) * sim_targets | |||
| sim_t2i_targets = self.alpha * F.softmax( | |||
| sim_t2i_m, dim=1) + (1 - self.alpha) * sim_targets | |||
| sim_i2t = image_feat @ text_feat_all / self.temp | |||
| sim_t2i = text_feat @ image_feat_all / self.temp | |||
| if self.distill: | |||
| loss_i2t = -torch.sum( | |||
| F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, | |||
| dim=1).mean() | |||
| loss_t2i = -torch.sum( | |||
| F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, | |||
| dim=1).mean() | |||
| else: | |||
| loss_i2t = -torch.sum( | |||
| F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean() | |||
| loss_t2i = -torch.sum( | |||
| F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean() | |||
| loss_ita = (loss_i2t + loss_t2i) / 2 | |||
| self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx) | |||
| # forward the positive image-text pair | |||
| _, output_pos = self.fusion_encoder( | |||
| encoder_embeds=text_embeds, | |||
| attention_mask=text.attention_mask, | |||
| encoder_hidden_states=image_embeds, | |||
| encoder_attention_mask=image_atts, | |||
| return_dict=False, | |||
| ) | |||
| with torch.no_grad(): | |||
| bs = image.size(0) | |||
| weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1) | |||
| weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1) | |||
| mask = torch.eq(idx, idx.T) | |||
| weights_i2t.masked_fill_(mask, 0) | |||
| weights_t2i.masked_fill_(mask, 0) | |||
| # select a negative image for each text | |||
| image_embeds_neg = [] | |||
| for b in range(bs): | |||
| neg_idx = torch.multinomial(weights_t2i[b], 1).item() | |||
| image_embeds_neg.append(image_embeds[neg_idx]) | |||
| image_embeds_neg = torch.stack(image_embeds_neg, dim=0) | |||
| # select a negative text for each image | |||
| text_embeds_neg = [] | |||
| text_atts_neg = [] | |||
| for b in range(bs): | |||
| neg_idx = torch.multinomial(weights_i2t[b], 1).item() | |||
| text_embeds_neg.append(text_embeds[neg_idx]) | |||
| text_atts_neg.append(text.attention_mask[neg_idx]) | |||
| text_embeds_neg = torch.stack(text_embeds_neg, dim=0) | |||
| text_atts_neg = torch.stack(text_atts_neg, dim=0) | |||
| text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) | |||
| text_atts_all = torch.cat([text.attention_mask, text_atts_neg], | |||
| dim=0) | |||
| image_embeds_all = torch.cat([image_embeds_neg, image_embeds], | |||
| dim=0) | |||
| image_atts_all = torch.cat([image_atts, image_atts], dim=0) | |||
| _, output_neg = self.fusion_encoder( | |||
| encoder_embeds=text_embeds_all, | |||
| attention_mask=text_atts_all, | |||
| encoder_hidden_states=image_embeds_all, | |||
| encoder_attention_mask=image_atts_all, | |||
| return_dict=False, | |||
| ) | |||
| vl_embeddings = torch.cat( | |||
| [output_pos[:, 0, :], output_neg[:, 0, :]], dim=0) | |||
| vl_output = self.itm_head(vl_embeddings) | |||
| ones_tmp = torch.ones(bs, dtype=torch.long) | |||
| zeros_tmp = torch.zeros(2 * bs, dtype=torch.long) | |||
| itm_labels = torch.cat([ones_tmp, zeros_tmp], | |||
| dim=0).to(image.device) | |||
| loss_itm = F.cross_entropy(vl_output, itm_labels) | |||
| return loss_ita + loss_itm | |||
| else: | |||
| text_output = self.text_encoder( | |||
| text.input_ids, attention_mask=text.attention_mask) | |||
| text_feat = text_output.last_hidden_state | |||
| image_feat = self.visual_encoder.visual( | |||
| image, skip_last_layer=True) | |||
| image_feat = self.visn_layer_norm(self.visn_fc(image_feat)) | |||
| image_att = torch.ones( | |||
| image_feat.size()[:-1], | |||
| dtype=torch.long, | |||
| device=image_feat.device) | |||
| _, output = self.fusion_encoder( | |||
| encoder_embeds=text_feat, | |||
| attention_mask=text.attention_mask, | |||
| encoder_hidden_states=image_feat, | |||
| encoder_attention_mask=image_att, | |||
| return_dict=False, | |||
| ) | |||
| scores = self.itm_head(output[:, 0, :]) | |||
| scores = F.softmax(scores, dim=-1) | |||
| return scores | |||
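To summarize the training path implemented above: the image-text contrastive (ITA) term scores each in-batch feature against the other modality's in-batch features plus the momentum queues, and the ITM term classifies fused positive and hard-negative pairs. Below is a stripped-down sketch of the ITA half with hard diagonal labels; the real code instead builds soft targets from `idx`/`idx_queue` (so duplicate images count as positives) and, when `distill` is enabled, mixes in the momentum model's similarities with weight `alpha`:

```python
import torch
import torch.nn.functional as F

def ita_loss_sketch(image_feat, text_feat, image_queue, text_queue, temp=0.07):
    """image_feat, text_feat: (B, D) L2-normalized; image_queue, text_queue: (D, Q) L2-normalized."""
    # Every sample is scored against B in-batch candidates plus Q queued negatives.
    sim_i2t = image_feat @ torch.cat([text_feat.t(), text_queue], dim=1) / temp   # (B, B + Q)
    sim_t2i = text_feat @ torch.cat([image_feat.t(), image_queue], dim=1) / temp  # (B, B + Q)
    targets = torch.arange(image_feat.size(0), device=image_feat.device)          # matched pair is on the diagonal
    return (F.cross_entropy(sim_i2t, targets) + F.cross_entropy(sim_t2i, targets)) / 2

img = F.normalize(torch.randn(4, 256), dim=-1)
txt = F.normalize(torch.randn(4, 256), dim=-1)
queue_i = F.normalize(torch.randn(256, 16), dim=0)
queue_t = F.normalize(torch.randn(256, 16), dim=0)
print(ita_loss_sketch(img, txt, queue_i, queue_t))
```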
| @@ -12,6 +12,7 @@ __all__ = ['MPlugForAllTasks'] | |||
| @MODELS.register_module( | |||
| Tasks.visual_question_answering, module_name=Models.mplug) | |||
| @MODELS.register_module(Tasks.image_captioning, module_name=Models.mplug) | |||
| @MODELS.register_module(Tasks.image_text_retrieval, module_name=Models.mplug) | |||
| class MPlugForAllTasks(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -43,39 +44,50 @@ class MPlugForAllTasks(TorchModel): | |||
| ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), | |||
| ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) | |||
| if not self.training and 'answer_input_ids' not in input: | |||
| topk_ids, _ = self.model(**input) | |||
| # inference | |||
| if not self.training and 'question' in input: | |||
| output = self.model(input['image'], input['question'], train=False) | |||
| if not isinstance(output, tuple): | |||
| return output | |||
| topk_ids, _ = output | |||
| pred_string: str = self.tokenizer.decode(topk_ids[0][0]) | |||
| for _old, _new in replace_tokens_bert: | |||
| pred_string = pred_string.replace(_old, _new) | |||
| pred_string = pred_string.strip() | |||
| return pred_string | |||
| else: | |||
| import addict | |||
| # train and evaluate | |||
| import addict | |||
| image = input['image'] | |||
| answer = addict.Dict( | |||
| input_ids=input['answer_input_ids'], | |||
| attention_mask=input['answer_attention_mask']) | |||
| if 'index' not in input: | |||
| question = addict.Dict( | |||
| input_ids=input['question_input_ids'], | |||
| attention_mask=input['question_attention_mask']) | |||
| answer = addict.Dict( | |||
| input_ids=input['answer_input_ids'], | |||
| attention_mask=input['answer_attention_mask']) | |||
| output = self.model( | |||
| input['image'], question, answer, train=self.training) | |||
| if self.training: | |||
| return {'loss': output} | |||
| topk_ids, _ = output | |||
| preds: List[str] = [ | |||
| self.tokenizer.decode(batch[0]) for batch in topk_ids | |||
| ] | |||
| for i in range(len(preds)): | |||
| for _old, _new in replace_tokens_bert: | |||
| preds[i] = preds[i].replace(_old, _new) | |||
| preds[i] = preds[i].strip() | |||
| tgts: List[str] = [ | |||
| self.tokenizer.decode(batch) | |||
| for batch in input['answer_input_ids'].cpu().numpy().tolist() | |||
| ] | |||
| for i in range(len(tgts)): | |||
| for _old, _new in replace_tokens_bert: | |||
| tgts[i] = tgts[i].replace(_old, _new) | |||
| preds[i] = preds[i].strip() | |||
| return {'preds': preds, 'tgts': tgts} | |||
| output = self.model(image, question, answer, train=self.training) | |||
| else: | |||
| index = input['index'] | |||
| output = self.model(image, answer, index, train=self.training) | |||
| if self.training: | |||
| return {'loss': output} | |||
| # evaluate | |||
| topk_ids, _ = output | |||
| preds: List[str] = [ | |||
| self.tokenizer.decode(batch[0]) for batch in topk_ids | |||
| ] | |||
| for i in range(len(preds)): | |||
| for _old, _new in replace_tokens_bert: | |||
| preds[i] = preds[i].replace(_old, _new) | |||
| preds[i] = preds[i].strip() | |||
| tgts: List[str] = [ | |||
| self.tokenizer.decode(batch) | |||
| for batch in input['answer_input_ids'].cpu().numpy().tolist() | |||
| ] | |||
| for i in range(len(tgts)): | |||
| for _old, _new in replace_tokens_bert: | |||
| tgts[i] = tgts[i].replace(_old, _new) | |||
| preds[i] = preds[i].strip() | |||
| return {'preds': preds, 'tgts': tgts} | |||
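The branching above is driven entirely by which keys the preprocessor put into `input`: inference takes the `'question'` path, captioning/VQA training takes the question/answer path, and retrieval training is detected by the presence of `'index'`, in which case the tokenized answer serves as the retrieval text. An illustrative sketch of the two training batch layouts with placeholder tensors (shapes are assumptions for illustration, not values from this PR):

```python
import torch

B, L = 2, 25
images = torch.zeros(B, 3, 224, 224)
q_ids = torch.zeros(B, L, dtype=torch.long); q_mask = torch.ones(B, L, dtype=torch.long)
a_ids = torch.zeros(B, L, dtype=torch.long); a_mask = torch.ones(B, L, dtype=torch.long)

# Captioning / VQA training batch: routed to self.model(image, question, answer, train=True)
batch_caption = {
    'image': images,
    'question_input_ids': q_ids, 'question_attention_mask': q_mask,
    'answer_input_ids': a_ids, 'answer_attention_mask': a_mask,
}

# Retrieval training batch: the extra 'index' key (per-image ids from the preprocessor)
# routes it to self.model(image, answer, index, train=True)
batch_retrieval = dict(batch_caption, index=torch.arange(B))
```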
| @@ -0,0 +1,51 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Model, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import MPlugPreprocessor, Preprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_text_retrieval, module_name=Pipelines.image_text_retrieval) | |||
| class ImageTextRetrievalPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create a | |||
| image text retrieval pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model) | |||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||
| f'model must be a single str or Model, but got {type(model)}' | |||
| if isinstance(model, str): | |||
| pipe_model = Model.from_pretrained(model) | |||
| elif isinstance(model, Model): | |||
| pipe_model = model | |||
| else: | |||
| raise NotImplementedError | |||
| pipe_model.model.eval() | |||
| if preprocessor is None: | |||
| preprocessor = MPlugPreprocessor(pipe_model.model_dir) | |||
| super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return {OutputKeys.SCORES: inputs[0].tolist()} | |||
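A usage sketch for the new pipeline; the model id and inputs mirror the tests added later in this PR, and the output follows `postprocess` above (a softmax over the ITM head's two classes, where index 1 is the matched class given the ITM labels in the model code):

```python
from PIL import Image

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

retrieval = pipeline(
    Tasks.image_text_retrieval,
    model='damo/mplug_image-text-retrieval_flickr30k_large_en')
image = Image.open('data/test/images/image-text-retrieval.jpg')
text = ('Two young guys with shaggy hair look at their hands '
        'while hanging out in the yard.')
result = retrieval({'image': image, 'question': text})
print(result)  # e.g. {'scores': [p_no_match, p_match]}
```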
| @@ -1,6 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from typing import Any, Dict, List, Union | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| import torch | |||
| from PIL import Image | |||
| @@ -104,6 +104,7 @@ class MPlugPreprocessor(Preprocessor): | |||
| self._tokenizer = None | |||
| self._patch_resize_transform = None | |||
| self._image_map = {} | |||
| @property | |||
| def tokenizer(self): | |||
| @@ -133,31 +134,31 @@ class MPlugPreprocessor(Preprocessor): | |||
| ]) | |||
| return self._patch_resize_transform | |||
| def __call__(self, *args, **kwargs): | |||
| call_mapping = { | |||
| Tasks.visual_question_answering: self.image_text_call, | |||
| Tasks.image_captioning: self.image_text_call, | |||
| } | |||
| def image_open(self, path: str) -> Tuple[Image.Image, int]: | |||
| if path not in self._image_map: | |||
| index = len(self._image_map) | |||
| self._image_map[path] = (Image.open(path), index) | |||
| return self._image_map[path] | |||
| def __call__( | |||
| self, data: Union[Image.Image, tuple, | |||
| Dict[str, Any]]) -> Dict[str, Any]: | |||
| self.cfg = Config.from_file( | |||
| osp.join(self.model_dir, ModelFile.CONFIGURATION)) | |||
| return call_mapping[self.cfg.task](*args, **kwargs) | |||
| def image_text_call( | |||
| self, data: Union[Image.Image, tuple, | |||
| Dict[str, Any]]) -> Dict[str, Any]: | |||
| if isinstance(data, (Image.Image, str)): | |||
| image = data | |||
| elif isinstance(data, tuple): | |||
| image = data[0] | |||
| else: | |||
| image = data['image'] | |||
| index = 0 | |||
| if isinstance(image, str): | |||
| image = Image.open(image) | |||
| question = '' if self.cfg.task != Tasks.visual_question_answering \ | |||
| else data[1 if isinstance(data, tuple) else 'question'] | |||
| image, index = self.image_open(image) | |||
| image = image.convert('RGB') | |||
| image = self.patch_resize_transform(image) | |||
| question = '' if self.cfg.task == Tasks.image_captioning \ | |||
| else data[1 if isinstance(data, tuple) else 'question'] | |||
| question = self.tokenizer( | |||
| question.lower(), | |||
| padding='max_length', | |||
| @@ -167,7 +168,7 @@ class MPlugPreprocessor(Preprocessor): | |||
| if self.mode == ModeKeys.INFERENCE: | |||
| image = torch.stack([image], dim=0) | |||
| return {'image': image, 'question': question, 'train': False} | |||
| return {'image': image, 'question': question} | |||
| else: | |||
| answer = data['answer'] | |||
| answer = self.tokenizer( | |||
| @@ -176,10 +177,13 @@ class MPlugPreprocessor(Preprocessor): | |||
| truncation=True, | |||
| max_length=self.tokenizer_max_length, | |||
| return_tensors='pt') | |||
| return { | |||
| output = { | |||
| 'image': image, | |||
| 'question_input_ids': question.input_ids.squeeze(), | |||
| 'question_attention_mask': question.attention_mask.squeeze(), | |||
| 'answer_input_ids': answer.input_ids.squeeze(), | |||
| 'answer_attention_mask': answer.attention_mask.squeeze(), | |||
| } | |||
| if self.cfg.task == Tasks.image_text_retrieval: | |||
| output['index'] = index | |||
| return output | |||
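One non-obvious part of this hunk: `image_open` memoizes by path, so every sample pointing at the same image file gets the same integer `index`, which is what the retrieval model uses (via `idx`/`idx_queue`) to treat multiple captions of one image as positives rather than in-batch negatives. A small illustration, assuming the image files exist and that the model directory points at a retrieval model (both placeholders, not taken from this PR):

```python
from modelscope.preprocessors import MPlugPreprocessor

preprocessor = MPlugPreprocessor('/path/to/mplug_image-text-retrieval_model')  # placeholder dir
_, first_index = preprocessor.image_open('img_0.jpg')
_, second_index = preprocessor.image_open('img_0.jpg')
assert first_index == second_index == 0   # same path is cached, so it keeps its index
_, other_index = preprocessor.image_open('img_1.jpg')
assert other_index == 1                   # each new path gets the next index
```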
| @@ -121,6 +121,7 @@ class MultiModalTasks(object): | |||
| visual_question_answering = 'visual-question-answering' | |||
| visual_entailment = 'visual-entailment' | |||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
| image_text_retrieval = 'image-text-retrieval' | |||
| class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | |||
| @@ -54,6 +54,27 @@ class MplugTasksTest(unittest.TestCase): | |||
| result = pipeline_vqa(input) | |||
| print(result) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_image_text_retrieval_with_model(self): | |||
| model = Model.from_pretrained( | |||
| 'damo/mplug_image-text-retrieval_flickr30k_large_en') | |||
| pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model) | |||
| image = Image.open('data/test/images/image-text-retrieval.jpg') | |||
| question = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.' | |||
| input = {'image': image, 'question': question} | |||
| result = pipeline_retrieval(input) | |||
| print(result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_image_text_retrieval_with_name(self): | |||
| model = 'damo/mplug_image-text-retrieval_flickr30k_large_en' | |||
| pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model) | |||
| image = Image.open('data/test/images/image-text-retrieval.jpg') | |||
| question = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.' | |||
| input = {'image': image, 'question': question} | |||
| result = pipeline_retrieval(input) | |||
| print(result) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -4,8 +4,6 @@ import shutil | |||
| import tempfile | |||
| import unittest | |||
| from PIL import Image | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.multi_modal import MPlugForAllTasks | |||
| @@ -23,7 +21,10 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| datadict = MsDataset.load('coco_captions_small_slice') | |||
| from modelscope.utils.constant import DownloadMode | |||
| datadict = MsDataset.load( | |||
| 'coco_captions_small_slice', | |||
| download_mode=DownloadMode.FORCE_REDOWNLOAD) | |||
| self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map( | |||
| lambda _: { | |||
| 'question': 'what the picture describes?' | |||
| @@ -35,17 +36,19 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
| }).rename_column('image:FILE', | |||
| 'image').rename_column('answer:Value', 'answer')) | |||
| self.max_epochs = 3 | |||
| def tearDown(self): | |||
| shutil.rmtree(self.tmp_dir) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer_with_caption(self): | |||
| kwargs = dict( | |||
| model='damo/mplug_image-captioning_coco_base_en', | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| max_epochs=self.max_epochs, | |||
| work_dir=self.tmp_dir) | |||
| trainer: EpochBasedTrainer = build_trainer( | |||
| @@ -53,15 +56,11 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(3): | |||
| for i in range(self.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_caption_with_model_and_args(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| cache_path = snapshot_download( | |||
| 'damo/mplug_image-captioning_coco_base_en') | |||
| model = MPlugForAllTasks.from_pretrained(cache_path) | |||
| @@ -70,7 +69,7 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
| model=model, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| max_epochs=2, | |||
| max_epochs=self.max_epochs, | |||
| work_dir=self.tmp_dir) | |||
| trainer: EpochBasedTrainer = build_trainer( | |||
| @@ -78,16 +77,16 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(2): | |||
| for i in range(self.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer_with_vqa(self): | |||
| kwargs = dict( | |||
| model='damo/mplug_visual-question-answering_coco_large_en', | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| max_epochs=self.max_epochs, | |||
| work_dir=self.tmp_dir) | |||
| trainer: EpochBasedTrainer = build_trainer( | |||
| @@ -95,15 +94,11 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(3): | |||
| for i in range(self.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_vqa_with_model_and_args(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| cache_path = snapshot_download( | |||
| 'damo/mplug_visual-question-answering_coco_large_en') | |||
| model = MPlugForAllTasks.from_pretrained(cache_path) | |||
| @@ -112,7 +107,45 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
| model=model, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| max_epochs=2, | |||
| max_epochs=self.max_epochs, | |||
| work_dir=self.tmp_dir) | |||
| trainer: EpochBasedTrainer = build_trainer( | |||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(self.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer_with_retrieval(self): | |||
| kwargs = dict( | |||
| model='damo/mplug_image-text-retrieval_flickr30k_large_en', | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| max_epochs=self.max_epochs, | |||
| work_dir=self.tmp_dir) | |||
| trainer: EpochBasedTrainer = build_trainer( | |||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(self.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_retrieval_with_model_and_args(self): | |||
| cache_path = snapshot_download( | |||
| 'damo/mplug_image-text-retrieval_flickr30k_large_en') | |||
| model = MPlugForAllTasks.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| max_epochs=self.max_epochs, | |||
| work_dir=self.tmp_dir) | |||
| trainer: EpochBasedTrainer = build_trainer( | |||
| @@ -120,7 +153,7 @@ class TestFinetuneMPlug(unittest.TestCase): | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(2): | |||
| for i in range(self.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||