Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9458640master
| @@ -124,3 +124,4 @@ replace.sh | |||||
| # Pytorch | # Pytorch | ||||
| *.pth | *.pth | ||||
| *.pt | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:33e21c16d5388684b61d7251b9d4e418f8146c3ba3fa400ebd8d913058687cfc | |||||
| size 431888 | |||||
| @@ -34,6 +34,7 @@ class Models(object): | |||||
| gemm = 'gemm-generative-multi-modal' | gemm = 'gemm-generative-multi-modal' | ||||
| mplug = 'mplug' | mplug = 'mplug' | ||||
| imagen = 'imagen-text-to-image-synthesis' | imagen = 'imagen-text-to-image-synthesis' | ||||
| video_clip = 'video-clip-multi-modal-embedding' | |||||
| class TaskModels(object): | class TaskModels(object): | ||||
| @@ -99,6 +100,7 @@ class Pipelines(object): | |||||
| generative_multi_modal_embedding = 'generative-multi-modal-embedding' | generative_multi_modal_embedding = 'generative-multi-modal-embedding' | ||||
| visual_question_answering = 'visual-question-answering' | visual_question_answering = 'visual-question-answering' | ||||
| text_to_image_synthesis = 'text-to-image-synthesis' | text_to_image_synthesis = 'text-to-image-synthesis' | ||||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||||
| class Trainers(object): | class Trainers(object): | ||||
| @@ -5,10 +5,10 @@ from .base import Model | |||||
| from .builder import MODELS, build_model | from .builder import MODELS, build_model | ||||
| try: | try: | ||||
| from .audio.ans.frcrn import FRCRNModel | |||||
| from .audio.asr import GenericAutomaticSpeechRecognition | from .audio.asr import GenericAutomaticSpeechRecognition | ||||
| from .audio.tts import SambertHifigan | |||||
| from .audio.kws import GenericKeyWordSpotting | from .audio.kws import GenericKeyWordSpotting | ||||
| from .audio.ans.frcrn import FRCRNModel | |||||
| from .audio.tts import SambertHifigan | |||||
| except ModuleNotFoundError as e: | except ModuleNotFoundError as e: | ||||
| print(AUDIO_IMPORT_ERROR.format(e)) | print(AUDIO_IMPORT_ERROR.format(e)) | ||||
| @@ -29,8 +29,8 @@ try: | |||||
| SbertForZeroShotClassification, SpaceForDialogIntent, | SbertForZeroShotClassification, SpaceForDialogIntent, | ||||
| SpaceForDialogModeling, SpaceForDialogStateTracking, | SpaceForDialogModeling, SpaceForDialogStateTracking, | ||||
| StructBertForMaskedLM, VecoForMaskedLM) | StructBertForMaskedLM, VecoForMaskedLM) | ||||
| from .nlp.heads import (SequenceClassificationHead) | |||||
| from .nlp.backbones import (SbertModel) | |||||
| from .nlp.backbones import SbertModel | |||||
| from .nlp.heads import SequenceClassificationHead | |||||
| except ModuleNotFoundError as e: | except ModuleNotFoundError as e: | ||||
| if str(e) == "No module named 'pytorch'": | if str(e) == "No module named 'pytorch'": | ||||
| pass | pass | ||||
| @@ -1,6 +1,8 @@ | |||||
| from .clip.clip_model import CLIPForMultiModalEmbedding | from .clip.clip_model import CLIPForMultiModalEmbedding | ||||
| from .gemm.gemm_model import GEMMForMultiModalEmbedding | from .gemm.gemm_model import GEMMForMultiModalEmbedding | ||||
| from .imagen.imagen_model import ImagenForTextToImageSynthesis | from .imagen.imagen_model import ImagenForTextToImageSynthesis | ||||
| from .mmr.models.clip_for_multi_model_video_embedding import \ | |||||
| VideoCLIPForMultiModalEmbedding | |||||
| from .mplug_for_visual_question_answering import \ | from .mplug_for_visual_question_answering import \ | ||||
| MPlugForVisualQuestionAnswering | MPlugForVisualQuestionAnswering | ||||
| from .ofa_for_image_captioning_model import OfaForImageCaptioning | from .ofa_for_image_captioning_model import OfaForImageCaptioning | ||||
| @@ -784,7 +784,7 @@ class BertModel(nn.Module): | |||||
| elif config.transformer_type.lower() == 'act': | elif config.transformer_type.lower() == 'act': | ||||
| self.encoder = BERTEncoderACT(config) | self.encoder = BERTEncoderACT(config) | ||||
| elif config.transformer_type.lower() == 'textnas': | elif config.transformer_type.lower() == 'textnas': | ||||
| from textnas_final import op_dict, input_dict, skip_dict | |||||
| from textnas_final import input_dict, op_dict, skip_dict | |||||
| self.encoder = TextNASEncoder(config, op_dict, input_dict, | self.encoder = TextNASEncoder(config, op_dict, input_dict, | ||||
| skip_dict) | skip_dict) | ||||
| else: | else: | ||||
| @@ -0,0 +1,114 @@ | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch as th | |||||
| from PIL import Image | |||||
| from torchvision.transforms import (CenterCrop, Compose, InterpolationMode, | |||||
| Normalize, Resize, ToTensor) | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| class RawVideoExtractorCV2(): | |||||
| def __init__( | |||||
| self, | |||||
| centercrop=False, | |||||
| size=224, | |||||
| frame_rate=-1, | |||||
| ): | |||||
| self.centercrop = centercrop | |||||
| self.size = size | |||||
| self.framerate = frame_rate | |||||
| self.transform = self._transform(self.size) | |||||
| def _transform(self, n_px): | |||||
| return Compose([ | |||||
| Resize(n_px, interpolation=InterpolationMode.BICUBIC), | |||||
| CenterCrop(n_px), | |||||
| lambda image: image.convert('RGB'), | |||||
| ToTensor(), | |||||
| Normalize((0.48145466, 0.4578275, 0.40821073), | |||||
| (0.26862954, 0.26130258, 0.27577711)), | |||||
| ]) | |||||
| def video_to_tensor(self, | |||||
| video_file, | |||||
| preprocess, | |||||
| sample_fp=0, | |||||
| start_time=None, | |||||
| end_time=None): | |||||
| if start_time is not None or end_time is not None: | |||||
| assert isinstance(start_time, int) and isinstance(end_time, int) \ | |||||
| and start_time > -1 and end_time > start_time | |||||
| assert sample_fp > -1 | |||||
| # Samples a frame sample_fp X frames. | |||||
| cap = cv2.VideoCapture(video_file) | |||||
| frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |||||
| fps = int(cap.get(cv2.CAP_PROP_FPS)) | |||||
| if fps == 0: | |||||
| logger.info(f'{video_file} with fps 0!!!') | |||||
| total_duration = (frameCount + fps - 1) // fps | |||||
| start_sec, end_sec = 0, total_duration | |||||
| if start_time is not None: | |||||
| start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration | |||||
| cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) | |||||
| interval = 1 | |||||
| if sample_fp > 0: | |||||
| interval = fps // sample_fp | |||||
| else: | |||||
| sample_fp = fps | |||||
| if interval == 0: | |||||
| interval = 1 | |||||
| inds = [ind for ind in np.arange(0, fps, interval)] | |||||
| assert len(inds) >= sample_fp | |||||
| inds = inds[:sample_fp] | |||||
| ret = True | |||||
| images = [] | |||||
| for sec in np.arange(start_sec, end_sec + 1): | |||||
| if not ret: | |||||
| break | |||||
| sec_base = int(sec * fps) | |||||
| for ind in inds: | |||||
| cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) | |||||
| ret, frame = cap.read() | |||||
| if not ret: | |||||
| break | |||||
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |||||
| images.append( | |||||
| preprocess(Image.fromarray(frame_rgb).convert('RGB'))) | |||||
| cap.release() | |||||
| if len(images) > 0: | |||||
| video_data = th.tensor(np.stack(images)) | |||||
| else: | |||||
| video_data = th.zeros(1) | |||||
| return {'video': video_data} | |||||
| def get_video_data(self, video_path, start_time=None, end_time=None): | |||||
| image_input = self.video_to_tensor( | |||||
| video_path, | |||||
| self.transform, | |||||
| sample_fp=self.framerate, | |||||
| start_time=start_time, | |||||
| end_time=end_time) | |||||
| return image_input | |||||
| def process_raw_data(self, raw_video_data): | |||||
| tensor_size = raw_video_data.size() | |||||
| tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], | |||||
| tensor_size[-1]) | |||||
| return tensor | |||||
| # An ordinary video frame extractor based CV2 | |||||
| RawVideoExtractor = RawVideoExtractorCV2 | |||||
| @@ -0,0 +1,218 @@ | |||||
| import os | |||||
| import random | |||||
| from os.path import exists | |||||
| from typing import Any, Dict | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| from PIL import Image | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.multi_modal.mmr.dataloaders.rawvideo_util import \ | |||||
| RawVideoExtractor | |||||
| from modelscope.models.multi_modal.mmr.models.modeling import CLIP4Clip | |||||
| from modelscope.models.multi_modal.mmr.models.tokenization_clip import \ | |||||
| SimpleTokenizer as ClipTokenizer | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @MODELS.register_module( | |||||
| Tasks.video_multi_modal_embedding, module_name=Models.video_clip) | |||||
| class VideoCLIPForMultiModalEmbedding(Model): | |||||
| def __init__(self, model_dir, device_id=-1): | |||||
| super().__init__(model_dir=model_dir, device_id=device_id) | |||||
| # model config parameters | |||||
| with open(f'{model_dir}/{ModelFile.CONFIGURATION}', 'r') as json_file: | |||||
| model_config = json.load(json_file) | |||||
| model_config = model_config['paras'] | |||||
| model_config['model_dir'] = model_dir | |||||
| self.SPECIAL_TOKEN = { | |||||
| 'CLS_TOKEN': '<|startoftext|>', | |||||
| 'SEP_TOKEN': '<|endoftext|>', | |||||
| 'MASK_TOKEN': '[MASK]', | |||||
| 'UNK_TOKEN': '[UNK]', | |||||
| 'PAD_TOKEN': '[PAD]' | |||||
| } | |||||
| self.max_words = model_config['max_words'] | |||||
| self.max_frames = model_config['max_frames'] | |||||
| self.feature_framerate = model_config['feature_framerate'] | |||||
| self.image_resolution = 224 | |||||
| self.device = model_config['device'] | |||||
| self.init_model = f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}' | |||||
| self.tokenizer = ClipTokenizer(model_dir) | |||||
| self.rawVideoExtractor = RawVideoExtractor( | |||||
| frame_rate=self.feature_framerate, size=self.image_resolution) | |||||
| self.local_transform = self.rawVideoExtractor.transform | |||||
| self.model = CLIP4Clip(model_config) | |||||
| if hasattr(self.model, 'module'): | |||||
| self.model = self.model.module.to(self.device) | |||||
| else: | |||||
| self.model = self.model.to(self.device) | |||||
| if self.init_model: | |||||
| assert exists(self.init_model) | |||||
| model_state_dict = torch.load(self.init_model, map_location='cpu') | |||||
| self.model.load_state_dict(model_state_dict, strict=False) | |||||
| self.model.to(self.device) | |||||
| def _get_text(self, caption, tokenizer, enable_zh=False): | |||||
| if len(caption) == 3: | |||||
| _caption_text, s, e = caption | |||||
| elif len(caption) == 4: | |||||
| _caption_text, s, e, pos = caption | |||||
| else: | |||||
| NotImplementedError | |||||
| if isinstance(_caption_text, list): | |||||
| caption_text = random.choice(_caption_text) | |||||
| else: | |||||
| caption_text = _caption_text | |||||
| if enable_zh: | |||||
| _token = tokenizer.encode(caption_text) | |||||
| input_ids = _token.ids | |||||
| input_mask = _token.attention_mask | |||||
| segment_ids = _token.type_ids | |||||
| else: | |||||
| words = tokenizer.tokenize(caption_text) | |||||
| words = [self.SPECIAL_TOKEN['CLS_TOKEN']] + words | |||||
| total_length_with_CLS = self.max_words - 1 | |||||
| if len(words) > total_length_with_CLS: | |||||
| words = words[:total_length_with_CLS] | |||||
| words = words + [self.SPECIAL_TOKEN['SEP_TOKEN']] | |||||
| input_ids = tokenizer.convert_tokens_to_ids(words) | |||||
| input_mask = [1] * len(input_ids) | |||||
| segment_ids = [0] * len(input_ids) | |||||
| while len(input_ids) < self.max_words: | |||||
| input_ids.append(0) | |||||
| input_mask.append(0) | |||||
| segment_ids.append(0) | |||||
| assert len(input_ids) == self.max_words | |||||
| assert len(input_mask) == self.max_words | |||||
| assert len(segment_ids) == self.max_words | |||||
| pairs_text = np.array(input_ids) | |||||
| pairs_mask = np.array(input_mask) | |||||
| pairs_segment = np.array(segment_ids) | |||||
| return pairs_text, pairs_mask, pairs_segment, s, e | |||||
| def _get_rawvideo_dec(self, | |||||
| video_path, | |||||
| rawVideoExtractor, | |||||
| local_transform, | |||||
| s=None, | |||||
| e=None): | |||||
| video_mask = np.zeros(self.max_frames, dtype=np.long) | |||||
| max_video_length = 0 | |||||
| # T x 3 x H x W | |||||
| video = np.zeros((self.max_frames, 3, rawVideoExtractor.size, | |||||
| rawVideoExtractor.size), | |||||
| dtype=np.float) | |||||
| if s is None: | |||||
| start_time, end_time = None, None | |||||
| else: | |||||
| start_time = int(s) | |||||
| end_time = int(e) | |||||
| start_time = start_time if start_time >= 0. else 0. | |||||
| end_time = end_time if end_time >= 0. else 0. | |||||
| if start_time > end_time: | |||||
| start_time, end_time = end_time, start_time | |||||
| elif start_time == end_time: | |||||
| end_time = end_time + 1 | |||||
| if exists(video_path): | |||||
| from decord import VideoReader, cpu | |||||
| vreader = VideoReader(video_path, ctx=cpu(0)) | |||||
| else: | |||||
| logger.error('non video input, output is wrong!!!') | |||||
| return video, video_mask | |||||
| fps = vreader.get_avg_fps() | |||||
| f_start = 0 if start_time is None else int(start_time * fps) | |||||
| f_end = int( | |||||
| min(1000000000 if end_time is None else end_time * fps, | |||||
| len(vreader) - 1)) | |||||
| num_frames = f_end - f_start + 1 | |||||
| if num_frames > 0: | |||||
| # L x T x 3 x H x W | |||||
| sample_fps = int(self.feature_framerate) | |||||
| t_stride = int(round(float(fps) / sample_fps)) | |||||
| all_pos = list(range(f_start, f_end + 1, t_stride)) | |||||
| if len(all_pos) > self.max_frames: | |||||
| sample_pos = [ | |||||
| all_pos[_] for _ in np.linspace( | |||||
| 0, len(all_pos) - 1, num=self.max_frames, dtype=int) | |||||
| ] | |||||
| else: | |||||
| sample_pos = all_pos | |||||
| patch_images = [ | |||||
| Image.fromarray(f) | |||||
| for f in vreader.get_batch(sample_pos).asnumpy() | |||||
| ] | |||||
| patch_images = torch.stack( | |||||
| [local_transform(img) for img in patch_images]) | |||||
| slice_len = patch_images.shape[0] | |||||
| max_video_length = max_video_length if max_video_length > slice_len else slice_len | |||||
| if slice_len < 1: | |||||
| pass | |||||
| else: | |||||
| video[:slice_len, ...] = patch_images | |||||
| else: | |||||
| logger.error('video path: {} error. video id: {}'.format( | |||||
| video_path, video_id)) | |||||
| video_mask[:max_video_length] = [1] * max_video_length | |||||
| return video, video_mask | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| from modelscope.outputs import OutputKeys | |||||
| output = {} | |||||
| if 'video' in input and input['video'] is not None: | |||||
| video_path = input['video'] | |||||
| video, video_mask = self._get_rawvideo_dec(video_path, | |||||
| self.rawVideoExtractor, | |||||
| self.local_transform) | |||||
| video = torch.unsqueeze( | |||||
| torch.from_numpy(video), dim=0).to(self.device) | |||||
| video_mask = torch.unsqueeze( | |||||
| torch.from_numpy(video_mask), dim=0).to(self.device) | |||||
| if 'text' in input and input['text'] is not None: | |||||
| caption = input['text'] | |||||
| pairs_text, pairs_mask, pairs_segment, s, e = self._get_text( | |||||
| caption, self.tokenizer, enable_zh=False) | |||||
| input_ids = torch.unsqueeze( | |||||
| torch.from_numpy(pairs_text), dim=0).to(self.device) | |||||
| input_mask = torch.unsqueeze( | |||||
| torch.from_numpy(pairs_mask), dim=0).to(self.device) | |||||
| segment_ids = torch.unsqueeze( | |||||
| torch.from_numpy(pairs_segment), dim=0).to(self.device) | |||||
| sequence_output, visual_output = self.model.get_sequence_visual_output( | |||||
| input_ids, segment_ids, input_mask, video, video_mask) | |||||
| logger.info('text feature: {}'.format(sequence_output[0][0][0])) | |||||
| logger.info('video feature: {}'.format(visual_output[0][0][0])) | |||||
| output[OutputKeys.VIDEO_EMBEDDING] = visual_output | |||||
| output[OutputKeys.TEXT_EMBEDDING] = sequence_output | |||||
| return output | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -0,0 +1,42 @@ | |||||
| import numpy as np | |||||
| def get_retrieved_videos(sims, k): | |||||
| """ | |||||
| Returns list of retrieved top k videos based on the sims matrix | |||||
| Args: | |||||
| sims: similar matrix. | |||||
| K: top k number of videos | |||||
| """ | |||||
| argm = np.argsort(-sims, axis=1) | |||||
| topk = argm[:, :k].reshape(-1) | |||||
| retrieved_videos = np.unique(topk) | |||||
| return retrieved_videos | |||||
| def get_index_to_normalize(sims, videos): | |||||
| """ | |||||
| Returns list of indices to normalize from sims based on videos | |||||
| Args: | |||||
| sims: similar matrix. | |||||
| videos: video array. | |||||
| """ | |||||
| argm = np.argsort(-sims, axis=1)[:, 0] | |||||
| result = np.array(list(map(lambda x: x in videos, argm))) | |||||
| result = np.nonzero(result) | |||||
| return result | |||||
| def qb_norm(train_test, test_test, args): | |||||
| k = args.get('k', 1) | |||||
| beta = args.get('beta', 20) | |||||
| retrieved_videos = get_retrieved_videos(train_test, k) | |||||
| test_test_normalized = test_test | |||||
| train_test = np.exp(train_test * beta) | |||||
| test_test = np.exp(test_test * beta) | |||||
| normalizing_sum = np.sum(train_test, axis=0) | |||||
| index_for_normalizing = get_index_to_normalize(test_test, retrieved_videos) | |||||
| test_test_normalized[index_for_normalizing, :] = \ | |||||
| np.divide(test_test[index_for_normalizing, :], normalizing_sum) | |||||
| return test_test_normalized | |||||
| @@ -0,0 +1,508 @@ | |||||
| import os | |||||
| import platform | |||||
| from collections import OrderedDict | |||||
| from types import SimpleNamespace | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | |||||
| from modelscope.models.multi_modal.mmr.models.module_clip import ( | |||||
| _PT_NAME, CLIP, QuickGELU, convert_weights) | |||||
| from modelscope.models.multi_modal.mmr.models.module_cross import \ | |||||
| Transformer as TransformerClip | |||||
| from modelscope.models.multi_modal.mmr.models.until_module import (AllGather, | |||||
| CrossEn, | |||||
| LayerNorm) | |||||
| from modelscope.utils.logger import get_logger | |||||
| allgather = AllGather.apply | |||||
| logger = get_logger() | |||||
| __all__ = ['CLIP4Clip'] | |||||
| class CLIP4Clip(nn.Module): | |||||
| def __init__(self, config): | |||||
| super(CLIP4Clip, self).__init__() | |||||
| self.config = config | |||||
| self.loose_type = config['loose_type'] | |||||
| self.sim_header = config['sim_header'] | |||||
| if self.sim_header in [ | |||||
| 'tightTransf', 'tightFc1', 'tightFc2', 'tightFc3', 'tightFc4', | |||||
| 'tightMean', 'tightFc5' | |||||
| ]: | |||||
| assert self.loose_type is False | |||||
| backbone = config['pretrained_clip_name'] | |||||
| # fix backbone without downlond | |||||
| model_path = '{}/ViT-B-16.pt'.format(config['model_dir']) | |||||
| if not os.path.exists(model_path): | |||||
| logger.info('no model loaded!!!') | |||||
| try: | |||||
| # loading JIT archive | |||||
| model = torch.jit.load(model_path, map_location='cpu').eval() | |||||
| state_dict = model.state_dict() | |||||
| except RuntimeError: | |||||
| state_dict = torch.load(model_path, map_location='cpu') | |||||
| vision_width = state_dict['visual.conv1.weight'].shape[0] | |||||
| vision_layers = len([ | |||||
| k for k in state_dict.keys() | |||||
| if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') | |||||
| ]) | |||||
| vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] | |||||
| grid_size = round( | |||||
| (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) | |||||
| image_resolution = vision_patch_size * grid_size | |||||
| embed_dim = state_dict['text_projection'].shape[1] | |||||
| context_length = state_dict['positional_embedding'].shape[0] | |||||
| vocab_size = state_dict['token_embedding.weight'].shape[0] | |||||
| transformer_width = state_dict['ln_final.weight'].shape[0] | |||||
| transformer_heads = transformer_width // 64 | |||||
| transformer_layers = len( | |||||
| set( | |||||
| k.split('.')[2] for k in state_dict | |||||
| if k.startswith('transformer.resblocks'))) | |||||
| cut_top_layer = 0 | |||||
| self.clip = CLIP( | |||||
| embed_dim, | |||||
| image_resolution, | |||||
| vision_layers - cut_top_layer, | |||||
| vision_width, | |||||
| vision_patch_size, | |||||
| context_length, | |||||
| vocab_size, | |||||
| transformer_width, | |||||
| transformer_heads, | |||||
| transformer_layers - cut_top_layer, | |||||
| linear_patch=config['linear_patch'], | |||||
| use_gc=config['use_gc']).float() | |||||
| if (platform.system() != 'Darwin'): | |||||
| convert_weights(self.clip) # fp16 | |||||
| if backbone in ['ViT-B/32', 'ViT-B/16']: | |||||
| cross_config = SimpleNamespace(**{ | |||||
| 'hidden_size': 512, | |||||
| 'max_position_embeddings': 128, | |||||
| }) | |||||
| elif backbone in ['ViT-L/14', 'ViT-B/14-336px']: | |||||
| cross_config = SimpleNamespace(**{ | |||||
| 'hidden_size': 768, | |||||
| 'max_position_embeddings': 128, | |||||
| }) | |||||
| else: | |||||
| raise ValueError | |||||
| cross_config.max_position_embeddings = context_length | |||||
| self.cross_config = cross_config | |||||
| self.text_weight_fc = nn.Sequential( | |||||
| nn.Linear(transformer_width, transformer_width), | |||||
| nn.ReLU(inplace=True), nn.Linear(transformer_width, 1)) | |||||
| self.video_weight_fc = nn.Sequential( | |||||
| nn.Linear(transformer_width, transformer_width), | |||||
| nn.ReLU(inplace=True), nn.Linear(transformer_width, 1)) | |||||
| if self.loose_type is False: | |||||
| raise NotImplementedError | |||||
| if self.sim_header in ['seqLSTM', 'seqTransf', 'tightFc1']: | |||||
| self.frame_position_embeddings = nn.Embedding( | |||||
| cross_config.max_position_embeddings, cross_config.hidden_size) | |||||
| if self.sim_header in ['seqTransf', 'tightFc1']: | |||||
| self.transformerClip = TransformerClip( | |||||
| width=transformer_width, | |||||
| layers=config['cross_num_hidden_layers'], | |||||
| heads=transformer_heads, | |||||
| ) | |||||
| if self.sim_header == 'seqLSTM': | |||||
| self.lstm_visual = nn.LSTM( | |||||
| input_size=cross_config.hidden_size, | |||||
| hidden_size=cross_config.hidden_size, | |||||
| batch_first=True, | |||||
| bidirectional=False, | |||||
| num_layers=1) | |||||
| self.loss_fct = CrossEn(config) | |||||
| self.apply(self.init_weights) | |||||
| self.clip.load_state_dict(state_dict, strict=False) | |||||
| # ===> Initialization trick [HARD CODE] | |||||
| if backbone not in _PT_NAME: | |||||
| raise NotImplementedError | |||||
| # reload | |||||
| else: | |||||
| if config['linear_patch'] == '3d': | |||||
| raise NotImplementedError | |||||
| new_state_dict = OrderedDict() | |||||
| if self.sim_header == 'tightTransf': | |||||
| raise NotImplementedError | |||||
| if self.sim_header in ['seqLSTM', 'seqTransf', 'seqFc1']: | |||||
| contain_frame_position = False | |||||
| for key in state_dict.keys(): | |||||
| if key.find('frame_position_embeddings') > -1: | |||||
| contain_frame_position = True | |||||
| break | |||||
| if contain_frame_position is False: | |||||
| for key, val in state_dict.items(): | |||||
| if key == 'positional_embedding': | |||||
| new_state_dict[ | |||||
| 'frame_position_embeddings.weight'] = val.clone() | |||||
| continue | |||||
| if self.sim_header in [ | |||||
| 'seqTransf', 'seqFc1' | |||||
| ] and key.find('transformer.resblocks') == 0: | |||||
| num_layer = int(key.split('.')[2]) | |||||
| # cut from beginning | |||||
| if num_layer < config['cross_num_hidden_layers']: | |||||
| new_state_dict[key.replace( | |||||
| 'transformer.', | |||||
| 'transformerClip.')] = val.clone() | |||||
| continue | |||||
| # <=== End of initialization trick | |||||
| self.load_state_dict( | |||||
| new_state_dict, strict=False | |||||
| ) # only update new state (seqTransf/seqLSTM/tightTransf) | |||||
| if self.sim_header == 'tightFc5': | |||||
| raise ValueError | |||||
| def forward(self, | |||||
| input_ids, | |||||
| token_type_ids, | |||||
| attention_mask, | |||||
| video, | |||||
| video_mask=None): | |||||
| input_ids = input_ids.view(-1, input_ids.shape[-1]) | |||||
| token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) | |||||
| attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) | |||||
| video_mask = video_mask.view(-1, video_mask.shape[-1]) | |||||
| # B x T x 3 x H x W - > (B x T) x 3 x H x W | |||||
| video = torch.as_tensor(video).float() | |||||
| if len(video.shape) == 6: # image | |||||
| b, bs, ts, channel, h, w = video.shape | |||||
| b = b * bs | |||||
| else: # video | |||||
| b, ts, channel, h, w = video.shape | |||||
| video = video.view(b * ts, channel, h, w) | |||||
| sequence_output, visual_output = self.get_sequence_visual_output( | |||||
| input_ids, | |||||
| token_type_ids, | |||||
| attention_mask, | |||||
| video, | |||||
| video_mask, | |||||
| shaped=True) | |||||
| if self.training: | |||||
| loss = 0. | |||||
| sim_matrix1, sim_matrix2, barlow_loss = self.get_similarity_logits( | |||||
| sequence_output, | |||||
| visual_output, | |||||
| attention_mask, | |||||
| video_mask, | |||||
| shaped=True, | |||||
| loose_type=self.loose_type) | |||||
| sim_loss = (self.loss_fct(sim_matrix1) | |||||
| + self.loss_fct(sim_matrix2)) / 2 | |||||
| loss += sim_loss + barlow_loss * self.config.cdcr_lambda | |||||
| return loss | |||||
| else: | |||||
| return None | |||||
| def get_sequence_output(self, | |||||
| input_ids, | |||||
| token_type_ids, | |||||
| attention_mask, | |||||
| shaped=False): | |||||
| if shaped is False: | |||||
| input_ids = input_ids.view(-1, input_ids.shape[-1]) | |||||
| token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) | |||||
| attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) | |||||
| bs_pair = input_ids.size(0) | |||||
| sequence_hidden = self.clip.encode_text( | |||||
| input_ids, return_hidden=True, prompt=None)[1].float() | |||||
| sequence_hidden = sequence_hidden.view(bs_pair, -1, | |||||
| sequence_hidden.size(-1)) | |||||
| return sequence_hidden | |||||
| def get_visual_output(self, video, video_mask, shaped=False): | |||||
| if shaped is False: | |||||
| video_mask = video_mask.view(-1, video_mask.shape[-1]) | |||||
| video = torch.as_tensor(video).float() | |||||
| b, ts, channel, h, w = video.shape | |||||
| video = video.view(b * ts, channel, h, w) | |||||
| bs_pair = video_mask.size(0) | |||||
| visual_hidden = self.clip.encode_image(video).float() | |||||
| visual_hidden = visual_hidden.float().view(bs_pair, -1, | |||||
| visual_hidden.size(-1)) | |||||
| return visual_hidden | |||||
| def get_sequence_visual_output(self, | |||||
| input_ids, | |||||
| token_type_ids, | |||||
| attention_mask, | |||||
| video, | |||||
| video_mask, | |||||
| shaped=False): | |||||
| if shaped is False: | |||||
| input_ids = input_ids.view(-1, input_ids.shape[-1]) | |||||
| token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) | |||||
| attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) | |||||
| video_mask = video_mask.view(-1, video_mask.shape[-1]) | |||||
| video = torch.as_tensor(video).float() | |||||
| if len(video.shape) == 6: # image | |||||
| b, bs, ts, channel, h, w = video.shape | |||||
| b = b * bs | |||||
| else: # video | |||||
| b, ts, channel, h, w = video.shape | |||||
| video = video.view(b * ts, channel, h, w) | |||||
| sequence_output = self.get_sequence_output( | |||||
| input_ids, token_type_ids, attention_mask, shaped=True) | |||||
| visual_output = self.get_visual_output(video, video_mask, shaped=True) | |||||
| return sequence_output, visual_output | |||||
| def agg_video_feat(self, visual_output, video_mask, sim_header='meanP'): | |||||
| if self.config.max_sum == 0: | |||||
| raise ValueError | |||||
| if sim_header == 'meanP': | |||||
| # Default: Parameter-free type | |||||
| pass | |||||
| elif sim_header == 'seqLSTM': | |||||
| # Sequential type: LSTM | |||||
| visual_output_original = visual_output | |||||
| visual_output = pack_padded_sequence( | |||||
| visual_output, | |||||
| torch.sum(video_mask, dim=-1).cpu(), | |||||
| batch_first=True, | |||||
| enforce_sorted=False) | |||||
| visual_output, _ = self.lstm_visual(visual_output) | |||||
| if self.training: | |||||
| self.lstm_visual.flatten_parameters() | |||||
| visual_output, _ = pad_packed_sequence( | |||||
| visual_output, batch_first=True) | |||||
| visual_output = torch.cat( | |||||
| (visual_output, visual_output_original[:, | |||||
| visual_output.size(1):, | |||||
| ...].contiguous()), | |||||
| dim=1) | |||||
| visual_output = visual_output + visual_output_original | |||||
| elif sim_header == 'seqTransf': | |||||
| # Sequential type: Transformer Encoder | |||||
| visual_output_original = visual_output | |||||
| seq_length = visual_output.size(1) | |||||
| position_ids = torch.arange( | |||||
| seq_length, dtype=torch.long, device=visual_output.device) | |||||
| position_ids = position_ids.unsqueeze(0).expand( | |||||
| visual_output.size(0), -1) | |||||
| frame_position_embeddings = self.frame_position_embeddings( | |||||
| position_ids) | |||||
| visual_output = visual_output + frame_position_embeddings | |||||
| extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0 | |||||
| extended_video_mask = extended_video_mask.expand( | |||||
| -1, video_mask.size(1), -1) | |||||
| visual_output = visual_output.permute(1, 0, 2) # NLD -> LND | |||||
| visual_output = self.transformerClip(visual_output, | |||||
| extended_video_mask) | |||||
| visual_output = visual_output.permute(1, 0, 2) # LND -> NLD | |||||
| visual_output = visual_output + visual_output_original | |||||
| return visual_output | |||||
| def wti_interaction(self, text_feat, video_feat, text_mask, video_mask): | |||||
| text_weight = self.text_weight_fc(text_feat).squeeze( | |||||
| 2) # B x N_t x D -> B x N_t | |||||
| text_weight.masked_fill_( | |||||
| torch.tensor((1 - text_mask), dtype=torch.bool), float('-inf')) | |||||
| text_weight = torch.softmax(text_weight, dim=-1) # B x N_t | |||||
| video_weight = self.video_weight_fc(video_feat).squeeze( | |||||
| 2) # B x N_v x D -> B x N_v | |||||
| video_weight.masked_fill_( | |||||
| torch.tensor((1 - video_mask), dtype=torch.bool), float('-inf')) | |||||
| video_weight = torch.softmax(video_weight, dim=-1) # B x N_v | |||||
| text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) | |||||
| video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True) | |||||
| retrieve_logits = torch.einsum('atd,bvd->abtv', | |||||
| [text_feat, video_feat]) | |||||
| retrieve_logits = torch.einsum('abtv,at->abtv', | |||||
| [retrieve_logits, text_mask]) | |||||
| retrieve_logits = torch.einsum('abtv,bv->abtv', | |||||
| [retrieve_logits, video_mask]) | |||||
| t2v_logits, max_idx1 = retrieve_logits.max(dim=-1) # abtv -> abt | |||||
| t2v_logits = torch.einsum('abt,at->ab', [t2v_logits, text_weight]) | |||||
| v2t_logits, max_idx2 = retrieve_logits.max(dim=-2) # abtv -> abv | |||||
| v2t_logits = torch.einsum('abv,bv->ab', [v2t_logits, video_weight]) | |||||
| retrieve_logits = (t2v_logits + v2t_logits) / 2.0 | |||||
| if self.training: | |||||
| logit_scale = self.clip.logit_scale.exp() | |||||
| retrieve_logits = logit_scale * retrieve_logits | |||||
| # selecet max | |||||
| max_idx1 = max_idx1[torch.arange(max_idx1.shape[0]), | |||||
| torch.arange(max_idx1.shape[1])] | |||||
| max_idx2 = max_idx2[torch.arange(max_idx2.shape[0]), | |||||
| torch.arange(max_idx2.shape[1])] | |||||
| max_t_feat = text_feat[torch.arange(max_idx2.shape[0]). | |||||
| repeat_interleave(max_idx2.shape[1]), | |||||
| max_idx2.flatten()].squeeze(1) | |||||
| max_v_feat = video_feat[torch.arange(max_idx1.shape[0]). | |||||
| repeat_interleave(max_idx1.shape[1]), | |||||
| max_idx1.flatten()].squeeze(1) | |||||
| t_feat = text_feat.reshape(-1, text_feat.shape[-1]) | |||||
| t_mask = text_mask.flatten().type(torch.bool) | |||||
| v_feat = video_feat.reshape(-1, video_feat.shape[-1]) | |||||
| v_mask = video_mask.flatten().type(torch.bool) | |||||
| t_feat = t_feat[t_mask] | |||||
| v_feat = v_feat[v_mask] | |||||
| max_t_feat = max_t_feat[v_mask] | |||||
| max_v_feat = max_v_feat[t_mask] | |||||
| text_weight = text_weight.flatten()[t_mask] | |||||
| video_weight = video_weight.flatten()[v_mask] | |||||
| z_a_norm = (t_feat - t_feat.mean(0)) / t_feat.std(0) # (BxN_t)xD | |||||
| z_b_norm = (max_v_feat - max_v_feat.mean(0)) / max_v_feat.std( | |||||
| 0) # (BxN_t)xD | |||||
| x_a_norm = (v_feat - v_feat.mean(0)) / v_feat.std(0) # (BxN_v)xD | |||||
| x_b_norm = (max_t_feat - max_t_feat.mean(0)) / max_t_feat.std( | |||||
| 0) # (BxN_v)xD | |||||
| # cross-correlation matrix | |||||
| N, D = z_a_norm.shape | |||||
| B = text_feat.shape[0] | |||||
| c1 = torch.einsum('acd,a->cd', | |||||
| torch.einsum('ac,ad->acd', z_a_norm, z_b_norm), | |||||
| text_weight) / B # DxD | |||||
| c2 = torch.einsum('acd,a->cd', | |||||
| torch.einsum('ac,ad->acd', x_a_norm, x_b_norm), | |||||
| video_weight) / B # DxD | |||||
| c = (c1 + c2) / 2.0 | |||||
| # loss | |||||
| on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() | |||||
| off_diag = c.flatten()[1:].view(D - 1, D + 1)[:, :-1].pow_(2).sum() | |||||
| cdcr_loss = ( | |||||
| on_diag * self.config.cdcr_alpha1 | |||||
| + off_diag * self.config.cdcr_alpha2) | |||||
| return retrieve_logits, retrieve_logits.T, cdcr_loss | |||||
| else: | |||||
| return retrieve_logits, retrieve_logits.T | |||||
| def _loose_similarity(self, | |||||
| sequence_output, | |||||
| visual_output, | |||||
| attention_mask, | |||||
| video_mask, | |||||
| sim_header='seqTransf'): | |||||
| sequence_output, visual_output = sequence_output.contiguous( | |||||
| ), visual_output.contiguous() | |||||
| visual_output = self.agg_video_feat(visual_output, video_mask, | |||||
| sim_header) | |||||
| if self.training: # batch merge here | |||||
| visual_output = allgather(visual_output, self.config) | |||||
| attention_mask = allgather(attention_mask, self.config) | |||||
| video_mask = allgather(video_mask, self.config) | |||||
| sequence_output = allgather(sequence_output, self.config) | |||||
| torch.distributed.barrier() # force sync | |||||
| return self.wti_interaction(sequence_output, visual_output, | |||||
| attention_mask, video_mask) | |||||
| def get_similarity_logits(self, | |||||
| sequence_output, | |||||
| visual_output, | |||||
| attention_mask, | |||||
| video_mask, | |||||
| shaped=False, | |||||
| loose_type=False): | |||||
| if shaped is False: | |||||
| attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) | |||||
| video_mask = video_mask.view(-1, video_mask.shape[-1]) | |||||
| if loose_type: | |||||
| assert self.sim_header in ['meanP', 'seqLSTM', 'seqTransf'] | |||||
| if self.training: | |||||
| retrieve_logits1, retrieve_logits2, barlow_loss = self._loose_similarity( | |||||
| sequence_output, | |||||
| visual_output, | |||||
| attention_mask, | |||||
| video_mask, | |||||
| sim_header=self.sim_header) | |||||
| return retrieve_logits1, retrieve_logits2, barlow_loss | |||||
| else: | |||||
| retrieve_logits1, retrieve_logits2 = self._loose_similarity( | |||||
| sequence_output, | |||||
| visual_output, | |||||
| attention_mask, | |||||
| video_mask, | |||||
| sim_header=self.sim_header) | |||||
| return retrieve_logits1, retrieve_logits2 | |||||
| else: | |||||
| raise NotImplementedError | |||||
| @property | |||||
| def dtype(self): | |||||
| """ | |||||
| :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). | |||||
| """ | |||||
| try: | |||||
| return next(self.parameters()).dtype | |||||
| except StopIteration: | |||||
| # For nn.DataParallel compatibility in PyTorch 1.5 | |||||
| def find_tensor_attributes(module: nn.Module): | |||||
| tuples = [(k, v) for k, v in module.__dict__.items() | |||||
| if torch.is_tensor(v)] | |||||
| return tuples | |||||
| gen = self._named_members(get_members_fn=find_tensor_attributes) | |||||
| first_tuple = next(gen) | |||||
| return first_tuple[1].dtype | |||||
| def init_weights(self, module): | |||||
| """ Initialize the weights. | |||||
| """ | |||||
| if isinstance(module, (nn.Linear, nn.Embedding)): | |||||
| # Slightly different from the TF version which uses truncated_normal for initialization | |||||
| # cf https://github.com/pytorch/pytorch/pull/5617 | |||||
| module.weight.data.normal_(mean=0.0, std=0.02) | |||||
| elif isinstance(module, LayerNorm): | |||||
| if 'beta' in dir(module) and 'gamma' in dir(module): | |||||
| module.beta.data.zero_() | |||||
| module.gamma.data.fill_(1.0) | |||||
| else: | |||||
| module.bias.data.zero_() | |||||
| module.weight.data.fill_(1.0) | |||||
| if isinstance(module, nn.Linear) and module.bias is not None: | |||||
| module.bias.data.zero_() | |||||
| @@ -0,0 +1,526 @@ | |||||
| # Part of the implementation is borrowed and modified from The OpenAI CLIP project. | |||||
| import hashlib | |||||
| import os | |||||
| import urllib | |||||
| import warnings | |||||
| from collections import OrderedDict | |||||
| from typing import Tuple, Union | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| import torch.utils.checkpoint as checkpoint | |||||
| from torch import nn | |||||
| from tqdm import tqdm | |||||
| _MODELS = {} | |||||
| _PT_NAME = {'ViT-B/16': 'ViT-B-16.pt'} | |||||
| def available_models(): | |||||
| """Returns the names of available CLIP models""" | |||||
| return list(_MODELS.keys()) | |||||
| class Bottleneck(nn.Module): | |||||
| expansion = 4 | |||||
| def __init__(self, inplanes, planes, stride=1): | |||||
| super(Bottleneck, self).__init__() | |||||
| # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 | |||||
| self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(planes) | |||||
| self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) | |||||
| self.bn2 = nn.BatchNorm2d(planes) | |||||
| self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() | |||||
| self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) | |||||
| self.bn3 = nn.BatchNorm2d(planes * self.expansion) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.downsample = None | |||||
| self.stride = stride | |||||
| if stride > 1 or inplanes != planes * Bottleneck.expansion: | |||||
| # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 | |||||
| self.downsample = nn.Sequential( | |||||
| OrderedDict([('-1', nn.AvgPool2d(stride)), | |||||
| ('0', | |||||
| nn.Conv2d( | |||||
| inplanes, | |||||
| planes * self.expansion, | |||||
| 1, | |||||
| stride=1, | |||||
| bias=False)), | |||||
| ('1', nn.BatchNorm2d(planes * self.expansion))])) | |||||
| def forward(self, x: torch.Tensor): | |||||
| identity = x | |||||
| out = self.relu(self.bn1(self.conv1(x))) | |||||
| out = self.relu(self.bn2(self.conv2(out))) | |||||
| out = self.avgpool(out) | |||||
| out = self.bn3(self.conv3(out)) | |||||
| if self.downsample is not None: | |||||
| identity = self.downsample(x) | |||||
| out += identity | |||||
| out = self.relu(out) | |||||
| return out | |||||
| class AttentionPool2d(nn.Module): | |||||
| def __init__(self, | |||||
| spacial_dim: int, | |||||
| embed_dim: int, | |||||
| num_heads: int, | |||||
| output_dim: int = None): | |||||
| super(AttentionPool2d, self).__init__() | |||||
| self.positional_embedding = nn.Parameter( | |||||
| torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) | |||||
| self.k_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.q_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.v_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) | |||||
| self.num_heads = num_heads | |||||
| def forward(self, x): | |||||
| x = x.reshape(x.shape[0], x.shape[1], | |||||
| x.shape[2] * x.shape[3]).permute(2, 0, | |||||
| 1) # NCHW -> (HW)NC | |||||
| x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC | |||||
| x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC | |||||
| x, _ = F.multi_head_attention_forward( | |||||
| query=x, | |||||
| key=x, | |||||
| value=x, | |||||
| embed_dim_to_check=x.shape[-1], | |||||
| num_heads=self.num_heads, | |||||
| q_proj_weight=self.q_proj.weight, | |||||
| k_proj_weight=self.k_proj.weight, | |||||
| v_proj_weight=self.v_proj.weight, | |||||
| in_proj_weight=None, | |||||
| in_proj_bias=torch.cat( | |||||
| [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), | |||||
| bias_k=None, | |||||
| bias_v=None, | |||||
| add_zero_attn=False, | |||||
| dropout_p=0, | |||||
| out_proj_weight=self.c_proj.weight, | |||||
| out_proj_bias=self.c_proj.bias, | |||||
| use_separate_proj_weight=True, | |||||
| training=self.training, | |||||
| need_weights=False) | |||||
| return x[0] | |||||
| class ModifiedResNet(nn.Module): | |||||
| """ | |||||
| A ResNet class that is similar to torchvision's but contains the following changes: | |||||
| - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |||||
| - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |||||
| - The final pooling layer is a QKV attention instead of an average pool | |||||
| """ | |||||
| def __init__(self, | |||||
| layers, | |||||
| output_dim, | |||||
| heads, | |||||
| input_resolution=224, | |||||
| width=64): | |||||
| super(ModifiedResNet, self).__init__() | |||||
| self.output_dim = output_dim | |||||
| self.input_resolution = input_resolution | |||||
| # the 3-layer stem | |||||
| self.conv1 = nn.Conv2d( | |||||
| 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(width // 2) | |||||
| self.conv2 = nn.Conv2d( | |||||
| width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |||||
| self.bn2 = nn.BatchNorm2d(width // 2) | |||||
| self.conv3 = nn.Conv2d( | |||||
| width // 2, width, kernel_size=3, padding=1, bias=False) | |||||
| self.bn3 = nn.BatchNorm2d(width) | |||||
| self.avgpool = nn.AvgPool2d(2) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| # residual layers | |||||
| self._inplanes = width # this is a *mutable* variable used during construction | |||||
| self.layer1 = self._make_layer(width, layers[0]) | |||||
| self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |||||
| self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |||||
| self.layer4 = self._make_layer(width * 8, layers[3], stride=2) | |||||
| embed_dim = width * 32 # the ResNet feature dimension | |||||
| self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, | |||||
| heads, output_dim) | |||||
| def _make_layer(self, planes, blocks, stride=1): | |||||
| layers = [Bottleneck(self._inplanes, planes, stride)] | |||||
| self._inplanes = planes * Bottleneck.expansion | |||||
| for _ in range(1, blocks): | |||||
| layers.append(Bottleneck(self._inplanes, planes)) | |||||
| return nn.Sequential(*layers) | |||||
| def forward(self, x): | |||||
| def stem(x): | |||||
| for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), | |||||
| (self.conv3, self.bn3)]: | |||||
| x = self.relu(bn(conv(x))) | |||||
| x = self.avgpool(x) | |||||
| return x | |||||
| x = x.type(self.conv1.weight.dtype) | |||||
| x = stem(x) | |||||
| x = self.layer1(x) | |||||
| x = self.layer2(x) | |||||
| x = self.layer3(x) | |||||
| x = self.layer4(x) | |||||
| x = self.attnpool(x) | |||||
| return x | |||||
| class LayerNorm(nn.LayerNorm): | |||||
| """Subclass torch's LayerNorm to handle fp16.""" | |||||
| def forward(self, x: torch.Tensor): | |||||
| orig_type = x.dtype | |||||
| ret = super().forward(x.type(torch.float32)) | |||||
| return ret.type(orig_type) | |||||
| class QuickGELU(nn.Module): | |||||
| def forward(self, x: torch.Tensor): | |||||
| return x * torch.sigmoid(1.702 * x) | |||||
| class ResidualAttentionBlock(nn.Module): | |||||
| def __init__(self, d_model: int, n_head: int, attn_mask=None): | |||||
| super(ResidualAttentionBlock, self).__init__() | |||||
| self.attn = nn.MultiheadAttention(d_model, n_head) | |||||
| self.ln_1 = LayerNorm(d_model) | |||||
| self.mlp = nn.Sequential( | |||||
| OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), | |||||
| ('gelu', QuickGELU()), | |||||
| ('c_proj', nn.Linear(d_model * 4, d_model))])) | |||||
| self.ln_2 = LayerNorm(d_model) | |||||
| self.attn_mask = attn_mask | |||||
| def attention(self, x: torch.Tensor): | |||||
| attn_mask_ = self.attn_mask | |||||
| if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): | |||||
| attn_mask_ = self.attn_mask(x.size(0)) # LND | |||||
| attn_mask_ = attn_mask_.to( | |||||
| dtype=x.dtype, device=x.device) if attn_mask_ is not None else None | |||||
| return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] | |||||
| def forward(self, x): | |||||
| x = x + self.attention(self.ln_1(x)) | |||||
| x = x + self.mlp(self.ln_2(x)) | |||||
| return x | |||||
| class Transformer(nn.Module): | |||||
| def __init__(self, | |||||
| width: int, | |||||
| layers: int, | |||||
| heads: int, | |||||
| attn_mask=None, | |||||
| use_gc=0): | |||||
| super(Transformer, self).__init__() | |||||
| self.width = width | |||||
| self.layers = layers | |||||
| self.resblocks = nn.Sequential(*[ | |||||
| ResidualAttentionBlock(width, heads, attn_mask) | |||||
| for _ in range(layers) | |||||
| ]) | |||||
| self.use_gc = use_gc | |||||
| def forward(self, x: torch.Tensor): | |||||
| if self.use_gc > 0: | |||||
| for blk in self.resblocks: | |||||
| x = checkpoint.checkpoint(blk, x) | |||||
| return x | |||||
| else: | |||||
| return self.resblocks(x) | |||||
| class VisualTransformer(nn.Module): | |||||
| def __init__(self, | |||||
| input_resolution: int, | |||||
| patch_size: int, | |||||
| width: int, | |||||
| layers: int, | |||||
| heads: int, | |||||
| output_dim: int, | |||||
| linear_patch: str = '2d', | |||||
| use_gc: int = 0): | |||||
| super(VisualTransformer, self).__init__() | |||||
| self.input_resolution = input_resolution | |||||
| self.output_dim = output_dim | |||||
| self.conv1 = nn.Conv2d( | |||||
| in_channels=3, | |||||
| out_channels=width, | |||||
| kernel_size=patch_size, | |||||
| stride=patch_size, | |||||
| bias=False) | |||||
| scale = width**-0.5 | |||||
| self.class_embedding = nn.Parameter(scale * torch.randn(width)) | |||||
| self.positional_embedding = nn.Parameter(scale * torch.randn( | |||||
| (input_resolution // patch_size)**2 + 1, width)) | |||||
| self.ln_pre = LayerNorm(width) | |||||
| self.transformer = Transformer(width, layers, heads, use_gc=use_gc) | |||||
| self.ln_post = LayerNorm(width) | |||||
| self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) | |||||
| # For 3D | |||||
| assert linear_patch in ['2d', '3d'] | |||||
| self.linear_patch = linear_patch | |||||
| if self.linear_patch == '3d': | |||||
| self.conv2 = nn.Conv3d( | |||||
| in_channels=3, | |||||
| out_channels=width, | |||||
| kernel_size=(3, patch_size, patch_size), | |||||
| stride=(1, patch_size, patch_size), | |||||
| padding=(1, 0, 0), | |||||
| bias=False) | |||||
| def forward(self, x: torch.Tensor, video_frame=-1): | |||||
| if self.linear_patch == '3d': | |||||
| assert video_frame != -1 | |||||
| x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], | |||||
| x.shape[-1]) | |||||
| x_3d = x_3d.permute(0, 2, 1, 3, 4) | |||||
| x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid] | |||||
| x_3d = x_3d.permute(0, 2, 1, 3, | |||||
| 4) # shape = [*, frame, width, grid, grid] | |||||
| x = x_3d.reshape( | |||||
| -1, x_3d.shape[-3], x_3d.shape[-2], | |||||
| x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid] | |||||
| else: | |||||
| x = self.conv1(x) # shape = [*, width, grid, grid] | |||||
| x = x.reshape(x.shape[0], x.shape[1], | |||||
| -1) # shape = [*, width, grid ** 2] | |||||
| x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||||
| _x = self.class_embedding.to(x.dtype) + torch.zeros( | |||||
| x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) | |||||
| x = torch.cat([_x, x], dim=1) | |||||
| x = x + self.positional_embedding.to(x.dtype) | |||||
| x = self.ln_pre(x) | |||||
| x = x.permute(1, 0, 2) # NLD -> LND | |||||
| x = self.transformer(x) | |||||
| x = x.permute(1, 0, 2) # LND -> NLD | |||||
| return x | |||||
| class CLIP(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| embed_dim: int, | |||||
| # vision | |||||
| image_resolution: int, | |||||
| vision_layers: Union[Tuple[int, int, int, int], int], | |||||
| vision_width: int, | |||||
| vision_patch_size: int, | |||||
| # text | |||||
| context_length: int, | |||||
| vocab_size: int, | |||||
| transformer_width: int, | |||||
| transformer_heads: int, | |||||
| transformer_layers: int, | |||||
| # vision linear of patch | |||||
| linear_patch: str = '2d', | |||||
| use_gc: int = 0): | |||||
| super(CLIP, self).__init__() | |||||
| self.context_length = context_length | |||||
| if isinstance(vision_layers, (tuple, list)): | |||||
| vision_heads = vision_width * 32 // 64 | |||||
| self.visual = ModifiedResNet( | |||||
| layers=vision_layers, | |||||
| output_dim=embed_dim, | |||||
| heads=vision_heads, | |||||
| input_resolution=image_resolution, | |||||
| width=vision_width) | |||||
| else: | |||||
| vision_heads = vision_width // 64 | |||||
| self.visual = VisualTransformer( | |||||
| input_resolution=image_resolution, | |||||
| patch_size=vision_patch_size, | |||||
| width=vision_width, | |||||
| layers=vision_layers, | |||||
| heads=vision_heads, | |||||
| output_dim=embed_dim, | |||||
| linear_patch=linear_patch, | |||||
| use_gc=use_gc) | |||||
| self.transformer = Transformer( | |||||
| width=transformer_width, | |||||
| layers=transformer_layers, | |||||
| heads=transformer_heads, | |||||
| attn_mask=self.build_attention_mask) | |||||
| self.vocab_size = vocab_size | |||||
| self.token_embedding = nn.Embedding(vocab_size, transformer_width) | |||||
| self.positional_embedding = nn.Parameter( | |||||
| torch.empty(self.context_length, transformer_width)) | |||||
| self.ln_final = LayerNorm(transformer_width) | |||||
| self.text_projection = nn.Parameter( | |||||
| torch.empty(transformer_width, embed_dim)) | |||||
| self.logit_scale = nn.Parameter(torch.ones([])) | |||||
| self.initialize_parameters() | |||||
| def initialize_parameters(self): | |||||
| nn.init.normal_(self.token_embedding.weight, std=0.02) | |||||
| nn.init.normal_(self.positional_embedding, std=0.01) | |||||
| if isinstance(self.visual, ModifiedResNet): | |||||
| if self.visual.attnpool is not None: | |||||
| std = self.visual.attnpool.c_proj.in_features**-0.5 | |||||
| nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) | |||||
| nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) | |||||
| nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) | |||||
| nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) | |||||
| for resnet_block in [ | |||||
| self.visual.layer1, self.visual.layer2, self.visual.layer3, | |||||
| self.visual.layer4 | |||||
| ]: | |||||
| for name, param in resnet_block.named_parameters(): | |||||
| if name.endswith('bn3.weight'): | |||||
| nn.init.zeros_(param) | |||||
| proj_std = (self.transformer.width**-0.5) * ( | |||||
| (2 * self.transformer.layers)**-0.5) | |||||
| attn_std = self.transformer.width**-0.5 | |||||
| fc_std = (2 * self.transformer.width)**-0.5 | |||||
| for block in self.transformer.resblocks: | |||||
| nn.init.normal_(block.attn.in_proj_weight, std=attn_std) | |||||
| nn.init.normal_(block.attn.out_proj.weight, std=proj_std) | |||||
| nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) | |||||
| nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) | |||||
| if self.text_projection is not None: | |||||
| nn.init.normal_( | |||||
| self.text_projection, std=self.transformer.width**-0.5) | |||||
| def build_attention_mask(self, context_length): | |||||
| # lazily create causal attention mask, with full attention between the vision tokens | |||||
| # pytorch uses additive attention mask; fill with -inf | |||||
| mask = torch.zeros(context_length, context_length) | |||||
| mask.fill_(float('-inf')) | |||||
| mask.triu_(1) # zero out the lower diagonal | |||||
| return mask | |||||
| @property | |||||
| def dtype(self): | |||||
| return self.visual.conv1.weight.dtype | |||||
| def encode_image(self, image, return_hidden=False): | |||||
| hidden = self.visual(image.type(self.dtype)) | |||||
| hidden = self.visual.ln_post(hidden) @ self.visual.proj | |||||
| x = hidden[:, 0, :] | |||||
| if return_hidden: | |||||
| return x, hidden | |||||
| return x | |||||
| def encode_text(self, text, return_hidden=False, prompt=None): | |||||
| x = self.token_embedding(text).type( | |||||
| self.dtype) # [batch_size, n_ctx, d_model] | |||||
| if prompt: | |||||
| x = prompt(x) | |||||
| pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) | |||||
| x = x + pos_emd | |||||
| x = x.permute(1, 0, 2) # NLD -> LND | |||||
| x = self.transformer(x) | |||||
| x = x.permute(1, 0, 2) # LND -> NLD | |||||
| hidden = self.ln_final(x).type(self.dtype) @ self.text_projection | |||||
| # take features from the eot embedding (eot_token is the highest number in each sequence) | |||||
| x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] | |||||
| if return_hidden: | |||||
| return x, hidden | |||||
| return x | |||||
| def forward(self, image, text): | |||||
| image_features = self.encode_image(image) | |||||
| text_features = self.encode_text(text) | |||||
| # normalized features | |||||
| image_features = image_features / image_features.norm( | |||||
| dim=-1, keepdim=True) | |||||
| text_features = text_features / text_features.norm( | |||||
| dim=-1, keepdim=True) | |||||
| # cosine similarity as logits | |||||
| logit_scale = self.logit_scale.exp() | |||||
| logits_per_image = logit_scale * image_features @ text_features.t() | |||||
| logits_per_text = logit_scale * text_features @ image_features.t() | |||||
| return logits_per_image, logits_per_text | |||||
| def convert_weights(model: nn.Module): | |||||
| """Convert applicable model parameters to fp16""" | |||||
| def _convert_weights_to_fp16(lay): | |||||
| # l = lay | |||||
| if isinstance(lay, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): | |||||
| lay.weight.data = lay.weight.data.half() | |||||
| if lay.bias is not None: | |||||
| lay.bias.data = lay.bias.data.half() | |||||
| if isinstance(lay, nn.MultiheadAttention): | |||||
| for attr in [ | |||||
| *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], | |||||
| 'in_proj_bias', 'bias_k', 'bias_v' | |||||
| ]: | |||||
| tensor = getattr(lay, attr) | |||||
| if tensor is not None: | |||||
| tensor.data = tensor.data.half() | |||||
| for name in ['text_projection', 'proj']: | |||||
| if hasattr(lay, name): | |||||
| attr = getattr(lay, name) | |||||
| if attr is not None: | |||||
| attr.data = attr.data.half() | |||||
| model.apply(_convert_weights_to_fp16) | |||||
| @@ -0,0 +1,100 @@ | |||||
| from __future__ import absolute_import, division, print_function | |||||
| import logging | |||||
| from collections import OrderedDict | |||||
| import json | |||||
| import torch | |||||
| from torch import nn | |||||
| from .until_module import ACT2FN, LayerNorm | |||||
| logger = logging.getLogger(__name__) | |||||
| class QuickGELU(nn.Module): | |||||
| def forward(self, x: torch.Tensor): | |||||
| return x * torch.sigmoid(1.702 * x) | |||||
| class ResidualAttentionBlock(nn.Module): | |||||
| def __init__(self, d_model: int, n_head: int): | |||||
| super().__init__() | |||||
| self.attn = nn.MultiheadAttention(d_model, n_head) | |||||
| self.ln_1 = LayerNorm(d_model) | |||||
| self.mlp = nn.Sequential( | |||||
| OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), | |||||
| ('gelu', QuickGELU()), | |||||
| ('c_proj', nn.Linear(d_model * 4, d_model))])) | |||||
| self.ln_2 = LayerNorm(d_model) | |||||
| self.n_head = n_head | |||||
| def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): | |||||
| attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) | |||||
| return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] | |||||
| def forward(self, para_tuple: tuple): | |||||
| # x: torch.Tensor, attn_mask: torch.Tensor | |||||
| x, attn_mask = para_tuple | |||||
| x = x + self.attention(self.ln_1(x), attn_mask) | |||||
| x = x + self.mlp(self.ln_2(x)) | |||||
| return (x, attn_mask) | |||||
| class Transformer(nn.Module): | |||||
| def __init__(self, width: int, layers: int, heads: int): | |||||
| super().__init__() | |||||
| self.width = width | |||||
| self.layers = layers | |||||
| self.resblocks = nn.Sequential( | |||||
| *[ResidualAttentionBlock(width, heads) for _ in range(layers)]) | |||||
| def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): | |||||
| return self.resblocks((x, attn_mask))[0] | |||||
| class CrossEmbeddings(nn.Module): | |||||
| """Construct the embeddings from word, position and token_type embeddings. | |||||
| """ | |||||
| def __init__(self, config): | |||||
| super(CrossEmbeddings, self).__init__() | |||||
| self.position_embeddings = nn.Embedding(config.max_position_embeddings, | |||||
| config.hidden_size) | |||||
| self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||||
| def forward(self, concat_embeddings, concat_type=None): | |||||
| _, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) | |||||
| position_ids = torch.arange( | |||||
| seq_length, dtype=torch.long, device=concat_embeddings.device) | |||||
| position_ids = position_ids.unsqueeze(0).expand( | |||||
| concat_embeddings.size(0), -1) | |||||
| position_embeddings = self.position_embeddings(position_ids) | |||||
| embeddings = concat_embeddings + position_embeddings # + token_type_embeddings | |||||
| embeddings = self.dropout(embeddings) | |||||
| return embeddings | |||||
| class CrossPooler(nn.Module): | |||||
| def __init__(self, config): | |||||
| super(CrossPooler, self).__init__() | |||||
| self.ln_pool = LayerNorm(config.hidden_size) | |||||
| self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||||
| self.activation = QuickGELU() | |||||
| def forward(self, hidden_states, hidden_mask): | |||||
| # We "pool" the model by simply taking the hidden state corresponding | |||||
| # to the first token. | |||||
| hidden_states = self.ln_pool(hidden_states) | |||||
| pooled_output = hidden_states[:, 0] | |||||
| pooled_output = self.dense(pooled_output) | |||||
| pooled_output = self.activation(pooled_output) | |||||
| return pooled_output | |||||
| @@ -0,0 +1,158 @@ | |||||
| import gzip | |||||
| import html | |||||
| import os | |||||
| from functools import lru_cache | |||||
| import ftfy | |||||
| import regex as re | |||||
| @lru_cache() | |||||
| def bytes_to_unicode(): | |||||
| """ | |||||
| Returns list of utf-8 byte and a corresponding list of unicode strings. | |||||
| The reversible bpe codes work on unicode strings. | |||||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||||
| This is a signficant percentage of your normal, say, 32K bpe vocab. | |||||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||||
| """ | |||||
| bs = list(range(ord('!'), | |||||
| ord('~') + 1)) + list(range( | |||||
| ord('¡'), | |||||
| ord('¬') + 1)) + list(range(ord('®'), | |||||
| ord('ÿ') + 1)) | |||||
| cs = bs[:] | |||||
| n = 0 | |||||
| for b in range(2**8): | |||||
| if b not in bs: | |||||
| bs.append(b) | |||||
| cs.append(2**8 + n) | |||||
| n += 1 | |||||
| cs = [chr(n) for n in cs] | |||||
| return dict(zip(bs, cs)) | |||||
| def get_pairs(word): | |||||
| """Return set of symbol pairs in a word. | |||||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||||
| """ | |||||
| pairs = set() | |||||
| prev_char = word[0] | |||||
| for char in word[1:]: | |||||
| pairs.add((prev_char, char)) | |||||
| prev_char = char | |||||
| return pairs | |||||
| def basic_clean(text): | |||||
| text = ftfy.fix_text(text) | |||||
| text = html.unescape(html.unescape(text)) | |||||
| return text.strip() | |||||
| def whitespace_clean(text): | |||||
| text = re.sub(r'\s+', ' ', text) | |||||
| text = text.strip() | |||||
| return text | |||||
| class SimpleTokenizer(object): | |||||
| def __init__(self, model_dir): | |||||
| bpe_path = '{}/bpe_simple_vocab_16e6.txt.gz'.format(model_dir) | |||||
| self.byte_encoder = bytes_to_unicode() | |||||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||||
| merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') | |||||
| merges = merges[1:49152 - 256 - 2 + 1] | |||||
| merges = [tuple(merge.split()) for merge in merges] | |||||
| vocab = list(bytes_to_unicode().values()) | |||||
| vocab = vocab + [v + '</w>' for v in vocab] | |||||
| for merge in merges: | |||||
| vocab.append(''.join(merge)) | |||||
| vocab.extend(['<|startoftext|>', '<|endoftext|>']) | |||||
| self.encoder = dict(zip(vocab, range(len(vocab)))) | |||||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||||
| self.bpe_ranks = dict(zip(merges, range(len(merges)))) | |||||
| self.cache = { | |||||
| '<|startoftext|>': '<|startoftext|>', | |||||
| '<|endoftext|>': '<|endoftext|>' | |||||
| } | |||||
| self.pat = re.compile( | |||||
| r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", | |||||
| re.IGNORECASE) | |||||
| self.vocab = self.encoder | |||||
| def bpe(self, token): | |||||
| if token in self.cache: | |||||
| return self.cache[token] | |||||
| word = tuple(token[:-1]) + (token[-1] + '</w>', ) | |||||
| pairs = get_pairs(word) | |||||
| if not pairs: | |||||
| return token + '</w>' | |||||
| while True: | |||||
| bigram = min( | |||||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||||
| if bigram not in self.bpe_ranks: | |||||
| break | |||||
| first, second = bigram | |||||
| new_word = [] | |||||
| i = 0 | |||||
| while i < len(word): | |||||
| try: | |||||
| j = word.index(first, i) | |||||
| new_word.extend(word[i:j]) | |||||
| i = j | |||||
| except Exception: | |||||
| new_word.extend(word[i:]) | |||||
| break | |||||
| if word[i] == first and i < len(word) - 1 and word[ | |||||
| i + 1] == second: | |||||
| new_word.append(first + second) | |||||
| i += 2 | |||||
| else: | |||||
| new_word.append(word[i]) | |||||
| i += 1 | |||||
| new_word = tuple(new_word) | |||||
| word = new_word | |||||
| if len(word) == 1: | |||||
| break | |||||
| else: | |||||
| pairs = get_pairs(word) | |||||
| word = ' '.join(word) | |||||
| self.cache[token] = word | |||||
| return word | |||||
| def encode(self, text): | |||||
| bpe_tokens = [] | |||||
| text = whitespace_clean(basic_clean(text)).lower() | |||||
| for token in re.findall(self.pat, text): | |||||
| token = ''.join(self.byte_encoder[b] | |||||
| for b in token.encode('utf-8')) | |||||
| bpe_tokens.extend(self.encoder[bpe_token] | |||||
| for bpe_token in self.bpe(token).split(' ')) | |||||
| return bpe_tokens | |||||
| def decode(self, tokens): | |||||
| text = ''.join([self.decoder[token] for token in tokens]) | |||||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||||
| 'utf-8', errors='replace').replace('</w>', ' ') | |||||
| return text | |||||
| def tokenize(self, text): | |||||
| tokens = [] | |||||
| text = whitespace_clean(basic_clean(text)).lower() | |||||
| for token in re.findall(self.pat, text): | |||||
| token = ''.join(self.byte_encoder[b] | |||||
| for b in token.encode('utf-8')) | |||||
| tokens.extend( | |||||
| bpe_token for bpe_token in self.bpe(token).split(' ')) | |||||
| return tokens | |||||
| def convert_tokens_to_ids(self, tokens): | |||||
| return [self.encoder[bpe_token] for bpe_token in tokens] | |||||
| @@ -0,0 +1,120 @@ | |||||
| # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. | |||||
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """PyTorch BERT model.""" | |||||
| import logging | |||||
| import math | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from torch import nn | |||||
| logger = logging.getLogger(__name__) | |||||
| def gelu(x): | |||||
| """Implementation of the gelu activation function. | |||||
| For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): | |||||
| 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) | |||||
| """ | |||||
| return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||||
| def swish(x): | |||||
| return x * torch.sigmoid(x) | |||||
| ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} | |||||
| class LayerNorm(nn.Module): | |||||
| def __init__(self, hidden_size, eps=1e-12): | |||||
| """Construct a layernorm module in the TF style (epsilon inside the square root). | |||||
| """ | |||||
| super(LayerNorm, self).__init__() | |||||
| self.weight = nn.Parameter(torch.ones(hidden_size)) | |||||
| self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||||
| self.variance_epsilon = eps | |||||
| def forward(self, x): | |||||
| u = x.mean(-1, keepdim=True) | |||||
| s = (x - u).pow(2).mean(-1, keepdim=True) | |||||
| x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||||
| return self.weight * x + self.bias | |||||
| class CrossEn(nn.Module): | |||||
| def __init__(self, config=None): | |||||
| super(CrossEn, self).__init__() | |||||
| def forward(self, sim_matrix): | |||||
| logpt = F.log_softmax(sim_matrix, dim=-1) | |||||
| logpt = torch.diag(logpt) | |||||
| nce_loss = -logpt | |||||
| sim_loss = nce_loss.mean() | |||||
| return sim_loss | |||||
| class AllGather(torch.autograd.Function): | |||||
| """An autograd function that performs allgather on a tensor.""" | |||||
| @staticmethod | |||||
| def forward(ctx, tensor, args): | |||||
| if args.world_size == 1: | |||||
| ctx.rank = args.local_rank | |||||
| ctx.batch_size = tensor.shape[0] | |||||
| return tensor | |||||
| else: | |||||
| output = [torch.empty_like(tensor) for _ in range(args.world_size)] | |||||
| torch.distributed.all_gather(output, tensor) | |||||
| ctx.rank = args.local_rank | |||||
| ctx.batch_size = tensor.shape[0] | |||||
| return torch.cat(output, dim=0) | |||||
| @staticmethod | |||||
| def backward(ctx, grad_output): | |||||
| return ( | |||||
| grad_output[ctx.batch_size * ctx.rank:ctx.batch_size | |||||
| * (ctx.rank + 1)], | |||||
| None, | |||||
| ) | |||||
| class AllGather2(torch.autograd.Function): | |||||
| """An autograd function that performs allgather on a tensor.""" | |||||
| # https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 | |||||
| @staticmethod | |||||
| def forward(ctx, tensor, args): | |||||
| if args.world_size == 1: | |||||
| ctx.rank = args.local_rank | |||||
| ctx.batch_size = tensor.shape[0] | |||||
| return tensor | |||||
| else: | |||||
| output = [torch.empty_like(tensor) for _ in range(args.world_size)] | |||||
| torch.distributed.all_gather(output, tensor) | |||||
| ctx.rank = args.local_rank | |||||
| ctx.batch_size = tensor.shape[0] | |||||
| return torch.cat(output, dim=0) | |||||
| @staticmethod | |||||
| def backward(ctx, grad_output): | |||||
| grad_input = grad_output.clone() | |||||
| torch.distributed.all_reduce( | |||||
| grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) | |||||
| return (grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1) | |||||
| * ctx.batch_size], None) | |||||
| @@ -25,9 +25,9 @@ class BertForSequenceClassification(Model): | |||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| import torch | |||||
| from easynlp.appzoo import SequenceClassification | from easynlp.appzoo import SequenceClassification | ||||
| from easynlp.core.predictor import get_model_predictor | from easynlp.core.predictor import get_model_predictor | ||||
| import torch | |||||
| self.model = get_model_predictor( | self.model = get_model_predictor( | ||||
| model_dir=self.model_dir, | model_dir=self.model_dir, | ||||
| model_cls=SequenceClassification, | model_cls=SequenceClassification, | ||||
| @@ -21,7 +21,8 @@ class PalmForTextGeneration(TorchModel): | |||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator | |||||
| from sofa.models.palm_v2 import (PalmForConditionalGeneration, | |||||
| Translator) | |||||
| self.model = PalmForConditionalGeneration.from_pretrained(model_dir) | self.model = PalmForConditionalGeneration.from_pretrained(model_dir) | ||||
| self.tokenizer = self.model.tokenizer | self.tokenizer = self.model.tokenizer | ||||
| self.generator = Translator(self.model) | self.generator = Translator(self.model) | ||||
| @@ -27,7 +27,8 @@ class SpaceForDialogIntent(Model): | |||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| from modelscope.trainers.nlp.space.trainer.intent_trainer import IntentTrainer | |||||
| from modelscope.trainers.nlp.space.trainer.intent_trainer import \ | |||||
| IntentTrainer | |||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| self.config = kwargs.pop( | self.config = kwargs.pop( | ||||
| 'config', | 'config', | ||||
| @@ -22,7 +22,7 @@ class SpaceForDialogStateTracking(Model): | |||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| from sofa.models.space import SpaceForDST, SpaceConfig | |||||
| from sofa.models.space import SpaceConfig, SpaceForDST | |||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| self.config = SpaceConfig.from_pretrained(self.model_dir) | self.config = SpaceConfig.from_pretrained(self.model_dir) | ||||
| @@ -225,8 +225,8 @@ class MsDataset: | |||||
| continue | continue | ||||
| retained_columns.append(k) | retained_columns.append(k) | ||||
| import torch | |||||
| import math | import math | ||||
| import torch | |||||
| class MsIterableDataset(torch.utils.data.IterableDataset): | class MsIterableDataset(torch.utils.data.IterableDataset): | ||||
| @@ -74,6 +74,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.text_to_image_synthesis: | Tasks.text_to_image_synthesis: | ||||
| (Pipelines.text_to_image_synthesis, | (Pipelines.text_to_image_synthesis, | ||||
| 'damo/cv_imagen_text-to-image-synthesis_tiny'), | 'damo/cv_imagen_text-to-image-synthesis_tiny'), | ||||
| Tasks.video_multi_modal_embedding: | |||||
| (Pipelines.video_multi_modal_embedding, | |||||
| 'damo/multi_modal_clip_vtretrival_msrvtt_53'), | |||||
| Tasks.image_color_enhance: (Pipelines.image_color_enhance, | Tasks.image_color_enhance: (Pipelines.image_color_enhance, | ||||
| 'damo/cv_csrnet_image-color-enhance-models'), | 'damo/cv_csrnet_image-color-enhance-models'), | ||||
| Tasks.virtual_tryon: (Pipelines.virtual_tryon, | Tasks.virtual_tryon: (Pipelines.virtual_tryon, | ||||
| @@ -3,7 +3,10 @@ try: | |||||
| from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline | from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline | ||||
| from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline | from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline | ||||
| from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline | from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline | ||||
| from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline | |||||
| from .video_multi_modal_embedding_pipeline import \ | |||||
| VideoMultiModalEmbeddingPipeline | |||||
| from .visual_question_answering_pipeline import \ | |||||
| VisualQuestionAnsweringPipeline | |||||
| except ModuleNotFoundError as e: | except ModuleNotFoundError as e: | ||||
| if str(e) == "No module named 'torch'": | if str(e) == "No module named 'torch'": | ||||
| pass | pass | ||||
| @@ -0,0 +1,42 @@ | |||||
| from typing import Any, Dict | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.pipelines.base import Input | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from ..base import Model, Pipeline | |||||
| from ..builder import PIPELINES | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.video_multi_modal_embedding, | |||||
| module_name=Pipelines.video_multi_modal_embedding) | |||||
| class VideoMultiModalEmbeddingPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| use `model` to create a video_multi_modal_embedding pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| return input | |||||
| def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | |||||
| with self.place_device(): | |||||
| out = self.forward(input) | |||||
| self._check_output(out) | |||||
| return out | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return self.model(input) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -15,6 +15,7 @@ try: | |||||
| from .dialog_modeling_pipeline import * # noqa F403 | from .dialog_modeling_pipeline import * # noqa F403 | ||||
| from .dialog_state_tracking_pipeline import * # noqa F403 | from .dialog_state_tracking_pipeline import * # noqa F403 | ||||
| from .fill_mask_pipeline import * # noqa F403 | from .fill_mask_pipeline import * # noqa F403 | ||||
| from .named_entity_recognition_pipeline import * # noqa F403 | |||||
| from .nli_pipeline import * # noqa F403 | from .nli_pipeline import * # noqa F403 | ||||
| from .sentence_similarity_pipeline import * # noqa F403 | from .sentence_similarity_pipeline import * # noqa F403 | ||||
| from .sentiment_classification_pipeline import * # noqa F403 | from .sentiment_classification_pipeline import * # noqa F403 | ||||
| @@ -22,7 +23,6 @@ try: | |||||
| from .text_generation_pipeline import * # noqa F403 | from .text_generation_pipeline import * # noqa F403 | ||||
| from .word_segmentation_pipeline import * # noqa F403 | from .word_segmentation_pipeline import * # noqa F403 | ||||
| from .zero_shot_classification_pipeline import * # noqa F403 | from .zero_shot_classification_pipeline import * # noqa F403 | ||||
| from .named_entity_recognition_pipeline import * # noqa F403 | |||||
| except ModuleNotFoundError as e: | except ModuleNotFoundError as e: | ||||
| if str(e) == "No module named 'torch'": | if str(e) == "No module named 'torch'": | ||||
| pass | pass | ||||
| @@ -25,7 +25,7 @@ class DialogStateTrackingPreprocessor(Preprocessor): | |||||
| """ | """ | ||||
| super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
| from sofa.models.space import SpaceTokenizer, SpaceConfig | |||||
| from sofa.models.space import SpaceConfig, SpaceTokenizer | |||||
| self.model_dir: str = model_dir | self.model_dir: str = model_dir | ||||
| self.config = SpaceConfig.from_pretrained(self.model_dir) | self.config = SpaceConfig.from_pretrained(self.model_dir) | ||||
| self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir) | self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir) | ||||
| @@ -78,7 +78,8 @@ class SequenceClassificationTrainer(BaseTrainer): | |||||
| import torch | import torch | ||||
| from easynlp.appzoo import load_dataset | from easynlp.appzoo import load_dataset | ||||
| from easynlp.appzoo.dataset import GeneralDataset | from easynlp.appzoo.dataset import GeneralDataset | ||||
| from easynlp.appzoo.sequence_classification.model import SequenceClassification | |||||
| from easynlp.appzoo.sequence_classification.model import \ | |||||
| SequenceClassification | |||||
| from easynlp.utils import losses | from easynlp.utils import losses | ||||
| from sklearn.metrics import f1_score | from sklearn.metrics import f1_score | ||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
| @@ -77,6 +77,7 @@ class MultiModalTasks(object): | |||||
| multi_modal_embedding = 'multi-modal-embedding' | multi_modal_embedding = 'multi-modal-embedding' | ||||
| generative_multi_modal_embedding = 'generative-multi-modal-embedding' | generative_multi_modal_embedding = 'generative-multi-modal-embedding' | ||||
| visual_question_answering = 'visual-question-answering' | visual_question_answering = 'visual-question-answering' | ||||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||||
| class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | ||||
| @@ -85,7 +86,6 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | |||||
| Holds the standard task name to use for identifying different tasks. | Holds the standard task name to use for identifying different tasks. | ||||
| This should be used to register models, pipelines, trainers. | This should be used to register models, pipelines, trainers. | ||||
| """ | """ | ||||
| reverse_field_index = {} | reverse_field_index = {} | ||||
| @staticmethod | @staticmethod | ||||
| @@ -89,8 +89,8 @@ def get_model_type(model_dir): | |||||
| def parse_label_mapping(model_dir): | def parse_label_mapping(model_dir): | ||||
| import os | |||||
| import json | import json | ||||
| import os | |||||
| label2id = None | label2id = None | ||||
| label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING) | label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING) | ||||
| if os.path.exists(label_path): | if os.path.exists(label_path): | ||||
| @@ -70,9 +70,9 @@ def parse_requirements(fname='requirements.txt', with_version=True): | |||||
| CommandLine: | CommandLine: | ||||
| python -c "import setup; print(setup.parse_requirements())" | python -c "import setup; print(setup.parse_requirements())" | ||||
| """ | """ | ||||
| import re | |||||
| import sys | import sys | ||||
| from os.path import exists | from os.path import exists | ||||
| import re | |||||
| require_fpath = fname | require_fpath = fname | ||||
| def parse_line(line): | def parse_line(line): | ||||
| @@ -0,0 +1,45 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| import numpy as np | |||||
| from modelscope.models import Model | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | |||||
| logger = get_logger() | |||||
| class VideoMultiModalEmbeddingTest(unittest.TestCase): | |||||
| model_id = 'damo/multi_modal_clip_vtretrival_msrvtt_53' | |||||
| video_path = 'data/test/videos/multi_modal_test_video_9770.mp4' | |||||
| caption = ('a person is connecting something to system', None, None) | |||||
| _input = {'video': video_path, 'text': caption} | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run(self): | |||||
| pipeline_video_multi_modal_embedding = pipeline( | |||||
| Tasks.video_multi_modal_embedding, model=self.model_id) | |||||
| output = pipeline_video_multi_modal_embedding(self._input) | |||||
| logger.info('text feature: {}'.format( | |||||
| output['text_embedding'][0][0][0])) | |||||
| logger.info('video feature: {}'.format( | |||||
| output['video_embedding'][0][0][0])) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipeline_video_multi_modal_embedding = pipeline( | |||||
| task=Tasks.video_multi_modal_embedding) | |||||
| output = pipeline_video_multi_modal_embedding(self._input) | |||||
| logger.info('text feature: {}'.format( | |||||
| output['text_embedding'][0][0][0])) | |||||
| logger.info('video feature: {}'.format( | |||||
| output['video_embedding'][0][0][0])) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||