Add Chinese image-text feature model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9157786
* add multi-modal-feature
* Fix issues raised in code review
Target branch: master
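
For reviewers, a minimal usage sketch of the new pipeline, mirroring the unit test added at the end of this diff. The local image path is a placeholder; the call signature follows the pipeline and test code below.

```python
# Minimal usage sketch of the new multi-modal-embedding pipeline; mirrors the
# added unit test. 'landscape.jpg' is a placeholder local image path.
from PIL import Image

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

model_id = 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'
multi_modal_embedding = pipeline(Tasks.multi_modal_embedding, model=model_id)

# text branch, as exercised by the unit test
text_out = multi_modal_embedding({'text': '一张风景图'})
print(text_out['text_embedding'].shape)  # expected (1, D)

# image branch; 'img' accepts a PIL.Image or an np.ndarray (BGR)
img_out = multi_modal_embedding({'img': Image.open('landscape.jpg').convert('RGB')})
print(img_out['img_embedding'].shape)  # expected (1, D)
```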
@@ -24,6 +24,7 @@ class Models(object):

    # multi-modal models
    ofa = 'ofa'
    clip = 'clip-multi-modal-embedding'


class Pipelines(object):

@@ -55,6 +56,7 @@ class Pipelines(object):

    # multi-modal tasks
    image_caption = 'image-caption'
    multi_modal_embedding = 'multi-modal-embedding'


class Trainers(object):

@@ -4,5 +4,5 @@ from .audio.tts.am import SambertNetHifi16k
from .audio.tts.vocoder import Hifigan16k
from .base import Model
from .builder import MODELS, build_model
-from .multi_model import OfaForImageCaptioning
+from .multi_modal import OfaForImageCaptioning
from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity

@@ -1 +1,2 @@
from .clip.clip_model import CLIPForMultiModalEmbedding
from .image_captioning_model import OfaForImageCaptioning
@@ -0,0 +1,26 @@
import torch.nn as nn
from transformers import BertConfig, BertForMaskedLM


class TextTransformer(nn.Module):

    def __init__(self, config_dict, feat_dim=768):
        super(TextTransformer, self).__init__()
        bert_config = BertConfig.from_dict(config_dict)
        self.bert = BertForMaskedLM(bert_config).bert
        self.projector = nn.Linear(
            bert_config.hidden_size, feat_dim, bias=False)

    def forward(self, input_ids, attention_mask):
        trans_features = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }
        output_states = self.bert(**trans_features, return_dict=False)
        output_tokens = output_states[0]
        cls_tokens = output_tokens[:, 0, :]
        return self.projector(cls_tokens)
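
As a quick sanity check, a hedged standalone sketch of the `TextTransformer` interface. The `BertConfig` values here are illustrative assumptions; the shipped model reads its config from `encoder_config.json`.

```python
# Hedged shape check for TextTransformer; the BertConfig values are
# illustrative assumptions, not the ones in encoder_config.json.
import torch
from transformers import BertConfig

from modelscope.models.multi_modal.clip.clip_bert import TextTransformer

config_dict = BertConfig(vocab_size=21128, hidden_size=768).to_dict()
text_encoder = TextTransformer(config_dict, feat_dim=768)

input_ids = torch.zeros((1, 30), dtype=torch.long)  # padded token ids
attention_mask = torch.ones((1, 30))                # 1 marks real tokens
features = text_encoder(input_ids, attention_mask)
print(features.shape)  # torch.Size([1, 768])
```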
@@ -0,0 +1,158 @@
import os.path as osp
from typing import Any, Dict

import cv2  # needed by the np.ndarray branch in forward() below
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tokenizers import BertWordPieceTokenizer
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.clip.clip_bert import TextTransformer
from modelscope.models.multi_modal.clip.clip_vit import VisionTransformer
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['CLIPForMultiModalEmbedding']
class CLIPModel(nn.Module):

    def __init__(self, model_dir):
        super(CLIPModel, self).__init__()
        # including vision config and text config
        model_config = json.load(
            open('{}/encoder_config.json'.format(model_dir)))

        # vision encoder
        vision_config = model_config['vision_config']
        self.img_size = vision_config['input_resolution']
        self.vision_encoder = VisionTransformer(
            input_resolution=self.img_size,
            patch_size=vision_config['patch_size'],
            width=vision_config['width'],
            layers=vision_config['layers'],
            heads=vision_config['heads'],
            output_dim=vision_config['feat_dim'])

        # text encoder
        text_config = model_config['text_config']
        self.text_encoder = TextTransformer(
            text_config['bert_config'], feat_dim=text_config['feat_dim'])

    def forward(self, input_data, input_type):
        if input_type == 'img':
            img_embedding = self.vision_encoder(input_data)
            img_embedding = F.normalize(img_embedding, p=2.0, dim=1)
            return img_embedding
        elif input_type == 'text':
            text_ids_tensor, text_mask_tensor = input_data
            text_embedding = self.text_encoder(text_ids_tensor,
                                               text_mask_tensor)
            text_embedding = F.normalize(text_embedding, p=2.0, dim=1)
            return text_embedding
        else:
            raise ValueError('Unknown input type')
@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip)
class CLIPForMultiModalEmbedding(Model):

    def __init__(self, model_dir, device_id=-1):
        super().__init__(model_dir=model_dir, device_id=device_id)
        self.clip_model = CLIPModel(model_dir=model_dir)
        pretrained_params = torch.load(
            '{}/pytorch_model.bin'.format(model_dir), 'cpu')
        self.clip_model.load_state_dict(pretrained_params)
        self.clip_model.eval()

        self.device_id = device_id
        if self.device_id >= 0:
            self.clip_model.to('cuda:{}'.format(self.device_id))
            logger.info('Use GPU: {}'.format(self.device_id))
        else:
            logger.info('Use CPU for inference')

        # image preprocessor
        norm_op = Normalize((0.48145466, 0.4578275, 0.40821073),
                            (0.26862954, 0.26130258, 0.27577711))
        self.img_preprocessor = Compose([
            Resize((self.clip_model.img_size, self.clip_model.img_size),
                   interpolation=Image.BICUBIC),
            ToTensor(), norm_op
        ])

        # text tokenizer
        vocab_path = '{}/vocab.txt'.format(model_dir)
        self.text_tokenizer = BertWordPieceTokenizer(
            vocab_path, lowercase=False)
        self.text_tokenizer.enable_truncation(max_length=30)

    def tokenize_text(self, text_str):
        tokens = self.text_tokenizer.encode(text_str)
        max_tokens = 30
        text_ids_tensor = torch.zeros((1, max_tokens)).long()
        text_mask_tensor = torch.zeros((1, max_tokens))

        text_ids, text_mask = tokens.ids, tokens.attention_mask
        text_ids_tensor[0, 0:len(text_ids)] = torch.tensor(text_ids)
        text_mask_tensor[0, 0:len(text_mask)] = torch.tensor(text_mask)

        return text_ids_tensor, text_mask_tensor

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        output = {'img_embedding': None, 'text_embedding': None}
        if 'img' in input and input['img'] is not None:
            input_img = input['img']
            if isinstance(input_img, Image.Image):
                img_tensor = self.img_preprocessor(input_img)[None, ...]
            elif isinstance(input_img, np.ndarray):
                if len(input_img.shape) == 2:
                    input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR)
                input_img = input_img[:, :, ::-1]  # in rgb order
                input_img = Image.fromarray(
                    input_img.astype('uint8')).convert('RGB')
                img_tensor = self.img_preprocessor(input_img)[None, ...]
            else:
                raise TypeError(
                    f'img should be either PIL.Image or np.array, but got {type(input_img)}'
                )

            if self.device_id >= 0:
                img_tensor = img_tensor.to('cuda:{}'.format(self.device_id))

            img_embedding = self.clip_model(
                input_data=img_tensor, input_type='img')
            output['img_embedding'] = img_embedding.data.cpu().numpy()

        if 'text' in input and input['text'] is not None:
            text_str = input['text']
            if isinstance(text_str, str):
                text_ids_tensor, text_mask_tensor = self.tokenize_text(
                    text_str)
            else:
                raise TypeError(
                    f'text should be str, but got {type(text_str)}')

            if self.device_id >= 0:
                text_ids_tensor = text_ids_tensor.to('cuda:{}'.format(
                    self.device_id))
                text_mask_tensor = text_mask_tensor.to('cuda:{}'.format(
                    self.device_id))

            text_embedding = self.clip_model(
                input_data=(text_ids_tensor, text_mask_tensor),
                input_type='text')
            output['text_embedding'] = text_embedding.data.cpu().numpy()

        return output

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
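
Since `CLIPModel.forward` L2-normalizes both embeddings, image-text similarity reduces to a dot product. A hedged sketch of direct model usage, calling the model instance the same way `MultiModalEmbeddingPipeline.forward` does; `demo.jpg` is a placeholder path.

```python
# Sketch: image-text similarity from the embeddings produced above.
# 'demo.jpg' is a placeholder; the model id is the one registered in this PR.
import numpy as np
from PIL import Image

from modelscope.models import Model

model = Model.from_pretrained(
    'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding')
output = model({'img': Image.open('demo.jpg').convert('RGB'), 'text': '一张风景图'})
# Both embeddings are already L2-normalized, so the dot product is the cosine similarity.
score = np.dot(output['img_embedding'], output['text_embedding'].T)[0][0]
print(score)
```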
@@ -0,0 +1,121 @@
# Copyright 2021 The OpenAI CLIP Authors. All rights reserved.
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        return self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):

    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        class_embeddings = self.class_embedding.to(x.dtype) + \
            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([class_embeddings, x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
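
A rough shape check of the vision tower. The hyper-parameters below follow the standard ViT-L/14 layout and are assumptions; the actual values are read from `encoder_config.json` at model load time.

```python
# Hedged shape check for VisionTransformer with assumed ViT-L/14-style params.
import torch

from modelscope.models.multi_modal.clip.clip_vit import VisionTransformer

vit = VisionTransformer(
    input_resolution=224,
    patch_size=14,
    width=1024,
    layers=24,
    heads=16,
    output_dim=768)
dummy = torch.randn(1, 3, 224, 224)  # one RGB image, already normalized
with torch.no_grad():
    feat = vit(dummy)
print(feat.shape)  # torch.Size([1, 768])
```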
@@ -21,8 +21,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.sentence_similarity:
    (Pipelines.sentence_similarity,
     'damo/nlp_structbert_sentence-similarity_chinese-base'),
-   Tasks.image_matting:
-   (Pipelines.image_matting, 'damo/cv_unet_image-matting'),
+   Tasks.image_matting: (Pipelines.image_matting,
+                         'damo/cv_unet_image-matting'),
    Tasks.text_classification: (Pipelines.sentiment_analysis,
                                'damo/bert-base-sst2'),
    Tasks.text_generation: (Pipelines.text_generation,

@@ -37,6 +37,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
    Tasks.multi_modal_embedding:
    (Pipelines.multi_modal_embedding,
     'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding')
}
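
With the `DEFAULT_MODEL_FOR_PIPELINE` entry above, the task can also be constructed without an explicit model id, as `test_run_with_default_model` in the new unit test does; a one-line sketch:

```python
# Sketch: rely on the default model registered above for this task.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

default_pipe = pipeline(task=Tasks.multi_modal_embedding)
print(default_pipe({'text': '一张风景图'})['text_embedding'].shape)  # (1, D)
```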
@@ -1 +1,2 @@
from .image_captioning_pipeline import ImageCaptionPipeline
from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline

@@ -0,0 +1,34 @@
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Model, Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
    Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding)
class MultiModalEmbeddingPipeline(Pipeline):

    def __init__(self, model: Union[Model, str], device_id: int = -1):
        if isinstance(model, str):
            pipe_model = Model.from_pretrained(model)
        elif isinstance(model, Model):
            pipe_model = model
        else:
            raise NotImplementedError(
                'model must be a Model instance or a valid model id (str)')
        super().__init__(model=pipe_model)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -117,6 +117,13 @@ TASK_OUTPUTS = {
    # }
    Tasks.image_captioning: ['caption'],

    # multi-modal embedding result for single sample
    # {
    #   "img_embedding": np.array with shape [1, D],
    #   "text_embedding": np.array with shape [1, D]
    # }
    Tasks.multi_modal_embedding: ['img_embedding', 'text_embedding'],

    # visual grounding result for single sample
    # {
    #   "boxes": [

@@ -57,6 +57,7 @@ class Tasks(object):
    image_captioning = 'image-captioning'
    visual_grounding = 'visual-grounding'
    text_to_image_synthesis = 'text-to-image-synthesis'
    multi_modal_embedding = 'multi-modal-embedding'


class InputFields(object):
@@ -0,0 +1,52 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class MultiModalEmbeddingTest(unittest.TestCase):
    model_id = 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'
    test_text = {'text': '一张风景图'}

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run(self):
        pipe_line_multi_modal_embedding = pipeline(
            Tasks.multi_modal_embedding, model=self.model_id)
        test_str_embedding = pipe_line_multi_modal_embedding(
            self.test_text)['text_embedding']
        print(np.sum(np.abs(test_str_embedding)))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        pipe_line_multi_modal_embedding = pipeline(
            task=Tasks.multi_modal_embedding, model=model)
        test_str_embedding = pipe_line_multi_modal_embedding(
            self.test_text)['text_embedding']
        print(np.sum(np.abs(test_str_embedding)))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipe_line_multi_modal_embedding = pipeline(
            task=Tasks.multi_modal_embedding, model=self.model_id)
        test_str_embedding = pipe_line_multi_modal_embedding(
            self.test_text)['text_embedding']
        print(np.sum(np.abs(test_str_embedding)))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipe_line_multi_modal_embedding = pipeline(
            task=Tasks.multi_modal_embedding)
        test_str_embedding = pipe_line_multi_modal_embedding(
            self.test_text)['text_embedding']
        print(np.sum(np.abs(test_str_embedding)))


if __name__ == '__main__':
    unittest.main()