Support pipeline and finetuning for the MPLUG model's image-text-retrieval task
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9919955
master
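A minimal inference sketch for the new pipeline, mirroring the unit test added below (the model id, test image, and caption are taken from that test; the 'scores' output key follows the pipeline's postprocess, and index 1 of the two-way ITM softmax corresponds to "matched" per the training labels):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Model id and inputs come from the tests added in this PR.
retrieval = pipeline(
    Tasks.image_text_retrieval,
    model='damo/mplug_image-text-retrieval_flickr30k_large_en')
result = retrieval({
    # The test passes a PIL.Image; a file path should also work via the preprocessor.
    'image': 'data/test/images/image-text-retrieval.jpg',
    'question': 'Two young guys with shaggy hair look at their hands '
                'while hanging out in the yard.'
})
print(result)  # e.g. {'scores': [p_unmatched, p_matched]}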
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b012c7e966f6550874ccb85ef9602d483aa89b8623dff9ffcdb0faab8f2ca9ab
+size 218143
@@ -170,6 +170,7 @@ class Pipelines(object):
     multi_modal_similarity = 'multi-modal-similarity'
     text_to_image_synthesis = 'text-to-image-synthesis'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
+    image_text_retrieval = 'image-text-retrieval'

 class Trainers(object):
@@ -64,6 +64,10 @@ class MPlugConfig(PretrainedConfig):
             clip_transformer_width=768,
             clip_transformer_heads=12,
             clip_transformer_layers=12,
+            # retrieval
+            queue_size=65536,
+            embed_dim=256,
+            temp=0.07,
             **kwargs):
         super().__init__(**kwargs)

@@ -99,6 +103,10 @@ class MPlugConfig(PretrainedConfig):
         self.clip_transformer_width = clip_transformer_width
         self.clip_transformer_heads = clip_transformer_heads
         self.clip_transformer_layers = clip_transformer_layers
+        # retrieval
+        self.queue_size = queue_size
+        self.embed_dim = embed_dim
+        self.temp = temp

     @classmethod
     def from_yaml_file(cls, yaml_file: Union[str,
@@ -1855,7 +1855,8 @@ class MPlug(PreTrainedModel):
         task_mapping = {
             Tasks.visual_question_answering: MPlugForVisualQuestionAnswering,
-            Tasks.image_captioning: MPLUGForImageCaption
+            Tasks.image_captioning: MPlugForImageCaption,
+            Tasks.image_text_retrieval: MPlugForImageTextRetrieval,
         }
         config = cls.config_class.from_yaml_file(
             os.path.join(model_dir, CONFIG_NAME))
@@ -1915,6 +1916,33 @@ class MPlug(PreTrainedModel):
             clip_model.visual.positional_embedding = pos_embed
         return clip_model

+    def init_distill(self, config):
+        self.distill = config.distill
+        if self.distill:
+            self.visual_encoder_m = self._initialize_clip(config)
+            self.text_encoder_m = BertModel(
+                self.config_encoder, add_pooling_layer=False)
+            self.fusion_encoder_m = FusionModel(
+                self.config_fusion, add_pooling_layer=False)
+            self.text_decoder_m = BertLMHeadModel(self.config_decoder)
+            self.model_pairs = [
+                [self.visual_encoder, self.visual_encoder_m],
+                [self.text_encoder, self.text_encoder_m],
+                [self.text_decoder, self.text_decoder_m],
+            ]
+            if self.config_encoder.hidden_size != config.vision_width:
+                self.visn_fc_m = nn.Linear(config.vision_width,
+                                           self.config_encoder.hidden_size)
+                self.visn_layer_norm_m = nn.LayerNorm(
+                    self.config_encoder.hidden_size, eps=1e-12)
+                self.dropout_m = nn.Dropout(
+                    self.config_encoder.hidden_dropout_prob)
+                self.model_pairs.extend(
+                    [[self.visn_fc, self.visn_fc_m],
+                     [self.visn_layer_norm, self.visn_layer_norm_m]])
+            self.copy_params()
+            self.momentum = 0.995
+
     def forward(self, *args, **kwargs):
         raise NotImplementedError
@@ -1978,33 +2006,6 @@ class MPlugForVisualQuestionAnswering(MPlug):
         self.beam_generator = TextGenerator(config, self.text_decoder)
         self.init_distill(config)

-    def init_distill(self, config):
-        self.distill = config.distill
-        if self.distill:
-            self.visual_encoder_m = self._initialize_clip(config)
-            self.text_encoder_m = BertModel(
-                self.config_encoder, add_pooling_layer=False)
-            self.fusion_encoder_m = FusionModel(
-                self.config_fusion, add_pooling_layer=False)
-            self.text_decoder_m = BertLMHeadModel(self.config_decoder)
-            self.model_pairs = [
-                [self.visual_encoder, self.visual_encoder_m],
-                [self.text_encoder, self.text_encoder_m],
-                [self.text_decoder, self.text_decoder_m],
-            ]
-            if self.config_encoder.hidden_size != config.vision_width:
-                self.visn_fc_m = nn.Linear(config.vision_width,
-                                           self.config_encoder.hidden_size)
-                self.visn_layer_norm_m = nn.LayerNorm(
-                    self.config_encoder.hidden_size, eps=1e-12)
-                self.dropout_m = nn.Dropout(
-                    self.config_encoder.hidden_dropout_prob)
-                self.model_pairs.extend(
-                    [[self.visn_fc, self.visn_fc_m],
-                     [self.visn_layer_norm, self.visn_layer_norm_m]])
-            self.copy_params()
-            self.momentum = 0.995
-
     def forward(self,
                 image,
                 question,
@@ -2142,7 +2143,7 @@ class MPlugForVisualQuestionAnswering(MPlug):
         return topk_ids, topk_probs

-class MPLUGForImageCaption(MPlug):
+class MPlugForImageCaption(MPlug):

     def __init__(self, config):
         super().__init__(config)
@@ -2215,3 +2216,264 @@ class MPLUGForImageCaption(MPlug):
         else:
             topk_ids, topk_probs = self.generation(image_embeds, image_atts)
         return topk_ids, topk_probs
+
+class MPlugForImageTextRetrieval(MPlug):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.embed_dim = config.embed_dim
+        self.temp = nn.Parameter(torch.ones([]) * config.temp)
+        self.queue_size = config.queue_size
+        self.momentum = config.momentum
+        self.alpha = config.alpha
+        self.text_width = self.config_encoder.hidden_size
+        self.vision_proj = nn.Linear(self.text_width, self.embed_dim)
+        self.text_proj = nn.Linear(self.text_width, self.embed_dim)
+        self.itm_head = nn.Linear(self.text_width, 2)
+
+        self.register_buffer('image_queue',
+                             torch.randn(self.embed_dim, self.queue_size))
+        self.register_buffer('text_queue',
+                             torch.randn(self.embed_dim, self.queue_size))
+        self.register_buffer('idx_queue', torch.full((1, self.queue_size),
+                                                     -100))
+        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
+        self.image_queue = F.normalize(self.image_queue, dim=0)
+        self.text_queue = F.normalize(self.text_queue, dim=0)
+
+        self.init_distill(config)
+
+    def init_distill(self, config):
+        self.distill = config.distill
+        if self.distill:
+            self.visual_encoder_m = self._initialize_clip(config)
+            self.text_encoder_m = BertModel(
+                self.config_encoder, add_pooling_layer=False)
+            self.fusion_encoder_m = FusionModel(
+                self.config_fusion, add_pooling_layer=False)
+            self.vision_proj_m = nn.Linear(self.text_width, self.embed_dim)
+            self.text_proj_m = nn.Linear(self.text_width, self.embed_dim)
+            self.model_pairs = [
+                [self.visual_encoder, self.visual_encoder_m],
+                [self.text_encoder, self.text_encoder_m],
+                [self.text_proj, self.text_proj_m],
+                [self.vision_proj, self.vision_proj_m],
+            ]
+            if self.config_encoder.hidden_size != config.vision_width:
+                self.visn_fc_m = nn.Linear(config.vision_width,
+                                           self.config_encoder.hidden_size)
+                self.visn_layer_norm_m = nn.LayerNorm(
+                    self.config_encoder.hidden_size, eps=1e-12)
+                self.dropout_m = nn.Dropout(
+                    self.config_encoder.hidden_dropout_prob)
+                self.model_pairs.extend(
+                    [[self.visn_fc, self.visn_fc_m],
+                     [self.visn_layer_norm, self.visn_layer_norm_m]])
+            self.copy_params()
+            self.momentum = 0.995
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, image_feat, text_feat, idx):
+
+        def concat_all_gather(tensor):
+            """
+            Performs all_gather operation on the provided tensors.
+            *** Warning ***: torch.distributed.all_gather has no gradient.
+            """
+            if not torch.distributed.is_initialized():
+                return tensor
+            tensors_gather = [
+                torch.ones_like(tensor)
+                for _ in range(torch.distributed.get_world_size())
+            ]
+            torch.distributed.all_gather(
+                tensors_gather, tensor, async_op=False)
+            output = torch.cat(tensors_gather, dim=0)
+            return output
+
+        # gather keys before updating queue
+        image_feats = concat_all_gather(image_feat)
+        text_feats = concat_all_gather(text_feat)
+        idxs = concat_all_gather(idx)
+
+        batch_size = image_feats.shape[0]
+        ptr = int(self.queue_ptr)
+        # assert self.queue_size % batch_size == 0  # for simplicity
+
+        # replace the keys at ptr (dequeue and enqueue)
+        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
+        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
+        self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
+        ptr = (ptr + batch_size) % self.queue_size  # move pointer
+
+        self.queue_ptr[0] = ptr
+
+    def forward(self, image, text, idx=None, train=True):
+        if train:
+            image_embeds = self.visual_encoder.visual(
+                image, skip_last_layer=True)
+            if self.large:
+                image_embeds = self.dropout(
+                    self.visn_layer_norm(self.visn_fc(image_embeds)))
+            image_atts = torch.ones(
+                image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+            image_feat = F.normalize(
+                self.vision_proj(image_embeds[:, 0, :]), dim=-1)
+
+            text_output = self.text_encoder(
+                text.input_ids,
+                attention_mask=text.attention_mask,
+                return_dict=True)
+            text_embeds = text_output.last_hidden_state
+            text_feat = F.normalize(
+                self.text_proj(text_embeds[:, 0, :]), dim=-1)
+
+            idx = idx.view(-1, 1)
+            idx_all = torch.cat(
+                [idx.t(), self.idx_queue.clone().detach()], dim=1)
+            pos_idx = torch.eq(idx, idx_all).float()
+            sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
+
+            with torch.no_grad():
+                self._momentum_update()
+                image_embeds_m = self.visual_encoder_m.visual(
+                    image, skip_last_layer=True)
+                if self.large:
+                    image_embeds_m = self.dropout_m(
+                        self.visn_layer_norm_m(self.visn_fc_m(image_embeds_m)))
+                image_feat_m = F.normalize(
+                    self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1)
+                image_feat_all = torch.cat(
+                    [image_feat_m.t(),
+                     self.image_queue.clone().detach()],
+                    dim=1)
+                text_output_m = self.text_encoder_m(
+                    text.input_ids,
+                    attention_mask=text.attention_mask,
+                    return_dict=True)
+                text_feat_m = F.normalize(
+                    self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]),
+                    dim=-1)
+                text_feat_all = torch.cat(
+                    [text_feat_m.t(),
+                     self.text_queue.clone().detach()], dim=1)
+
+                if self.distill:
+                    sim_i2t_m = image_feat_m @ text_feat_all / self.temp
+                    sim_t2i_m = text_feat_m @ image_feat_all / self.temp
+                    sim_i2t_targets = self.alpha * F.softmax(
+                        sim_i2t_m, dim=1) + (1 - self.alpha) * sim_targets
+                    sim_t2i_targets = self.alpha * F.softmax(
+                        sim_t2i_m, dim=1) + (1 - self.alpha) * sim_targets
+
+            sim_i2t = image_feat @ text_feat_all / self.temp
+            sim_t2i = text_feat @ image_feat_all / self.temp
+
+            if self.distill:
+                loss_i2t = -torch.sum(
+                    F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets,
+                    dim=1).mean()
+                loss_t2i = -torch.sum(
+                    F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets,
+                    dim=1).mean()
+            else:
+                loss_i2t = -torch.sum(
+                    F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
+                loss_t2i = -torch.sum(
+                    F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
+
+            loss_ita = (loss_i2t + loss_t2i) / 2
+            self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)
+
+            # forward the positive image-text pair
+            _, output_pos = self.fusion_encoder(
+                encoder_embeds=text_embeds,
+                attention_mask=text.attention_mask,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=False,
+            )
+
+            with torch.no_grad():
+                bs = image.size(0)
+                weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1)
+                weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1)
+
+                mask = torch.eq(idx, idx.T)
+                weights_i2t.masked_fill_(mask, 0)
+                weights_t2i.masked_fill_(mask, 0)
+
+            # select a negative image for each text
+            image_embeds_neg = []
+            for b in range(bs):
+                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+                image_embeds_neg.append(image_embeds[neg_idx])
+            image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
+
+            # select a negative text for each image
+            text_embeds_neg = []
+            text_atts_neg = []
+            for b in range(bs):
+                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+                text_embeds_neg.append(text_embeds[neg_idx])
+                text_atts_neg.append(text.attention_mask[neg_idx])
+            text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
+            text_atts_neg = torch.stack(text_atts_neg, dim=0)
+
+            text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
+            text_atts_all = torch.cat([text.attention_mask, text_atts_neg],
+                                      dim=0)
+            image_embeds_all = torch.cat([image_embeds_neg, image_embeds],
+                                         dim=0)
+            image_atts_all = torch.cat([image_atts, image_atts], dim=0)
+
+            _, output_neg = self.fusion_encoder(
+                encoder_embeds=text_embeds_all,
+                attention_mask=text_atts_all,
+                encoder_hidden_states=image_embeds_all,
+                encoder_attention_mask=image_atts_all,
+                return_dict=False,
+            )
+
+            vl_embeddings = torch.cat(
+                [output_pos[:, 0, :], output_neg[:, 0, :]], dim=0)
+            vl_output = self.itm_head(vl_embeddings)
+
+            ones_tmp = torch.ones(bs, dtype=torch.long)
+            zeros_tmp = torch.zeros(2 * bs, dtype=torch.long)
+            itm_labels = torch.cat([ones_tmp, zeros_tmp],
+                                   dim=0).to(image.device)
+            loss_itm = F.cross_entropy(vl_output, itm_labels)
+
+            return loss_ita + loss_itm
+        else:
+            text_output = self.text_encoder(
+                text.input_ids, attention_mask=text.attention_mask)
+            text_feat = text_output.last_hidden_state
+            image_feat = self.visual_encoder.visual(
+                image, skip_last_layer=True)
+            image_feat = self.visn_layer_norm(self.visn_fc(image_feat))
+            image_att = torch.ones(
+                image_feat.size()[:-1],
+                dtype=torch.long,
+                device=image_feat.device)
+            _, output = self.fusion_encoder(
+                encoder_embeds=text_feat,
+                attention_mask=text.attention_mask,
+                encoder_hidden_states=image_feat,
+                encoder_attention_mask=image_att,
+                return_dict=False,
+            )
+            scores = self.itm_head(output[:, 0, :])
+            scores = F.softmax(scores, dim=-1)
+            return scores
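For reviewers new to this head, the sketch below condenses what MPlugForImageTextRetrieval.forward optimizes during training: an image-text contrastive (ITA) loss computed against a feature queue, plus a binary image-text matching (ITM) loss over fused positive/negative pairs. It is an illustration with toy tensors only; it omits the momentum encoders, soft distillation targets, and hard-negative mining used in the code above, and all names are placeholders:

import torch
import torch.nn.functional as F

bs, dim, queue_size, temp = 4, 256, 16, 0.07

# Unit-normalized projections of the current batch plus a queue of previously
# enqueued features that act as extra negatives.
image_feat = F.normalize(torch.randn(bs, dim), dim=-1)
text_feat = F.normalize(torch.randn(bs, dim), dim=-1)
image_queue = F.normalize(torch.randn(dim, queue_size), dim=0)
text_queue = F.normalize(torch.randn(dim, queue_size), dim=0)

# Similarity of each image to [batch texts | queued texts], and vice versa.
sim_i2t = image_feat @ torch.cat([text_feat.t(), text_queue], dim=1) / temp
sim_t2i = text_feat @ torch.cat([image_feat.t(), image_queue], dim=1) / temp

# The paired sample in the batch is the positive; queue entries are negatives.
targets = torch.zeros_like(sim_i2t)
targets[:, :bs] = torch.eye(bs)
loss_ita = 0.5 * (
    -(F.log_softmax(sim_i2t, dim=1) * targets).sum(1).mean()
    - (F.log_softmax(sim_t2i, dim=1) * targets).sum(1).mean())

# ITM: 2-way classification on fused [CLS] embeddings of bs positive pairs
# followed by 2*bs negative pairs (random placeholders here).
vl_cls = torch.randn(3 * bs, 768)
itm_head = torch.nn.Linear(768, 2)
itm_labels = torch.cat([torch.ones(bs, dtype=torch.long),
                        torch.zeros(2 * bs, dtype=torch.long)])
loss_itm = F.cross_entropy(itm_head(vl_cls), itm_labels)

loss = loss_ita + loss_itm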
@@ -12,6 +12,7 @@ __all__ = ['MPlugForAllTasks']

 @MODELS.register_module(
     Tasks.visual_question_answering, module_name=Models.mplug)
 @MODELS.register_module(Tasks.image_captioning, module_name=Models.mplug)
+@MODELS.register_module(Tasks.image_text_retrieval, module_name=Models.mplug)
 class MPlugForAllTasks(TorchModel):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -43,39 +44,50 @@ class MPlugForAllTasks(TorchModel):
             ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
             ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))

-        if not self.training and 'answer_input_ids' not in input:
-            topk_ids, _ = self.model(**input)
+        # inference
+        if not self.training and 'question' in input:
+            output = self.model(input['image'], input['question'], train=False)
+            if not isinstance(output, tuple):
+                return output
+            topk_ids, _ = output
             pred_string: str = self.tokenizer.decode(topk_ids[0][0])
             for _old, _new in replace_tokens_bert:
                 pred_string = pred_string.replace(_old, _new)
             pred_string = pred_string.strip()
             return pred_string
-        else:
-            import addict
+
+        # train and evaluate
+        import addict
+        image = input['image']
+        answer = addict.Dict(
+            input_ids=input['answer_input_ids'],
+            attention_mask=input['answer_attention_mask'])
+        if 'index' not in input:
             question = addict.Dict(
                 input_ids=input['question_input_ids'],
                 attention_mask=input['question_attention_mask'])
-            answer = addict.Dict(
-                input_ids=input['answer_input_ids'],
-                attention_mask=input['answer_attention_mask'])
-            output = self.model(
-                input['image'], question, answer, train=self.training)
-            if self.training:
-                return {'loss': output}
-            topk_ids, _ = output
-            preds: List[str] = [
-                self.tokenizer.decode(batch[0]) for batch in topk_ids
-            ]
-            for i in range(len(preds)):
-                for _old, _new in replace_tokens_bert:
-                    preds[i] = preds[i].replace(_old, _new)
-                preds[i] = preds[i].strip()
-            tgts: List[str] = [
-                self.tokenizer.decode(batch)
-                for batch in input['answer_input_ids'].cpu().numpy().tolist()
-            ]
-            for i in range(len(tgts)):
-                for _old, _new in replace_tokens_bert:
-                    tgts[i] = tgts[i].replace(_old, _new)
-                preds[i] = preds[i].strip()
-            return {'preds': preds, 'tgts': tgts}
+            output = self.model(image, question, answer, train=self.training)
+        else:
+            index = input['index']
+            output = self.model(image, answer, index, train=self.training)
+        if self.training:
+            return {'loss': output}
+
+        # evaluate
+        topk_ids, _ = output
+        preds: List[str] = [
+            self.tokenizer.decode(batch[0]) for batch in topk_ids
+        ]
+        for i in range(len(preds)):
+            for _old, _new in replace_tokens_bert:
+                preds[i] = preds[i].replace(_old, _new)
+            preds[i] = preds[i].strip()
+        tgts: List[str] = [
+            self.tokenizer.decode(batch)
+            for batch in input['answer_input_ids'].cpu().numpy().tolist()
+        ]
+        for i in range(len(tgts)):
+            for _old, _new in replace_tokens_bert:
+                tgts[i] = tgts[i].replace(_old, _new)
+            tgts[i] = tgts[i].strip()
+        return {'preds': preds, 'tgts': tgts}
@@ -0,0 +1,51 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import MPlugPreprocessor, Preprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_text_retrieval, module_name=Pipelines.image_text_retrieval)
+class ImageTextRetrievalPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an image-text retrieval
+        pipeline for prediction.
+
+        Args:
+            model: a Model instance or a model id on the ModelScope hub.
+            preprocessor: an optional preprocessor; an MPlugPreprocessor is
+                built from the model directory when not given.
+        """
+        super().__init__(model=model)
+        assert isinstance(model, str) or isinstance(model, Model), \
+            f'model must be a single str or Model, but got {type(model)}'
+        if isinstance(model, str):
+            pipe_model = Model.from_pretrained(model)
+        elif isinstance(model, Model):
+            pipe_model = model
+        else:
+            raise NotImplementedError
+        pipe_model.model.eval()
+        if preprocessor is None:
+            preprocessor = MPlugPreprocessor(pipe_model.model_dir)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return {OutputKeys.SCORES: inputs[0].tolist()}
@@ -1,6 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Tuple, Union

 import torch
 from PIL import Image

@@ -104,6 +104,7 @@ class MPlugPreprocessor(Preprocessor):
         self._tokenizer = None
         self._patch_resize_transform = None
+        self._image_map = {}

     @property
     def tokenizer(self):

@@ -133,31 +134,31 @@ class MPlugPreprocessor(Preprocessor):
         ])
         return self._patch_resize_transform

-    def __call__(self, *args, **kwargs):
-        call_mapping = {
-            Tasks.visual_question_answering: self.image_text_call,
-            Tasks.image_captioning: self.image_text_call,
-        }
+    def image_open(self, path: str) -> Tuple[Image.Image, int]:
+        if path not in self._image_map:
+            index = len(self._image_map)
+            self._image_map[path] = (Image.open(path), index)
+        return self._image_map[path]
+
+    def __call__(
+            self, data: Union[Image.Image, tuple,
+                              Dict[str, Any]]) -> Dict[str, Any]:
         self.cfg = Config.from_file(
             osp.join(self.model_dir, ModelFile.CONFIGURATION))
-        return call_mapping[self.cfg.task](*args, **kwargs)
-
-    def image_text_call(
-            self, data: Union[Image.Image, tuple,
-                              Dict[str, Any]]) -> Dict[str, Any]:
         if isinstance(data, (Image.Image, str)):
             image = data
         elif isinstance(data, tuple):
             image = data[0]
         else:
             image = data['image']
+        index = 0
         if isinstance(image, str):
-            image = Image.open(image)
-        question = '' if self.cfg.task != Tasks.visual_question_answering \
-            else data[1 if isinstance(data, tuple) else 'question']
+            image, index = self.image_open(image)
         image = image.convert('RGB')
         image = self.patch_resize_transform(image)
+        question = '' if self.cfg.task == Tasks.image_captioning \
+            else data[1 if isinstance(data, tuple) else 'question']
         question = self.tokenizer(
             question.lower(),
             padding='max_length',

@@ -167,7 +168,7 @@ class MPlugPreprocessor(Preprocessor):
         if self.mode == ModeKeys.INFERENCE:
             image = torch.stack([image], dim=0)
-            return {'image': image, 'question': question, 'train': False}
+            return {'image': image, 'question': question}
         else:
             answer = data['answer']
             answer = self.tokenizer(

@@ -176,10 +177,13 @@ class MPlugPreprocessor(Preprocessor):
                 truncation=True,
                 max_length=self.tokenizer_max_length,
                 return_tensors='pt')
-            return {
+            output = {
                 'image': image,
                 'question_input_ids': question.input_ids.squeeze(),
                 'question_attention_mask': question.attention_mask.squeeze(),
                 'answer_input_ids': answer.input_ids.squeeze(),
                 'answer_attention_mask': answer.attention_mask.squeeze(),
             }
+            if self.cfg.task == Tasks.image_text_retrieval:
+                output['index'] = index
+            return output
@@ -121,6 +121,7 @@ class MultiModalTasks(object):
     visual_question_answering = 'visual-question-answering'
     visual_entailment = 'visual-entailment'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
+    image_text_retrieval = 'image-text-retrieval'

 class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks):
@@ -54,6 +54,27 @@ class MplugTasksTest(unittest.TestCase):
         result = pipeline_vqa(input)
         print(result)

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_image_text_retrieval_with_model(self):
+        model = Model.from_pretrained(
+            'damo/mplug_image-text-retrieval_flickr30k_large_en')
+        pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model)
+        image = Image.open('data/test/images/image-text-retrieval.jpg')
+        question = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.'
+        input = {'image': image, 'question': question}
+        result = pipeline_retrieval(input)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_image_text_retrieval_with_name(self):
+        model = 'damo/mplug_image-text-retrieval_flickr30k_large_en'
+        pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model)
+        image = Image.open('data/test/images/image-text-retrieval.jpg')
+        question = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.'
+        input = {'image': image, 'question': question}
+        result = pipeline_retrieval(input)
+        print(result)

 if __name__ == '__main__':
     unittest.main()
@@ -4,8 +4,6 @@ import shutil
 import tempfile
 import unittest

-from PIL import Image
-
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Trainers
 from modelscope.models.multi_modal import MPlugForAllTasks

@@ -23,7 +21,10 @@ class TestFinetuneMPlug(unittest.TestCase):
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)

-        datadict = MsDataset.load('coco_captions_small_slice')
+        from modelscope.utils.constant import DownloadMode
+        datadict = MsDataset.load(
+            'coco_captions_small_slice',
+            download_mode=DownloadMode.FORCE_REDOWNLOAD)
         self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map(
             lambda _: {
                 'question': 'what the picture describes?'
@@ -35,17 +36,19 @@ class TestFinetuneMPlug(unittest.TestCase):
             }).rename_column('image:FILE',
                              'image').rename_column('answer:Value', 'answer'))
+        self.max_epochs = 3

     def tearDown(self):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_caption(self):
         kwargs = dict(
             model='damo/mplug_image-captioning_coco_base_en',
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)

         trainer: EpochBasedTrainer = build_trainer(

@@ -53,15 +56,11 @@ class TestFinetuneMPlug(unittest.TestCase):
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(3):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_caption_with_model_and_args(self):
-        tmp_dir = tempfile.TemporaryDirectory().name
-        if not os.path.exists(tmp_dir):
-            os.makedirs(tmp_dir)
-
         cache_path = snapshot_download(
             'damo/mplug_image-captioning_coco_base_en')
         model = MPlugForAllTasks.from_pretrained(cache_path)
@@ -70,7 +69,7 @@ class TestFinetuneMPlug(unittest.TestCase):
             model=model,
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
-            max_epochs=2,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)

         trainer: EpochBasedTrainer = build_trainer(

@@ -78,16 +77,16 @@ class TestFinetuneMPlug(unittest.TestCase):
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(2):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_vqa(self):
         kwargs = dict(
             model='damo/mplug_visual-question-answering_coco_large_en',
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)

         trainer: EpochBasedTrainer = build_trainer(

@@ -95,15 +94,11 @@ class TestFinetuneMPlug(unittest.TestCase):
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(3):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_vqa_with_model_and_args(self):
-        tmp_dir = tempfile.TemporaryDirectory().name
-        if not os.path.exists(tmp_dir):
-            os.makedirs(tmp_dir)
-
         cache_path = snapshot_download(
             'damo/mplug_visual-question-answering_coco_large_en')
         model = MPlugForAllTasks.from_pretrained(cache_path)
@@ -112,7 +107,45 @@ class TestFinetuneMPlug(unittest.TestCase):
             model=model,
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
-            max_epochs=2,
+            max_epochs=self.max_epochs,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(self.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_with_retrieval(self):
+        kwargs = dict(
+            model='damo/mplug_image-text-retrieval_flickr30k_large_en',
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            max_epochs=self.max_epochs,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(self.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_retrieval_with_model_and_args(self):
+        cache_path = snapshot_download(
+            'damo/mplug_image-text-retrieval_flickr30k_large_en')
+        model = MPlugForAllTasks.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)

         trainer: EpochBasedTrainer = build_trainer(

@@ -120,7 +153,7 @@ class TestFinetuneMPlug(unittest.TestCase):
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(2):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)