Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9458640 (branch: master)
| @@ -124,3 +124,4 @@ replace.sh | |||
| # Pytorch | |||
| *.pth | |||
| *.pt | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:33e21c16d5388684b61d7251b9d4e418f8146c3ba3fa400ebd8d913058687cfc | |||
| size 431888 | |||
| @@ -34,6 +34,7 @@ class Models(object): | |||
| gemm = 'gemm-generative-multi-modal' | |||
| mplug = 'mplug' | |||
| imagen = 'imagen-text-to-image-synthesis' | |||
| video_clip = 'video-clip-multi-modal-embedding' | |||
| class TaskModels(object): | |||
| @@ -99,6 +100,7 @@ class Pipelines(object): | |||
| generative_multi_modal_embedding = 'generative-multi-modal-embedding' | |||
| visual_question_answering = 'visual-question-answering' | |||
| text_to_image_synthesis = 'text-to-image-synthesis' | |||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
| class Trainers(object): | |||
| @@ -5,10 +5,10 @@ from .base import Model | |||
| from .builder import MODELS, build_model | |||
| try: | |||
| from .audio.ans.frcrn import FRCRNModel | |||
| from .audio.asr import GenericAutomaticSpeechRecognition | |||
| from .audio.tts import SambertHifigan | |||
| from .audio.kws import GenericKeyWordSpotting | |||
| from .audio.ans.frcrn import FRCRNModel | |||
| from .audio.tts import SambertHifigan | |||
| except ModuleNotFoundError as e: | |||
| print(AUDIO_IMPORT_ERROR.format(e)) | |||
| @@ -29,8 +29,8 @@ try: | |||
| SbertForZeroShotClassification, SpaceForDialogIntent, | |||
| SpaceForDialogModeling, SpaceForDialogStateTracking, | |||
| StructBertForMaskedLM, VecoForMaskedLM) | |||
| from .nlp.heads import (SequenceClassificationHead) | |||
| from .nlp.backbones import (SbertModel) | |||
| from .nlp.backbones import SbertModel | |||
| from .nlp.heads import SequenceClassificationHead | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'pytorch'": | |||
| pass | |||
| @@ -1,6 +1,8 @@ | |||
| from .clip.clip_model import CLIPForMultiModalEmbedding | |||
| from .gemm.gemm_model import GEMMForMultiModalEmbedding | |||
| from .imagen.imagen_model import ImagenForTextToImageSynthesis | |||
| from .mmr.models.clip_for_multi_model_video_embedding import \ | |||
| VideoCLIPForMultiModalEmbedding | |||
| from .mplug_for_visual_question_answering import \ | |||
| MPlugForVisualQuestionAnswering | |||
| from .ofa_for_image_captioning_model import OfaForImageCaptioning | |||
| @@ -784,7 +784,7 @@ class BertModel(nn.Module): | |||
| elif config.transformer_type.lower() == 'act': | |||
| self.encoder = BERTEncoderACT(config) | |||
| elif config.transformer_type.lower() == 'textnas': | |||
| from textnas_final import op_dict, input_dict, skip_dict | |||
| from textnas_final import input_dict, op_dict, skip_dict | |||
| self.encoder = TextNASEncoder(config, op_dict, input_dict, | |||
| skip_dict) | |||
| else: | |||
| @@ -0,0 +1,114 @@ | |||
| import cv2 | |||
| import numpy as np | |||
| import torch as th | |||
| from PIL import Image | |||
| from torchvision.transforms import (CenterCrop, Compose, InterpolationMode, | |||
| Normalize, Resize, ToTensor) | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| class RawVideoExtractorCV2(): | |||
| def __init__( | |||
| self, | |||
| centercrop=False, | |||
| size=224, | |||
| frame_rate=-1, | |||
| ): | |||
| self.centercrop = centercrop | |||
| self.size = size | |||
| self.framerate = frame_rate | |||
| self.transform = self._transform(self.size) | |||
| def _transform(self, n_px): | |||
| return Compose([ | |||
| Resize(n_px, interpolation=InterpolationMode.BICUBIC), | |||
| CenterCrop(n_px), | |||
| lambda image: image.convert('RGB'), | |||
| ToTensor(), | |||
| Normalize((0.48145466, 0.4578275, 0.40821073), | |||
| (0.26862954, 0.26130258, 0.27577711)), | |||
| ]) | |||
| def video_to_tensor(self, | |||
| video_file, | |||
| preprocess, | |||
| sample_fp=0, | |||
| start_time=None, | |||
| end_time=None): | |||
| if start_time is not None or end_time is not None: | |||
| assert isinstance(start_time, int) and isinstance(end_time, int) \ | |||
| and start_time > -1 and end_time > start_time | |||
| assert sample_fp > -1 | |||
| # Sample sample_fp frames from each second of the video. | |||
| cap = cv2.VideoCapture(video_file) | |||
| frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |||
| fps = int(cap.get(cv2.CAP_PROP_FPS)) | |||
| if fps == 0: | |||
| logger.error(f'{video_file} has fps 0, cannot sample frames!') | |||
| cap.release() | |||
| return {'video': th.zeros(1)} | |||
| total_duration = (frameCount + fps - 1) // fps | |||
| start_sec, end_sec = 0, total_duration | |||
| if start_time is not None: | |||
| start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration | |||
| cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) | |||
| interval = 1 | |||
| if sample_fp > 0: | |||
| interval = fps // sample_fp | |||
| else: | |||
| sample_fp = fps | |||
| if interval == 0: | |||
| interval = 1 | |||
| inds = [ind for ind in np.arange(0, fps, interval)] | |||
| assert len(inds) >= sample_fp | |||
| inds = inds[:sample_fp] | |||
| ret = True | |||
| images = [] | |||
| for sec in np.arange(start_sec, end_sec + 1): | |||
| if not ret: | |||
| break | |||
| sec_base = int(sec * fps) | |||
| for ind in inds: | |||
| cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) | |||
| ret, frame = cap.read() | |||
| if not ret: | |||
| break | |||
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |||
| images.append( | |||
| preprocess(Image.fromarray(frame_rgb).convert('RGB'))) | |||
| cap.release() | |||
| if len(images) > 0: | |||
| video_data = th.tensor(np.stack(images)) | |||
| else: | |||
| video_data = th.zeros(1) | |||
| return {'video': video_data} | |||
| def get_video_data(self, video_path, start_time=None, end_time=None): | |||
| image_input = self.video_to_tensor( | |||
| video_path, | |||
| self.transform, | |||
| sample_fp=self.framerate, | |||
| start_time=start_time, | |||
| end_time=end_time) | |||
| return image_input | |||
| def process_raw_data(self, raw_video_data): | |||
| tensor_size = raw_video_data.size() | |||
| tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], | |||
| tensor_size[-1]) | |||
| return tensor | |||
| # An ordinary video frame extractor based on OpenCV (cv2) | |||
| RawVideoExtractor = RawVideoExtractorCV2 | |||
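For reference, a minimal sketch of driving the extractor above; the clip path `sample.mp4` is a placeholder, not part of the change set:

```python
# Minimal usage sketch for RawVideoExtractorCV2 ('sample.mp4' is a placeholder path).
from modelscope.models.multi_modal.mmr.dataloaders.rawvideo_util import RawVideoExtractor

extractor = RawVideoExtractor(frame_rate=1, size=224)  # sample one frame per second
data = extractor.get_video_data('sample.mp4')          # -> {'video': Tensor}
frames = data['video']                                 # (num_frames, 3, 224, 224)
print(frames.shape)
print(extractor.process_raw_data(frames).shape)        # (num_frames, 1, 3, 224, 224)
```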
| @@ -0,0 +1,218 @@ | |||
| import os | |||
| import random | |||
| from os.path import exists | |||
| from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from PIL import Image | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Model | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.multi_modal.mmr.dataloaders.rawvideo_util import \ | |||
| RawVideoExtractor | |||
| from modelscope.models.multi_modal.mmr.models.modeling import CLIP4Clip | |||
| from modelscope.models.multi_modal.mmr.models.tokenization_clip import \ | |||
| SimpleTokenizer as ClipTokenizer | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @MODELS.register_module( | |||
| Tasks.video_multi_modal_embedding, module_name=Models.video_clip) | |||
| class VideoCLIPForMultiModalEmbedding(Model): | |||
| def __init__(self, model_dir, device_id=-1): | |||
| super().__init__(model_dir=model_dir, device_id=device_id) | |||
| # model config parameters | |||
| with open(f'{model_dir}/{ModelFile.CONFIGURATION}', 'r') as json_file: | |||
| model_config = json.load(json_file) | |||
| model_config = model_config['paras'] | |||
| model_config['model_dir'] = model_dir | |||
| self.SPECIAL_TOKEN = { | |||
| 'CLS_TOKEN': '<|startoftext|>', | |||
| 'SEP_TOKEN': '<|endoftext|>', | |||
| 'MASK_TOKEN': '[MASK]', | |||
| 'UNK_TOKEN': '[UNK]', | |||
| 'PAD_TOKEN': '[PAD]' | |||
| } | |||
| self.max_words = model_config['max_words'] | |||
| self.max_frames = model_config['max_frames'] | |||
| self.feature_framerate = model_config['feature_framerate'] | |||
| self.image_resolution = 224 | |||
| self.device = model_config['device'] | |||
| self.init_model = f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}' | |||
| self.tokenizer = ClipTokenizer(model_dir) | |||
| self.rawVideoExtractor = RawVideoExtractor( | |||
| frame_rate=self.feature_framerate, size=self.image_resolution) | |||
| self.local_transform = self.rawVideoExtractor.transform | |||
| self.model = CLIP4Clip(model_config) | |||
| if hasattr(self.model, 'module'): | |||
| self.model = self.model.module.to(self.device) | |||
| else: | |||
| self.model = self.model.to(self.device) | |||
| if self.init_model: | |||
| assert exists(self.init_model) | |||
| model_state_dict = torch.load(self.init_model, map_location='cpu') | |||
| self.model.load_state_dict(model_state_dict, strict=False) | |||
| self.model.to(self.device) | |||
| def _get_text(self, caption, tokenizer, enable_zh=False): | |||
| if len(caption) == 3: | |||
| _caption_text, s, e = caption | |||
| elif len(caption) == 4: | |||
| _caption_text, s, e, pos = caption | |||
| else: | |||
| raise NotImplementedError | |||
| if isinstance(_caption_text, list): | |||
| caption_text = random.choice(_caption_text) | |||
| else: | |||
| caption_text = _caption_text | |||
| if enable_zh: | |||
| _token = tokenizer.encode(caption_text) | |||
| input_ids = _token.ids | |||
| input_mask = _token.attention_mask | |||
| segment_ids = _token.type_ids | |||
| else: | |||
| words = tokenizer.tokenize(caption_text) | |||
| words = [self.SPECIAL_TOKEN['CLS_TOKEN']] + words | |||
| total_length_with_CLS = self.max_words - 1 | |||
| if len(words) > total_length_with_CLS: | |||
| words = words[:total_length_with_CLS] | |||
| words = words + [self.SPECIAL_TOKEN['SEP_TOKEN']] | |||
| input_ids = tokenizer.convert_tokens_to_ids(words) | |||
| input_mask = [1] * len(input_ids) | |||
| segment_ids = [0] * len(input_ids) | |||
| while len(input_ids) < self.max_words: | |||
| input_ids.append(0) | |||
| input_mask.append(0) | |||
| segment_ids.append(0) | |||
| assert len(input_ids) == self.max_words | |||
| assert len(input_mask) == self.max_words | |||
| assert len(segment_ids) == self.max_words | |||
| pairs_text = np.array(input_ids) | |||
| pairs_mask = np.array(input_mask) | |||
| pairs_segment = np.array(segment_ids) | |||
| return pairs_text, pairs_mask, pairs_segment, s, e | |||
| def _get_rawvideo_dec(self, | |||
| video_path, | |||
| rawVideoExtractor, | |||
| local_transform, | |||
| s=None, | |||
| e=None): | |||
| video_mask = np.zeros(self.max_frames, dtype=np.int64) | |||
| max_video_length = 0 | |||
| # T x 3 x H x W | |||
| video = np.zeros((self.max_frames, 3, rawVideoExtractor.size, | |||
| rawVideoExtractor.size), | |||
| dtype=np.float64) | |||
| if s is None: | |||
| start_time, end_time = None, None | |||
| else: | |||
| start_time = int(s) | |||
| end_time = int(e) | |||
| start_time = start_time if start_time >= 0. else 0. | |||
| end_time = end_time if end_time >= 0. else 0. | |||
| if start_time > end_time: | |||
| start_time, end_time = end_time, start_time | |||
| elif start_time == end_time: | |||
| end_time = end_time + 1 | |||
| if exists(video_path): | |||
| from decord import VideoReader, cpu | |||
| vreader = VideoReader(video_path, ctx=cpu(0)) | |||
| else: | |||
| logger.error('non video input, output is wrong!!!') | |||
| return video, video_mask | |||
| fps = vreader.get_avg_fps() | |||
| f_start = 0 if start_time is None else int(start_time * fps) | |||
| f_end = int( | |||
| min(1000000000 if end_time is None else end_time * fps, | |||
| len(vreader) - 1)) | |||
| num_frames = f_end - f_start + 1 | |||
| if num_frames > 0: | |||
| # L x T x 3 x H x W | |||
| sample_fps = int(self.feature_framerate) | |||
| t_stride = int(round(float(fps) / sample_fps)) | |||
| all_pos = list(range(f_start, f_end + 1, t_stride)) | |||
| if len(all_pos) > self.max_frames: | |||
| sample_pos = [ | |||
| all_pos[_] for _ in np.linspace( | |||
| 0, len(all_pos) - 1, num=self.max_frames, dtype=int) | |||
| ] | |||
| else: | |||
| sample_pos = all_pos | |||
| patch_images = [ | |||
| Image.fromarray(f) | |||
| for f in vreader.get_batch(sample_pos).asnumpy() | |||
| ] | |||
| patch_images = torch.stack( | |||
| [local_transform(img) for img in patch_images]) | |||
| slice_len = patch_images.shape[0] | |||
| max_video_length = max_video_length if max_video_length > slice_len else slice_len | |||
| if slice_len < 1: | |||
| pass | |||
| else: | |||
| video[:slice_len, ...] = patch_images | |||
| else: | |||
| logger.error('video path: {} error.'.format(video_path)) | |||
| video_mask[:max_video_length] = [1] * max_video_length | |||
| return video, video_mask | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| from modelscope.outputs import OutputKeys | |||
| output = {} | |||
| if 'video' in input and input['video'] is not None: | |||
| video_path = input['video'] | |||
| video, video_mask = self._get_rawvideo_dec(video_path, | |||
| self.rawVideoExtractor, | |||
| self.local_transform) | |||
| video = torch.unsqueeze( | |||
| torch.from_numpy(video), dim=0).to(self.device) | |||
| video_mask = torch.unsqueeze( | |||
| torch.from_numpy(video_mask), dim=0).to(self.device) | |||
| if 'text' in input and input['text'] is not None: | |||
| caption = input['text'] | |||
| pairs_text, pairs_mask, pairs_segment, s, e = self._get_text( | |||
| caption, self.tokenizer, enable_zh=False) | |||
| input_ids = torch.unsqueeze( | |||
| torch.from_numpy(pairs_text), dim=0).to(self.device) | |||
| input_mask = torch.unsqueeze( | |||
| torch.from_numpy(pairs_mask), dim=0).to(self.device) | |||
| segment_ids = torch.unsqueeze( | |||
| torch.from_numpy(pairs_segment), dim=0).to(self.device) | |||
| sequence_output, visual_output = self.model.get_sequence_visual_output( | |||
| input_ids, segment_ids, input_mask, video, video_mask) | |||
| logger.info('text feature: {}'.format(sequence_output[0][0][0])) | |||
| logger.info('video feature: {}'.format(visual_output[0][0][0])) | |||
| output[OutputKeys.VIDEO_EMBEDDING] = visual_output | |||
| output[OutputKeys.TEXT_EMBEDDING] = sequence_output | |||
| return output | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
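A hedged sketch of exercising the registered model directly; `model_dir` and the video path are placeholders, and a real directory must contain the configuration file, the packaged ViT-B-16.pt weights and the BPE vocab referenced above:

```python
# Sketch only: the paths below are placeholders, not shipped with this change.
from modelscope.models.multi_modal import VideoCLIPForMultiModalEmbedding
from modelscope.outputs import OutputKeys

model = VideoCLIPForMultiModalEmbedding(model_dir='/path/to/model_dir')
inputs = {
    'video': '/path/to/sample_video.mp4',
    # _get_text expects a (caption, start, end) tuple
    'text': ('a person is cooking dinner', None, None),
}
outputs = model.forward(inputs)
print(outputs[OutputKeys.VIDEO_EMBEDDING].shape)
print(outputs[OutputKeys.TEXT_EMBEDDING].shape)
```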
| @@ -0,0 +1,42 @@ | |||
| import numpy as np | |||
| def get_retrieved_videos(sims, k): | |||
| """ | |||
| Returns the unique indices of videos retrieved in the top k for any query in the sims matrix. | |||
| Args: | |||
| sims: similarity matrix (queries x videos). | |||
| k: number of top-ranked videos kept per query. | |||
| """ | |||
| argm = np.argsort(-sims, axis=1) | |||
| topk = argm[:, :k].reshape(-1) | |||
| retrieved_videos = np.unique(topk) | |||
| return retrieved_videos | |||
| def get_index_to_normalize(sims, videos): | |||
| """ | |||
| Returns the indices of rows in sims whose top-1 video lies in the retrieved set. | |||
| Args: | |||
| sims: similarity matrix (queries x videos). | |||
| videos: array of retrieved video indices. | |||
| """ | |||
| argm = np.argsort(-sims, axis=1)[:, 0] | |||
| result = np.array(list(map(lambda x: x in videos, argm))) | |||
| result = np.nonzero(result) | |||
| return result | |||
| def qb_norm(train_test, test_test, args): | |||
| k = args.get('k', 1) | |||
| beta = args.get('beta', 20) | |||
| retrieved_videos = get_retrieved_videos(train_test, k) | |||
| test_test_normalized = test_test | |||
| train_test = np.exp(train_test * beta) | |||
| test_test = np.exp(test_test * beta) | |||
| normalizing_sum = np.sum(train_test, axis=0) | |||
| index_for_normalizing = get_index_to_normalize(test_test, retrieved_videos) | |||
| test_test_normalized[index_for_normalizing, :] = \ | |||
| np.divide(test_test[index_for_normalizing, :], normalizing_sum) | |||
| return test_test_normalized | |||
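The querybank-norm helpers above work on plain NumPy similarity matrices, so they can be smoke-tested in isolation. A toy sketch with random data follows (shapes are illustrative; the hunk does not show the module path, so the functions are assumed to be in scope):

```python
# Toy smoke test for qb_norm; copies are passed because the test-test matrix
# is modified in place through the test_test_normalized alias.
import numpy as np

rng = np.random.default_rng(0)
train_test = rng.random((8, 5))  # 8 querybank texts x 5 test videos
test_test = rng.random((4, 5))   # 4 test texts x 5 test videos

normalized = qb_norm(train_test.copy(), test_test.copy(), {'k': 1, 'beta': 20})
print(normalized.shape)  # (4, 5): rows whose top-1 video was retrieved are renormalized
```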
| @@ -0,0 +1,508 @@ | |||
| import os | |||
| import platform | |||
| from collections import OrderedDict | |||
| from types import SimpleNamespace | |||
| import torch | |||
| from torch import nn | |||
| from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | |||
| from modelscope.models.multi_modal.mmr.models.module_clip import ( | |||
| _PT_NAME, CLIP, QuickGELU, convert_weights) | |||
| from modelscope.models.multi_modal.mmr.models.module_cross import \ | |||
| Transformer as TransformerClip | |||
| from modelscope.models.multi_modal.mmr.models.until_module import (AllGather, | |||
| CrossEn, | |||
| LayerNorm) | |||
| from modelscope.utils.logger import get_logger | |||
| allgather = AllGather.apply | |||
| logger = get_logger() | |||
| __all__ = ['CLIP4Clip'] | |||
| class CLIP4Clip(nn.Module): | |||
| def __init__(self, config): | |||
| super(CLIP4Clip, self).__init__() | |||
| self.config = config | |||
| self.loose_type = config['loose_type'] | |||
| self.sim_header = config['sim_header'] | |||
| if self.sim_header in [ | |||
| 'tightTransf', 'tightFc1', 'tightFc2', 'tightFc3', 'tightFc4', | |||
| 'tightMean', 'tightFc5' | |||
| ]: | |||
| assert self.loose_type is False | |||
| backbone = config['pretrained_clip_name'] | |||
| # load the packaged backbone weights locally instead of downloading them | |||
| model_path = '{}/ViT-B-16.pt'.format(config['model_dir']) | |||
| if not os.path.exists(model_path): | |||
| logger.error('CLIP backbone weights not found at {}'.format(model_path)) | |||
| try: | |||
| # loading JIT archive | |||
| model = torch.jit.load(model_path, map_location='cpu').eval() | |||
| state_dict = model.state_dict() | |||
| except RuntimeError: | |||
| state_dict = torch.load(model_path, map_location='cpu') | |||
| vision_width = state_dict['visual.conv1.weight'].shape[0] | |||
| vision_layers = len([ | |||
| k for k in state_dict.keys() | |||
| if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') | |||
| ]) | |||
| vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] | |||
| grid_size = round( | |||
| (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) | |||
| image_resolution = vision_patch_size * grid_size | |||
| embed_dim = state_dict['text_projection'].shape[1] | |||
| context_length = state_dict['positional_embedding'].shape[0] | |||
| vocab_size = state_dict['token_embedding.weight'].shape[0] | |||
| transformer_width = state_dict['ln_final.weight'].shape[0] | |||
| transformer_heads = transformer_width // 64 | |||
| transformer_layers = len( | |||
| set( | |||
| k.split('.')[2] for k in state_dict | |||
| if k.startswith('transformer.resblocks'))) | |||
| cut_top_layer = 0 | |||
| self.clip = CLIP( | |||
| embed_dim, | |||
| image_resolution, | |||
| vision_layers - cut_top_layer, | |||
| vision_width, | |||
| vision_patch_size, | |||
| context_length, | |||
| vocab_size, | |||
| transformer_width, | |||
| transformer_heads, | |||
| transformer_layers - cut_top_layer, | |||
| linear_patch=config['linear_patch'], | |||
| use_gc=config['use_gc']).float() | |||
| if (platform.system() != 'Darwin'): | |||
| convert_weights(self.clip) # fp16 | |||
| if backbone in ['ViT-B/32', 'ViT-B/16']: | |||
| cross_config = SimpleNamespace(**{ | |||
| 'hidden_size': 512, | |||
| 'max_position_embeddings': 128, | |||
| }) | |||
| elif backbone in ['ViT-L/14', 'ViT-B/14-336px']: | |||
| cross_config = SimpleNamespace(**{ | |||
| 'hidden_size': 768, | |||
| 'max_position_embeddings': 128, | |||
| }) | |||
| else: | |||
| raise ValueError | |||
| cross_config.max_position_embeddings = context_length | |||
| self.cross_config = cross_config | |||
| self.text_weight_fc = nn.Sequential( | |||
| nn.Linear(transformer_width, transformer_width), | |||
| nn.ReLU(inplace=True), nn.Linear(transformer_width, 1)) | |||
| self.video_weight_fc = nn.Sequential( | |||
| nn.Linear(transformer_width, transformer_width), | |||
| nn.ReLU(inplace=True), nn.Linear(transformer_width, 1)) | |||
| if self.loose_type is False: | |||
| raise NotImplementedError | |||
| if self.sim_header in ['seqLSTM', 'seqTransf', 'tightFc1']: | |||
| self.frame_position_embeddings = nn.Embedding( | |||
| cross_config.max_position_embeddings, cross_config.hidden_size) | |||
| if self.sim_header in ['seqTransf', 'tightFc1']: | |||
| self.transformerClip = TransformerClip( | |||
| width=transformer_width, | |||
| layers=config['cross_num_hidden_layers'], | |||
| heads=transformer_heads, | |||
| ) | |||
| if self.sim_header == 'seqLSTM': | |||
| self.lstm_visual = nn.LSTM( | |||
| input_size=cross_config.hidden_size, | |||
| hidden_size=cross_config.hidden_size, | |||
| batch_first=True, | |||
| bidirectional=False, | |||
| num_layers=1) | |||
| self.loss_fct = CrossEn(config) | |||
| self.apply(self.init_weights) | |||
| self.clip.load_state_dict(state_dict, strict=False) | |||
| # ===> Initialization trick [HARD CODE] | |||
| if backbone not in _PT_NAME: | |||
| raise NotImplementedError | |||
| # reload | |||
| else: | |||
| if config['linear_patch'] == '3d': | |||
| raise NotImplementedError | |||
| new_state_dict = OrderedDict() | |||
| if self.sim_header == 'tightTransf': | |||
| raise NotImplementedError | |||
| if self.sim_header in ['seqLSTM', 'seqTransf', 'seqFc1']: | |||
| contain_frame_position = False | |||
| for key in state_dict.keys(): | |||
| if key.find('frame_position_embeddings') > -1: | |||
| contain_frame_position = True | |||
| break | |||
| if contain_frame_position is False: | |||
| for key, val in state_dict.items(): | |||
| if key == 'positional_embedding': | |||
| new_state_dict[ | |||
| 'frame_position_embeddings.weight'] = val.clone() | |||
| continue | |||
| if self.sim_header in [ | |||
| 'seqTransf', 'seqFc1' | |||
| ] and key.find('transformer.resblocks') == 0: | |||
| num_layer = int(key.split('.')[2]) | |||
| # cut from beginning | |||
| if num_layer < config['cross_num_hidden_layers']: | |||
| new_state_dict[key.replace( | |||
| 'transformer.', | |||
| 'transformerClip.')] = val.clone() | |||
| continue | |||
| # <=== End of initialization trick | |||
| self.load_state_dict( | |||
| new_state_dict, strict=False | |||
| ) # only update new state (seqTransf/seqLSTM/tightTransf) | |||
| if self.sim_header == 'tightFc5': | |||
| raise ValueError | |||
| def forward(self, | |||
| input_ids, | |||
| token_type_ids, | |||
| attention_mask, | |||
| video, | |||
| video_mask=None): | |||
| input_ids = input_ids.view(-1, input_ids.shape[-1]) | |||
| token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) | |||
| attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) | |||
| video_mask = video_mask.view(-1, video_mask.shape[-1]) | |||
| # B x T x 3 x H x W - > (B x T) x 3 x H x W | |||
| video = torch.as_tensor(video).float() | |||
| if len(video.shape) == 6: # image | |||
| b, bs, ts, channel, h, w = video.shape | |||
| b = b * bs | |||
| else: # video | |||
| b, ts, channel, h, w = video.shape | |||
| video = video.view(b * ts, channel, h, w) | |||
| sequence_output, visual_output = self.get_sequence_visual_output( | |||
| input_ids, | |||
| token_type_ids, | |||
| attention_mask, | |||
| video, | |||
| video_mask, | |||
| shaped=True) | |||
| if self.training: | |||
| loss = 0. | |||
| sim_matrix1, sim_matrix2, barlow_loss = self.get_similarity_logits( | |||
| sequence_output, | |||
| visual_output, | |||
| attention_mask, | |||
| video_mask, | |||
| shaped=True, | |||
| loose_type=self.loose_type) | |||
| sim_loss = (self.loss_fct(sim_matrix1) | |||
| + self.loss_fct(sim_matrix2)) / 2 | |||
| loss += sim_loss + barlow_loss * self.config.cdcr_lambda | |||
| return loss | |||
| else: | |||
| return None | |||
| def get_sequence_output(self, | |||
| input_ids, | |||
| token_type_ids, | |||
| attention_mask, | |||
| shaped=False): | |||
| if shaped is False: | |||
| input_ids = input_ids.view(-1, input_ids.shape[-1]) | |||
| token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) | |||
| attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) | |||
| bs_pair = input_ids.size(0) | |||
| sequence_hidden = self.clip.encode_text( | |||
| input_ids, return_hidden=True, prompt=None)[1].float() | |||
| sequence_hidden = sequence_hidden.view(bs_pair, -1, | |||
| sequence_hidden.size(-1)) | |||
| return sequence_hidden | |||
| def get_visual_output(self, video, video_mask, shaped=False): | |||
| if shaped is False: | |||
| video_mask = video_mask.view(-1, video_mask.shape[-1]) | |||
| video = torch.as_tensor(video).float() | |||
| b, ts, channel, h, w = video.shape | |||
| video = video.view(b * ts, channel, h, w) | |||
| bs_pair = video_mask.size(0) | |||
| visual_hidden = self.clip.encode_image(video).float() | |||
| visual_hidden = visual_hidden.float().view(bs_pair, -1, | |||
| visual_hidden.size(-1)) | |||
| return visual_hidden | |||
| def get_sequence_visual_output(self, | |||
| input_ids, | |||
| token_type_ids, | |||
| attention_mask, | |||
| video, | |||
| video_mask, | |||
| shaped=False): | |||
| if shaped is False: | |||
| input_ids = input_ids.view(-1, input_ids.shape[-1]) | |||
| token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) | |||
| attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) | |||
| video_mask = video_mask.view(-1, video_mask.shape[-1]) | |||
| video = torch.as_tensor(video).float() | |||
| if len(video.shape) == 6: # image | |||
| b, bs, ts, channel, h, w = video.shape | |||
| b = b * bs | |||
| else: # video | |||
| b, ts, channel, h, w = video.shape | |||
| video = video.view(b * ts, channel, h, w) | |||
| sequence_output = self.get_sequence_output( | |||
| input_ids, token_type_ids, attention_mask, shaped=True) | |||
| visual_output = self.get_visual_output(video, video_mask, shaped=True) | |||
| return sequence_output, visual_output | |||
| def agg_video_feat(self, visual_output, video_mask, sim_header='meanP'): | |||
| if self.config.max_sum == 0: | |||
| raise ValueError | |||
| if sim_header == 'meanP': | |||
| # Default: Parameter-free type | |||
| pass | |||
| elif sim_header == 'seqLSTM': | |||
| # Sequential type: LSTM | |||
| visual_output_original = visual_output | |||
| visual_output = pack_padded_sequence( | |||
| visual_output, | |||
| torch.sum(video_mask, dim=-1).cpu(), | |||
| batch_first=True, | |||
| enforce_sorted=False) | |||
| visual_output, _ = self.lstm_visual(visual_output) | |||
| if self.training: | |||
| self.lstm_visual.flatten_parameters() | |||
| visual_output, _ = pad_packed_sequence( | |||
| visual_output, batch_first=True) | |||
| visual_output = torch.cat( | |||
| (visual_output, visual_output_original[:, | |||
| visual_output.size(1):, | |||
| ...].contiguous()), | |||
| dim=1) | |||
| visual_output = visual_output + visual_output_original | |||
| elif sim_header == 'seqTransf': | |||
| # Sequential type: Transformer Encoder | |||
| visual_output_original = visual_output | |||
| seq_length = visual_output.size(1) | |||
| position_ids = torch.arange( | |||
| seq_length, dtype=torch.long, device=visual_output.device) | |||
| position_ids = position_ids.unsqueeze(0).expand( | |||
| visual_output.size(0), -1) | |||
| frame_position_embeddings = self.frame_position_embeddings( | |||
| position_ids) | |||
| visual_output = visual_output + frame_position_embeddings | |||
| extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0 | |||
| extended_video_mask = extended_video_mask.expand( | |||
| -1, video_mask.size(1), -1) | |||
| visual_output = visual_output.permute(1, 0, 2) # NLD -> LND | |||
| visual_output = self.transformerClip(visual_output, | |||
| extended_video_mask) | |||
| visual_output = visual_output.permute(1, 0, 2) # LND -> NLD | |||
| visual_output = visual_output + visual_output_original | |||
| return visual_output | |||
| def wti_interaction(self, text_feat, video_feat, text_mask, video_mask): | |||
| text_weight = self.text_weight_fc(text_feat).squeeze( | |||
| 2) # B x N_t x D -> B x N_t | |||
| text_weight.masked_fill_( | |||
| torch.tensor((1 - text_mask), dtype=torch.bool), float('-inf')) | |||
| text_weight = torch.softmax(text_weight, dim=-1) # B x N_t | |||
| video_weight = self.video_weight_fc(video_feat).squeeze( | |||
| 2) # B x N_v x D -> B x N_v | |||
| video_weight.masked_fill_( | |||
| torch.tensor((1 - video_mask), dtype=torch.bool), float('-inf')) | |||
| video_weight = torch.softmax(video_weight, dim=-1) # B x N_v | |||
| text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) | |||
| video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True) | |||
| retrieve_logits = torch.einsum('atd,bvd->abtv', | |||
| [text_feat, video_feat]) | |||
| retrieve_logits = torch.einsum('abtv,at->abtv', | |||
| [retrieve_logits, text_mask]) | |||
| retrieve_logits = torch.einsum('abtv,bv->abtv', | |||
| [retrieve_logits, video_mask]) | |||
| t2v_logits, max_idx1 = retrieve_logits.max(dim=-1) # abtv -> abt | |||
| t2v_logits = torch.einsum('abt,at->ab', [t2v_logits, text_weight]) | |||
| v2t_logits, max_idx2 = retrieve_logits.max(dim=-2) # abtv -> abv | |||
| v2t_logits = torch.einsum('abv,bv->ab', [v2t_logits, video_weight]) | |||
| retrieve_logits = (t2v_logits + v2t_logits) / 2.0 | |||
| if self.training: | |||
| logit_scale = self.clip.logit_scale.exp() | |||
| retrieve_logits = logit_scale * retrieve_logits | |||
| # select max | |||
| max_idx1 = max_idx1[torch.arange(max_idx1.shape[0]), | |||
| torch.arange(max_idx1.shape[1])] | |||
| max_idx2 = max_idx2[torch.arange(max_idx2.shape[0]), | |||
| torch.arange(max_idx2.shape[1])] | |||
| max_t_feat = text_feat[torch.arange(max_idx2.shape[0]). | |||
| repeat_interleave(max_idx2.shape[1]), | |||
| max_idx2.flatten()].squeeze(1) | |||
| max_v_feat = video_feat[torch.arange(max_idx1.shape[0]). | |||
| repeat_interleave(max_idx1.shape[1]), | |||
| max_idx1.flatten()].squeeze(1) | |||
| t_feat = text_feat.reshape(-1, text_feat.shape[-1]) | |||
| t_mask = text_mask.flatten().type(torch.bool) | |||
| v_feat = video_feat.reshape(-1, video_feat.shape[-1]) | |||
| v_mask = video_mask.flatten().type(torch.bool) | |||
| t_feat = t_feat[t_mask] | |||
| v_feat = v_feat[v_mask] | |||
| max_t_feat = max_t_feat[v_mask] | |||
| max_v_feat = max_v_feat[t_mask] | |||
| text_weight = text_weight.flatten()[t_mask] | |||
| video_weight = video_weight.flatten()[v_mask] | |||
| z_a_norm = (t_feat - t_feat.mean(0)) / t_feat.std(0) # (BxN_t)xD | |||
| z_b_norm = (max_v_feat - max_v_feat.mean(0)) / max_v_feat.std( | |||
| 0) # (BxN_t)xD | |||
| x_a_norm = (v_feat - v_feat.mean(0)) / v_feat.std(0) # (BxN_v)xD | |||
| x_b_norm = (max_t_feat - max_t_feat.mean(0)) / max_t_feat.std( | |||
| 0) # (BxN_v)xD | |||
| # cross-correlation matrix | |||
| N, D = z_a_norm.shape | |||
| B = text_feat.shape[0] | |||
| c1 = torch.einsum('acd,a->cd', | |||
| torch.einsum('ac,ad->acd', z_a_norm, z_b_norm), | |||
| text_weight) / B # DxD | |||
| c2 = torch.einsum('acd,a->cd', | |||
| torch.einsum('ac,ad->acd', x_a_norm, x_b_norm), | |||
| video_weight) / B # DxD | |||
| c = (c1 + c2) / 2.0 | |||
| # loss | |||
| on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() | |||
| off_diag = c.flatten()[1:].view(D - 1, D + 1)[:, :-1].pow_(2).sum() | |||
| cdcr_loss = ( | |||
| on_diag * self.config.cdcr_alpha1 | |||
| + off_diag * self.config.cdcr_alpha2) | |||
| return retrieve_logits, retrieve_logits.T, cdcr_loss | |||
| else: | |||
| return retrieve_logits, retrieve_logits.T | |||
| def _loose_similarity(self, | |||
| sequence_output, | |||
| visual_output, | |||
| attention_mask, | |||
| video_mask, | |||
| sim_header='seqTransf'): | |||
| sequence_output, visual_output = sequence_output.contiguous( | |||
| ), visual_output.contiguous() | |||
| visual_output = self.agg_video_feat(visual_output, video_mask, | |||
| sim_header) | |||
| if self.training: # batch merge here | |||
| visual_output = allgather(visual_output, self.config) | |||
| attention_mask = allgather(attention_mask, self.config) | |||
| video_mask = allgather(video_mask, self.config) | |||
| sequence_output = allgather(sequence_output, self.config) | |||
| torch.distributed.barrier() # force sync | |||
| return self.wti_interaction(sequence_output, visual_output, | |||
| attention_mask, video_mask) | |||
| def get_similarity_logits(self, | |||
| sequence_output, | |||
| visual_output, | |||
| attention_mask, | |||
| video_mask, | |||
| shaped=False, | |||
| loose_type=False): | |||
| if shaped is False: | |||
| attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) | |||
| video_mask = video_mask.view(-1, video_mask.shape[-1]) | |||
| if loose_type: | |||
| assert self.sim_header in ['meanP', 'seqLSTM', 'seqTransf'] | |||
| if self.training: | |||
| retrieve_logits1, retrieve_logits2, barlow_loss = self._loose_similarity( | |||
| sequence_output, | |||
| visual_output, | |||
| attention_mask, | |||
| video_mask, | |||
| sim_header=self.sim_header) | |||
| return retrieve_logits1, retrieve_logits2, barlow_loss | |||
| else: | |||
| retrieve_logits1, retrieve_logits2 = self._loose_similarity( | |||
| sequence_output, | |||
| visual_output, | |||
| attention_mask, | |||
| video_mask, | |||
| sim_header=self.sim_header) | |||
| return retrieve_logits1, retrieve_logits2 | |||
| else: | |||
| raise NotImplementedError | |||
| @property | |||
| def dtype(self): | |||
| """ | |||
| :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). | |||
| """ | |||
| try: | |||
| return next(self.parameters()).dtype | |||
| except StopIteration: | |||
| # For nn.DataParallel compatibility in PyTorch 1.5 | |||
| def find_tensor_attributes(module: nn.Module): | |||
| tuples = [(k, v) for k, v in module.__dict__.items() | |||
| if torch.is_tensor(v)] | |||
| return tuples | |||
| gen = self._named_members(get_members_fn=find_tensor_attributes) | |||
| first_tuple = next(gen) | |||
| return first_tuple[1].dtype | |||
| def init_weights(self, module): | |||
| """ Initialize the weights. | |||
| """ | |||
| if isinstance(module, (nn.Linear, nn.Embedding)): | |||
| # Slightly different from the TF version which uses truncated_normal for initialization | |||
| # cf https://github.com/pytorch/pytorch/pull/5617 | |||
| module.weight.data.normal_(mean=0.0, std=0.02) | |||
| elif isinstance(module, LayerNorm): | |||
| if 'beta' in dir(module) and 'gamma' in dir(module): | |||
| module.beta.data.zero_() | |||
| module.gamma.data.fill_(1.0) | |||
| else: | |||
| module.bias.data.zero_() | |||
| module.weight.data.fill_(1.0) | |||
| if isinstance(module, nn.Linear) and module.bias is not None: | |||
| module.bias.data.zero_() | |||
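At inference time, wti_interaction reduces to a few einsum contractions. A self-contained toy with random tensors is sketched below; the softmax weights stand in for the learned text_weight_fc / video_weight_fc outputs and are not the model's actual parameters:

```python
# Standalone toy of the weighted token-wise interaction used at inference time.
import torch

B, Nt, Nv, D = 2, 4, 6, 8
text_feat = torch.randn(B, Nt, D)
video_feat = torch.randn(B, Nv, D)
text_mask = torch.ones(B, Nt)
video_mask = torch.ones(B, Nv)
text_weight = torch.softmax(torch.randn(B, Nt), dim=-1)   # stand-in for text_weight_fc
video_weight = torch.softmax(torch.randn(B, Nv), dim=-1)  # stand-in for video_weight_fc

text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True)

logits = torch.einsum('atd,bvd->abtv', [text_feat, video_feat])
logits = torch.einsum('abtv,at->abtv', [logits, text_mask])
logits = torch.einsum('abtv,bv->abtv', [logits, video_mask])
t2v = torch.einsum('abt,at->ab', [logits.max(dim=-1)[0], text_weight])
v2t = torch.einsum('abv,bv->ab', [logits.max(dim=-2)[0], video_weight])
print(((t2v + v2t) / 2.0).shape)  # (B, B) text-to-video similarity matrix
```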
| @@ -0,0 +1,526 @@ | |||
| # Part of the implementation is borrowed and modified from the OpenAI CLIP project. | |||
| import hashlib | |||
| import os | |||
| import urllib | |||
| import warnings | |||
| from collections import OrderedDict | |||
| from typing import Tuple, Union | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import torch.utils.checkpoint as checkpoint | |||
| from torch import nn | |||
| from tqdm import tqdm | |||
| _MODELS = {} | |||
| _PT_NAME = {'ViT-B/16': 'ViT-B-16.pt'} | |||
| def available_models(): | |||
| """Returns the names of available CLIP models""" | |||
| return list(_MODELS.keys()) | |||
| class Bottleneck(nn.Module): | |||
| expansion = 4 | |||
| def __init__(self, inplanes, planes, stride=1): | |||
| super(Bottleneck, self).__init__() | |||
| # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 | |||
| self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) | |||
| self.bn1 = nn.BatchNorm2d(planes) | |||
| self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) | |||
| self.bn2 = nn.BatchNorm2d(planes) | |||
| self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() | |||
| self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) | |||
| self.bn3 = nn.BatchNorm2d(planes * self.expansion) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.downsample = None | |||
| self.stride = stride | |||
| if stride > 1 or inplanes != planes * Bottleneck.expansion: | |||
| # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 | |||
| self.downsample = nn.Sequential( | |||
| OrderedDict([('-1', nn.AvgPool2d(stride)), | |||
| ('0', | |||
| nn.Conv2d( | |||
| inplanes, | |||
| planes * self.expansion, | |||
| 1, | |||
| stride=1, | |||
| bias=False)), | |||
| ('1', nn.BatchNorm2d(planes * self.expansion))])) | |||
| def forward(self, x: torch.Tensor): | |||
| identity = x | |||
| out = self.relu(self.bn1(self.conv1(x))) | |||
| out = self.relu(self.bn2(self.conv2(out))) | |||
| out = self.avgpool(out) | |||
| out = self.bn3(self.conv3(out)) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out += identity | |||
| out = self.relu(out) | |||
| return out | |||
| class AttentionPool2d(nn.Module): | |||
| def __init__(self, | |||
| spacial_dim: int, | |||
| embed_dim: int, | |||
| num_heads: int, | |||
| output_dim: int = None): | |||
| super(AttentionPool2d, self).__init__() | |||
| self.positional_embedding = nn.Parameter( | |||
| torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) | |||
| self.k_proj = nn.Linear(embed_dim, embed_dim) | |||
| self.q_proj = nn.Linear(embed_dim, embed_dim) | |||
| self.v_proj = nn.Linear(embed_dim, embed_dim) | |||
| self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) | |||
| self.num_heads = num_heads | |||
| def forward(self, x): | |||
| x = x.reshape(x.shape[0], x.shape[1], | |||
| x.shape[2] * x.shape[3]).permute(2, 0, | |||
| 1) # NCHW -> (HW)NC | |||
| x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC | |||
| x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC | |||
| x, _ = F.multi_head_attention_forward( | |||
| query=x, | |||
| key=x, | |||
| value=x, | |||
| embed_dim_to_check=x.shape[-1], | |||
| num_heads=self.num_heads, | |||
| q_proj_weight=self.q_proj.weight, | |||
| k_proj_weight=self.k_proj.weight, | |||
| v_proj_weight=self.v_proj.weight, | |||
| in_proj_weight=None, | |||
| in_proj_bias=torch.cat( | |||
| [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), | |||
| bias_k=None, | |||
| bias_v=None, | |||
| add_zero_attn=False, | |||
| dropout_p=0, | |||
| out_proj_weight=self.c_proj.weight, | |||
| out_proj_bias=self.c_proj.bias, | |||
| use_separate_proj_weight=True, | |||
| training=self.training, | |||
| need_weights=False) | |||
| return x[0] | |||
| class ModifiedResNet(nn.Module): | |||
| """ | |||
| A ResNet class that is similar to torchvision's but contains the following changes: | |||
| - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |||
| - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |||
| - The final pooling layer is a QKV attention instead of an average pool | |||
| """ | |||
| def __init__(self, | |||
| layers, | |||
| output_dim, | |||
| heads, | |||
| input_resolution=224, | |||
| width=64): | |||
| super(ModifiedResNet, self).__init__() | |||
| self.output_dim = output_dim | |||
| self.input_resolution = input_resolution | |||
| # the 3-layer stem | |||
| self.conv1 = nn.Conv2d( | |||
| 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |||
| self.bn1 = nn.BatchNorm2d(width // 2) | |||
| self.conv2 = nn.Conv2d( | |||
| width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |||
| self.bn2 = nn.BatchNorm2d(width // 2) | |||
| self.conv3 = nn.Conv2d( | |||
| width // 2, width, kernel_size=3, padding=1, bias=False) | |||
| self.bn3 = nn.BatchNorm2d(width) | |||
| self.avgpool = nn.AvgPool2d(2) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| # residual layers | |||
| self._inplanes = width # this is a *mutable* variable used during construction | |||
| self.layer1 = self._make_layer(width, layers[0]) | |||
| self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |||
| self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |||
| self.layer4 = self._make_layer(width * 8, layers[3], stride=2) | |||
| embed_dim = width * 32 # the ResNet feature dimension | |||
| self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, | |||
| heads, output_dim) | |||
| def _make_layer(self, planes, blocks, stride=1): | |||
| layers = [Bottleneck(self._inplanes, planes, stride)] | |||
| self._inplanes = planes * Bottleneck.expansion | |||
| for _ in range(1, blocks): | |||
| layers.append(Bottleneck(self._inplanes, planes)) | |||
| return nn.Sequential(*layers) | |||
| def forward(self, x): | |||
| def stem(x): | |||
| for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), | |||
| (self.conv3, self.bn3)]: | |||
| x = self.relu(bn(conv(x))) | |||
| x = self.avgpool(x) | |||
| return x | |||
| x = x.type(self.conv1.weight.dtype) | |||
| x = stem(x) | |||
| x = self.layer1(x) | |||
| x = self.layer2(x) | |||
| x = self.layer3(x) | |||
| x = self.layer4(x) | |||
| x = self.attnpool(x) | |||
| return x | |||
| class LayerNorm(nn.LayerNorm): | |||
| """Subclass torch's LayerNorm to handle fp16.""" | |||
| def forward(self, x: torch.Tensor): | |||
| orig_type = x.dtype | |||
| ret = super().forward(x.type(torch.float32)) | |||
| return ret.type(orig_type) | |||
| class QuickGELU(nn.Module): | |||
| def forward(self, x: torch.Tensor): | |||
| return x * torch.sigmoid(1.702 * x) | |||
| class ResidualAttentionBlock(nn.Module): | |||
| def __init__(self, d_model: int, n_head: int, attn_mask=None): | |||
| super(ResidualAttentionBlock, self).__init__() | |||
| self.attn = nn.MultiheadAttention(d_model, n_head) | |||
| self.ln_1 = LayerNorm(d_model) | |||
| self.mlp = nn.Sequential( | |||
| OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), | |||
| ('gelu', QuickGELU()), | |||
| ('c_proj', nn.Linear(d_model * 4, d_model))])) | |||
| self.ln_2 = LayerNorm(d_model) | |||
| self.attn_mask = attn_mask | |||
| def attention(self, x: torch.Tensor): | |||
| attn_mask_ = self.attn_mask | |||
| if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): | |||
| attn_mask_ = self.attn_mask(x.size(0)) # LND | |||
| attn_mask_ = attn_mask_.to( | |||
| dtype=x.dtype, device=x.device) if attn_mask_ is not None else None | |||
| return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] | |||
| def forward(self, x): | |||
| x = x + self.attention(self.ln_1(x)) | |||
| x = x + self.mlp(self.ln_2(x)) | |||
| return x | |||
| class Transformer(nn.Module): | |||
| def __init__(self, | |||
| width: int, | |||
| layers: int, | |||
| heads: int, | |||
| attn_mask=None, | |||
| use_gc=0): | |||
| super(Transformer, self).__init__() | |||
| self.width = width | |||
| self.layers = layers | |||
| self.resblocks = nn.Sequential(*[ | |||
| ResidualAttentionBlock(width, heads, attn_mask) | |||
| for _ in range(layers) | |||
| ]) | |||
| self.use_gc = use_gc | |||
| def forward(self, x: torch.Tensor): | |||
| if self.use_gc > 0: | |||
| for blk in self.resblocks: | |||
| x = checkpoint.checkpoint(blk, x) | |||
| return x | |||
| else: | |||
| return self.resblocks(x) | |||
| class VisualTransformer(nn.Module): | |||
| def __init__(self, | |||
| input_resolution: int, | |||
| patch_size: int, | |||
| width: int, | |||
| layers: int, | |||
| heads: int, | |||
| output_dim: int, | |||
| linear_patch: str = '2d', | |||
| use_gc: int = 0): | |||
| super(VisualTransformer, self).__init__() | |||
| self.input_resolution = input_resolution | |||
| self.output_dim = output_dim | |||
| self.conv1 = nn.Conv2d( | |||
| in_channels=3, | |||
| out_channels=width, | |||
| kernel_size=patch_size, | |||
| stride=patch_size, | |||
| bias=False) | |||
| scale = width**-0.5 | |||
| self.class_embedding = nn.Parameter(scale * torch.randn(width)) | |||
| self.positional_embedding = nn.Parameter(scale * torch.randn( | |||
| (input_resolution // patch_size)**2 + 1, width)) | |||
| self.ln_pre = LayerNorm(width) | |||
| self.transformer = Transformer(width, layers, heads, use_gc=use_gc) | |||
| self.ln_post = LayerNorm(width) | |||
| self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) | |||
| # For 3D | |||
| assert linear_patch in ['2d', '3d'] | |||
| self.linear_patch = linear_patch | |||
| if self.linear_patch == '3d': | |||
| self.conv2 = nn.Conv3d( | |||
| in_channels=3, | |||
| out_channels=width, | |||
| kernel_size=(3, patch_size, patch_size), | |||
| stride=(1, patch_size, patch_size), | |||
| padding=(1, 0, 0), | |||
| bias=False) | |||
| def forward(self, x: torch.Tensor, video_frame=-1): | |||
| if self.linear_patch == '3d': | |||
| assert video_frame != -1 | |||
| x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], | |||
| x.shape[-1]) | |||
| x_3d = x_3d.permute(0, 2, 1, 3, 4) | |||
| x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid] | |||
| x_3d = x_3d.permute(0, 2, 1, 3, | |||
| 4) # shape = [*, frame, width, grid, grid] | |||
| x = x_3d.reshape( | |||
| -1, x_3d.shape[-3], x_3d.shape[-2], | |||
| x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid] | |||
| else: | |||
| x = self.conv1(x) # shape = [*, width, grid, grid] | |||
| x = x.reshape(x.shape[0], x.shape[1], | |||
| -1) # shape = [*, width, grid ** 2] | |||
| x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||
| _x = self.class_embedding.to(x.dtype) + torch.zeros( | |||
| x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) | |||
| x = torch.cat([_x, x], dim=1) | |||
| x = x + self.positional_embedding.to(x.dtype) | |||
| x = self.ln_pre(x) | |||
| x = x.permute(1, 0, 2) # NLD -> LND | |||
| x = self.transformer(x) | |||
| x = x.permute(1, 0, 2) # LND -> NLD | |||
| return x | |||
| class CLIP(nn.Module): | |||
| def __init__( | |||
| self, | |||
| embed_dim: int, | |||
| # vision | |||
| image_resolution: int, | |||
| vision_layers: Union[Tuple[int, int, int, int], int], | |||
| vision_width: int, | |||
| vision_patch_size: int, | |||
| # text | |||
| context_length: int, | |||
| vocab_size: int, | |||
| transformer_width: int, | |||
| transformer_heads: int, | |||
| transformer_layers: int, | |||
| # vision linear of patch | |||
| linear_patch: str = '2d', | |||
| use_gc: int = 0): | |||
| super(CLIP, self).__init__() | |||
| self.context_length = context_length | |||
| if isinstance(vision_layers, (tuple, list)): | |||
| vision_heads = vision_width * 32 // 64 | |||
| self.visual = ModifiedResNet( | |||
| layers=vision_layers, | |||
| output_dim=embed_dim, | |||
| heads=vision_heads, | |||
| input_resolution=image_resolution, | |||
| width=vision_width) | |||
| else: | |||
| vision_heads = vision_width // 64 | |||
| self.visual = VisualTransformer( | |||
| input_resolution=image_resolution, | |||
| patch_size=vision_patch_size, | |||
| width=vision_width, | |||
| layers=vision_layers, | |||
| heads=vision_heads, | |||
| output_dim=embed_dim, | |||
| linear_patch=linear_patch, | |||
| use_gc=use_gc) | |||
| self.transformer = Transformer( | |||
| width=transformer_width, | |||
| layers=transformer_layers, | |||
| heads=transformer_heads, | |||
| attn_mask=self.build_attention_mask) | |||
| self.vocab_size = vocab_size | |||
| self.token_embedding = nn.Embedding(vocab_size, transformer_width) | |||
| self.positional_embedding = nn.Parameter( | |||
| torch.empty(self.context_length, transformer_width)) | |||
| self.ln_final = LayerNorm(transformer_width) | |||
| self.text_projection = nn.Parameter( | |||
| torch.empty(transformer_width, embed_dim)) | |||
| self.logit_scale = nn.Parameter(torch.ones([])) | |||
| self.initialize_parameters() | |||
| def initialize_parameters(self): | |||
| nn.init.normal_(self.token_embedding.weight, std=0.02) | |||
| nn.init.normal_(self.positional_embedding, std=0.01) | |||
| if isinstance(self.visual, ModifiedResNet): | |||
| if self.visual.attnpool is not None: | |||
| std = self.visual.attnpool.c_proj.in_features**-0.5 | |||
| nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) | |||
| nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) | |||
| nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) | |||
| nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) | |||
| for resnet_block in [ | |||
| self.visual.layer1, self.visual.layer2, self.visual.layer3, | |||
| self.visual.layer4 | |||
| ]: | |||
| for name, param in resnet_block.named_parameters(): | |||
| if name.endswith('bn3.weight'): | |||
| nn.init.zeros_(param) | |||
| proj_std = (self.transformer.width**-0.5) * ( | |||
| (2 * self.transformer.layers)**-0.5) | |||
| attn_std = self.transformer.width**-0.5 | |||
| fc_std = (2 * self.transformer.width)**-0.5 | |||
| for block in self.transformer.resblocks: | |||
| nn.init.normal_(block.attn.in_proj_weight, std=attn_std) | |||
| nn.init.normal_(block.attn.out_proj.weight, std=proj_std) | |||
| nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) | |||
| nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) | |||
| if self.text_projection is not None: | |||
| nn.init.normal_( | |||
| self.text_projection, std=self.transformer.width**-0.5) | |||
| def build_attention_mask(self, context_length): | |||
| # lazily create causal attention mask, with full attention between the vision tokens | |||
| # pytorch uses additive attention mask; fill with -inf | |||
| mask = torch.zeros(context_length, context_length) | |||
| mask.fill_(float('-inf')) | |||
| mask.triu_(1) # zero out the lower diagonal | |||
| return mask | |||
| @property | |||
| def dtype(self): | |||
| return self.visual.conv1.weight.dtype | |||
| def encode_image(self, image, return_hidden=False): | |||
| hidden = self.visual(image.type(self.dtype)) | |||
| hidden = self.visual.ln_post(hidden) @ self.visual.proj | |||
| x = hidden[:, 0, :] | |||
| if return_hidden: | |||
| return x, hidden | |||
| return x | |||
| def encode_text(self, text, return_hidden=False, prompt=None): | |||
| x = self.token_embedding(text).type( | |||
| self.dtype) # [batch_size, n_ctx, d_model] | |||
| if prompt: | |||
| x = prompt(x) | |||
| pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) | |||
| x = x + pos_emd | |||
| x = x.permute(1, 0, 2) # NLD -> LND | |||
| x = self.transformer(x) | |||
| x = x.permute(1, 0, 2) # LND -> NLD | |||
| hidden = self.ln_final(x).type(self.dtype) @ self.text_projection | |||
| # take features from the eot embedding (eot_token is the highest number in each sequence) | |||
| x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] | |||
| if return_hidden: | |||
| return x, hidden | |||
| return x | |||
| def forward(self, image, text): | |||
| image_features = self.encode_image(image) | |||
| text_features = self.encode_text(text) | |||
| # normalized features | |||
| image_features = image_features / image_features.norm( | |||
| dim=-1, keepdim=True) | |||
| text_features = text_features / text_features.norm( | |||
| dim=-1, keepdim=True) | |||
| # cosine similarity as logits | |||
| logit_scale = self.logit_scale.exp() | |||
| logits_per_image = logit_scale * image_features @ text_features.t() | |||
| logits_per_text = logit_scale * text_features @ image_features.t() | |||
| return logits_per_image, logits_per_text | |||
| def convert_weights(model: nn.Module): | |||
| """Convert applicable model parameters to fp16""" | |||
| def _convert_weights_to_fp16(lay): | |||
| if isinstance(lay, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): | |||
| lay.weight.data = lay.weight.data.half() | |||
| if lay.bias is not None: | |||
| lay.bias.data = lay.bias.data.half() | |||
| if isinstance(lay, nn.MultiheadAttention): | |||
| for attr in [ | |||
| *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], | |||
| 'in_proj_bias', 'bias_k', 'bias_v' | |||
| ]: | |||
| tensor = getattr(lay, attr) | |||
| if tensor is not None: | |||
| tensor.data = tensor.data.half() | |||
| for name in ['text_projection', 'proj']: | |||
| if hasattr(lay, name): | |||
| attr = getattr(lay, name) | |||
| if attr is not None: | |||
| attr.data = attr.data.half() | |||
| model.apply(_convert_weights_to_fp16) | |||
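convert_weights only touches conv, linear and attention parameters, so LayerNorms stay in fp32. A quick check on a throwaway module, assuming the function above is in scope:

```python
# Sanity check for convert_weights on a tiny module.
import torch
from torch import nn

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 4), nn.LayerNorm(4))
convert_weights(toy)
print(toy[0].weight.dtype)  # torch.float16
print(toy[1].weight.dtype)  # torch.float16
print(toy[2].weight.dtype)  # torch.float32 (LayerNorm is left untouched)
```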
| @@ -0,0 +1,100 @@ | |||
| from __future__ import absolute_import, division, print_function | |||
| import logging | |||
| from collections import OrderedDict | |||
| import json | |||
| import torch | |||
| from torch import nn | |||
| from .until_module import ACT2FN, LayerNorm | |||
| logger = logging.getLogger(__name__) | |||
| class QuickGELU(nn.Module): | |||
| def forward(self, x: torch.Tensor): | |||
| return x * torch.sigmoid(1.702 * x) | |||
| class ResidualAttentionBlock(nn.Module): | |||
| def __init__(self, d_model: int, n_head: int): | |||
| super().__init__() | |||
| self.attn = nn.MultiheadAttention(d_model, n_head) | |||
| self.ln_1 = LayerNorm(d_model) | |||
| self.mlp = nn.Sequential( | |||
| OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), | |||
| ('gelu', QuickGELU()), | |||
| ('c_proj', nn.Linear(d_model * 4, d_model))])) | |||
| self.ln_2 = LayerNorm(d_model) | |||
| self.n_head = n_head | |||
| def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): | |||
| attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) | |||
| return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] | |||
| def forward(self, para_tuple: tuple): | |||
| # x: torch.Tensor, attn_mask: torch.Tensor | |||
| x, attn_mask = para_tuple | |||
| x = x + self.attention(self.ln_1(x), attn_mask) | |||
| x = x + self.mlp(self.ln_2(x)) | |||
| return (x, attn_mask) | |||
| class Transformer(nn.Module): | |||
| def __init__(self, width: int, layers: int, heads: int): | |||
| super().__init__() | |||
| self.width = width | |||
| self.layers = layers | |||
| self.resblocks = nn.Sequential( | |||
| *[ResidualAttentionBlock(width, heads) for _ in range(layers)]) | |||
| def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): | |||
| return self.resblocks((x, attn_mask))[0] | |||
| class CrossEmbeddings(nn.Module): | |||
| """Construct the embeddings from word, position and token_type embeddings. | |||
| """ | |||
| def __init__(self, config): | |||
| super(CrossEmbeddings, self).__init__() | |||
| self.position_embeddings = nn.Embedding(config.max_position_embeddings, | |||
| config.hidden_size) | |||
| self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
| def forward(self, concat_embeddings, concat_type=None): | |||
| _, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) | |||
| position_ids = torch.arange( | |||
| seq_length, dtype=torch.long, device=concat_embeddings.device) | |||
| position_ids = position_ids.unsqueeze(0).expand( | |||
| concat_embeddings.size(0), -1) | |||
| position_embeddings = self.position_embeddings(position_ids) | |||
| embeddings = concat_embeddings + position_embeddings # + token_type_embeddings | |||
| embeddings = self.dropout(embeddings) | |||
| return embeddings | |||
| class CrossPooler(nn.Module): | |||
| def __init__(self, config): | |||
| super(CrossPooler, self).__init__() | |||
| self.ln_pool = LayerNorm(config.hidden_size) | |||
| self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
| self.activation = QuickGELU() | |||
| def forward(self, hidden_states, hidden_mask): | |||
| # We "pool" the model by simply taking the hidden state corresponding | |||
| # to the first token. | |||
| hidden_states = self.ln_pool(hidden_states) | |||
| pooled_output = hidden_states[:, 0] | |||
| pooled_output = self.dense(pooled_output) | |||
| pooled_output = self.activation(pooled_output) | |||
| return pooled_output | |||
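Unlike the CLIP-side Transformer, this one threads an (x, attn_mask) tuple through its residual blocks and repeats the mask per attention head. A minimal shape check, assuming the classes above are in scope:

```python
# Minimal shape check for the cross Transformer; an all-zero mask means full attention.
import torch

xfmr = Transformer(width=16, layers=2, heads=4)
x = torch.randn(5, 3, 16)         # LND layout: seq_len=5, batch=3, width=16
attn_mask = torch.zeros(3, 5, 5)  # (batch, seq, seq); repeat_interleave'd per head
out = xfmr(x, attn_mask)
print(out.shape)  # torch.Size([5, 3, 16])
```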
| @@ -0,0 +1,158 @@ | |||
| import gzip | |||
| import html | |||
| import os | |||
| from functools import lru_cache | |||
| import ftfy | |||
| import regex as re | |||
| @lru_cache() | |||
| def bytes_to_unicode(): | |||
| """ | |||
| Returns list of utf-8 byte and a corresponding list of unicode strings. | |||
| The reversible bpe codes work on unicode strings. | |||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||
| This is a significant percentage of your normal, say, 32K bpe vocab. | |||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||
| """ | |||
| bs = list(range(ord('!'), | |||
| ord('~') + 1)) + list(range( | |||
| ord('¡'), | |||
| ord('¬') + 1)) + list(range(ord('®'), | |||
| ord('ÿ') + 1)) | |||
| cs = bs[:] | |||
| n = 0 | |||
| for b in range(2**8): | |||
| if b not in bs: | |||
| bs.append(b) | |||
| cs.append(2**8 + n) | |||
| n += 1 | |||
| cs = [chr(n) for n in cs] | |||
| return dict(zip(bs, cs)) | |||
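As the docstring says, the mapping is a reversible cover of all 256 byte values; a quick check, assuming the function above is in scope:

```python
# bytes_to_unicode is a bijection from the 256 byte values to printable unicode chars.
b2u = bytes_to_unicode()
u2b = {v: k for k, v in b2u.items()}
assert len(b2u) == 256
assert all(u2b[b2u[b]] == b for b in range(256))
print(b2u[ord('a')], b2u[0])  # 'a' stays 'a'; control byte 0 is remapped to 'Ā'
```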
| def get_pairs(word): | |||
| """Return set of symbol pairs in a word. | |||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||
| """ | |||
| pairs = set() | |||
| prev_char = word[0] | |||
| for char in word[1:]: | |||
| pairs.add((prev_char, char)) | |||
| prev_char = char | |||
| return pairs | |||
| def basic_clean(text): | |||
| text = ftfy.fix_text(text) | |||
| text = html.unescape(html.unescape(text)) | |||
| return text.strip() | |||
| def whitespace_clean(text): | |||
| text = re.sub(r'\s+', ' ', text) | |||
| text = text.strip() | |||
| return text | |||
| class SimpleTokenizer(object): | |||
| def __init__(self, model_dir): | |||
| bpe_path = '{}/bpe_simple_vocab_16e6.txt.gz'.format(model_dir) | |||
| self.byte_encoder = bytes_to_unicode() | |||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||
| merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') | |||
| merges = merges[1:49152 - 256 - 2 + 1] | |||
| merges = [tuple(merge.split()) for merge in merges] | |||
| vocab = list(bytes_to_unicode().values()) | |||
| vocab = vocab + [v + '</w>' for v in vocab] | |||
| for merge in merges: | |||
| vocab.append(''.join(merge)) | |||
| vocab.extend(['<|startoftext|>', '<|endoftext|>']) | |||
| self.encoder = dict(zip(vocab, range(len(vocab)))) | |||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||
| self.bpe_ranks = dict(zip(merges, range(len(merges)))) | |||
| self.cache = { | |||
| '<|startoftext|>': '<|startoftext|>', | |||
| '<|endoftext|>': '<|endoftext|>' | |||
| } | |||
| self.pat = re.compile( | |||
| r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", | |||
| re.IGNORECASE) | |||
| self.vocab = self.encoder | |||
| def bpe(self, token): | |||
| if token in self.cache: | |||
| return self.cache[token] | |||
| word = tuple(token[:-1]) + (token[-1] + '</w>', ) | |||
| pairs = get_pairs(word) | |||
| if not pairs: | |||
| return token + '</w>' | |||
| while True: | |||
| bigram = min( | |||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||
| if bigram not in self.bpe_ranks: | |||
| break | |||
| first, second = bigram | |||
| new_word = [] | |||
| i = 0 | |||
| while i < len(word): | |||
| try: | |||
| j = word.index(first, i) | |||
| new_word.extend(word[i:j]) | |||
| i = j | |||
| except Exception: | |||
| new_word.extend(word[i:]) | |||
| break | |||
| if word[i] == first and i < len(word) - 1 and word[ | |||
| i + 1] == second: | |||
| new_word.append(first + second) | |||
| i += 2 | |||
| else: | |||
| new_word.append(word[i]) | |||
| i += 1 | |||
| new_word = tuple(new_word) | |||
| word = new_word | |||
| if len(word) == 1: | |||
| break | |||
| else: | |||
| pairs = get_pairs(word) | |||
| word = ' '.join(word) | |||
| self.cache[token] = word | |||
| return word | |||
| def encode(self, text): | |||
| bpe_tokens = [] | |||
| text = whitespace_clean(basic_clean(text)).lower() | |||
| for token in re.findall(self.pat, text): | |||
| token = ''.join(self.byte_encoder[b] | |||
| for b in token.encode('utf-8')) | |||
| bpe_tokens.extend(self.encoder[bpe_token] | |||
| for bpe_token in self.bpe(token).split(' ')) | |||
| return bpe_tokens | |||
| def decode(self, tokens): | |||
| text = ''.join([self.decoder[token] for token in tokens]) | |||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||
| 'utf-8', errors='replace').replace('</w>', ' ') | |||
| return text | |||
| def tokenize(self, text): | |||
| tokens = [] | |||
| text = whitespace_clean(basic_clean(text)).lower() | |||
| for token in re.findall(self.pat, text): | |||
| token = ''.join(self.byte_encoder[b] | |||
| for b in token.encode('utf-8')) | |||
| tokens.extend( | |||
| bpe_token for bpe_token in self.bpe(token).split(' ')) | |||
| return tokens | |||
| def convert_tokens_to_ids(self, tokens): | |||
| return [self.encoder[bpe_token] for bpe_token in tokens] | |||
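Illustrative usage of the tokenizer (the directory path is a placeholder; it only needs to contain bpe_simple_vocab_16e6.txt.gz, the standard CLIP BPE vocabulary):

tokenizer = SimpleTokenizer(model_dir='/path/to/model_dir')
tokens = tokenizer.tokenize('a person is connecting something to system')
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens[:3], ids[:3])
# decode() maps ids back to text, with '</w>' turned into trailing spaces
print(tokenizer.decode(tokenizer.encode('hello world')))   # 'hello world '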
| @@ -0,0 +1,120 @@ | |||
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """PyTorch BERT model.""" | |||
| import logging | |||
| import math | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| logger = logging.getLogger(__name__) | |||
| def gelu(x): | |||
| """Implementation of the gelu activation function. | |||
| For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): | |||
| 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) | |||
| """ | |||
| return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||
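A small numerical illustration of the docstring's note that the erf-based gelu and OpenAI GPT's tanh approximation differ slightly:

import math
import torch

x = torch.linspace(-3, 3, 7)
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
approx = 0.5 * x * (1 + torch.tanh(
    math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print((exact - approx).abs().max())   # small but nonzero difference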
| def swish(x): | |||
| return x * torch.sigmoid(x) | |||
| ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} | |||
| class LayerNorm(nn.Module): | |||
| def __init__(self, hidden_size, eps=1e-12): | |||
| """Construct a layernorm module in the TF style (epsilon inside the square root). | |||
| """ | |||
| super(LayerNorm, self).__init__() | |||
| self.weight = nn.Parameter(torch.ones(hidden_size)) | |||
| self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||
| self.variance_epsilon = eps | |||
| def forward(self, x): | |||
| u = x.mean(-1, keepdim=True) | |||
| s = (x - u).pow(2).mean(-1, keepdim=True) | |||
| x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||
| return self.weight * x + self.bias | |||
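For reference, with the same epsilon this module matches torch.nn.LayerNorm, which also places eps inside the square root; a quick check (illustrative):

import torch

x = torch.randn(2, 3, 8)
custom = LayerNorm(8, eps=1e-12)
builtin = torch.nn.LayerNorm(8, eps=1e-12)
print(torch.allclose(custom(x), builtin(x), atol=1e-5))   # True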
| class CrossEn(nn.Module): | |||
| def __init__(self, config=None): | |||
| super(CrossEn, self).__init__() | |||
| def forward(self, sim_matrix): | |||
| logpt = F.log_softmax(sim_matrix, dim=-1) | |||
| logpt = torch.diag(logpt) | |||
| nce_loss = -logpt | |||
| sim_loss = nce_loss.mean() | |||
| return sim_loss | |||
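CrossEn is the standard in-batch contrastive loss: a row-wise log-softmax over the similarity matrix, keeping only the diagonal (matched video-text pairs). A toy matrix shows the intent (values are illustrative):

import torch

sim = torch.randn(4, 4)        # 4 videos x 4 texts, matched pairs on the diagonal
sim.fill_diagonal_(5.0)        # give matched pairs a high similarity
print(CrossEn()(sim))          # close to zero when pairs are well separated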
| class AllGather(torch.autograd.Function): | |||
| """An autograd function that performs allgather on a tensor.""" | |||
| @staticmethod | |||
| def forward(ctx, tensor, args): | |||
| if args.world_size == 1: | |||
| ctx.rank = args.local_rank | |||
| ctx.batch_size = tensor.shape[0] | |||
| return tensor | |||
| else: | |||
| output = [torch.empty_like(tensor) for _ in range(args.world_size)] | |||
| torch.distributed.all_gather(output, tensor) | |||
| ctx.rank = args.local_rank | |||
| ctx.batch_size = tensor.shape[0] | |||
| return torch.cat(output, dim=0) | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| return ( | |||
| grad_output[ctx.batch_size * ctx.rank:ctx.batch_size | |||
| * (ctx.rank + 1)], | |||
| None, | |||
| ) | |||
| class AllGather2(torch.autograd.Function): | |||
| """An autograd function that performs allgather on a tensor.""" | |||
| # https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 | |||
| @staticmethod | |||
| def forward(ctx, tensor, args): | |||
| if args.world_size == 1: | |||
| ctx.rank = args.local_rank | |||
| ctx.batch_size = tensor.shape[0] | |||
| return tensor | |||
| else: | |||
| output = [torch.empty_like(tensor) for _ in range(args.world_size)] | |||
| torch.distributed.all_gather(output, tensor) | |||
| ctx.rank = args.local_rank | |||
| ctx.batch_size = tensor.shape[0] | |||
| return torch.cat(output, dim=0) | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| grad_input = grad_output.clone() | |||
| torch.distributed.all_reduce( | |||
| grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) | |||
| return (grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1) | |||
| * ctx.batch_size], None) | |||
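A sketch of how such an all-gather function is typically used in a contrastive setup (the `args` object and feature names are illustrative; with world_size == 1 the call is a pass-through, and under DDP it enlarges the in-batch negative pool while backward slices gradients back to the local shard):

from types import SimpleNamespace
import torch

allgather = AllGather.apply
args = SimpleNamespace(world_size=1, local_rank=0)   # single-process fallback path

video_feat = torch.randn(8, 512, requires_grad=True)
text_feat = torch.randn(8, 512, requires_grad=True)

all_video = allgather(video_feat, args)   # (world_size * batch, 512) under DDP
all_text = allgather(text_feat, args)
sim_matrix = all_text @ all_video.t()     # every text scored against every video
loss = CrossEn()(sim_matrix)
loss.backward()                           # gradients flow only to the local slice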
| @@ -25,9 +25,9 @@ class BertForSequenceClassification(Model): | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| import torch | |||
| from easynlp.appzoo import SequenceClassification | |||
| from easynlp.core.predictor import get_model_predictor | |||
| import torch | |||
| self.model = get_model_predictor( | |||
| model_dir=self.model_dir, | |||
| model_cls=SequenceClassification, | |||
| @@ -21,7 +21,8 @@ class PalmForTextGeneration(TorchModel): | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator | |||
| from sofa.models.palm_v2 import (PalmForConditionalGeneration, | |||
| Translator) | |||
| self.model = PalmForConditionalGeneration.from_pretrained(model_dir) | |||
| self.tokenizer = self.model.tokenizer | |||
| self.generator = Translator(self.model) | |||
| @@ -27,7 +27,8 @@ class SpaceForDialogIntent(Model): | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from modelscope.trainers.nlp.space.trainer.intent_trainer import IntentTrainer | |||
| from modelscope.trainers.nlp.space.trainer.intent_trainer import \ | |||
| IntentTrainer | |||
| self.model_dir = model_dir | |||
| self.config = kwargs.pop( | |||
| 'config', | |||
| @@ -22,7 +22,7 @@ class SpaceForDialogStateTracking(Model): | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from sofa.models.space import SpaceForDST, SpaceConfig | |||
| from sofa.models.space import SpaceConfig, SpaceForDST | |||
| self.model_dir = model_dir | |||
| self.config = SpaceConfig.from_pretrained(self.model_dir) | |||
| @@ -225,8 +225,8 @@ class MsDataset: | |||
| continue | |||
| retained_columns.append(k) | |||
| import torch | |||
| import math | |||
| import torch | |||
| class MsIterableDataset(torch.utils.data.IterableDataset): | |||
| @@ -74,6 +74,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.text_to_image_synthesis: | |||
| (Pipelines.text_to_image_synthesis, | |||
| 'damo/cv_imagen_text-to-image-synthesis_tiny'), | |||
| Tasks.video_multi_modal_embedding: | |||
| (Pipelines.video_multi_modal_embedding, | |||
| 'damo/multi_modal_clip_vtretrival_msrvtt_53'), | |||
| Tasks.image_color_enhance: (Pipelines.image_color_enhance, | |||
| 'damo/cv_csrnet_image-color-enhance-models'), | |||
| Tasks.virtual_tryon: (Pipelines.virtual_tryon, | |||
| @@ -3,7 +3,10 @@ try: | |||
| from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline | |||
| from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline | |||
| from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline | |||
| from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline | |||
| from .video_multi_modal_embedding_pipeline import \ | |||
| VideoMultiModalEmbeddingPipeline | |||
| from .visual_question_answering_pipeline import \ | |||
| VisualQuestionAnsweringPipeline | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'torch'": | |||
| pass | |||
| @@ -0,0 +1,42 @@ | |||
| from typing import Any, Dict | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from ..base import Model, Pipeline | |||
| from ..builder import PIPELINES | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.video_multi_modal_embedding, | |||
| module_name=Pipelines.video_multi_modal_embedding) | |||
| class VideoMultiModalEmbeddingPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| Use `model` to create a video multi-modal embedding pipeline for prediction. | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model) | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| return input | |||
| def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | |||
| with self.place_device(): | |||
| out = self.forward(input) | |||
| self._check_output(out) | |||
| return out | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| return self.model(input) | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
| @@ -15,6 +15,7 @@ try: | |||
| from .dialog_modeling_pipeline import * # noqa F403 | |||
| from .dialog_state_tracking_pipeline import * # noqa F403 | |||
| from .fill_mask_pipeline import * # noqa F403 | |||
| from .named_entity_recognition_pipeline import * # noqa F403 | |||
| from .nli_pipeline import * # noqa F403 | |||
| from .sentence_similarity_pipeline import * # noqa F403 | |||
| from .sentiment_classification_pipeline import * # noqa F403 | |||
| @@ -22,7 +23,6 @@ try: | |||
| from .text_generation_pipeline import * # noqa F403 | |||
| from .word_segmentation_pipeline import * # noqa F403 | |||
| from .zero_shot_classification_pipeline import * # noqa F403 | |||
| from .named_entity_recognition_pipeline import * # noqa F403 | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'torch'": | |||
| pass | |||
| @@ -25,7 +25,7 @@ class DialogStateTrackingPreprocessor(Preprocessor): | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from sofa.models.space import SpaceTokenizer, SpaceConfig | |||
| from sofa.models.space import SpaceConfig, SpaceTokenizer | |||
| self.model_dir: str = model_dir | |||
| self.config = SpaceConfig.from_pretrained(self.model_dir) | |||
| self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir) | |||
| @@ -78,7 +78,8 @@ class SequenceClassificationTrainer(BaseTrainer): | |||
| import torch | |||
| from easynlp.appzoo import load_dataset | |||
| from easynlp.appzoo.dataset import GeneralDataset | |||
| from easynlp.appzoo.sequence_classification.model import SequenceClassification | |||
| from easynlp.appzoo.sequence_classification.model import \ | |||
| SequenceClassification | |||
| from easynlp.utils import losses | |||
| from sklearn.metrics import f1_score | |||
| from torch.utils.data import DataLoader | |||
| @@ -77,6 +77,7 @@ class MultiModalTasks(object): | |||
| multi_modal_embedding = 'multi-modal-embedding' | |||
| generative_multi_modal_embedding = 'generative-multi-modal-embedding' | |||
| visual_question_answering = 'visual-question-answering' | |||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
| class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | |||
| @@ -85,7 +86,6 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | |||
| Holds the standard task name to use for identifying different tasks. | |||
| This should be used to register models, pipelines, trainers. | |||
| """ | |||
| reverse_field_index = {} | |||
| @staticmethod | |||
| @@ -89,8 +89,8 @@ def get_model_type(model_dir): | |||
| def parse_label_mapping(model_dir): | |||
| import os | |||
| import json | |||
| import os | |||
| label2id = None | |||
| label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING) | |||
| if os.path.exists(label_path): | |||
| @@ -70,9 +70,9 @@ def parse_requirements(fname='requirements.txt', with_version=True): | |||
| CommandLine: | |||
| python -c "import setup; print(setup.parse_requirements())" | |||
| """ | |||
| import re | |||
| import sys | |||
| from os.path import exists | |||
| import re | |||
| require_fpath = fname | |||
| def parse_line(line): | |||
| @@ -0,0 +1,45 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| import numpy as np | |||
| from modelscope.models import Model | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import test_level | |||
| logger = get_logger() | |||
| class VideoMultiModalEmbeddingTest(unittest.TestCase): | |||
| model_id = 'damo/multi_modal_clip_vtretrival_msrvtt_53' | |||
| video_path = 'data/test/videos/multi_modal_test_video_9770.mp4' | |||
| caption = ('a person is connecting something to system', None, None) | |||
| _input = {'video': video_path, 'text': caption} | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run(self): | |||
| pipeline_video_multi_modal_embedding = pipeline( | |||
| Tasks.video_multi_modal_embedding, model=self.model_id) | |||
| output = pipeline_video_multi_modal_embedding(self._input) | |||
| logger.info('text feature: {}'.format( | |||
| output['text_embedding'][0][0][0])) | |||
| logger.info('video feature: {}'.format( | |||
| output['video_embedding'][0][0][0])) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_video_multi_modal_embedding = pipeline( | |||
| task=Tasks.video_multi_modal_embedding) | |||
| output = pipeline_video_multi_modal_embedding(self._input) | |||
| logger.info('text feature: {}'.format( | |||
| output['text_embedding'][0][0][0])) | |||
| logger.info('video feature: {}'.format( | |||
| output['video_embedding'][0][0][0])) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||