
[to #42322933] Add mplug retrieval pipeline and finetune

Adds pipeline and finetuning support for the MPLUG model on the image-text-retrieval task.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9919955
master
hemu.zp yingda.chen · 3 years ago
commit a9089570e5
10 changed files with 487 additions and 91 deletions
1. data/test/images/image-text-retrieval.jpg (+3, -0)
2. modelscope/metainfo.py (+1, -0)
3. modelscope/models/multi_modal/mplug/configuration_mplug.py (+8, -0)
4. modelscope/models/multi_modal/mplug/modeling_mplug.py (+291, -29)
5. modelscope/models/multi_modal/mplug_for_all_tasks.py (+40, -28)
6. modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py (+51, -0)
7. modelscope/preprocessors/multi_modal.py (+19, -15)
8. modelscope/utils/constant.py (+1, -0)
9. tests/pipelines/test_mplug_tasks.py (+21, -0)
10. tests/trainers/test_finetune_mplug.py (+52, -19)

data/test/images/image-text-retrieval.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b012c7e966f6550874ccb85ef9602d483aa89b8623dff9ffcdb0faab8f2ca9ab
size 218143

modelscope/metainfo.py (+1, -0)

@@ -170,6 +170,7 @@ class Pipelines(object):
multi_modal_similarity = 'multi-modal-similarity'
text_to_image_synthesis = 'text-to-image-synthesis'
video_multi_modal_embedding = 'video-multi-modal-embedding'
image_text_retrieval = 'image-text-retrieval'


class Trainers(object):


modelscope/models/multi_modal/mplug/configuration_mplug.py (+8, -0)

@@ -64,6 +64,10 @@ class MPlugConfig(PretrainedConfig):
clip_transformer_width=768,
clip_transformer_heads=12,
clip_transformer_layers=12,
# retrieval
queue_size=65536,
embed_dim=256,
temp=0.07,
**kwargs):

super().__init__(**kwargs)
@@ -99,6 +103,10 @@ class MPlugConfig(PretrainedConfig):
self.clip_transformer_width = clip_transformer_width
self.clip_transformer_heads = clip_transformer_heads
self.clip_transformer_layers = clip_transformer_layers
# retrieval
self.queue_size = queue_size
self.embed_dim = embed_dim
self.temp = temp


@classmethod
def from_yaml_file(cls, yaml_file: Union[str,
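
For context, here is a minimal sketch (not part of the commit) of how the new retrieval fields surface on a config object; the values are simply the defaults added above, and it assumes MPlugConfig can be constructed directly from keyword arguments:

from modelscope.models.multi_modal.mplug.configuration_mplug import MPlugConfig

# The retrieval knobs ride along with the existing CLIP/BERT fields.
cfg = MPlugConfig(queue_size=65536, embed_dim=256, temp=0.07)
print(cfg.queue_size, cfg.embed_dim, cfg.temp)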


modelscope/models/multi_modal/mplug/modeling_mplug.py (+291, -29)

@@ -1855,7 +1855,8 @@ class MPlug(PreTrainedModel):


task_mapping = {
Tasks.visual_question_answering: MPlugForVisualQuestionAnswering,
Tasks.image_captioning: MPLUGForImageCaption
Tasks.image_captioning: MPlugForImageCaption,
Tasks.image_text_retrieval: MPlugForImageTextRetrieval,
}
config = cls.config_class.from_yaml_file(
os.path.join(model_dir, CONFIG_NAME))
@@ -1915,6 +1916,33 @@ class MPlug(PreTrainedModel):
clip_model.visual.positional_embedding = pos_embed
return clip_model


def init_distill(self, config):
self.distill = config.distill
if self.distill:
self.visual_encoder_m = self._initialize_clip(config)
self.text_encoder_m = BertModel(
self.config_encoder, add_pooling_layer=False)
self.fusion_encoder_m = FusionModel(
self.config_fusion, add_pooling_layer=False)
self.text_decoder_m = BertLMHeadModel(self.config_decoder)
self.model_pairs = [
[self.visual_encoder, self.visual_encoder_m],
[self.text_encoder, self.text_encoder_m],
[self.text_decoder, self.text_decoder_m],
]
if self.config_encoder.hidden_size != config.vision_width:
self.visn_fc_m = nn.Linear(config.vision_width,
self.config_encoder.hidden_size)
self.visn_layer_norm_m = nn.LayerNorm(
self.config_encoder.hidden_size, eps=1e-12)
self.dropout_m = nn.Dropout(
self.config_encoder.hidden_dropout_prob)
self.model_pairs.extend(
[[self.visn_fc, self.visn_fc_m],
[self.visn_layer_norm, self.visn_layer_norm_m]])
self.copy_params()
self.momentum = 0.995

def forward(self, *args, **kwargs):
raise NotImplementedError


@@ -1978,33 +2006,6 @@ class MPlugForVisualQuestionAnswering(MPlug):
self.beam_generator = TextGenerator(config, self.text_decoder)
self.init_distill(config)


def init_distill(self, config):
self.distill = config.distill
if self.distill:
self.visual_encoder_m = self._initialize_clip(config)
self.text_encoder_m = BertModel(
self.config_encoder, add_pooling_layer=False)
self.fusion_encoder_m = FusionModel(
self.config_fusion, add_pooling_layer=False)
self.text_decoder_m = BertLMHeadModel(self.config_decoder)
self.model_pairs = [
[self.visual_encoder, self.visual_encoder_m],
[self.text_encoder, self.text_encoder_m],
[self.text_decoder, self.text_decoder_m],
]
if self.config_encoder.hidden_size != config.vision_width:
self.visn_fc_m = nn.Linear(config.vision_width,
self.config_encoder.hidden_size)
self.visn_layer_norm_m = nn.LayerNorm(
self.config_encoder.hidden_size, eps=1e-12)
self.dropout_m = nn.Dropout(
self.config_encoder.hidden_dropout_prob)
self.model_pairs.extend(
[[self.visn_fc, self.visn_fc_m],
[self.visn_layer_norm, self.visn_layer_norm_m]])
self.copy_params()
self.momentum = 0.995

def forward(self,
image,
question,
@@ -2142,7 +2143,7 @@ class MPlugForVisualQuestionAnswering(MPlug):
return topk_ids, topk_probs


class MPLUGForImageCaption(MPlug):
class MPlugForImageCaption(MPlug):

def __init__(self, config):
super().__init__(config)
@@ -2215,3 +2216,264 @@ class MPLUGForImageCaption(MPlug):
else:
topk_ids, topk_probs = self.generation(image_embeds, image_atts)
return topk_ids, topk_probs


class MPlugForImageTextRetrieval(MPlug):

def __init__(self, config):
super().__init__(config)
self.embed_dim = config.embed_dim
self.temp = nn.Parameter(torch.ones([]) * config.temp)
self.queue_size = config.queue_size
self.momentum = config.momentum
self.alpha = config.alpha

self.queue_size = config.queue_size
self.text_width = self.config_encoder.hidden_size
self.embed_dim = config.embed_dim

self.vision_proj = nn.Linear(self.text_width, self.embed_dim)
self.text_proj = nn.Linear(self.text_width, self.embed_dim)
self.itm_head = nn.Linear(self.text_width, 2)

self.register_buffer('image_queue',
torch.randn(self.embed_dim, self.queue_size))
self.register_buffer('text_queue',
torch.randn(self.embed_dim, self.queue_size))
self.register_buffer('idx_queue', torch.full((1, self.queue_size),
-100))
self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

self.image_queue = F.normalize(self.image_queue, dim=0)
self.text_queue = F.normalize(self.text_queue, dim=0)
self.init_distill(config)

def init_distill(self, config):
self.distill = config.distill
if self.distill:
self.visual_encoder_m = self._initialize_clip(config)
self.text_encoder_m = BertModel(
self.config_encoder, add_pooling_layer=False)
self.fusion_encoder_m = FusionModel(
self.config_fusion, add_pooling_layer=False)
self.vision_proj_m = nn.Linear(self.text_width, self.embed_dim)
self.text_proj_m = nn.Linear(self.text_width, self.embed_dim)
self.model_pairs = [
[self.visual_encoder, self.visual_encoder_m],
[self.text_encoder, self.text_encoder_m],
[self.text_proj, self.text_proj_m],
[self.vision_proj, self.vision_proj_m],
]
if self.config_encoder.hidden_size != config.vision_width:
self.visn_fc_m = nn.Linear(config.vision_width,
self.config_encoder.hidden_size)
self.visn_layer_norm_m = nn.LayerNorm(
self.config_encoder.hidden_size, eps=1e-12)
self.dropout_m = nn.Dropout(
self.config_encoder.hidden_dropout_prob)
self.model_pairs.extend(
[[self.visn_fc, self.visn_fc_m],
[self.visn_layer_norm, self.visn_layer_norm_m]])
self.copy_params()
self.momentum = 0.995

@torch.no_grad()
def _dequeue_and_enqueue(self, image_feat, text_feat, idx):

def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
if not torch.distributed.is_initialized():
return tensor
tensors_gather = [
torch.ones_like(tensor)
for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(
tensors_gather, tensor, async_op=False)

output = torch.cat(tensors_gather, dim=0)
return output

# gather keys before updating queue
image_feats = concat_all_gather(image_feat)
text_feats = concat_all_gather(text_feat)
idxs = concat_all_gather(idx)

batch_size = image_feats.shape[0]

ptr = int(self.queue_ptr)
# assert self.queue_size % batch_size == 0 # for simplicity

# replace the keys at ptr (dequeue and enqueue)
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
ptr = (ptr + batch_size) % self.queue_size # move pointer

self.queue_ptr[0] = ptr

def forward(self, image, text, idx=None, train=True):
if train:
image_embeds = self.visual_encoder.visual(
image, skip_last_layer=True)
if self.large:
image_embeds = self.dropout(
self.visn_layer_norm(self.visn_fc(image_embeds)))
image_atts = torch.ones(
image_embeds.size()[:-1], dtype=torch.long).to(image.device)

image_feat = F.normalize(
self.vision_proj(image_embeds[:, 0, :]), dim=-1)
text_output = self.text_encoder(
text.input_ids,
attention_mask=text.attention_mask,
return_dict=True)
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(
self.text_proj(text_embeds[:, 0, :]), dim=-1)

idx = idx.view(-1, 1)
idx_all = torch.cat(
[idx.t(), self.idx_queue.clone().detach()], dim=1)
pos_idx = torch.eq(idx, idx_all).float()
sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)

with torch.no_grad():
self._momentum_update()
image_embeds_m = self.visual_encoder_m.visual(
image, skip_last_layer=True)
if self.large:
image_embeds_m = self.dropout_m(
self.visn_layer_norm_m(self.visn_fc_m(image_embeds_m)))
image_feat_m = F.normalize(
self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1)
image_feat_all = torch.cat(
[image_feat_m.t(),
self.image_queue.clone().detach()],
dim=1)
text_output_m = self.text_encoder_m(
text.input_ids,
attention_mask=text.attention_mask,
return_dict=True)
text_feat_m = F.normalize(
self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]),
dim=-1)
text_feat_all = torch.cat(
[text_feat_m.t(),
self.text_queue.clone().detach()], dim=1)

if self.distill:
sim_i2t_m = image_feat_m @ text_feat_all / self.temp
sim_t2i_m = text_feat_m @ image_feat_all / self.temp

sim_i2t_targets = self.alpha * F.softmax(
sim_i2t_m, dim=1) + (1 - self.alpha) * sim_targets
sim_t2i_targets = self.alpha * F.softmax(
sim_t2i_m, dim=1) + (1 - self.alpha) * sim_targets

sim_i2t = image_feat @ text_feat_all / self.temp
sim_t2i = text_feat @ image_feat_all / self.temp

if self.distill:
loss_i2t = -torch.sum(
F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets,
dim=1).mean()
loss_t2i = -torch.sum(
F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets,
dim=1).mean()
else:
loss_i2t = -torch.sum(
F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
loss_t2i = -torch.sum(
F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()

loss_ita = (loss_i2t + loss_t2i) / 2

self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)

# forward the positive image-text pair
_, output_pos = self.fusion_encoder(
encoder_embeds=text_embeds,
attention_mask=text.attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=False,
)
with torch.no_grad():
bs = image.size(0)
weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1)
weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1)

mask = torch.eq(idx, idx.T)
weights_i2t.masked_fill_(mask, 0)
weights_t2i.masked_fill_(mask, 0)

# select a negative image for each text
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
image_embeds_neg.append(image_embeds[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

# select a negative text for each image
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_embeds_neg.append(text_embeds[neg_idx])
text_atts_neg.append(text.attention_mask[neg_idx])
text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)

text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],
dim=0)

image_embeds_all = torch.cat([image_embeds_neg, image_embeds],
dim=0)
image_atts_all = torch.cat([image_atts, image_atts], dim=0)

_, output_neg = self.fusion_encoder(
encoder_embeds=text_embeds_all,
attention_mask=text_atts_all,
encoder_hidden_states=image_embeds_all,
encoder_attention_mask=image_atts_all,
return_dict=False,
)

vl_embeddings = torch.cat(
[output_pos[:, 0, :], output_neg[:, 0, :]], dim=0)
vl_output = self.itm_head(vl_embeddings)

ones_tmp = torch.ones(bs, dtype=torch.long)
zeros_tmp = torch.zeros(2 * bs, dtype=torch.long)
itm_labels = torch.cat([ones_tmp, zeros_tmp],
dim=0).to(image.device)
loss_itm = F.cross_entropy(vl_output, itm_labels)

return loss_ita + loss_itm
else:
text_output = self.text_encoder(
text.input_ids, attention_mask=text.attention_mask)
text_feat = text_output.last_hidden_state
image_feat = self.visual_encoder.visual(
image, skip_last_layer=True)
image_feat = self.visn_layer_norm(self.visn_fc(image_feat))
image_att = torch.ones(
image_feat.size()[:-1],
dtype=torch.long,
device=image_feat.device)
_, output = self.fusion_encoder(
encoder_embeds=text_feat,
attention_mask=text.attention_mask,
encoder_hidden_states=image_feat,
encoder_attention_mask=image_att,
return_dict=False,
)
scores = self.itm_head(output[:, 0, :])
scores = F.softmax(scores, dim=-1)

return scores
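
The new class follows the familiar ALBEF-style retrieval recipe: a contrastive (ITA) loss computed against in-batch features plus a momentum-encoder feature queue, and an image-text matching (ITM) head scored on positive and hard-negative pairs. As a reading aid, here is a condensed, self-contained sketch of the soft-target contrastive part that mirrors the code above (toy shapes and hypothetical inputs, not the shipped implementation):

import torch
import torch.nn.functional as F

def ita_loss(sim_i2t, sim_t2i, sim_targets, alpha=0.4,
             sim_i2t_m=None, sim_t2i_m=None):
    # With distillation, the targets mix the momentum encoder's predictions
    # with the hard similarity targets; otherwise the hard targets are used.
    if sim_i2t_m is not None and sim_t2i_m is not None:
        i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
        t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
    else:
        i2t_targets, t2i_targets = sim_targets, sim_targets
    loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * i2t_targets, dim=1).mean()
    loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * t2i_targets, dim=1).mean()
    return (loss_i2t + loss_t2i) / 2

# Toy example: a batch of 4 scored against the batch plus a queue of 16
# (the real queue_size default is 65536).
sim = torch.randn(4, 4 + 16)
targets = F.one_hot(torch.arange(4), num_classes=4 + 16).float()
print(ita_loss(sim, sim, targets).item())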

modelscope/models/multi_modal/mplug_for_all_tasks.py (+40, -28)

@@ -12,6 +12,7 @@ __all__ = ['MPlugForAllTasks']
@MODELS.register_module(
Tasks.visual_question_answering, module_name=Models.mplug)
@MODELS.register_module(Tasks.image_captioning, module_name=Models.mplug)
@MODELS.register_module(Tasks.image_text_retrieval, module_name=Models.mplug)
class MPlugForAllTasks(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
@@ -43,39 +44,50 @@ class MPlugForAllTasks(TorchModel):
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))


if not self.training and 'answer_input_ids' not in input:
topk_ids, _ = self.model(**input)
# inference
if not self.training and 'question' in input:
output = self.model(input['image'], input['question'], train=False)
if not isinstance(output, tuple):
return output
topk_ids, _ = output
pred_string: str = self.tokenizer.decode(topk_ids[0][0])
for _old, _new in replace_tokens_bert:
pred_string = pred_string.replace(_old, _new)
pred_string = pred_string.strip()
return pred_string
else:
import addict

# train and evaluate
import addict
image = input['image']
answer = addict.Dict(
input_ids=input['answer_input_ids'],
attention_mask=input['answer_attention_mask'])
if 'index' not in input:
question = addict.Dict(
input_ids=input['question_input_ids'],
attention_mask=input['question_attention_mask'])
answer = addict.Dict(
input_ids=input['answer_input_ids'],
attention_mask=input['answer_attention_mask'])
output = self.model(
input['image'], question, answer, train=self.training)
if self.training:
return {'loss': output}
topk_ids, _ = output
preds: List[str] = [
self.tokenizer.decode(batch[0]) for batch in topk_ids
]
for i in range(len(preds)):
for _old, _new in replace_tokens_bert:
preds[i] = preds[i].replace(_old, _new)
preds[i] = preds[i].strip()
tgts: List[str] = [
self.tokenizer.decode(batch)
for batch in input['answer_input_ids'].cpu().numpy().tolist()
]
for i in range(len(tgts)):
for _old, _new in replace_tokens_bert:
tgts[i] = tgts[i].replace(_old, _new)
preds[i] = preds[i].strip()
return {'preds': preds, 'tgts': tgts}
output = self.model(image, question, answer, train=self.training)
else:
index = input['index']
output = self.model(image, answer, index, train=self.training)
if self.training:
return {'loss': output}

# evaluate
topk_ids, _ = output
preds: List[str] = [
self.tokenizer.decode(batch[0]) for batch in topk_ids
]
for i in range(len(preds)):
for _old, _new in replace_tokens_bert:
preds[i] = preds[i].replace(_old, _new)
preds[i] = preds[i].strip()
tgts: List[str] = [
self.tokenizer.decode(batch)
for batch in input['answer_input_ids'].cpu().numpy().tolist()
]
for i in range(len(tgts)):
for _old, _new in replace_tokens_bert:
tgts[i] = tgts[i].replace(_old, _new)
preds[i] = preds[i].strip()
return {'preds': preds, 'tgts': tgts}
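
In short, the rewritten forward dispatches on the keys the preprocessor emits: inference inputs carry 'question' and are either decoded to text or, for retrieval, returned as raw scores, while training batches carry tokenized ids plus, for retrieval only, an 'index'. A hypothetical illustration of the two training-batch layouts (dummy tensors, shapes assumed):

import torch

dummy_image = torch.zeros(1, 3, 384, 384)         # image tensor (resolution assumed)
dummy_ids = torch.zeros(1, 25, dtype=torch.long)  # token ids (max length assumed)
dummy_mask = torch.ones_like(dummy_ids)

# VQA / captioning batch: question and answer ids, no 'index'.
vqa_batch = {
    'image': dummy_image,
    'question_input_ids': dummy_ids,
    'question_attention_mask': dummy_mask,
    'answer_input_ids': dummy_ids,
    'answer_attention_mask': dummy_mask,
}

# Image-text retrieval batch: the caption rides in the 'answer_*' slots and
# 'index' identifies the source image for the contrastive targets.
retrieval_batch = dict(vqa_batch, index=torch.tensor([0]))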

modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py (+51, -0)

@@ -0,0 +1,51 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import MPlugPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_text_retrieval, module_name=Pipelines.image_text_retrieval)
class ImageTextRetrievalPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a
image text retrieval pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
assert isinstance(model, str) or isinstance(model, Model), \
f'model must be a single str or Model, but got {type(model)}'
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError
pipe_model.model.eval()
if preprocessor is None:
preprocessor = MPlugPreprocessor(pipe_model.model_dir)
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return super().forward(inputs, **forward_params)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return {OutputKeys.SCORES: inputs[0].tolist()}
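
End-to-end usage mirrors the new test cases further down; a minimal sketch (downloading the model requires network access, and the 'scores' output key assumes OutputKeys.SCORES resolves to it):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

retrieval = pipeline(
    Tasks.image_text_retrieval,
    model='damo/mplug_image-text-retrieval_flickr30k_large_en')
result = retrieval({
    'image': 'data/test/images/image-text-retrieval.jpg',
    'question': 'Two young guys with shaggy hair look at their hands '
                'while hanging out in the yard.'
})
# The result holds the softmax over the ITM head's two classes, i.e. how
# well the text matches the image.
print(result)

Note that the preprocessor reuses the 'question' key for the caption text, which is why the input dict looks like a VQA call.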

modelscope/preprocessors/multi_modal.py (+19, -15)

@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Tuple, Union

import torch
from PIL import Image
@@ -104,6 +104,7 @@ class MPlugPreprocessor(Preprocessor):


self._tokenizer = None
self._patch_resize_transform = None
self._image_map = {}

@property
def tokenizer(self):
@@ -133,31 +134,31 @@ class MPlugPreprocessor(Preprocessor):
])
return self._patch_resize_transform


def __call__(self, *args, **kwargs):
call_mapping = {
Tasks.visual_question_answering: self.image_text_call,
Tasks.image_captioning: self.image_text_call,
}
def image_open(self, path: str) -> Tuple[Image.Image, int]:
if path not in self._image_map:
index = len(self._image_map)
self._image_map[path] = (Image.open(path), index)
return self._image_map[path]


def __call__(
self, data: Union[Image.Image, tuple,
Dict[str, Any]]) -> Dict[str, Any]:
self.cfg = Config.from_file(
osp.join(self.model_dir, ModelFile.CONFIGURATION))
return call_mapping[self.cfg.task](*args, **kwargs)


def image_text_call(
self, data: Union[Image.Image, tuple,
Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(data, (Image.Image, str)):
image = data
elif isinstance(data, tuple):
image = data[0]
else:
image = data['image']
index = 0
if isinstance(image, str):
image = Image.open(image)
question = '' if self.cfg.task != Tasks.visual_question_answering \
else data[1 if isinstance(data, tuple) else 'question']
image, index = self.image_open(image)
image = image.convert('RGB')
image = self.patch_resize_transform(image)
question = '' if self.cfg.task == Tasks.image_captioning \
else data[1 if isinstance(data, tuple) else 'question']
question = self.tokenizer(
question.lower(),
padding='max_length',
@@ -167,7 +168,7 @@ class MPlugPreprocessor(Preprocessor):


if self.mode == ModeKeys.INFERENCE:
image = torch.stack([image], dim=0)
return {'image': image, 'question': question, 'train': False}
return {'image': image, 'question': question}
else:
answer = data['answer']
answer = self.tokenizer(
@@ -176,10 +177,13 @@ class MPlugPreprocessor(Preprocessor):
truncation=True,
max_length=self.tokenizer_max_length,
return_tensors='pt')
return {
output = {
'image': image,
'question_input_ids': question.input_ids.squeeze(),
'question_attention_mask': question.attention_mask.squeeze(),
'answer_input_ids': answer.input_ids.squeeze(),
'answer_attention_mask': answer.attention_mask.squeeze(),
}
if self.cfg.task == Tasks.image_text_retrieval:
output['index'] = index
return output
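
The image_open cache above is what makes the retrieval finetune work on caption datasets: every caption that points at the same image file receives the same integer index, which the model then uses to build its contrastive targets. A stripped-down sketch of the idea (standalone, outside the class, purely illustrative):

from PIL import Image

_image_map = {}

def image_open(path):
    # The first time a path is seen it gets the next index; later calls
    # reuse both the opened image and the index.
    if path not in _image_map:
        _image_map[path] = (Image.open(path), len(_image_map))
    return _image_map[path]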

modelscope/utils/constant.py (+1, -0)

@@ -121,6 +121,7 @@ class MultiModalTasks(object):
visual_question_answering = 'visual-question-answering'
visual_entailment = 'visual-entailment'
video_multi_modal_embedding = 'video-multi-modal-embedding'
image_text_retrieval = 'image-text-retrieval'


class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks):


tests/pipelines/test_mplug_tasks.py (+21, -0)

@@ -54,6 +54,27 @@ class MplugTasksTest(unittest.TestCase):
result = pipeline_vqa(input)
print(result)


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_text_retrieval_with_model(self):
model = Model.from_pretrained(
'damo/mplug_image-text-retrieval_flickr30k_large_en')
pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model)
image = Image.open('data/test/images/image-text-retrieval.jpg')
question = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.'
input = {'image': image, 'question': question}
result = pipeline_retrieval(input)
print(result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_image_text_retrieval_with_name(self):
model = 'damo/mplug_image-text-retrieval_flickr30k_large_en'
pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model)
image = Image.open('data/test/images/image-text-retrieval.jpg')
question = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.'
input = {'image': image, 'question': question}
result = pipeline_retrieval(input)
print(result)



if __name__ == '__main__':
unittest.main()

tests/trainers/test_finetune_mplug.py (+52, -19)

@@ -4,8 +4,6 @@ import shutil
import tempfile
import unittest

from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.multi_modal import MPlugForAllTasks
@@ -23,7 +21,10 @@ class TestFinetuneMPlug(unittest.TestCase):
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)


datadict = MsDataset.load('coco_captions_small_slice')
from modelscope.utils.constant import DownloadMode
datadict = MsDataset.load(
'coco_captions_small_slice',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map(
lambda _: {
'question': 'what the picture describes?'
@@ -35,17 +36,19 @@ class TestFinetuneMPlug(unittest.TestCase):
}).rename_column('image:FILE',
'image').rename_column('answer:Value', 'answer'))


self.max_epochs = 3

def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_with_caption(self):

kwargs = dict(
model='damo/mplug_image-captioning_coco_base_en',
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=self.max_epochs,
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
@@ -53,15 +56,11 @@ class TestFinetuneMPlug(unittest.TestCase):
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(3):
for i in range(self.max_epochs):
self.assertIn(f'epoch_{i+1}.pth', results_files)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_trainer_with_caption_with_model_and_args(self):
tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)

cache_path = snapshot_download(
'damo/mplug_image-captioning_coco_base_en')
model = MPlugForAllTasks.from_pretrained(cache_path)
@@ -70,7 +69,7 @@ class TestFinetuneMPlug(unittest.TestCase):
model=model,
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=2,
max_epochs=self.max_epochs,
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
@@ -78,16 +77,16 @@ class TestFinetuneMPlug(unittest.TestCase):
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(2):
for i in range(self.max_epochs):
self.assertIn(f'epoch_{i+1}.pth', results_files)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_with_vqa(self):

kwargs = dict(
model='damo/mplug_visual-question-answering_coco_large_en',
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=self.max_epochs,
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
@@ -95,15 +94,11 @@ class TestFinetuneMPlug(unittest.TestCase):
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(3):
for i in range(self.max_epochs):
self.assertIn(f'epoch_{i+1}.pth', results_files)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_trainer_with_vqa_with_model_and_args(self):
tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)

cache_path = snapshot_download(
'damo/mplug_visual-question-answering_coco_large_en')
model = MPlugForAllTasks.from_pretrained(cache_path)
@@ -112,7 +107,45 @@ class TestFinetuneMPlug(unittest.TestCase):
model=model,
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=2,
max_epochs=self.max_epochs,
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
name=Trainers.nlp_base_trainer, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(self.max_epochs):
self.assertIn(f'epoch_{i+1}.pth', results_files)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_with_retrieval(self):
kwargs = dict(
model='damo/mplug_image-text-retrieval_flickr30k_large_en',
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=self.max_epochs,
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
name=Trainers.nlp_base_trainer, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(self.max_epochs):
self.assertIn(f'epoch_{i+1}.pth', results_files)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_trainer_with_retrieval_with_model_and_args(self):
cache_path = snapshot_download(
'damo/mplug_image-text-retrieval_flickr30k_large_en')
model = MPlugForAllTasks.from_pretrained(cache_path)
kwargs = dict(
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
model=model,
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
max_epochs=self.max_epochs,
work_dir=self.tmp_dir)

trainer: EpochBasedTrainer = build_trainer(
@@ -120,7 +153,7 @@ class TestFinetuneMPlug(unittest.TestCase):
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(2):
for i in range(self.max_epochs):
self.assertIn(f'epoch_{i+1}.pth', results_files)





