Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9908976master
@@ -59,6 +59,7 @@ class Models(object):
     gemm = 'gemm-generative-multi-modal'
     mplug = 'mplug'
     diffusion = 'diffusion-text-to-image-synthesis'
+    team = 'team-multi-modal-similarity'
     video_clip = 'video-clip-multi-modal-embedding'
@@ -166,6 +167,7 @@ class Pipelines(object):
     visual_question_answering = 'visual-question-answering'
     visual_grounding = 'visual-grounding'
     visual_entailment = 'visual-entailment'
+    multi_modal_similarity = 'multi-modal-similarity'
     text_to_image_synthesis = 'text-to-image-synthesis'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
     from .clip import CLIPForMultiModalEmbedding
     from .gemm import GEMMForMultiModalEmbedding
+    from .team import TEAMForMultiModalSimilarity
     from .diffusion import DiffusionForTextToImageSynthesis
     from .mmr import VideoCLIPForMultiModalEmbedding
     from .mplug_for_all_tasks import MPlugForAllTasks
@@ -19,6 +20,7 @@ else:
         'clip': ['CLIPForMultiModalEmbedding'],
         'diffusion': ['DiffusionForTextToImageSynthesis'],
         'gemm': ['GEMMForMultiModalEmbedding'],
+        'team': ['TEAMForMultiModalSimilarity'],
         'mmr': ['VideoCLIPForMultiModalEmbedding'],
         'mplug_for_all_tasks': ['MPlugForAllTasks'],
         'ofa_for_all_tasks': ['OfaForAllTasks'],
@@ -0,0 +1 @@
from .team_model import TEAMForMultiModalSimilarity
@@ -0,0 +1,126 @@
from typing import Any, Dict

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tokenizers import BertWordPieceTokenizer
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .utils import TEAM, BertWrapper, CLIPVisionWrapper, CrossLayer

logger = get_logger()

__all__ = ['TEAMForMultiModalSimilarity']


@MODELS.register_module(Tasks.multi_modal_similarity, module_name=Models.team)
class TEAMForMultiModalSimilarity(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)
        text_model = BertWrapper(
            config_json='{}/text_config.json'.format(model_dir),
            feat_dim=768,
            token_dim=1024)
        text_model.bert.cls = None
        image_model = CLIPVisionWrapper()
        self.model = TEAM(
            text_model,
            image_model,
            pretrained='{}/{}'.format(model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE))
        self.model.eval()

        self.device_id = device_id
        if self.device_id >= 0 and torch.cuda.is_available():
            self.model.to('cuda:{}'.format(self.device_id))
            logger.info('Use GPU: {}'.format(self.device_id))
        else:
            self.device_id = -1
            logger.info('Use CPU for inference')

        self.text_tokenizer = BertWordPieceTokenizer(
            '{}/{}'.format(model_dir, ModelFile.VOCAB_FILE), lowercase=False)
        self.text_tokenizer.enable_truncation(max_length=30)

        norm_op = Normalize((0.48145466, 0.4578275, 0.40821073),
                            (0.26862954, 0.26130258, 0.27577711))
        self.img_preprocessor = Compose([
            Resize((224, 224), interpolation=Image.BICUBIC),
            ToTensor(), norm_op
        ])

    def tokenize_text(self, text_str):
        tokens = self.text_tokenizer.encode(text_str)
        max_tokens = 30
        text_ids_tensor = torch.zeros((1, max_tokens)).long()
        text_mask_tensor = torch.zeros((1, max_tokens))
        text_ids, text_mask = tokens.ids, tokens.attention_mask
        text_ids_tensor[0, 0:len(text_ids)] = torch.tensor(text_ids)
        text_mask_tensor[0, 0:len(text_mask)] = torch.tensor(text_mask)
        return text_ids_tensor, text_mask_tensor

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            if 'img' in input and input['img'] is not None:
                input_img = input['img']
                input_img = LoadImage.convert_to_img(input_img)
                img_tensor = self.img_preprocessor(input_img)[None, ...]
                if self.device_id >= 0:
                    img_tensor = img_tensor.to('cuda:{}'.format(
                        self.device_id))
                _, _, image_feature, image_tensors = self.model.get_feature(
                    None, None, img_tensor)
                image_feature = image_feature.cpu().numpy()
            else:
                image_feature, image_tensors = None, None

            if 'text' in input and input['text'] is not None:
                text_str = input['text']
                if isinstance(text_str, str):
                    text_ids_tensor, text_mask_tensor = self.tokenize_text(
                        text_str)
                else:
                    raise TypeError(
                        f'text should be str, but got {type(text_str)}')
                if self.device_id >= 0:
                    text_ids_tensor = text_ids_tensor.to('cuda:{}'.format(
                        self.device_id))
                    text_mask_tensor = text_mask_tensor.to('cuda:{}'.format(
                        self.device_id))
                text_feature, text_tensors, _, _ = self.model.get_feature(
                    text_ids_tensor, text_mask_tensor, None)
                text_feature = text_feature.cpu().numpy()
            else:
                # also reset text_feature here, otherwise the output dict
                # below references an undefined name when only an image is given
                text_feature, text_tensors, text_mask_tensor = None, None, None
            if text_tensors is not None and text_mask_tensor is not None and image_tensors is not None:
                score = self.model.get_cross_score(text_tensors,
                                                   text_mask_tensor,
                                                   image_tensors)[0].item()
            else:
                score = None

            output = {
                OutputKeys.IMG_EMBEDDING: image_feature,
                OutputKeys.TEXT_EMBEDDING: text_feature,
                OutputKeys.SCORES: score
            }
            return output

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
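Reviewer note: for completeness, a minimal sketch of driving this class directly, outside the pipeline wrapper. The checkpoint directory and image path are placeholders; the directory is assumed to contain the text_config.json, vocab file and torch model binary that __init__ loads, and the output key literals follow the TASK_OUTPUTS comment added further down in this review.

# Minimal sketch, not part of this PR; paths below are placeholders.
from modelscope.models.multi_modal import TEAMForMultiModalSimilarity

model = TEAMForMultiModalSimilarity('/path/to/team_checkpoint_dir', device_id=-1)  # -1 forces CPU
out = model.forward({'img': 'demo.jpg', 'text': 'a photo of a living room'})
print(out['img_embedding'].shape)   # (1, 768), L2-normalized image embedding
print(out['text_embedding'].shape)  # (1, 768), L2-normalized text embedding
print(out['scores'])                # float cross-attention image-text match score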
@@ -0,0 +1,326 @@
""" TEAM multi-modal similarity model.
Base Transformer code is adapted from https://github.com/openai/CLIP/,
originally MIT License, Copyright (c) 2021 OpenAI.
"""
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import nn
from transformers import BertConfig, BertForMaskedLM


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        return self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None,
                 use_gc=False):
        super().__init__()
        self.use_gc = use_gc
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        if self.use_gc:
            for each_block in self.resblocks:
                x = checkpoint.checkpoint(each_block, x)
            return x
        else:
            return self.resblocks(x)


class VisionTransformer(nn.Module):

    def __init__(self,
                 input_resolution: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 output_dim: int,
                 use_gc=False):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads, use_gc=use_gc)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        class_embedding = self.class_embedding.to(x.dtype) + \
            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([class_embedding, x],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIPVisionWrapper(nn.Module):

    def __init__(self, ):
        super().__init__()
        self.vision_transformer = VisionTransformer(
            input_resolution=224,
            patch_size=14,
            width=1024,
            layers=24,
            heads=16,
            output_dim=768)

    def forward(self, x):
        x = self.vision_transformer.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        class_embedding = self.vision_transformer.class_embedding.to(x.dtype) + \
            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([class_embedding, x],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.vision_transformer.positional_embedding.to(x.dtype)
        x = self.vision_transformer.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.vision_transformer.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x_tensor = x.clone()

        x = self.vision_transformer.ln_post(x[:, 0, :])
        if self.vision_transformer.proj is not None:
            x = x @ self.vision_transformer.proj
        return x, x_tensor


class BertWrapper(nn.Module):

    def __init__(self, config_json, feat_dim, token_dim):
        super(BertWrapper, self).__init__()
        bert_config = BertConfig.from_json_file(config_json)
        self.bert = BertForMaskedLM(bert_config).bert
        self.projector = nn.Linear(768, feat_dim, bias=False)
        self.projector_token_embeds = nn.Linear(768, token_dim)

    def forward(self, input_ids, attention_mask):
        trans_features = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }
        output_states = self.bert(**trans_features, return_dict=False)
        output_tokens = output_states[0]
        cls_tokens = output_tokens[:, 0, :]  # CLS token is first token
        return self.projector(cls_tokens), self.projector_token_embeds(
            output_tokens)


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CrossLayer(nn.Module):

    def __init__(self, feat_dim, mlp_ratio):
        super(CrossLayer, self).__init__()
        self.norm1 = nn.LayerNorm(feat_dim)
        self.norm2 = nn.LayerNorm(feat_dim)
        self.norm3 = nn.LayerNorm(feat_dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=feat_dim, num_heads=16)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=feat_dim, num_heads=16)
        self.ffn = Mlp(
            in_features=feat_dim,
            hidden_features=feat_dim * mlp_ratio,
            drop=0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)

    def forward(self, text_tensors, text_masks, image_tensors,
                retrieved_tensors):
        retrieved_tensors_res = self.norm1(retrieved_tensors)
        retrieved_tensors_res = self.self_attn(
            (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
            (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
            retrieved_tensors_res.permute(1, 0, 2),
            key_padding_mask=(text_masks == 0),
        )[0].permute(1, 0, 2)
        retrieved_tensors = retrieved_tensors + self.dropout1(
            retrieved_tensors_res)

        retrieved_tensors_res = self.norm2(retrieved_tensors)
        retrieved_tensors_res = self.cross_attn(
            (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
            image_tensors.permute(1, 0, 2),
            image_tensors.permute(1, 0, 2))[0].permute(1, 0, 2)
        retrieved_tensors = retrieved_tensors + self.dropout2(
            retrieved_tensors_res)

        retrieved_tensors_res = self.norm3(retrieved_tensors)
        retrieved_tensors = retrieved_tensors + self.dropout3(
            self.ffn(retrieved_tensors_res))
        return retrieved_tensors


class TEAM(nn.Module):

    def __init__(self, text_model, image_model, pretrained):
        super(TEAM, self).__init__()
        self.text_model = text_model
        self.image_model = image_model
        self.cross_model = nn.ModuleList(
            [CrossLayer(feat_dim=1024, mlp_ratio=2)])
        self.image_tensor_fc = nn.Linear(1024, 768)
        self.text_tensor_fc = nn.Linear(1024, 768)

        params = torch.load(pretrained, 'cpu')
        self.load_state_dict(params, strict=True)

    def get_feature(self, text_data=None, text_mask=None, img_tensor=None):
        if text_data is not None:
            text_feature, text_tensors = self.text_model(text_data, text_mask)
            text_feature = F.normalize(text_feature, p=2.0, dim=1)
        else:
            text_feature, text_tensors = None, None

        if img_tensor is not None:
            image_feature, image_tensors = self.image_model(img_tensor)
            image_feature = F.normalize(image_feature, p=2.0, dim=1)
        else:
            image_feature, image_tensors = None, None

        return text_feature, text_tensors, image_feature, image_tensors

    def get_cross_score(self, text_tensors, text_mask, image_tensors):
        retrieved_tensors = torch.zeros_like(text_tensors)
        pair_score_list = []
        text_tensors_proj = self.text_tensor_fc(text_tensors)
        text_mask_float = text_mask.type(text_tensors_proj.dtype)
        for each_cross_model in self.cross_model:
            retrieved_tensors = each_cross_model(text_tensors, text_mask,
                                                 image_tensors,
                                                 retrieved_tensors)
            retrieved_tensors_proj = self.image_tensor_fc(retrieved_tensors)
            pair_score = torch.sum(
                F.normalize(retrieved_tensors_proj, p=2.0, dim=2)
                * F.normalize(text_tensors_proj, p=2.0, dim=2),
                dim=2)
            pair_score_reduced = torch.sum(
                pair_score * text_mask_float, dim=1) / torch.clamp(
                    torch.sum(text_mask_float, dim=1), min=1.0)
            pair_score_list.append(pair_score_reduced)
        return pair_score_list
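Reviewer note: the shape contract of the fusion path is easy to lose in the diff, so here is a self-contained sketch (dummy tensors, random weights, no checkpoint) of how get_cross_score drives CrossLayer. The dimensions follow the wiring above: 30 text tokens projected to 1024-d by BertWrapper, and 257 image tokens (16x16 patches at 224/14, plus the class token) from CLIPVisionWrapper. The import path is inferred from the new package layout and is an assumption.

import torch
# assumed module path: modelscope/models/multi_modal/team/utils.py
from modelscope.models.multi_modal.team.utils import CrossLayer

batch, n_text, n_img, dim = 2, 30, 257, 1024
text_tensors = torch.randn(batch, n_text, dim)   # BertWrapper token embeddings after projector_token_embeds
text_mask = torch.ones(batch, n_text)            # 1 = real token, 0 = padding
image_tensors = torch.randn(batch, n_img, dim)   # per-token states (x_tensor) from CLIPVisionWrapper
retrieved = torch.zeros_like(text_tensors)       # get_cross_score starts from zeros

layer = CrossLayer(feat_dim=dim, mlp_ratio=2)
with torch.no_grad():
    retrieved = layer(text_tensors, text_mask, image_tensors, retrieved)
print(retrieved.shape)  # torch.Size([2, 30, 1024]): one fused vector per text token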
@@ -499,6 +499,15 @@ TASK_OUTPUTS = {
     Tasks.generative_multi_modal_embedding:
     [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION],

+    # multi-modal similarity result for a single sample
+    # {
+    #     "img_embedding": np.array with shape [1, D],
+    #     "text_embedding": np.array with shape [1, D],
+    #     "scores": float
+    # }
+    Tasks.multi_modal_similarity:
+    [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES],
+
     # VQA result for a sample
     # {"text": "this is a text answer. "}
     Tasks.visual_question_answering: [OutputKeys.TEXT],
@@ -79,6 +79,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     (Pipelines.generative_multi_modal_embedding,
      'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
      ),
+    Tasks.multi_modal_similarity:
+    (Pipelines.multi_modal_similarity,
+     'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'),
     Tasks.visual_question_answering:
     (Pipelines.visual_question_answering,
      'damo/mplug_visual-question-answering_coco_large_en'),
@@ -0,0 +1,31 @@
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.multi_modal_similarity, module_name=Pipelines.multi_modal_similarity)
class TEAMMultiModalSimilarityPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a multi-modal similarity pipeline.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
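Reviewer note: because the model's forward tolerates a missing modality, this pipeline also works as a plain text or image embedder. A usage sketch, assuming the output key literals from the TASK_OUTPUTS comment above; the image path is a placeholder and the model id is the default registered in the builder hunk:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(
    Tasks.multi_modal_similarity,
    model='damo/multi-modal_team-vit-large-patch14_multi-modal-similarity')

# image + text: embeddings for both plus a cross-attention similarity score
out = p({'img': 'demo.jpg', 'text': 'a photo of a living room'})
print(out['scores'])

# text only: 'img_embedding' and 'scores' come back as None
out_text = p({'text': 'a photo of a living room'})
print(out_text['text_embedding'].shape)  # (1, 768)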
@@ -117,6 +117,7 @@ class MultiModalTasks(object):
     text_to_image_synthesis = 'text-to-image-synthesis'
     multi_modal_embedding = 'multi-modal-embedding'
     generative_multi_modal_embedding = 'generative-multi-modal-embedding'
+    multi_modal_similarity = 'multi-modal-similarity'
     visual_question_answering = 'visual-question-answering'
     visual_entailment = 'visual-entailment'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class MultiModalSimilarityTest(unittest.TestCase):
    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
    test_input = {
        'img': 'data/test/images/generative_multimodal.jpg',
        'text': '起居室照片'  # 'photo of a living room'
    }

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        multi_modal_similarity_pipeline = pipeline(
            Tasks.multi_modal_similarity, model=self.model_id)
        output = multi_modal_similarity_pipeline(self.test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        multi_modal_similarity_pipeline = pipeline(
            task=Tasks.multi_modal_similarity)
        output = multi_modal_similarity_pipeline(self.test_input)
        print(output)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        multi_modal_similarity_pipeline = pipeline(
            task=Tasks.multi_modal_similarity, model=model)
        output = multi_modal_similarity_pipeline(self.test_input)
        print(output)


if __name__ == '__main__':
    unittest.main()