Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10636269 (master)
| @@ -32,6 +32,7 @@ class Models(object): | |||
| image_reid_person = 'passvitb' | |||
| image_inpainting = 'FFTInpainting' | |||
| video_summarization = 'pgl-video-summarization' | |||
| language_guided_video_summarization = 'clip-it-language-guided-video-summarization' | |||
| swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
| text_driven_segmentation = 'text-driven-segmentation' | |||
| @@ -200,6 +201,7 @@ class Pipelines(object): | |||
| video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' | |||
| image_panoptic_segmentation = 'image-panoptic-segmentation' | |||
| video_summarization = 'googlenet_pgl_video_summarization' | |||
| language_guided_video_summarization = 'clip-it-video-summarization' | |||
| image_semantic_segmentation = 'image-semantic-segmentation' | |||
| image_reid_person = 'passvitb-image-reid-person' | |||
| image_inpainting = 'fft-inpainting' | |||
| @@ -10,10 +10,10 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
| image_panoptic_segmentation, image_portrait_enhancement, | |||
| image_reid_person, image_semantic_segmentation, | |||
| image_to_image_generation, image_to_image_translation, | |||
| movie_scene_segmentation, object_detection, | |||
| product_retrieval_embedding, realtime_object_detection, | |||
| referring_video_object_segmentation, salient_detection, | |||
| shop_segmentation, super_resolution, | |||
| language_guided_video_summarization, movie_scene_segmentation, | |||
| object_detection, product_retrieval_embedding, | |||
| realtime_object_detection, referring_video_object_segmentation, | |||
| salient_detection, shop_segmentation, super_resolution, | |||
| video_single_object_tracking, video_summarization, virual_tryon) | |||
| # yapf: enable | |||
| @@ -0,0 +1,25 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .summarizer import ( | |||
| ClipItVideoSummarization, ) | |||
| else: | |||
| _import_structure = { | |||
| 'summarizer': [ | |||
| 'ClipItVideoSummarization', | |||
| ] | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
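The `LazyImportModule` registration defers loading `summarizer` (and its heavy torch/BMT dependencies) until the symbol is first touched; callers still import it as a plain attribute, as the new pipeline below does:

```python
# resolved lazily through _import_structure on first access
from modelscope.models.cv.language_guided_video_summarization import ClipItVideoSummarization
```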
| @@ -0,0 +1,194 @@ | |||
| # Part of the implementation is borrowed and modified from BMT and video_features, | |||
| # publicly available at https://github.com/v-iashin/BMT | |||
| # and https://github.com/v-iashin/video_features | |||
| import argparse | |||
| import os | |||
| import os.path as osp | |||
| from copy import deepcopy | |||
| from typing import Dict, Union | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| from bmt_clipit.sample.single_video_prediction import (caption_proposals, | |||
| generate_proposals, | |||
| load_cap_model, | |||
| load_prop_model) | |||
| from bmt_clipit.utilities.proposal_utils import non_max_suppresion | |||
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |||
| from videofeatures_clipit.models.i3d.extract_i3d import ExtractI3D | |||
| from videofeatures_clipit.models.vggish.extract_vggish import ExtractVGGish | |||
| from videofeatures_clipit.utils.utils import (fix_tensorflow_gpu_allocation, | |||
| form_list_from_user_input) | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Tensor, TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.cv.language_guided_video_summarization.transformer import \ | |||
| Transformer | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| def extract_text(args): | |||
| # Loading models and other essential stuff | |||
| cap_cfg, cap_model, train_dataset = load_cap_model( | |||
| args.pretrained_cap_model_path, args.device_id) | |||
| prop_cfg, prop_model = load_prop_model(args.device_id, | |||
| args.prop_generator_model_path, | |||
| args.pretrained_cap_model_path, | |||
| args.max_prop_per_vid) | |||
| # Proposal | |||
| proposals = generate_proposals(prop_model, args.features, | |||
| train_dataset.pad_idx, prop_cfg, | |||
| args.device_id, args.duration_in_secs) | |||
| # NMS if specified | |||
| if args.nms_tiou_thresh is not None: | |||
| proposals = non_max_suppresion(proposals.squeeze(), | |||
| args.nms_tiou_thresh) | |||
| proposals = proposals.unsqueeze(0) | |||
| # Captions for each proposal | |||
| captions = caption_proposals(cap_model, args.features, train_dataset, | |||
| cap_cfg, args.device_id, proposals, | |||
| args.duration_in_secs) | |||
| return captions | |||
| def extract_video_features(video_path, tmp_path, feature_type, i3d_flow_path, | |||
| i3d_rgb_path, kinetics_class_labels, pwc_path, | |||
| vggish_model_path, vggish_pca_path, extraction_fps, | |||
| device): | |||
| default_args = dict( | |||
| device=device, | |||
| extraction_fps=extraction_fps, | |||
| feature_type=feature_type, | |||
| file_with_video_paths=None, | |||
| i3d_flow_path=i3d_flow_path, | |||
| i3d_rgb_path=i3d_rgb_path, | |||
| keep_frames=False, | |||
| kinetics_class_labels=kinetics_class_labels, | |||
| min_side_size=256, | |||
| pwc_path=pwc_path, | |||
| show_kinetics_pred=False, | |||
| stack_size=64, | |||
| step_size=64, | |||
| tmp_path=tmp_path, | |||
| vggish_model_path=vggish_model_path, | |||
| vggish_pca_path=vggish_pca_path, | |||
| ) | |||
| args = argparse.Namespace(**default_args) | |||
| if args.feature_type == 'i3d': | |||
| extractor = ExtractI3D(args) | |||
| elif args.feature_type == 'vggish': | |||
| extractor = ExtractVGGish(args) | |||
| else: | |||
| raise ValueError(f'unsupported feature_type: {args.feature_type}') | |||
| feats = extractor(video_path) | |||
| return feats | |||
| def video_features_to_txt(duration_in_secs, pretrained_cap_model_path, | |||
| prop_generator_model_path, features, device_id): | |||
| default_args = dict( | |||
| device_id=device_id, | |||
| duration_in_secs=duration_in_secs, | |||
| features=features, | |||
| pretrained_cap_model_path=pretrained_cap_model_path, | |||
| prop_generator_model_path=prop_generator_model_path, | |||
| max_prop_per_vid=100, | |||
| nms_tiou_thresh=0.4, | |||
| ) | |||
| args = argparse.Namespace(**default_args) | |||
| txt = extract_text(args) | |||
| return txt | |||
| @MODELS.register_module( | |||
| Tasks.language_guided_video_summarization, | |||
| module_name=Models.language_guided_video_summarization) | |||
| class ClipItVideoSummarization(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the video summarization model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
| self.loss = nn.MSELoss() | |||
| self.model = Transformer() | |||
| if torch.cuda.is_available(): | |||
| self._device = torch.device('cuda') | |||
| else: | |||
| self._device = torch.device('cpu') | |||
| self.model = self.model.to(self._device) | |||
| self.model = self.load_pretrained(self.model, model_path) | |||
| if self.training: | |||
| self.model.train() | |||
| else: | |||
| self.model.eval() | |||
| def load_pretrained(self, net, load_path, strict=True, param_key='params'): | |||
| if isinstance(net, (DataParallel, DistributedDataParallel)): | |||
| net = net.module | |||
| load_net = torch.load( | |||
| load_path, map_location=lambda storage, loc: storage) | |||
| if param_key is not None: | |||
| if param_key not in load_net and 'params' in load_net: | |||
| param_key = 'params' | |||
| logger.info( | |||
| f'Loading: {param_key} does not exist, use params.') | |||
| if param_key in load_net: | |||
| load_net = load_net[param_key] | |||
| logger.info( | |||
| f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' | |||
| ) | |||
| # remove unnecessary 'module.' | |||
| for k, v in deepcopy(load_net).items(): | |||
| if k.startswith('module.'): | |||
| load_net[k[7:]] = v | |||
| load_net.pop(k) | |||
| net.load_state_dict(load_net, strict=strict) | |||
| logger.info('load model done.') | |||
| return net | |||
| def _train_forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| frame_features = input['frame_features'] | |||
| txt_features = input['txt_features'] | |||
| gtscore = input['gtscore'] | |||
| preds, attn_weights = self.model(frame_features, txt_features, | |||
| frame_features) | |||
| return {'loss': self.loss(preds, gtscore)} | |||
| def _inference_forward(self, input: Dict[str, | |||
| Tensor]) -> Dict[str, Tensor]: | |||
| frame_features = input['frame_features'] | |||
| txt_features = input['txt_features'] | |||
| y, dec_output = self.model(frame_features, txt_features, | |||
| frame_features) | |||
| return {'scores': y} | |||
| def forward(self, input: Dict[str, | |||
| Tensor]) -> Dict[str, Union[list, Tensor]]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Tensor]): the preprocessed data | |||
| Returns: | |||
| Dict[str, Union[list, Tensor]]: results | |||
| """ | |||
| for key, value in input.items(): | |||
| input[key] = input[key].to(self._device) | |||
| if self.training: | |||
| return self._train_forward(input) | |||
| else: | |||
| return self._inference_forward(input) | |||
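A minimal inference sketch of the input contract (the checkpoint directory is a hypothetical placeholder; feature shapes follow the pipeline below, which feeds 512-d CLIP ViT-B/32 features and 7 sampled sentence features):

```python
import torch

# hypothetical local checkpoint directory containing ModelFile.TORCH_MODEL_FILE
model = ClipItVideoSummarization('/path/to/model_dir').eval()
inputs = {
    'frame_features': torch.randn(1, 120, 512),  # one 512-d CLIP image feature per sampled frame
    'txt_features': torch.randn(1, 1, 7 * 512),  # 7 sampled CLIP text features, flattened
}
with torch.no_grad():
    out = model(inputs)        # eval mode routes to _inference_forward
print(out['scores'].shape)     # torch.Size([1, 120]): per-frame importance scores
# in training mode the dict additionally needs 'gtscore' and forward returns {'loss': ...}
```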
| @@ -0,0 +1,25 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .models import ( | |||
| Transformer, ) | |||
| else: | |||
| _import_structure = { | |||
| 'models': [ | |||
| 'Transformer', | |||
| ] | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,48 @@ | |||
| # Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch, | |||
| # publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch | |||
| import torch | |||
| import torch.nn as nn | |||
| from .sub_layers import MultiHeadAttention, PositionwiseFeedForward | |||
| class EncoderLayer(nn.Module): | |||
| """Compose with two layers""" | |||
| def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): | |||
| super(EncoderLayer, self).__init__() | |||
| self.slf_attn = MultiHeadAttention( | |||
| n_head, d_model, d_k, d_v, dropout=dropout) | |||
| self.pos_ffn = PositionwiseFeedForward( | |||
| d_model, d_inner, dropout=dropout) | |||
| def forward(self, enc_input, slf_attn_mask=None): | |||
| enc_output, enc_slf_attn = self.slf_attn( | |||
| enc_input, enc_input, enc_input, mask=slf_attn_mask) | |||
| enc_output = self.pos_ffn(enc_output) | |||
| return enc_output, enc_slf_attn | |||
| class DecoderLayer(nn.Module): | |||
| """Compose with three layers""" | |||
| def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): | |||
| super(DecoderLayer, self).__init__() | |||
| self.slf_attn = MultiHeadAttention( | |||
| n_head, d_model, d_k, d_v, dropout=dropout) | |||
| self.enc_attn = MultiHeadAttention( | |||
| n_head, d_model, d_k, d_v, dropout=dropout) | |||
| self.pos_ffn = PositionwiseFeedForward( | |||
| d_model, d_inner, dropout=dropout) | |||
| def forward(self, | |||
| dec_input, | |||
| enc_output, | |||
| slf_attn_mask=None, | |||
| dec_enc_attn_mask=None): | |||
| dec_output, dec_slf_attn = self.slf_attn( | |||
| dec_input, dec_input, dec_input, mask=slf_attn_mask) | |||
| dec_output, dec_enc_attn = self.enc_attn( | |||
| dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) | |||
| dec_output = self.pos_ffn(dec_output) | |||
| return dec_output, dec_slf_attn, dec_enc_attn | |||
| @@ -0,0 +1,229 @@ | |||
| # Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch, | |||
| # publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| from .layers import DecoderLayer, EncoderLayer | |||
| from .sub_layers import MultiHeadAttention | |||
| class PositionalEncoding(nn.Module): | |||
| def __init__(self, d_hid, n_position=200): | |||
| super(PositionalEncoding, self).__init__() | |||
| # Not a parameter | |||
| self.register_buffer( | |||
| 'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) | |||
| def _get_sinusoid_encoding_table(self, n_position, d_hid): | |||
| """Sinusoid position encoding table""" | |||
| # TODO: make it with torch instead of numpy | |||
| def get_position_angle_vec(position): | |||
| return [ | |||
| position / np.power(10000, 2 * (hid_j // 2) / d_hid) | |||
| for hid_j in range(d_hid) | |||
| ] | |||
| sinusoid_table = np.array( | |||
| [get_position_angle_vec(pos_i) for pos_i in range(n_position)]) | |||
| sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||
| sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||
| return torch.FloatTensor(sinusoid_table).unsqueeze(0) | |||
| def forward(self, x): | |||
| return x + self.pos_table[:, :x.size(1)].clone().detach() | |||
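For reference, the table built above is the standard sinusoid encoding over positions, with $d_{\text{hid}}$ the feature dimension:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{hid}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{hid}}}}\right)$$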
| class Encoder(nn.Module): | |||
| """A encoder model with self attention mechanism.""" | |||
| def __init__(self, | |||
| d_word_vec=1024, | |||
| n_layers=6, | |||
| n_head=8, | |||
| d_k=64, | |||
| d_v=64, | |||
| d_model=512, | |||
| d_inner=2048, | |||
| dropout=0.1, | |||
| n_position=200): | |||
| super().__init__() | |||
| self.position_enc = PositionalEncoding( | |||
| d_word_vec, n_position=n_position) | |||
| self.dropout = nn.Dropout(p=dropout) | |||
| self.layer_stack = nn.ModuleList([ | |||
| EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) | |||
| for _ in range(n_layers) | |||
| ]) | |||
| self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) | |||
| self.d_model = d_model | |||
| def forward(self, enc_output, return_attns=False): | |||
| enc_slf_attn_list = [] | |||
| # -- Forward | |||
| enc_output = self.dropout(self.position_enc(enc_output)) | |||
| enc_output = self.layer_norm(enc_output) | |||
| for enc_layer in self.layer_stack: | |||
| enc_output, enc_slf_attn = enc_layer(enc_output) | |||
| enc_slf_attn_list += [enc_slf_attn] if return_attns else [] | |||
| if return_attns: | |||
| return enc_output, enc_slf_attn_list | |||
| return enc_output, | |||
| class Decoder(nn.Module): | |||
| """A decoder model with self attention mechanism.""" | |||
| def __init__(self, | |||
| d_word_vec=1024, | |||
| n_layers=6, | |||
| n_head=8, | |||
| d_k=64, | |||
| d_v=64, | |||
| d_model=512, | |||
| d_inner=2048, | |||
| n_position=200, | |||
| dropout=0.1): | |||
| super().__init__() | |||
| self.position_enc = PositionalEncoding( | |||
| d_word_vec, n_position=n_position) | |||
| self.dropout = nn.Dropout(p=dropout) | |||
| self.layer_stack = nn.ModuleList([ | |||
| DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) | |||
| for _ in range(n_layers) | |||
| ]) | |||
| self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) | |||
| self.d_model = d_model | |||
| def forward(self, | |||
| dec_output, | |||
| enc_output, | |||
| src_mask=None, | |||
| trg_mask=None, | |||
| return_attns=False): | |||
| dec_slf_attn_list, dec_enc_attn_list = [], [] | |||
| # -- Forward | |||
| dec_output = self.dropout(self.position_enc(dec_output)) | |||
| dec_output = self.layer_norm(dec_output) | |||
| for dec_layer in self.layer_stack: | |||
| dec_output, dec_slf_attn, dec_enc_attn = dec_layer( | |||
| dec_output, | |||
| enc_output, | |||
| slf_attn_mask=trg_mask, | |||
| dec_enc_attn_mask=src_mask) | |||
| dec_slf_attn_list += [dec_slf_attn] if return_attns else [] | |||
| dec_enc_attn_list += [dec_enc_attn] if return_attns else [] | |||
| if return_attns: | |||
| return dec_output, dec_slf_attn_list, dec_enc_attn_list | |||
| return dec_output, | |||
| class Transformer(nn.Module): | |||
| """A sequence to sequence model with attention mechanism.""" | |||
| def __init__(self, | |||
| num_sentence=7, | |||
| txt_atten_head=4, | |||
| d_frame_vec=512, | |||
| d_model=512, | |||
| d_inner=2048, | |||
| n_layers=6, | |||
| n_head=8, | |||
| d_k=256, | |||
| d_v=256, | |||
| dropout=0.1, | |||
| n_position=4000): | |||
| super().__init__() | |||
| self.d_model = d_model | |||
| self.layer_norm_img_src = nn.LayerNorm(d_frame_vec, eps=1e-6) | |||
| self.layer_norm_img_trg = nn.LayerNorm(d_frame_vec, eps=1e-6) | |||
| self.layer_norm_txt = nn.LayerNorm( | |||
| num_sentence * d_frame_vec, eps=1e-6) | |||
| self.linear_txt = nn.Linear( | |||
| in_features=num_sentence * d_frame_vec, out_features=d_model) | |||
| self.lg_attention = MultiHeadAttention( | |||
| n_head=txt_atten_head, d_model=d_model, d_k=d_k, d_v=d_v) | |||
| self.encoder = Encoder( | |||
| n_position=n_position, | |||
| d_word_vec=d_frame_vec, | |||
| d_model=d_model, | |||
| d_inner=d_inner, | |||
| n_layers=n_layers, | |||
| n_head=n_head, | |||
| d_k=d_k, | |||
| d_v=d_v, | |||
| dropout=dropout) | |||
| self.decoder = Decoder( | |||
| n_position=n_position, | |||
| d_word_vec=d_frame_vec, | |||
| d_model=d_model, | |||
| d_inner=d_inner, | |||
| n_layers=n_layers, | |||
| n_head=n_head, | |||
| d_k=d_k, | |||
| d_v=d_v, | |||
| dropout=dropout) | |||
| for p in self.parameters(): | |||
| if p.dim() > 1: | |||
| nn.init.xavier_uniform_(p) | |||
| assert d_model == d_frame_vec, 'the dimensions of all module outputs shall be the same.' | |||
| self.linear_1 = nn.Linear(in_features=d_model, out_features=d_model) | |||
| self.linear_2 = nn.Linear( | |||
| in_features=self.linear_1.out_features, out_features=1) | |||
| self.drop = nn.Dropout(p=0.5) | |||
| self.norm_y = nn.LayerNorm(normalized_shape=d_model, eps=1e-6) | |||
| self.norm_linear = nn.LayerNorm( | |||
| normalized_shape=self.linear_1.out_features, eps=1e-6) | |||
| self.relu = nn.ReLU() | |||
| self.sigmoid = nn.Sigmoid() | |||
| def forward(self, src_seq, src_txt, trg_seq): | |||
| features_txt = self.linear_txt(src_txt) | |||
| atten_seq, txt_attn = self.lg_attention(src_seq, features_txt, | |||
| features_txt) | |||
| enc_output, *_ = self.encoder(atten_seq) | |||
| dec_output, *_ = self.decoder(trg_seq, enc_output) | |||
| y = self.drop(enc_output) | |||
| y = self.norm_y(y) | |||
| # 2-layer NN (Regressor Network) | |||
| y = self.linear_1(y) | |||
| y = self.relu(y) | |||
| y = self.drop(y) | |||
| y = self.norm_linear(y) | |||
| y = self.linear_2(y) | |||
| y = self.sigmoid(y) | |||
| y = y.view(1, -1) | |||
| return y, dec_output | |||
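A minimal shape-check sketch with the defaults above (`d_frame_vec = d_model = 512`, `num_sentence = 7`), using random tensors in place of CLIP features:

```python
import torch
from modelscope.models.cv.language_guided_video_summarization.transformer import Transformer

model = Transformer().eval()
frames = torch.randn(1, 120, 512)   # 120 sampled frames, 512-d CLIP image features
text = torch.randn(1, 1, 7 * 512)   # 7 sampled CLIP text features, flattened
with torch.no_grad():
    scores, dec_output = model(frames, text, frames)  # frames serve as both src_seq and trg_seq
print(scores.shape)      # torch.Size([1, 120]): one sigmoid importance score per frame
print(dec_output.shape)  # torch.Size([1, 120, 512])
```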
| @@ -0,0 +1,27 @@ | |||
| # Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch, | |||
| # publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class ScaledDotProductAttention(nn.Module): | |||
| """Scaled Dot-Product Attention""" | |||
| def __init__(self, temperature, attn_dropout=0.1): | |||
| super().__init__() | |||
| self.temperature = temperature | |||
| self.dropout = nn.Dropout(attn_dropout) | |||
| def forward(self, q, k, v, mask=None): | |||
| attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) | |||
| if mask is not None: | |||
| attn = attn.masked_fill(mask == 0, -1e9) | |||
| attn = self.dropout(F.softmax(attn, dim=-1)) | |||
| output = torch.matmul(attn, v) | |||
| return output, attn | |||
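This is standard scaled dot-product attention; the caller below passes $\tau = \sqrt{d_k}$ as the temperature, and masked positions are filled with $-10^9$ before the softmax:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$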
| @@ -0,0 +1,83 @@ | |||
| # Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch, | |||
| # publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch | |||
| import numpy as np | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from .modules import ScaledDotProductAttention | |||
| class MultiHeadAttention(nn.Module): | |||
| """Multi-Head Attention module""" | |||
| def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): | |||
| super().__init__() | |||
| self.n_head = n_head | |||
| self.d_k = d_k | |||
| self.d_v = d_v | |||
| self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) | |||
| self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) | |||
| self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) | |||
| self.fc = nn.Linear(n_head * d_v, d_model, bias=False) | |||
| self.attention = ScaledDotProductAttention(temperature=d_k**0.5) | |||
| self.dropout = nn.Dropout(dropout) | |||
| self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) | |||
| def forward(self, q, k, v, mask=None): | |||
| d_k, d_v, n_head = self.d_k, self.d_v, self.n_head | |||
| sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) | |||
| residual = q | |||
| # Pass through the pre-attention projection: b x lq x (n*dv) | |||
| # Separate different heads: b x lq x n x dv | |||
| q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) | |||
| k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) | |||
| v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) | |||
| # Transpose for attention dot product: b x n x lq x dv | |||
| q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) | |||
| if mask is not None: | |||
| mask = mask.unsqueeze(1) # For head axis broadcasting. | |||
| q, attn = self.attention(q, k, v, mask=mask) | |||
| # Transpose to move the head dimension back: b x lq x n x dv | |||
| # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) | |||
| q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) | |||
| q = self.dropout(self.fc(q)) | |||
| q += residual | |||
| q = self.layer_norm(q) | |||
| return q, attn | |||
| class PositionwiseFeedForward(nn.Module): | |||
| """A two-feed-forward-layer module""" | |||
| def __init__(self, d_in, d_hid, dropout=0.1): | |||
| super().__init__() | |||
| self.w_1 = nn.Linear(d_in, d_hid) # position-wise | |||
| self.w_2 = nn.Linear(d_hid, d_in) # position-wise | |||
| self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) | |||
| self.dropout = nn.Dropout(dropout) | |||
| def forward(self, x): | |||
| residual = x | |||
| x = self.w_2(F.relu(self.w_1(x))) | |||
| x = self.dropout(x) | |||
| x += residual | |||
| x = self.layer_norm(x) | |||
| return x | |||
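A small shape-check sketch of the cross-attention pattern used by `lg_attention` in the Transformer above (frame features as queries, a single projected text feature as key/value); the import path is assumed from the file layout introduced in this PR:

```python
import torch
from modelscope.models.cv.language_guided_video_summarization.transformer.sub_layers import \
    MultiHeadAttention

attn = MultiHeadAttention(n_head=4, d_model=512, d_k=256, d_v=256)
q = torch.randn(1, 120, 512)   # frame features as queries
kv = torch.randn(1, 1, 512)    # projected text feature as key/value
out, weights = attn(q, kv, kv)
print(out.shape)      # torch.Size([1, 120, 512]): residual-added, layer-normed queries
print(weights.shape)  # torch.Size([1, 4, 120, 1]): one attention column per head
```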
| @@ -59,6 +59,7 @@ if TYPE_CHECKING: | |||
| from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline | |||
| from .hand_static_pipeline import HandStaticPipeline | |||
| from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline | |||
| from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline | |||
| else: | |||
| _import_structure = { | |||
| @@ -132,6 +133,9 @@ else: | |||
| 'referring_video_object_segmentation_pipeline': [ | |||
| 'ReferringVideoObjectSegmentationPipeline' | |||
| ], | |||
| 'language_guided_video_summarization_pipeline': [ | |||
| 'LanguageGuidedVideoSummarizationPipeline' | |||
| ] | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,250 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import random | |||
| import shutil | |||
| import tempfile | |||
| from typing import Any, Dict | |||
| import clip | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| from PIL import Image | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.language_guided_video_summarization import \ | |||
| ClipItVideoSummarization | |||
| from modelscope.models.cv.language_guided_video_summarization.summarizer import ( | |||
| extract_video_features, video_features_to_txt) | |||
| from modelscope.models.cv.video_summarization import summary_format | |||
| from modelscope.models.cv.video_summarization.summarizer import ( | |||
| generate_summary, get_change_points) | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.language_guided_video_summarization, | |||
| module_name=Pipelines.language_guided_video_summarization) | |||
| class LanguageGuidedVideoSummarizationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a language-guided video summarization pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, auto_collate=False, **kwargs) | |||
| logger.info(f'loading model from {model}') | |||
| self.model_dir = model | |||
| self.tmp_dir = kwargs.get('tmp_dir', None) | |||
| if self.tmp_dir is None: | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| config_path = osp.join(model, ModelFile.CONFIGURATION) | |||
| logger.info(f'loading config from {config_path}') | |||
| self.cfg = Config.from_file(config_path) | |||
| self.clip_model, self.clip_preprocess = clip.load( | |||
| 'ViT-B/32', | |||
| device=self.device, | |||
| download_root=os.path.join(self.model_dir, 'clip')) | |||
| self.clipit_model = ClipItVideoSummarization(model) | |||
| self.clipit_model = self.clipit_model.to(self.device).eval() | |||
| logger.info('load model done') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| if not isinstance(input, tuple): | |||
| raise TypeError(f'input should be a (video_path, sentences) tuple,' | |||
| f' but got {type(input)}') | |||
| video_path, sentences = input | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| frames = [] | |||
| picks = [] | |||
| cap = cv2.VideoCapture(video_path) | |||
| self.fps = cap.get(cv2.CAP_PROP_FPS) | |||
| self.frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) | |||
| frame_idx = 0 | |||
| # extract 1 frame every 15 frames in the video and save the frame index | |||
| while (cap.isOpened()): | |||
| ret, frame = cap.read() | |||
| if not ret: | |||
| break | |||
| if frame_idx % 15 == 0: | |||
| frames.append(frame) | |||
| picks.append(frame_idx) | |||
| frame_idx += 1 | |||
| n_frame = frame_idx | |||
| if sentences is None: | |||
| logger.info('input sentences is None, generating sentences from the video.') | |||
| tmp_path = os.path.join(self.tmp_dir, 'tmp') | |||
| i3d_flow_path = os.path.join(self.model_dir, 'i3d/i3d_flow.pt') | |||
| i3d_rgb_path = os.path.join(self.model_dir, 'i3d/i3d_rgb.pt') | |||
| kinetics_class_labels = os.path.join(self.model_dir, | |||
| 'i3d/label_map.txt') | |||
| pwc_path = os.path.join(self.model_dir, 'i3d/pwc_net.pt') | |||
| vggish_model_path = os.path.join(self.model_dir, | |||
| 'vggish/vggish_model.ckpt') | |||
| vggish_pca_path = os.path.join(self.model_dir, | |||
| 'vggish/vggish_pca_params.npz') | |||
| device = torch.device( | |||
| 'cuda' if torch.cuda.is_available() else 'cpu') | |||
| i3d_feats = extract_video_features( | |||
| video_path=video_path, | |||
| feature_type='i3d', | |||
| tmp_path=tmp_path, | |||
| i3d_flow_path=i3d_flow_path, | |||
| i3d_rgb_path=i3d_rgb_path, | |||
| kinetics_class_labels=kinetics_class_labels, | |||
| pwc_path=pwc_path, | |||
| vggish_model_path=vggish_model_path, | |||
| vggish_pca_path=vggish_pca_path, | |||
| extraction_fps=2, | |||
| device=device) | |||
| rgb = i3d_feats['rgb'] | |||
| flow = i3d_feats['flow'] | |||
| device = '/gpu:0' if torch.cuda.is_available() else '/cpu:0' | |||
| vggish = extract_video_features( | |||
| video_path=video_path, | |||
| feature_type='vggish', | |||
| tmp_path=tmp_path, | |||
| i3d_flow_path=i3d_flow_path, | |||
| i3d_rgb_path=i3d_rgb_path, | |||
| kinetics_class_labels=kinetics_class_labels, | |||
| pwc_path=pwc_path, | |||
| vggish_model_path=vggish_model_path, | |||
| vggish_pca_path=vggish_pca_path, | |||
| extraction_fps=2, | |||
| device=device) | |||
| audio = vggish['audio'] | |||
| duration_in_secs = float(self.frame_count) / self.fps | |||
| txt = video_features_to_txt( | |||
| duration_in_secs=duration_in_secs, | |||
| pretrained_cap_model_path=os.path.join( | |||
| self.model_dir, 'bmt/sample/best_cap_model.pt'), | |||
| prop_generator_model_path=os.path.join( | |||
| self.model_dir, 'bmt/sample/best_prop_model.pt'), | |||
| features={ | |||
| 'rgb': rgb, | |||
| 'flow': flow, | |||
| 'audio': audio | |||
| }, | |||
| device_id=0) | |||
| sentences = [item['sentence'] for item in txt] | |||
| clip_image_features = [] | |||
| for frame in frames: | |||
| x = self.clip_preprocess( | |||
| Image.fromarray(cv2.cvtColor( | |||
| frame, cv2.COLOR_BGR2RGB))).unsqueeze(0).to(self.device) | |||
| with torch.no_grad(): | |||
| f = self.clip_model.encode_image(x).squeeze(0).cpu().numpy() | |||
| clip_image_features.append(f) | |||
| clip_txt_features = [] | |||
| for sentence in sentences: | |||
| text_input = clip.tokenize(sentence).to(self.device) | |||
| with torch.no_grad(): | |||
| text_feature = self.clip_model.encode_text(text_input).squeeze( | |||
| 0).cpu().numpy() | |||
| clip_txt_features.append(text_feature) | |||
| clip_txt_features = self.sample_txt_feateures(clip_txt_features) | |||
| clip_txt_features = np.array(clip_txt_features).reshape((1, -1)) | |||
| result = { | |||
| 'video_name': video_path, | |||
| 'clip_image_features': np.array(clip_image_features), | |||
| 'clip_txt_features': np.array(clip_txt_features), | |||
| 'n_frame': n_frame, | |||
| 'picks': np.array(picks) | |||
| } | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| clip_image_features = input['clip_image_features'] | |||
| clip_txt_features = input['clip_txt_features'] | |||
| clip_image_features = self.norm_feature(clip_image_features) | |||
| clip_txt_features = self.norm_feature(clip_txt_features) | |||
| change_points, n_frame_per_seg = get_change_points( | |||
| clip_image_features, input['n_frame']) | |||
| summary = self.inference(clip_image_features, clip_txt_features, | |||
| input['n_frame'], input['picks'], | |||
| change_points) | |||
| output = summary_format(summary, self.fps) | |||
| return {OutputKeys.OUTPUT: output} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| if os.path.exists(self.tmp_dir): | |||
| shutil.rmtree(self.tmp_dir) | |||
| return inputs | |||
| def inference(self, clip_image_features, clip_txt_features, n_frames, | |||
| picks, change_points): | |||
| clip_image_features = torch.from_numpy( | |||
| np.array(clip_image_features, np.float32)).unsqueeze(0) | |||
| clip_txt_features = torch.from_numpy( | |||
| np.array(clip_txt_features, np.float32)).unsqueeze(0) | |||
| picks = np.array(picks, np.int32) | |||
| with torch.no_grad(): | |||
| results = self.clipit_model( | |||
| dict( | |||
| frame_features=clip_image_features, | |||
| txt_features=clip_txt_features)) | |||
| scores = results['scores'] | |||
| if not scores.device.type == 'cpu': | |||
| scores = scores.cpu() | |||
| scores = scores.squeeze(0).numpy().tolist() | |||
| summary = generate_summary([change_points], [scores], [n_frames], | |||
| [picks])[0] | |||
| return summary.tolist() | |||
| def sample_txt_feateures(self, feat, num=7): | |||
| while len(feat) < num: | |||
| feat.append(feat[-1]) | |||
| idxes = list(np.arange(0, len(feat))) | |||
| samples_idx = [] | |||
| for ii in range(num): | |||
| idx = random.choice(idxes) | |||
| while idx in samples_idx: | |||
| idx = random.choice(idxes) | |||
| samples_idx.append(idx) | |||
| samples_idx.sort() | |||
| samples = [] | |||
| for idx in samples_idx: | |||
| samples.append(feat[idx]) | |||
| return samples | |||
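A possible simplification, offered as a review suggestion rather than part of the PR: `random.sample` already draws unique indices, so the rejection loop can collapse to one line while `sorted` keeps temporal order:

```python
import random

def sample_txt_feateures(feat, num=7):
    # pad as in the method above, then draw `num` unique indices in temporal order
    while len(feat) < num:
        feat.append(feat[-1])
    idxs = sorted(random.sample(range(len(feat)), num))
    return [feat[i] for i in idxs]
```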
| def norm_feature(self, frames_feat): | |||
| for ii in range(len(frames_feat)): | |||
| frame_feat = frames_feat[ii] | |||
| frames_feat[ii] = frame_feat / np.linalg.norm(frame_feat) | |||
| frames_feat = frames_feat.reshape((frames_feat.shape[0], -1)) | |||
| return frames_feat | |||
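End to end, the pipeline expects a `(video_path, sentences)` tuple; a usage sketch with a placeholder video path and the model id used in the test below:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(Tasks.language_guided_video_summarization,
             model='damo/cv_clip-it_video-summarization_language-guided_en')
# guide the summary with user-provided text, or pass sentences=None to
# fall back to BMT-generated captions extracted from the video itself
result = p(('my_video.mp4', ['phone', 'hand']))
print(result)
```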
| @@ -80,6 +80,7 @@ class CVTasks(object): | |||
| video_embedding = 'video-embedding' | |||
| virtual_try_on = 'virtual-try-on' | |||
| movie_scene_segmentation = 'movie-scene-segmentation' | |||
| language_guided_video_summarization = 'language-guided-video-summarization' | |||
| # video segmentation | |||
| referring_video_object_segmentation = 'referring-video-object-segmentation' | |||
| @@ -1,5 +1,7 @@ | |||
| albumentations>=1.0.3 | |||
| av>=9.2.0 | |||
| bmt_clipit>=1.0 | |||
| clip>=1.0 | |||
| easydict | |||
| fairscale>=0.4.1 | |||
| fastai>=1.0.51 | |||
| @@ -33,3 +35,4 @@ tf_slim | |||
| timm>=0.4.9 | |||
| torchmetrics>=0.6.2 | |||
| torchvision | |||
| videofeatures_clipit>=1.0 | |||
| @@ -0,0 +1,49 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| import torch | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| class LanguageGuidedVideoSummarizationTest(unittest.TestCase, | |||
| DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.task = Tasks.language_guided_video_summarization | |||
| self.model_id = 'damo/cv_clip-it_video-summarization_language-guided_en' | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| video_path = 'data/test/videos/video_category_test_video.mp4' | |||
| # input can be sentences such as sentences=['phone', 'hand'], or sentences=None | |||
| sentences = None | |||
| summarization_pipeline = pipeline( | |||
| Tasks.language_guided_video_summarization, model=self.model_id) | |||
| result = summarization_pipeline((video_path, sentences)) | |||
| print(f'video summarization output: \n{result}.') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| video_path = 'data/test/videos/video_category_test_video.mp4' | |||
| summarization_pipeline = pipeline( | |||
| Tasks.language_guided_video_summarization) | |||
| result = summarization_pipeline(video_path) | |||
| print(f'video summarization output:\n {result}.') | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| def test_demo_compatibility(self): | |||
| self.compatibility_check() | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||