From 8c9348de2cae660875cc6a728bea64e56eb0c73f Mon Sep 17 00:00:00 2001 From: "eniac.xcw" Date: Tue, 30 Aug 2022 10:07:39 +0800 Subject: [PATCH] [to #42322933]add team model Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9908976 --- modelscope/metainfo.py | 2 + modelscope/models/multi_modal/__init__.py | 2 + .../models/multi_modal/team/__init__.py | 1 + .../models/multi_modal/team/team_model.py | 126 +++++++ modelscope/models/multi_modal/team/utils.py | 326 ++++++++++++++++++ modelscope/outputs.py | 9 + modelscope/pipelines/builder.py | 3 + .../team_multi_modal_similarity_pipeline.py | 31 ++ modelscope/utils/constant.py | 1 + .../pipelines/test_multi_modal_similarity.py | 42 +++ 10 files changed, 543 insertions(+) create mode 100644 modelscope/models/multi_modal/team/__init__.py create mode 100644 modelscope/models/multi_modal/team/team_model.py create mode 100644 modelscope/models/multi_modal/team/utils.py create mode 100644 modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py create mode 100644 tests/pipelines/test_multi_modal_similarity.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index d9e53ca7..58fd4f46 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -59,6 +59,7 @@ class Models(object): gemm = 'gemm-generative-multi-modal' mplug = 'mplug' diffusion = 'diffusion-text-to-image-synthesis' + team = 'team-multi-modal-similarity' video_clip = 'video-clip-multi-modal-embedding' @@ -166,6 +167,7 @@ class Pipelines(object): visual_question_answering = 'visual-question-answering' visual_grounding = 'visual-grounding' visual_entailment = 'visual-entailment' + multi_modal_similarity = 'multi-modal-similarity' text_to_image_synthesis = 'text-to-image-synthesis' video_multi_modal_embedding = 'video-multi-modal-embedding' diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 112b3a58..9219a281 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from .clip import CLIPForMultiModalEmbedding from .gemm import GEMMForMultiModalEmbedding + from .team import TEAMForMultiModalSimilarity from .diffusion import DiffusionForTextToImageSynthesis from .mmr import VideoCLIPForMultiModalEmbedding from .mplug_for_all_tasks import MPlugForAllTasks @@ -19,6 +20,7 @@ else: 'clip': ['CLIPForMultiModalEmbedding'], 'diffusion': ['DiffusionForTextToImageSynthesis'], 'gemm': ['GEMMForMultiModalEmbedding'], + 'team': ['TEAMForMultiModalSimilarity'], 'mmr': ['VideoCLIPForMultiModalEmbedding'], 'mplug_for_all_tasks': ['MPlugForAllTasks'], 'ofa_for_all_tasks': ['OfaForAllTasks'], diff --git a/modelscope/models/multi_modal/team/__init__.py b/modelscope/models/multi_modal/team/__init__.py new file mode 100644 index 00000000..0597040c --- /dev/null +++ b/modelscope/models/multi_modal/team/__init__.py @@ -0,0 +1 @@ +from .team_model import TEAMForMultiModalSimilarity diff --git a/modelscope/models/multi_modal/team/team_model.py b/modelscope/models/multi_modal/team/team_model.py new file mode 100644 index 00000000..4aa77e17 --- /dev/null +++ b/modelscope/models/multi_modal/team/team_model.py @@ -0,0 +1,126 @@ +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from tokenizers import BertWordPieceTokenizer +from torchvision.transforms import Compose, Normalize, Resize, ToTensor + +from modelscope.metainfo import Models 
+from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .utils import TEAM, BertWrapper, CLIPVisionWrapper, CrossLayer + +logger = get_logger() + +__all__ = ['TEAMForMultiModalSimilarity'] + + +@MODELS.register_module(Tasks.multi_modal_similarity, module_name=Models.team) +class TEAMForMultiModalSimilarity(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + text_model = BertWrapper( + config_json='{}/text_config.json'.format(model_dir), + feat_dim=768, + token_dim=1024) + text_model.bert.cls = None + image_model = CLIPVisionWrapper() + + self.model = TEAM( + text_model, + image_model, + pretrained='{}/{}'.format(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE)) + self.model.eval() + + self.device_id = device_id + if self.device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + + self.text_tokenizer = BertWordPieceTokenizer( + '{}/{}'.format(model_dir, ModelFile.VOCAB_FILE), lowercase=False) + self.text_tokenizer.enable_truncation(max_length=30) + + norm_op = Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)) + self.img_preprocessor = Compose([ + Resize((224, 224), interpolation=Image.BICUBIC), + ToTensor(), norm_op + ]) + + def tokenize_text(self, text_str): + tokens = self.text_tokenizer.encode(text_str) + max_tokens = 30 + text_ids_tensor = torch.zeros((1, max_tokens)).long() + text_mask_tensor = torch.zeros((1, max_tokens)) + text_ids, text_mask = tokens.ids, tokens.attention_mask + text_ids_tensor[0, 0:len(text_ids)] = torch.tensor(text_ids) + text_mask_tensor[0, 0:len(text_mask)] = torch.tensor(text_mask) + return text_ids_tensor, text_mask_tensor + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with torch.no_grad(): + if 'img' in input and input['img'] is not None: + input_img = input['img'] + input_img = LoadImage.convert_to_img(input_img) + img_tensor = self.img_preprocessor(input_img)[None, ...] 
+
+                if self.device_id >= 0:
+                    img_tensor = img_tensor.to('cuda:{}'.format(
+                        self.device_id))
+                _, _, image_feature, image_tensors = self.model.get_feature(
+                    None, None, img_tensor)
+                image_feature = image_feature.cpu().numpy()
+            else:
+                image_feature, image_tensors = None, None
+
+            if 'text' in input and input['text'] is not None:
+                text_str = input['text']
+                if isinstance(text_str, str):
+                    text_ids_tensor, text_mask_tensor = self.tokenize_text(
+                        text_str)
+                else:
+                    raise TypeError(
+                        f'text should be str, but got {type(text_str)}')
+
+                if self.device_id >= 0:
+                    text_ids_tensor = text_ids_tensor.to('cuda:{}'.format(
+                        self.device_id))
+                    text_mask_tensor = text_mask_tensor.to('cuda:{}'.format(
+                        self.device_id))
+                text_feature, text_tensors, _, _ = self.model.get_feature(
+                    text_ids_tensor, text_mask_tensor, None)
+                text_feature = text_feature.cpu().numpy()
+            else:
+                text_feature, text_tensors, text_mask_tensor = None, None, None
+
+            if text_tensors is not None and text_mask_tensor is not None and image_tensors is not None:
+                score = self.model.get_cross_score(text_tensors,
+                                                   text_mask_tensor,
+                                                   image_tensors)[0].item()
+            else:
+                score = None
+            output = {
+                OutputKeys.IMG_EMBEDDING: image_feature,
+                OutputKeys.TEXT_EMBEDDING: text_feature,
+                OutputKeys.SCORES: score
+            }
+            return output
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/models/multi_modal/team/utils.py b/modelscope/models/multi_modal/team/utils.py
new file mode 100644
index 00000000..3b3e394e
--- /dev/null
+++ b/modelscope/models/multi_modal/team/utils.py
@@ -0,0 +1,326 @@
+""" TEAM Model for Multi-Modal Similarity
+Base Transformer code is adapted from https://github.com/openai/CLIP/,
+originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from torch import nn
+from transformers import BertConfig, BertForMaskedLM
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+
+    def __init__(self,
+                 d_model: int,
+                 n_head: int,
+                 attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
+                         ('gelu', QuickGELU()),
+                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(
+            dtype=x.dtype,
+            device=x.device) if self.attn_mask is not None else None
+        return self.attn(
+            x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+
+    def __init__(self,
+                 width: int,
+                 layers: int,
+                 heads: int,
+                 attn_mask: torch.Tensor = None,
+                 use_gc=False):
+        super().__init__()
+        self.use_gc = use_gc
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[
+            ResidualAttentionBlock(width, heads, attn_mask)
+            for _ in range(layers)
+        ])
+
+    def 
forward(self, x: torch.Tensor): + if self.use_gc: + for each_block in self.resblocks: + x = checkpoint.checkpoint(each_block, x) + return x + else: + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + use_gc=False): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads, use_gc=use_gc) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + class_embedding = self.class_embedding.to(x.dtype) + \ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([class_embedding, x], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIPVisionWrapper(nn.Module): + + def __init__(self, ): + super().__init__() + self.vision_transformer = VisionTransformer( + input_resolution=224, + patch_size=14, + width=1024, + layers=24, + heads=16, + output_dim=768) + + def forward(self, x): + x = self.vision_transformer.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + class_embedding = self.vision_transformer.class_embedding.to(x.dtype) + \ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([class_embedding, x], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.vision_transformer.positional_embedding.to(x.dtype) + x = self.vision_transformer.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.vision_transformer.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x_tensor = x.clone() + x = self.vision_transformer.ln_post(x[:, 0, :]) + + if self.vision_transformer.proj is not None: + x = x @ self.vision_transformer.proj + + return x, x_tensor + + +class BertWrapper(nn.Module): + + def __init__(self, config_json, feat_dim, token_dim): + super(BertWrapper, self).__init__() + bert_config = BertConfig.from_json_file(config_json) + self.bert = BertForMaskedLM(bert_config).bert + + self.projector = nn.Linear(768, feat_dim, bias=False) + self.projector_token_embeds = nn.Linear(768, token_dim) + + def forward(self, input_ids, attention_mask): + trans_features = { + 'input_ids': input_ids, + 'attention_mask': attention_mask + } + output_states = self.bert(**trans_features, return_dict=False) + output_tokens = output_states[0] + + cls_tokens = output_tokens[:, 0, :] # CLS token is first token + + return self.projector(cls_tokens), self.projector_token_embeds( + 
output_tokens) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CrossLayer(nn.Module): + + def __init__(self, feat_dim, mlp_ratio): + super(CrossLayer, self).__init__() + self.norm1 = nn.LayerNorm(feat_dim) + self.norm2 = nn.LayerNorm(feat_dim) + self.norm3 = nn.LayerNorm(feat_dim) + + self.self_attn = nn.MultiheadAttention( + embed_dim=feat_dim, num_heads=16) + self.cross_attn = nn.MultiheadAttention( + embed_dim=feat_dim, num_heads=16) + self.ffn = Mlp( + in_features=feat_dim, + hidden_features=feat_dim * mlp_ratio, + drop=0.1) + + self.dropout1 = nn.Dropout(0.1) + self.dropout2 = nn.Dropout(0.1) + self.dropout3 = nn.Dropout(0.1) + + def forward(self, text_tensors, text_masks, image_tensors, + retrieved_tensors): + retrieved_tensors_res = self.norm1(retrieved_tensors) + retrieved_tensors_res = self.self_attn( + (text_tensors + retrieved_tensors_res).permute(1, 0, 2), + (text_tensors + retrieved_tensors_res).permute(1, 0, 2), + retrieved_tensors_res.permute(1, 0, 2), + key_padding_mask=(text_masks == 0), + )[0].permute(1, 0, 2) + retrieved_tensors = retrieved_tensors + self.dropout1( + retrieved_tensors_res) + + retrieved_tensors_res = self.norm2(retrieved_tensors) + retrieved_tensors_res = self.cross_attn( + (text_tensors + retrieved_tensors_res).permute(1, 0, 2), + image_tensors.permute(1, 0, 2), + image_tensors.permute(1, 0, 2))[0].permute(1, 0, 2) + retrieved_tensors = retrieved_tensors + self.dropout2( + retrieved_tensors_res) + + retrieved_tensors_res = self.norm3(retrieved_tensors) + retrieved_tensors = retrieved_tensors + self.dropout3( + self.ffn(retrieved_tensors_res)) + + return retrieved_tensors + + +class TEAM(nn.Module): + + def __init__(self, text_model, image_model, pretrained): + super(TEAM, self).__init__() + self.text_model = text_model + self.image_model = image_model + + self.cross_model = nn.ModuleList( + [CrossLayer(feat_dim=1024, mlp_ratio=2)]) + + self.image_tensor_fc = nn.Linear(1024, 768) + self.text_tensor_fc = nn.Linear(1024, 768) + + params = torch.load(pretrained, 'cpu') + self.load_state_dict(params, strict=True) + + def get_feature(self, text_data=None, text_mask=None, img_tensor=None): + if text_data is not None: + text_feature, text_tensors = self.text_model(text_data, text_mask) + text_feature = F.normalize(text_feature, p=2.0, dim=1) + else: + text_feature, text_tensors = None, None + + if img_tensor is not None: + image_feature, image_tensors = self.image_model(img_tensor) + image_feature = F.normalize(image_feature, p=2.0, dim=1) + else: + image_feature, image_tensors = None, None + + return text_feature, text_tensors, image_feature, image_tensors + + def get_cross_score(self, text_tensors, text_mask, image_tensors): + retrieved_tensors = torch.zeros_like(text_tensors) + pair_score_list = [] + text_tensors_proj = self.text_tensor_fc(text_tensors) + text_mask_float = text_mask.type(text_tensors_proj.dtype) + for each_cross_model in self.cross_model: + retrieved_tensors = each_cross_model(text_tensors, text_mask, + image_tensors, + 
retrieved_tensors) + retrieved_tensors_proj = self.image_tensor_fc(retrieved_tensors) + + pair_score = torch.sum( + F.normalize(retrieved_tensors_proj, p=2.0, dim=2) + * F.normalize(text_tensors_proj, p=2.0, dim=2), + dim=2) + pair_score_reduced = torch.sum( + pair_score * text_mask_float, dim=1) / torch.clamp( + torch.sum(text_mask_float, dim=1), min=1.0) + pair_score_list.append(pair_score_reduced) + return pair_score_list diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 622d9034..1c42a5f3 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -499,6 +499,15 @@ TASK_OUTPUTS = { Tasks.generative_multi_modal_embedding: [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION], + # multi-modal similarity result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # "text_embedding": np.array with shape [1, D], + # "similarity": float + # } + Tasks.multi_modal_similarity: + [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], + # VQA result for a sample # {"text": "this is a text answser. "} Tasks.visual_question_answering: [OutputKeys.TEXT], diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index f8f679e6..53f55b06 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -79,6 +79,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { (Pipelines.generative_multi_modal_embedding, 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' ), + Tasks.multi_modal_similarity: + (Pipelines.multi_modal_similarity, + 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'), Tasks.visual_question_answering: (Pipelines.visual_question_answering, 'damo/mplug_visual-question-answering_coco_large_en'), diff --git a/modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py b/modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py new file mode 100644 index 00000000..7d3ffed3 --- /dev/null +++ b/modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py @@ -0,0 +1,31 @@ +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.multi_modal_similarity, module_name=Pipelines.multi_modal_similarity) +class TEAMMultiModalSimilarityPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a multimodal similarity pipeline + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 2141a012..6d419a7e 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -117,6 +117,7 @@ class MultiModalTasks(object): text_to_image_synthesis = 'text-to-image-synthesis' multi_modal_embedding = 'multi-modal-embedding' generative_multi_modal_embedding = 'generative-multi-modal-embedding' + multi_modal_similarity = 'multi-modal-similarity' visual_question_answering = 'visual-question-answering' visual_entailment = 'visual-entailment' video_multi_modal_embedding = 'video-multi-modal-embedding' diff --git a/tests/pipelines/test_multi_modal_similarity.py b/tests/pipelines/test_multi_modal_similarity.py new file mode 100644 index 00000000..d1d6a7a8 --- /dev/null +++ b/tests/pipelines/test_multi_modal_similarity.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class MultiModalSimilarityTest(unittest.TestCase): + model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity' + test_input = { + 'img': 'data/test/images/generative_multimodal.jpg', + 'text': '起居室照片' + } + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + multi_modal_similarity_pipeline = pipeline( + Tasks.multi_modal_similarity, model=self.model_id) + output = multi_modal_similarity_pipeline(self.test_input) + print(output) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + multi_modal_similarity_pipeline = pipeline( + task=Tasks.multi_modal_similarity) + output = multi_modal_similarity_pipeline(self.test_input) + print(output) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + multi_modal_similarity_pipeline = pipeline( + task=Tasks.multi_modal_similarity, model=model) + output = multi_modal_similarity_pipeline(self.test_input) + print(output) + + +if __name__ == '__main__': + unittest.main()
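
Usage sketch for reviewers: a minimal example of how the new task is expected to be driven end to end, mirroring tests/pipelines/test_multi_modal_similarity.py above. The model id, the 'img'/'text' input keys and the output keys are the ones introduced in this patch; the image path is the same test asset and is only a placeholder.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the similarity pipeline from the default model registered in builder.py.
similarity_pipe = pipeline(
    Tasks.multi_modal_similarity,
    model='damo/multi-modal_team-vit-large-patch14_multi-modal-similarity')

# 'img' and 'text' may each be omitted; the cross-modal score is only
# computed when both modalities are provided.
result = similarity_pipe({
    'img': 'data/test/images/generative_multimodal.jpg',
    'text': '起居室照片'  # "a photo of a living room"
})

print(result[OutputKeys.IMG_EMBEDDING].shape)   # np.array with shape [1, D]
print(result[OutputKeys.TEXT_EMBEDDING].shape)  # np.array with shape [1, D]
print(result[OutputKeys.SCORES])                # float similarity score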