eniac.xcw (yingda.chen), 3 years ago
commit 8c9348de2c
10 changed files with 543 additions and 0 deletions
  1. modelscope/metainfo.py (+2 / -0)
  2. modelscope/models/multi_modal/__init__.py (+2 / -0)
  3. modelscope/models/multi_modal/team/__init__.py (+1 / -0)
  4. modelscope/models/multi_modal/team/team_model.py (+126 / -0)
  5. modelscope/models/multi_modal/team/utils.py (+326 / -0)
  6. modelscope/outputs.py (+9 / -0)
  7. modelscope/pipelines/builder.py (+3 / -0)
  8. modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py (+31 / -0)
  9. modelscope/utils/constant.py (+1 / -0)
  10. tests/pipelines/test_multi_modal_similarity.py (+42 / -0)

modelscope/metainfo.py (+2 / -0)

@@ -59,6 +59,7 @@ class Models(object):
    gemm = 'gemm-generative-multi-modal'
    mplug = 'mplug'
    diffusion = 'diffusion-text-to-image-synthesis'
    team = 'team-multi-modal-similarity'
    video_clip = 'video-clip-multi-modal-embedding'


@@ -166,6 +167,7 @@ class Pipelines(object):
    visual_question_answering = 'visual-question-answering'
    visual_grounding = 'visual-grounding'
    visual_entailment = 'visual-entailment'
    multi_modal_similarity = 'multi-modal-similarity'
    text_to_image_synthesis = 'text-to-image-synthesis'
    video_multi_modal_embedding = 'video-multi-modal-embedding'



modelscope/models/multi_modal/__init__.py (+2 / -0)

@@ -7,6 +7,7 @@ if TYPE_CHECKING:

    from .clip import CLIPForMultiModalEmbedding
    from .gemm import GEMMForMultiModalEmbedding
    from .team import TEAMForMultiModalSimilarity
    from .diffusion import DiffusionForTextToImageSynthesis
    from .mmr import VideoCLIPForMultiModalEmbedding
    from .mplug_for_all_tasks import MPlugForAllTasks
@@ -19,6 +20,7 @@ else:
        'clip': ['CLIPForMultiModalEmbedding'],
        'diffusion': ['DiffusionForTextToImageSynthesis'],
        'gemm': ['GEMMForMultiModalEmbedding'],
        'team': ['TEAMForMultiModalSimilarity'],
        'mmr': ['VideoCLIPForMultiModalEmbedding'],
        'mplug_for_all_tasks': ['MPlugForAllTasks'],
        'ofa_for_all_tasks': ['OfaForAllTasks'],


modelscope/models/multi_modal/team/__init__.py (+1 / -0)

@@ -0,0 +1 @@
from .team_model import TEAMForMultiModalSimilarity

modelscope/models/multi_modal/team/team_model.py (+126 / -0)

@@ -0,0 +1,126 @@
from typing import Any, Dict

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tokenizers import BertWordPieceTokenizer
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .utils import TEAM, BertWrapper, CLIPVisionWrapper, CrossLayer

logger = get_logger()

__all__ = ['TEAMForMultiModalSimilarity']


@MODELS.register_module(Tasks.multi_modal_similarity, module_name=Models.team)
class TEAMForMultiModalSimilarity(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)

        text_model = BertWrapper(
            config_json='{}/text_config.json'.format(model_dir),
            feat_dim=768,
            token_dim=1024)
        text_model.bert.cls = None
        image_model = CLIPVisionWrapper()

        self.model = TEAM(
            text_model,
            image_model,
            pretrained='{}/{}'.format(model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE))
        self.model.eval()

        self.device_id = device_id
        if self.device_id >= 0 and torch.cuda.is_available():
            self.model.to('cuda:{}'.format(self.device_id))
            logger.info('Use GPU: {}'.format(self.device_id))
        else:
            self.device_id = -1
            logger.info('Use CPU for inference')

        self.text_tokenizer = BertWordPieceTokenizer(
            '{}/{}'.format(model_dir, ModelFile.VOCAB_FILE), lowercase=False)
        self.text_tokenizer.enable_truncation(max_length=30)

        norm_op = Normalize((0.48145466, 0.4578275, 0.40821073),
                            (0.26862954, 0.26130258, 0.27577711))
        self.img_preprocessor = Compose([
            Resize((224, 224), interpolation=Image.BICUBIC),
            ToTensor(), norm_op
        ])

    def tokenize_text(self, text_str):
        tokens = self.text_tokenizer.encode(text_str)
        max_tokens = 30
        text_ids_tensor = torch.zeros((1, max_tokens)).long()
        text_mask_tensor = torch.zeros((1, max_tokens))
        text_ids, text_mask = tokens.ids, tokens.attention_mask
        text_ids_tensor[0, 0:len(text_ids)] = torch.tensor(text_ids)
        text_mask_tensor[0, 0:len(text_mask)] = torch.tensor(text_mask)
        return text_ids_tensor, text_mask_tensor

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            if 'img' in input and input['img'] is not None:
                input_img = input['img']
                input_img = LoadImage.convert_to_img(input_img)
                img_tensor = self.img_preprocessor(input_img)[None, ...]

                if self.device_id >= 0:
                    img_tensor = img_tensor.to('cuda:{}'.format(
                        self.device_id))
                _, _, image_feature, image_tensors = self.model.get_feature(
                    None, None, img_tensor)
                image_feature = image_feature.cpu().numpy()
            else:
                image_feature, image_tensors = None, None

            if 'text' in input and input['text'] is not None:
                text_str = input['text']
                if isinstance(text_str, str):
                    text_ids_tensor, text_mask_tensor = self.tokenize_text(
                        text_str)
                else:
                    raise TypeError(
                        f'text should be str, but got {type(text_str)}')

                if self.device_id >= 0:
                    text_ids_tensor = text_ids_tensor.to('cuda:{}'.format(
                        self.device_id))
                    text_mask_tensor = text_mask_tensor.to('cuda:{}'.format(
                        self.device_id))
                text_feature, text_tensors, _, _ = self.model.get_feature(
                    text_ids_tensor, text_mask_tensor, None)
                text_feature = text_feature.cpu().numpy()
            else:
                # without text input, all text-side outputs default to None
                text_feature, text_tensors, text_mask_tensor = None, None, None

            if text_tensors is not None and text_mask_tensor is not None and image_tensors is not None:
                score = self.model.get_cross_score(text_tensors,
                                                   text_mask_tensor,
                                                   image_tensors)[0].item()
            else:
                score = None
            output = {
                OutputKeys.IMG_EMBEDDING: image_feature,
                OutputKeys.TEXT_EMBEDDING: text_feature,
                OutputKeys.SCORES: score
            }
            return output

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
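For reference, `forward` accepts an `img` and/or a `text` entry and returns the image/text embeddings plus, when both modalities are given, a cross-attention similarity score under `OutputKeys.SCORES`. A minimal sketch of calling the model class directly, assuming `Model.from_pretrained` resolves the model id used in the test file below and that `TorchModel.__call__` dispatches to `forward` as the pipeline's own `forward` suggests:

from modelscope.models import Model
from modelscope.outputs import OutputKeys

model = Model.from_pretrained(
    'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity')
result = model({
    'img': 'data/test/images/generative_multimodal.jpg',  # placeholder test image
    'text': '起居室照片'  # i.e. 'a photo of a living room'
})
print(result[OutputKeys.IMG_EMBEDDING].shape)   # (1, 768) image embedding
print(result[OutputKeys.TEXT_EMBEDDING].shape)  # (1, 768) text embedding
print(result[OutputKeys.SCORES])                # float cross-attention score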

modelscope/models/multi_modal/team/utils.py (+326 / -0)

@@ -0,0 +1,326 @@
""" Generative Multimodal Model
Base Transformer code is adapted from https://github.com/openai/CLIP/,
originally MIT License, Copyright (c) 2021 OpenAI,
"""
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import nn
from transformers import BertConfig, BertForMaskedLM


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        return self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None,
                 use_gc=False):
        super().__init__()
        self.use_gc = use_gc
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        if self.use_gc:
            for each_block in self.resblocks:
                x = checkpoint.checkpoint(each_block, x)
            return x
        else:
            return self.resblocks(x)


class VisionTransformer(nn.Module):

    def __init__(self,
                 input_resolution: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 output_dim: int,
                 use_gc=False):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads, use_gc=use_gc)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        class_embedding = self.class_embedding.to(x.dtype) + \
            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([class_embedding, x],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIPVisionWrapper(nn.Module):

    def __init__(self, ):
        super().__init__()
        self.vision_transformer = VisionTransformer(
            input_resolution=224,
            patch_size=14,
            width=1024,
            layers=24,
            heads=16,
            output_dim=768)

    def forward(self, x):
        x = self.vision_transformer.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        class_embedding = self.vision_transformer.class_embedding.to(x.dtype) + \
            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([class_embedding, x],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.vision_transformer.positional_embedding.to(x.dtype)
        x = self.vision_transformer.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.vision_transformer.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x_tensor = x.clone()
        x = self.vision_transformer.ln_post(x[:, 0, :])

        if self.vision_transformer.proj is not None:
            x = x @ self.vision_transformer.proj

        return x, x_tensor


class BertWrapper(nn.Module):

    def __init__(self, config_json, feat_dim, token_dim):
        super(BertWrapper, self).__init__()
        bert_config = BertConfig.from_json_file(config_json)
        self.bert = BertForMaskedLM(bert_config).bert

        self.projector = nn.Linear(768, feat_dim, bias=False)
        self.projector_token_embeds = nn.Linear(768, token_dim)

    def forward(self, input_ids, attention_mask):
        trans_features = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }
        output_states = self.bert(**trans_features, return_dict=False)
        output_tokens = output_states[0]

        cls_tokens = output_tokens[:, 0, :]  # CLS token is first token

        return self.projector(cls_tokens), self.projector_token_embeds(
            output_tokens)


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CrossLayer(nn.Module):

    def __init__(self, feat_dim, mlp_ratio):
        super(CrossLayer, self).__init__()
        self.norm1 = nn.LayerNorm(feat_dim)
        self.norm2 = nn.LayerNorm(feat_dim)
        self.norm3 = nn.LayerNorm(feat_dim)

        self.self_attn = nn.MultiheadAttention(
            embed_dim=feat_dim, num_heads=16)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=feat_dim, num_heads=16)
        self.ffn = Mlp(
            in_features=feat_dim,
            hidden_features=feat_dim * mlp_ratio,
            drop=0.1)

        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)

    def forward(self, text_tensors, text_masks, image_tensors,
                retrieved_tensors):
        retrieved_tensors_res = self.norm1(retrieved_tensors)
        retrieved_tensors_res = self.self_attn(
            (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
            (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
            retrieved_tensors_res.permute(1, 0, 2),
            key_padding_mask=(text_masks == 0),
        )[0].permute(1, 0, 2)
        retrieved_tensors = retrieved_tensors + self.dropout1(
            retrieved_tensors_res)

        retrieved_tensors_res = self.norm2(retrieved_tensors)
        retrieved_tensors_res = self.cross_attn(
            (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
            image_tensors.permute(1, 0, 2),
            image_tensors.permute(1, 0, 2))[0].permute(1, 0, 2)
        retrieved_tensors = retrieved_tensors + self.dropout2(
            retrieved_tensors_res)

        retrieved_tensors_res = self.norm3(retrieved_tensors)
        retrieved_tensors = retrieved_tensors + self.dropout3(
            self.ffn(retrieved_tensors_res))

        return retrieved_tensors


class TEAM(nn.Module):

    def __init__(self, text_model, image_model, pretrained):
        super(TEAM, self).__init__()
        self.text_model = text_model
        self.image_model = image_model

        self.cross_model = nn.ModuleList(
            [CrossLayer(feat_dim=1024, mlp_ratio=2)])

        self.image_tensor_fc = nn.Linear(1024, 768)
        self.text_tensor_fc = nn.Linear(1024, 768)

        params = torch.load(pretrained, 'cpu')
        self.load_state_dict(params, strict=True)

    def get_feature(self, text_data=None, text_mask=None, img_tensor=None):
        if text_data is not None:
            text_feature, text_tensors = self.text_model(text_data, text_mask)
            text_feature = F.normalize(text_feature, p=2.0, dim=1)
        else:
            text_feature, text_tensors = None, None

        if img_tensor is not None:
            image_feature, image_tensors = self.image_model(img_tensor)
            image_feature = F.normalize(image_feature, p=2.0, dim=1)
        else:
            image_feature, image_tensors = None, None

        return text_feature, text_tensors, image_feature, image_tensors

    def get_cross_score(self, text_tensors, text_mask, image_tensors):
        retrieved_tensors = torch.zeros_like(text_tensors)
        pair_score_list = []
        text_tensors_proj = self.text_tensor_fc(text_tensors)
        text_mask_float = text_mask.type(text_tensors_proj.dtype)
        for each_cross_model in self.cross_model:
            retrieved_tensors = each_cross_model(text_tensors, text_mask,
                                                 image_tensors,
                                                 retrieved_tensors)
            retrieved_tensors_proj = self.image_tensor_fc(retrieved_tensors)

            pair_score = torch.sum(
                F.normalize(retrieved_tensors_proj, p=2.0, dim=2)
                * F.normalize(text_tensors_proj, p=2.0, dim=2),
                dim=2)
            pair_score_reduced = torch.sum(
                pair_score * text_mask_float, dim=1) / torch.clamp(
                    torch.sum(text_mask_float, dim=1), min=1.0)
            pair_score_list.append(pair_score_reduced)
        return pair_score_list
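To sanity-check the cross-attention scoring path: BertWrapper projects token states to token_dim = 1024, and CLIPVisionWrapper (ViT-L/14 at 224x224) yields 16*16 + 1 = 257 visual tokens of width 1024, so CrossLayer operates on [B, L, 1024] text tensors and [B, 257, 1024] image tensors. A rough, self-contained sketch with random tensors (batch size 2 and length 30 are illustrative; the import path follows this file's location in the repo):

import torch

from modelscope.models.multi_modal.team.utils import CrossLayer

layer = CrossLayer(feat_dim=1024, mlp_ratio=2).eval()  # eval() disables dropout
text_tensors = torch.randn(2, 30, 1024)             # projected BERT token states
text_masks = torch.ones(2, 30)                      # 1 = real token, 0 = padding
image_tensors = torch.randn(2, 257, 1024)           # CLIP ViT token sequence
retrieved_tensors = torch.zeros_like(text_tensors)  # as in TEAM.get_cross_score

with torch.no_grad():
    out = layer(text_tensors, text_masks, image_tensors, retrieved_tensors)
print(out.shape)  # torch.Size([2, 30, 1024])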

modelscope/outputs.py (+9 / -0)

@@ -499,6 +499,15 @@ TASK_OUTPUTS = {
    Tasks.generative_multi_modal_embedding:
    [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION],

    # multi-modal similarity result for a single sample
    # {
    #     "img_embedding": np.array with shape [1, D],
    #     "text_embedding": np.array with shape [1, D],
    #     "scores": float
    # }
    Tasks.multi_modal_similarity:
    [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES],

    # VQA result for a sample
    # {"text": "this is a text answer."}
    Tasks.visual_question_answering: [OutputKeys.TEXT],


modelscope/pipelines/builder.py (+3 / -0)

@@ -79,6 +79,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    (Pipelines.generative_multi_modal_embedding,
     'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
     ),
    Tasks.multi_modal_similarity:
    (Pipelines.multi_modal_similarity,
     'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'),
    Tasks.visual_question_answering:
    (Pipelines.visual_question_answering,
     'damo/mplug_visual-question-answering_coco_large_en'),


modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py (+31 / -0)

@@ -0,0 +1,31 @@
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.multi_modal_similarity, module_name=Pipelines.multi_modal_similarity)
class TEAMMultiModalSimilarityPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a multi-modal similarity pipeline.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
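The pipeline itself is thin: preprocess and postprocess are identity functions and forward delegates to TEAMForMultiModalSimilarity, so end-to-end usage reduces to building the pipeline and reading the keys registered in outputs.py. A short sketch (the image path is a placeholder; omitting `model` relies on the default registered in pipelines/builder.py above):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

similarity = pipeline(Tasks.multi_modal_similarity)
result = similarity({'img': 'path/to/image.jpg', 'text': 'a living room'})
print(result[OutputKeys.SCORES])  # None if only one modality is provided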

modelscope/utils/constant.py (+1 / -0)

@@ -117,6 +117,7 @@ class MultiModalTasks(object):
    text_to_image_synthesis = 'text-to-image-synthesis'
    multi_modal_embedding = 'multi-modal-embedding'
    generative_multi_modal_embedding = 'generative-multi-modal-embedding'
    multi_modal_similarity = 'multi-modal-similarity'
    visual_question_answering = 'visual-question-answering'
    visual_entailment = 'visual-entailment'
    video_multi_modal_embedding = 'video-multi-modal-embedding'


tests/pipelines/test_multi_modal_similarity.py (+42 / -0)

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class MultiModalSimilarityTest(unittest.TestCase):
    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
    test_input = {
        'img': 'data/test/images/generative_multimodal.jpg',
        'text': '起居室照片'  # i.e. 'a photo of a living room'
    }

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        multi_modal_similarity_pipeline = pipeline(
            Tasks.multi_modal_similarity, model=self.model_id)
        output = multi_modal_similarity_pipeline(self.test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        multi_modal_similarity_pipeline = pipeline(
            task=Tasks.multi_modal_similarity)
        output = multi_modal_similarity_pipeline(self.test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        multi_modal_similarity_pipeline = pipeline(
            task=Tasks.multi_modal_similarity, model=model)
        output = multi_modal_similarity_pipeline(self.test_input)
        print(output)


if __name__ == '__main__':
    unittest.main()
