diff --git a/data/test/images/generative_multimodal.jpg b/data/test/images/generative_multimodal.jpg new file mode 100644 index 00000000..b7b32939 --- /dev/null +++ b/data/test/images/generative_multimodal.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24b78db10990c809380508b962decb53cb16db582135cb3c7d56c48f71d5ceb8 +size 39683 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 31f37e76..d4bb64aa 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -31,6 +31,7 @@ class Models(object): # multi-modal models ofa = 'ofa' clip = 'clip-multi-modal-embedding' + gemm = 'gemm-generative-multi-modal' mplug = 'mplug' imagen = 'imagen-text-to-image-synthesis' @@ -95,6 +96,7 @@ class Pipelines(object): # multi-modal tasks image_captioning = 'image-captioning' multi_modal_embedding = 'multi-modal-embedding' + generative_multi_modal_embedding = 'generative-multi-modal-embedding' visual_question_answering = 'visual-question-answering' text_to_image_synthesis = 'text-to-image-synthesis' diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 1f60878b..89db0290 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -1,4 +1,5 @@ from .clip.clip_model import CLIPForMultiModalEmbedding +from .gemm.gemm_model import GEMMForMultiModalEmbedding from .imagen.imagen_model import ImagenForTextToImageSynthesis from .mplug_for_visual_question_answering import \ MPlugForVisualQuestionAnswering diff --git a/modelscope/models/multi_modal/gemm/__init__.py b/modelscope/models/multi_modal/gemm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/gemm/gemm_base.py b/modelscope/models/multi_modal/gemm/gemm_base.py new file mode 100644 index 00000000..26eea0d5 --- /dev/null +++ b/modelscope/models/multi_modal/gemm/gemm_base.py @@ -0,0 +1,550 @@ +""" Generative Multimodal Model +Base modules are adapted from https://github.com/openai/CLIP/, +originally MIT License, Copyright (c) 2021 OpenAI, +and adapted from https://github.com/lucidrains/CoCa-pytorch/, +originally MIT License, Copyright (c) 2022 Phil Wang. 
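+The module provides the vision backbones (ModifiedResNet and VisualTransformer),
+the text Transformer, and the GEVL / GEMMModel wrappers used for embedding
+extraction and caption generation.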
+""" + +import os +from collections import OrderedDict +from typing import Tuple, Union + +import json +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import LayerNorm + +from modelscope.models.multi_modal.gemm.tokenizer import (SimpleTokenizer, + clip_tokenize) + + +class Bottleneck(nn.Module): + """ ResNet style bottleneck module + From https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + if stride > 1 or inplanes != planes * Bottleneck.expansion: + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + return out + + +class QuickGELU(nn.Module): + """ A quick version of GELU module + From https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + """ Multihead attention block with residual link + Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + attn_mask = self.attn_mask + if attn_mask is not None and attn_mask.shape[0] > x.shape[0]: + attn_mask = self.attn_mask[:x.shape[0], :x.shape[0]] + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + """ Transformer encoder module + Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + use_gc: bool = False): + super().__init__() + self.use_gc = use_gc + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class AttentionPool2d(nn.Module): + """ Pool layer with attention 
module + Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, 1) + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) + x = x + self.positional_embedding[:, None, :].to(x.dtype) + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + return x.permute(1, 0, 2).contiguous() + + +class CrossAttention(nn.Module): + """ Cross attention module with query and context as input + Adapted from https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py + """ + + def __init__(self, + dim, + *, + context_dim=None, + dim_head=64, + heads=8, + parallel_ff=False, + ff_mult=4, + norm_context=False): + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + inner_dim = heads * dim_head + context_dim = dim if context_dim is None else context_dim + self.norm = LayerNorm(dim) + self.context_norm = LayerNorm( + context_dim) if norm_context else nn.Identity() + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + ff_inner_dim = ff_mult * dim + self.ff = nn.Sequential( + nn.Linear(dim, ff_inner_dim * 2, bias=False), SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False)) if parallel_ff else None + + def forward(self, x, context): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + x = self.norm(x) + context = self.context_norm(context) + + q = self.to_q(x) + q = q.view(q.shape[0], q.shape[1], self.heads, + -1).permute(0, 2, 1, 3).contiguous() + q = q * self.scale + k, v = self.to_kv(context).chunk(2, dim=-1) + sim = torch.einsum('b h i d, b j d -> b h i j', q, k) + sim = sim - sim.amax(dim=-1, keepdim=True) + attn = sim.softmax(dim=-1) + out = torch.einsum('b h i j, b j d -> b h i d', attn, v) + out = out.permute(0, 2, 1, + 3).contiguous().reshape(out.shape[0], out.shape[2], + -1) + out = self.to_out(out) + if self.ff is not None: + out = out + self.ff(x) + return out + + +class ModifiedResNet(nn.Module): + """ Modified ResNet backbone + From https://github.com/openai/CLIP/blob/main/clip/model.py + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + self._inplanes = width + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class VisualTransformer(nn.Module): + """ ViT transformer backbone + From https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int, use_gc: bool): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer(width, layers, heads, use_gc=use_gc) + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) + z = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([self.class_embedding.to(x.dtype) + z, x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + return x + + +class GEVL(nn.Module): + """ Generative vision-language model + Support learning from both generative and contrastive loss. + Given image and text input, it could output the features of + image and text respectively. Furthermore, caption could also + be produced when image input is available. 
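+    encode_image and encode_text return L2-normalized embeddings;
+    image_to_text autoregressively decodes a caption token by token
+    from the pooled image tokens with the caption decoder.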
+ """ + + def __init__(self, embed_dim: int, image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], + int], vision_width: int, + vision_patch_size: int, context_length: int, vocab_size: int, + transformer_width: int, transformer_heads: int, + transformer_layers: int, use_gc: bool, tokenizer): + nn.Module.__init__(self) + self.context_length = context_length + self.vis_token_size = context_length + self.tokenizer = tokenizer + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + use_gc=use_gc) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + use_gc=use_gc) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.vis_token_projection = nn.Parameter( + torch.empty(embed_dim, transformer_width)) + nn.init.normal_( + self.vis_token_projection, std=self.transformer.width**-0.5) + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.decoder = Transformer( + width=transformer_width, + layers=4, + heads=transformer_heads, + attn_mask=self.build_attention_mask( + self.vis_token_size + self.context_length, + self.vis_token_size), + use_gc=use_gc) + self.to_logits = nn.Sequential( + LayerNorm(transformer_width), + nn.Linear(transformer_width, transformer_width), + nn.Linear(transformer_width, vocab_size, bias=False)) + self.gen_logit_scale = nn.Parameter( + torch.ones([]) * np.log(np.log(vocab_size))) + self.bias = nn.Parameter(torch.ones(vocab_size)) + self.to_logits[-1].weight = self.token_embedding.weight + self.to_logits[-1].bias = self.bias + self.img_queries = nn.Parameter( + torch.randn(self.vis_token_size, transformer_width)) + self.img_attn_pool = CrossAttention( + dim=transformer_width, norm_context=True) + self.img_attn_pool_norm = LayerNorm(transformer_width) + + def build_attention_mask(self, seq_length=None, prefix_length=0): + seq_length = self.context_length if seq_length is None else seq_length + mask = torch.empty(seq_length, seq_length) + mask.fill_(torch.tensor(torch.finfo(torch.float16).min)) + mask.triu_(1) + if prefix_length > 0: + mask[:prefix_length, :prefix_length] = 0 + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, return_tokens=False): + image_outputs = self.visual(image) + image_features = image_outputs[:, 0, :] + image_features = image_features / image_features.norm( + dim=-1, p=2, keepdim=True) + if return_tokens: + image_tokens = image_outputs[:, 1:, :] @ self.vis_token_projection + return image_features, image_tokens + else: + return image_features + + def encode_text(self, text, return_tokens=False): + x = self.token_embedding(text) + x = x + self.positional_embedding[:x.shape[1], :] + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = 
x.permute(1, 0, 2) + x = self.ln_final(x) + text_features = x[torch.arange(x.shape[0]), + text.argmax(dim=-1), ...] @ self.text_projection + text_features = text_features / text_features.norm( + dim=-1, p=2, keepdim=True) + if return_tokens: + text_tokens = x + return text_features, text_tokens + else: + return text_features + + def image_to_text(self, image): + image_features, image_tokens = self.encode_image( + image, return_tokens=True) + img_queries = self.img_queries.expand(image_tokens.shape[0], -1, -1) + img_token_features = self.img_attn_pool(img_queries, image_tokens) + img_token_features = self.img_attn_pool_norm(img_token_features) + sot_token = self.tokenizer.encoder['<|startoftext|>'] + eot_token = self.tokenizer.encoder['<|endoftext|>'] + text_input = image.new_ones( + image.shape[0], 1, dtype=torch.long) * sot_token + input_tokens = img_token_features + pred_tokens = [] + for text_idx in range(self.context_length): + text_features, text_tokens = self.encode_text( + text_input, return_tokens=True) + input_tokens = torch.cat([img_token_features, text_tokens], axis=1) + out_embs = self.decoder(input_tokens.permute(1, 0, 2).contiguous()) + gen_logits = self.to_logits(out_embs[-1:, ...]) + probs = F.softmax(self.gen_logit_scale.exp() * gen_logits, dim=-1) + pred = torch.argmax( + probs * (1.0 + torch.rand_like(probs)), axis=-1) + pred_tokens.append(pred) + text_input = torch.cat( + [text_input, pred.permute(1, 0).contiguous()], axis=1) + pred_text_tokens = torch.cat(pred_tokens, axis=0).permute(1, 0) + text_list = [] + for out_tokens in pred_text_tokens: + tokens = [] + for x in out_tokens: + if x >= eot_token or x <= 0: + break + tokens.append(int(x)) + out_text = self.tokenizer.decode(tokens) + out_text = out_text.strip() + text_list.append(out_text) + return image_features, text_list[0] + + +class GEMMModel(nn.Module): + """ Generative multi-modal model, wrapper of GEVL module. + It takes image or text or both of them as input, and output + features of input or caption when image input is available. 
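+    The forward pass returns a dict with keys 'image_feature',
+    'text_feature' and 'caption'; an entry is None when the
+    corresponding input is absent or captioning is disabled.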
+ """ + + def __init__(self, model_dir): + super().__init__() + with open('{}/encoder_config.json'.format(model_dir), 'r') as f: + model_config = json.loads(f.read()) + model_name = list(model_config.keys())[0] + config_args = model_config[model_name] + bpe_path = os.path.join(model_dir, 'bpe_vocab_16e6.txt.gz') + self.tokenizer = SimpleTokenizer(bpe_path) + self.model = GEVL(*config_args, self.tokenizer) + + def tokenize(self, text_str): + text_tensor = clip_tokenize(self.tokenizer, [text_str])[0] + return text_tensor + + def parse_feat(self, feat): + out = feat.cpu().numpy() + return out + + @torch.no_grad() + def forward(self, image=None, text=None, captioning=True): + img_feature, text_feature, caption = None, None, None + if captioning and image is not None: + img_feature, caption = self.model.image_to_text(image) + elif image is not None: + img_feature = self.parse_feat(self.model.encode_image(image)) + if text is not None: + text_feature = self.parse_feat(self.model.encode_text(text)) + out = { + 'image_feature': img_feature, + 'text_feature': text_feature, + 'caption': caption, + } + return out diff --git a/modelscope/models/multi_modal/gemm/gemm_model.py b/modelscope/models/multi_modal/gemm/gemm_model.py new file mode 100644 index 00000000..a5380858 --- /dev/null +++ b/modelscope/models/multi_modal/gemm/gemm_model.py @@ -0,0 +1,88 @@ +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms as T + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.gemm.gemm_base import GEMMModel +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['GEMMForMultiModalEmbedding'] + + +@MODELS.register_module( + Tasks.generative_multi_modal_embedding, module_name=Models.gemm) +class GEMMForMultiModalEmbedding(TorchModel): + """ Generative multi-modal model for multi-modal embedding + Inputs could be image or text or both of them. + Outputs could be features of input image or text, + image caption could also be produced when image is available. + """ + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.gemm_model = GEMMModel(model_dir=model_dir) + pretrained_params = torch.load('{}/{}'.format( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) + self.gemm_model.load_state_dict(pretrained_params) + self.gemm_model.eval() + self.device_id = device_id + if self.device_id >= 0 and torch.cuda.is_available(): + self.gemm_model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + logger.info('Use CPU for inference') + self.img_preprocessor = T.Compose([ + T.Resize(224), + T.CenterCrop(224), + T.ToTensor(), + T.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)) + ]) + + def parse_image(self, input_img): + if input_img is None: + return None + input_img = LoadImage.convert_to_img(input_img) + img_tensor = self.img_preprocessor(input_img)[None, ...] 
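+        # img_preprocessor resizes/center-crops to 224 and applies CLIP
+        # mean/std normalization; [None, ...] adds the batch dimension
+        # expected by GEMMModel.forward.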
+ if self.device_id >= 0: + img_tensor = img_tensor.to('cuda:{}'.format(self.device_id)) + return img_tensor + + def parse_text(self, text_str): + if text_str is None: + return None + if isinstance(text_str, str): + text_ids_tensor = self.gemm_model.tokenize(text_str) + else: + raise TypeError(f'text should be str, but got {type(text_str)}') + if self.device_id >= 0: + text_ids_tensor = text_ids_tensor.to('cuda:{}'.format( + self.device_id)) + return text_ids_tensor.view(1, -1) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + image = self.parse_image(input.get('image', input.get('img', None))) + text = self.parse_text(input.get('text', input.get('txt', None))) + captioning = input.get('captioning', False) is True + out = self.gemm_model(image, text, captioning) + output = { + OutputKeys.IMG_EMBEDDING: out.get('image_feature', None), + OutputKeys.TEXT_EMBEDDING: out.get('text_feature', None), + OutputKeys.CAPTION: out.get('caption', None) + } + return output diff --git a/modelscope/models/multi_modal/gemm/tokenizer.py b/modelscope/models/multi_modal/gemm/tokenizer.py new file mode 100644 index 00000000..af962ceb --- /dev/null +++ b/modelscope/models/multi_modal/gemm/tokenizer.py @@ -0,0 +1,197 @@ +""" CLIP Tokenizer +Adapted from https://github.com/openai/CLIP. +Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+        merges = merges[1:49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {
+            '<|startoftext|>': '<|startoftext|>',
+            '<|endoftext|>': '<|endoftext|>'
+        }
+        self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + '</w>'
+
+        error_list = []
+        while True:
+            bigram = min(
+                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except Exception as err:
+                    error_list.append(err)
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[
+                        i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        if len(error_list) > 100:
+            print(error_list[-1])
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b]
+                            for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token]
+                              for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            'utf-8', errors='replace').replace('</w>', ' ')
+        return text
+
+
+def clip_tokenize(tokenizer, texts, context_length=77, truncate=True):
+    """
+    Returns the tokenized representation of given input string(s)
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+    truncate: bool
+        Whether to truncate the text in case its encoding is longer than the context length
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
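+    Example (illustrative)
+    ----------------------
+    # tokenizer is a SimpleTokenizer instance
+    tokens = clip_tokenize(tokenizer, ['a diagram', 'a dog'])
+    # -> int tensor of shape [2, 77], zero-padded after <|endoftext|>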
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = tokenizer.encoder['<|startoftext|>'] + eot_token = tokenizer.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f'Input {texts[i]} is too long for context length {context_length}' + ) + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/modelscope/outputs.py b/modelscope/outputs.py index da770b70..921b2bc3 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -271,6 +271,15 @@ TASK_OUTPUTS = { Tasks.multi_modal_embedding: [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING], + # generative multi-modal embedding result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # "text_embedding": np.array with shape [1, D], + # "caption": "this is an image caption text." + # } + Tasks.generative_multi_modal_embedding: + [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION], + # visual grounding result for single sample # { # "boxes": [ diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 58730d9a..224d6379 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -62,6 +62,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding, 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'), + Tasks.generative_multi_modal_embedding: + (Pipelines.generative_multi_modal_embedding, + 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' + ), Tasks.visual_question_answering: (Pipelines.visual_question_answering, 'damo/mplug_visual-question-answering_coco_large_en'), diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index 76c22238..0f3c0444 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -1,6 +1,7 @@ try: from .image_captioning_pipeline import ImageCaptionPipeline from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline + from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline except ModuleNotFoundError as e: diff --git a/modelscope/pipelines/multi_modal/generative_multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/generative_multi_modal_embedding_pipeline.py new file mode 100644 index 00000000..f5a180b6 --- /dev/null +++ b/modelscope/pipelines/multi_modal/generative_multi_modal_embedding_pipeline.py @@ -0,0 +1,32 @@ +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.generative_multi_modal_embedding, + module_name=Pipelines.generative_multi_modal_embedding) +class GEMMMultiModalEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a 
generative multimodal embedding pipeline + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 893db798..44cd87f4 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -75,6 +75,7 @@ class MultiModalTasks(object): visual_grounding = 'visual-grounding' text_to_image_synthesis = 'text-to-image-synthesis' multi_modal_embedding = 'multi-modal-embedding' + generative_multi_modal_embedding = 'generative-multi-modal-embedding' visual_question_answering = 'visual-question-answering' diff --git a/tests/pipelines/test_generative_multi_modal_embedding.py b/tests/pipelines/test_generative_multi_modal_embedding.py new file mode 100644 index 00000000..ccca8f4e --- /dev/null +++ b/tests/pipelines/test_generative_multi_modal_embedding.py @@ -0,0 +1,70 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class GEMMMultiModalEmbeddingTest(unittest.TestCase): + model_id = 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' + test_input = { + 'image': 'data/test/images/generative_multimodal.jpg', + 'text': + 'interior design of modern living room with fireplace in a new house', + 'captioning': False + } + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run(self): + generative_multi_modal_embedding_pipeline = pipeline( + Tasks.generative_multi_modal_embedding, model=self.model_id) + output = generative_multi_modal_embedding_pipeline(self.test_input) + print(output) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding) + output = generative_multi_modal_embedding_pipeline(self.test_input) + print(output) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding, model=model) + output = generative_multi_modal_embedding_pipeline(self.test_input) + print(output) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_output_captioning(self): + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding, model=self.model_id) + test_input = {'image': self.test_input['image'], 'captioning': True} + output = generative_multi_modal_embedding_pipeline(test_input) + print(output) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_output_only_image(self): + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding, model=self.model_id) + test_input = {'image': self.test_input['image'], 'captioning': False} + output = generative_multi_modal_embedding_pipeline(test_input) + print(output) + + 
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_output_only_text(self): + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding, model=self.model_id) + test_input = {'text': self.test_input['text']} + output = generative_multi_modal_embedding_pipeline(test_input) + print(output) + + +if __name__ == '__main__': + unittest.main()
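A minimal usage sketch of the new task (illustrative; it mirrors the tests above, uses the default model id registered in modelscope/pipelines/builder.py, and reads the output keys documented in the TASK_OUTPUTS entry added to modelscope/outputs.py):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

model_id = 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
gemm_pipeline = pipeline(Tasks.generative_multi_modal_embedding, model=model_id)

# Joint image/text embedding, captioning disabled.
output = gemm_pipeline({
    'image': 'data/test/images/generative_multimodal.jpg',
    'text': 'interior design of modern living room with fireplace in a new house',
    'captioning': False,
})
print(output['img_embedding'].shape, output['text_embedding'].shape)  # (1, D) each

# Image-only input with caption generation enabled.
output = gemm_pipeline({
    'image': 'data/test/images/generative_multimodal.jpg',
    'captioning': True,
})
print(output['caption'])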