From 541e460f8bab0ba1d8301ac2d31c028578764adf Mon Sep 17 00:00:00 2001 From: "james.wjg" Date: Wed, 16 Nov 2022 15:20:26 +0800 Subject: [PATCH] add support for cv/language_guided_video_summarization Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10636269 --- modelscope/metainfo.py | 2 + modelscope/models/cv/__init__.py | 8 +- .../__init__.py | 25 ++ .../summarizer.py | 194 ++++++++++++++ .../transformer/__init__.py | 25 ++ .../transformer/layers.py | 48 ++++ .../transformer/models.py | 229 ++++++++++++++++ .../transformer/modules.py | 27 ++ .../transformer/sub_layers.py | 83 ++++++ modelscope/pipelines/cv/__init__.py | 4 + ...age_guided_video_summarization_pipeline.py | 250 ++++++++++++++++++ modelscope/utils/constant.py | 1 + requirements/cv.txt | 3 + ...est_language_guided_video_summarization.py | 49 ++++ 14 files changed, 944 insertions(+), 4 deletions(-) create mode 100755 modelscope/models/cv/language_guided_video_summarization/__init__.py create mode 100755 modelscope/models/cv/language_guided_video_summarization/summarizer.py create mode 100755 modelscope/models/cv/language_guided_video_summarization/transformer/__init__.py create mode 100755 modelscope/models/cv/language_guided_video_summarization/transformer/layers.py create mode 100755 modelscope/models/cv/language_guided_video_summarization/transformer/models.py create mode 100755 modelscope/models/cv/language_guided_video_summarization/transformer/modules.py create mode 100755 modelscope/models/cv/language_guided_video_summarization/transformer/sub_layers.py create mode 100755 modelscope/pipelines/cv/language_guided_video_summarization_pipeline.py create mode 100755 tests/pipelines/test_language_guided_video_summarization.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index c7c3e729..ccd36349 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -32,6 +32,7 @@ class Models(object): image_reid_person = 'passvitb' image_inpainting = 'FFTInpainting' video_summarization = 'pgl-video-summarization' + language_guided_video_summarization = 'clip-it-language-guided-video-summarization' swinL_semantic_segmentation = 'swinL-semantic-segmentation' vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' text_driven_segmentation = 'text-driven-segmentation' @@ -200,6 +201,7 @@ class Pipelines(object): video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' image_panoptic_segmentation = 'image-panoptic-segmentation' video_summarization = 'googlenet_pgl_video_summarization' + language_guided_video_summarization = 'clip-it-video-summarization' image_semantic_segmentation = 'image-semantic-segmentation' image_reid_person = 'passvitb-image-reid-person' image_inpainting = 'fft-inpainting' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 64039863..de972032 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -10,10 +10,10 @@ from . 
import (action_recognition, animal_recognition, body_2d_keypoints, image_panoptic_segmentation, image_portrait_enhancement, image_reid_person, image_semantic_segmentation, image_to_image_generation, image_to_image_translation, - movie_scene_segmentation, object_detection, - product_retrieval_embedding, realtime_object_detection, - referring_video_object_segmentation, salient_detection, - shop_segmentation, super_resolution, + language_guided_video_summarization, movie_scene_segmentation, + object_detection, product_retrieval_embedding, + realtime_object_detection, referring_video_object_segmentation, + salient_detection, shop_segmentation, super_resolution, video_single_object_tracking, video_summarization, virual_tryon) # yapf: enable diff --git a/modelscope/models/cv/language_guided_video_summarization/__init__.py b/modelscope/models/cv/language_guided_video_summarization/__init__.py new file mode 100755 index 00000000..73f7bd03 --- /dev/null +++ b/modelscope/models/cv/language_guided_video_summarization/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .summarizer import ( + ClipItVideoSummarization, ) + +else: + _import_structure = { + 'summarizer': [ + 'ClipItVideoSummarization', + ] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/language_guided_video_summarization/summarizer.py b/modelscope/models/cv/language_guided_video_summarization/summarizer.py new file mode 100755 index 00000000..654dc3ea --- /dev/null +++ b/modelscope/models/cv/language_guided_video_summarization/summarizer.py @@ -0,0 +1,194 @@ +# Part of the implementation is borrowed and modified from BMT and video_features, +# publicly available at https://github.com/v-iashin/BMT +# and https://github.com/v-iashin/video_features + +import argparse +import os +import os.path as osp +from copy import deepcopy +from typing import Dict, Union + +import numpy as np +import torch +import torch.nn as nn +from bmt_clipit.sample.single_video_prediction import (caption_proposals, + generate_proposals, + load_cap_model, + load_prop_model) +from bmt_clipit.utilities.proposal_utils import non_max_suppresion +from torch.nn.parallel import DataParallel, DistributedDataParallel +from videofeatures_clipit.models.i3d.extract_i3d import ExtractI3D +from videofeatures_clipit.models.vggish.extract_vggish import ExtractVGGish +from videofeatures_clipit.utils.utils import (fix_tensorflow_gpu_allocation, + form_list_from_user_input) + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.language_guided_video_summarization.transformer import \ + Transformer +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def extract_text(args): + # Loading models and other essential stuff + cap_cfg, cap_model, train_dataset = load_cap_model( + args.pretrained_cap_model_path, args.device_id) + prop_cfg, prop_model = load_prop_model(args.device_id, + args.prop_generator_model_path, + args.pretrained_cap_model_path, + args.max_prop_per_vid) + # Proposal + proposals = generate_proposals(prop_model, args.features, + train_dataset.pad_idx, prop_cfg, + args.device_id, 
args.duration_in_secs) + # NMS if specified + if args.nms_tiou_thresh is not None: + proposals = non_max_suppresion(proposals.squeeze(), + args.nms_tiou_thresh) + proposals = proposals.unsqueeze(0) + # Captions for each proposal + captions = caption_proposals(cap_model, args.features, train_dataset, + cap_cfg, args.device_id, proposals, + args.duration_in_secs) + return captions + + +def extract_video_features(video_path, tmp_path, feature_type, i3d_flow_path, + i3d_rgb_path, kinetics_class_labels, pwc_path, + vggish_model_path, vggish_pca_path, extraction_fps, + device): + default_args = dict( + device=device, + extraction_fps=extraction_fps, + feature_type=feature_type, + file_with_video_paths=None, + i3d_flow_path=i3d_flow_path, + i3d_rgb_path=i3d_rgb_path, + keep_frames=False, + kinetics_class_labels=kinetics_class_labels, + min_side_size=256, + pwc_path=pwc_path, + show_kinetics_pred=False, + stack_size=64, + step_size=64, + tmp_path=tmp_path, + vggish_model_path=vggish_model_path, + vggish_pca_path=vggish_pca_path, + ) + args = argparse.Namespace(**default_args) + + if args.feature_type == 'i3d': + extractor = ExtractI3D(args) + elif args.feature_type == 'vggish': + extractor = ExtractVGGish(args) + + feats = extractor(video_path) + return feats + + +def video_features_to_txt(duration_in_secs, pretrained_cap_model_path, + prop_generator_model_path, features, device_id): + default_args = dict( + device_id=device_id, + duration_in_secs=duration_in_secs, + features=features, + pretrained_cap_model_path=pretrained_cap_model_path, + prop_generator_model_path=prop_generator_model_path, + max_prop_per_vid=100, + nms_tiou_thresh=0.4, + ) + args = argparse.Namespace(**default_args) + txt = extract_text(args) + return txt + + +@MODELS.register_module( + Tasks.language_guided_video_summarization, + module_name=Models.language_guided_video_summarization) +class ClipItVideoSummarization(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the video summarization model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + + self.loss = nn.MSELoss() + self.model = Transformer() + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.model = self.model.to(self._device) + + self.model = self.load_pretrained(self.model, model_path) + + if self.training: + self.model.train() + else: + self.model.eval() + + def load_pretrained(self, net, load_path, strict=True, param_key='params'): + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + load_net = torch.load( + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info( + f'Loading: {param_key} does not exist, use params.') + if param_key in load_net: + load_net = load_net[param_key] + logger.info( + f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' + ) + # remove unnecessary 'module.' 
+ for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + logger.info('load model done.') + return net + + def _train_forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + frame_features = input['frame_features'] + txt_features = input['txt_features'] + gtscore = input['gtscore'] + preds, attn_weights = self.model(frame_features, txt_features, + frame_features) + return {'loss': self.loss(preds, gtscore)} + + def _inference_forward(self, input: Dict[str, + Tensor]) -> Dict[str, Tensor]: + frame_features = input['frame_features'] + txt_features = input['txt_features'] + y, dec_output = self.model(frame_features, txt_features, + frame_features) + return {'scores': y} + + def forward(self, input: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Union[list, Tensor]]: results + """ + for key, value in input.items(): + input[key] = input[key].to(self._device) + if self.training: + return self._train_forward(input) + else: + return self._inference_forward(input) diff --git a/modelscope/models/cv/language_guided_video_summarization/transformer/__init__.py b/modelscope/models/cv/language_guided_video_summarization/transformer/__init__.py new file mode 100755 index 00000000..68dccccf --- /dev/null +++ b/modelscope/models/cv/language_guided_video_summarization/transformer/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .models import ( + Transformer, ) + +else: + _import_structure = { + 'models': [ + 'Transformer', + ] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/language_guided_video_summarization/transformer/layers.py b/modelscope/models/cv/language_guided_video_summarization/transformer/layers.py new file mode 100755 index 00000000..6782c209 --- /dev/null +++ b/modelscope/models/cv/language_guided_video_summarization/transformer/layers.py @@ -0,0 +1,48 @@ +# Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch, +# publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch +import torch +import torch.nn as nn + +from .sub_layers import MultiHeadAttention, PositionwiseFeedForward + + +class EncoderLayer(nn.Module): + """Compose with two layers""" + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(EncoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward( + d_model, d_inner, dropout=dropout) + + def forward(self, enc_input, slf_attn_mask=None): + enc_output, enc_slf_attn = self.slf_attn( + enc_input, enc_input, enc_input, mask=slf_attn_mask) + enc_output = self.pos_ffn(enc_output) + return enc_output, enc_slf_attn + + +class DecoderLayer(nn.Module): + """Compose with three layers""" + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(DecoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout) + self.enc_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout) + 
self.pos_ffn = PositionwiseFeedForward( + d_model, d_inner, dropout=dropout) + + def forward(self, + dec_input, + enc_output, + slf_attn_mask=None, + dec_enc_attn_mask=None): + dec_output, dec_slf_attn = self.slf_attn( + dec_input, dec_input, dec_input, mask=slf_attn_mask) + dec_output, dec_enc_attn = self.enc_attn( + dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) + dec_output = self.pos_ffn(dec_output) + return dec_output, dec_slf_attn, dec_enc_attn diff --git a/modelscope/models/cv/language_guided_video_summarization/transformer/models.py b/modelscope/models/cv/language_guided_video_summarization/transformer/models.py new file mode 100755 index 00000000..f4ae34ee --- /dev/null +++ b/modelscope/models/cv/language_guided_video_summarization/transformer/models.py @@ -0,0 +1,229 @@ +# Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch, +# publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch + +import numpy as np +import torch +import torch.nn as nn + +from .layers import DecoderLayer, EncoderLayer +from .sub_layers import MultiHeadAttention + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_hid, n_position=200): + super(PositionalEncoding, self).__init__() + + # Not a parameter + self.register_buffer( + 'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) + + def _get_sinusoid_encoding_table(self, n_position, d_hid): + """Sinusoid position encoding table""" + + # TODO: make it with torch instead of numpy + + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + def forward(self, x): + return x + self.pos_table[:, :x.size(1)].clone().detach() + + +class Encoder(nn.Module): + """A encoder model with self attention mechanism.""" + + def __init__(self, + d_word_vec=1024, + n_layers=6, + n_head=8, + d_k=64, + d_v=64, + d_model=512, + d_inner=2048, + dropout=0.1, + n_position=200): + + super().__init__() + + self.position_enc = PositionalEncoding( + d_word_vec, n_position=n_position) + self.dropout = nn.Dropout(p=dropout) + self.layer_stack = nn.ModuleList([ + EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers) + ]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.d_model = d_model + + def forward(self, enc_output, return_attns=False): + + enc_slf_attn_list = [] + # -- Forward + enc_output = self.dropout(self.position_enc(enc_output)) + enc_output = self.layer_norm(enc_output) + + for enc_layer in self.layer_stack: + enc_output, enc_slf_attn = enc_layer(enc_output) + enc_slf_attn_list += [enc_slf_attn] if return_attns else [] + + if return_attns: + return enc_output, enc_slf_attn_list + return enc_output, + + +class Decoder(nn.Module): + """A decoder model with self attention mechanism.""" + + def __init__(self, + d_word_vec=1024, + n_layers=6, + n_head=8, + d_k=64, + d_v=64, + d_model=512, + d_inner=2048, + n_position=200, + dropout=0.1): + + super().__init__() + + self.position_enc = PositionalEncoding( + d_word_vec, n_position=n_position) + self.dropout = nn.Dropout(p=dropout) + self.layer_stack = nn.ModuleList([ + DecoderLayer(d_model, d_inner, n_head, 
d_k, d_v, dropout=dropout) + for _ in range(n_layers) + ]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.d_model = d_model + + def forward(self, + dec_output, + enc_output, + src_mask=None, + trg_mask=None, + return_attns=False): + + dec_slf_attn_list, dec_enc_attn_list = [], [] + + # -- Forward + dec_output = self.dropout(self.position_enc(dec_output)) + dec_output = self.layer_norm(dec_output) + + for dec_layer in self.layer_stack: + dec_output, dec_slf_attn, dec_enc_attn = dec_layer( + dec_output, + enc_output, + slf_attn_mask=trg_mask, + dec_enc_attn_mask=src_mask) + dec_slf_attn_list += [dec_slf_attn] if return_attns else [] + dec_enc_attn_list += [dec_enc_attn] if return_attns else [] + + if return_attns: + return dec_output, dec_slf_attn_list, dec_enc_attn_list + return dec_output, + + +class Transformer(nn.Module): + """A sequence to sequence model with attention mechanism.""" + + def __init__(self, + num_sentence=7, + txt_atten_head=4, + d_frame_vec=512, + d_model=512, + d_inner=2048, + n_layers=6, + n_head=8, + d_k=256, + d_v=256, + dropout=0.1, + n_position=4000): + + super().__init__() + + self.d_model = d_model + + self.layer_norm_img_src = nn.LayerNorm(d_frame_vec, eps=1e-6) + self.layer_norm_img_trg = nn.LayerNorm(d_frame_vec, eps=1e-6) + self.layer_norm_txt = nn.LayerNorm( + num_sentence * d_frame_vec, eps=1e-6) + + self.linear_txt = nn.Linear( + in_features=num_sentence * d_frame_vec, out_features=d_model) + self.lg_attention = MultiHeadAttention( + n_head=txt_atten_head, d_model=d_model, d_k=d_k, d_v=d_v) + + self.encoder = Encoder( + n_position=n_position, + d_word_vec=d_frame_vec, + d_model=d_model, + d_inner=d_inner, + n_layers=n_layers, + n_head=n_head, + d_k=d_k, + d_v=d_v, + dropout=dropout) + + self.decoder = Decoder( + n_position=n_position, + d_word_vec=d_frame_vec, + d_model=d_model, + d_inner=d_inner, + n_layers=n_layers, + n_head=n_head, + d_k=d_k, + d_v=d_v, + dropout=dropout) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + assert d_model == d_frame_vec, 'the dimensions of all module outputs shall be the same.' 
+ + self.linear_1 = nn.Linear(in_features=d_model, out_features=d_model) + self.linear_2 = nn.Linear( + in_features=self.linear_1.out_features, out_features=1) + + self.drop = nn.Dropout(p=0.5) + self.norm_y = nn.LayerNorm(normalized_shape=d_model, eps=1e-6) + self.norm_linear = nn.LayerNorm( + normalized_shape=self.linear_1.out_features, eps=1e-6) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, src_seq, src_txt, trg_seq): + + features_txt = self.linear_txt(src_txt) + atten_seq, txt_attn = self.lg_attention(src_seq, features_txt, + features_txt) + + enc_output, *_ = self.encoder(atten_seq) + dec_output, *_ = self.decoder(trg_seq, enc_output) + + y = self.drop(enc_output) + y = self.norm_y(y) + + # 2-layer NN (Regressor Network) + y = self.linear_1(y) + y = self.relu(y) + y = self.drop(y) + y = self.norm_linear(y) + + y = self.linear_2(y) + y = self.sigmoid(y) + y = y.view(1, -1) + + return y, dec_output diff --git a/modelscope/models/cv/language_guided_video_summarization/transformer/modules.py b/modelscope/models/cv/language_guided_video_summarization/transformer/modules.py new file mode 100755 index 00000000..03ef8eaf --- /dev/null +++ b/modelscope/models/cv/language_guided_video_summarization/transformer/modules.py @@ -0,0 +1,27 @@ +# Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch, +# publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ScaledDotProductAttention(nn.Module): + """Scaled Dot-Product Attention""" + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + + def forward(self, q, k, v, mask=None): + + attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) + + if mask is not None: + attn = attn.masked_fill(mask == 0, -1e9) + + attn = self.dropout(F.softmax(attn, dim=-1)) + output = torch.matmul(attn, v) + + return output, attn diff --git a/modelscope/models/cv/language_guided_video_summarization/transformer/sub_layers.py b/modelscope/models/cv/language_guided_video_summarization/transformer/sub_layers.py new file mode 100755 index 00000000..42e10abb --- /dev/null +++ b/modelscope/models/cv/language_guided_video_summarization/transformer/sub_layers.py @@ -0,0 +1,83 @@ +# Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch, +# publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch + +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from .modules import ScaledDotProductAttention + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention module""" + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) + self.fc = nn.Linear(n_head * d_v, d_model, bias=False) + + self.attention = ScaledDotProductAttention(temperature=d_k**0.5) + + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + def forward(self, q, k, v, mask=None): + + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) + + residual = q + + # Pass through the 
pre-attention projection: b x lq x (n*dv) + # Separate different heads: b x lq x n x dv + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + + # Transpose for attention dot product: b x n x lq x dv + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + if mask is not None: + mask = mask.unsqueeze(1) # For head axis broadcasting. + + q, attn = self.attention(q, k, v, mask=mask) + + # Transpose to move the head dimension back: b x lq x n x dv + # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) + q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) + q = self.dropout(self.fc(q)) + q += residual + + q = self.layer_norm(q) + + return q, attn + + +class PositionwiseFeedForward(nn.Module): + """A two-feed-forward-layer module""" + + def __init__(self, d_in, d_hid, dropout=0.1): + super().__init__() + self.w_1 = nn.Linear(d_in, d_hid) # position-wise + self.w_2 = nn.Linear(d_hid, d_in) # position-wise + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + + residual = x + + x = self.w_2(F.relu(self.w_1(x))) + x = self.dropout(x) + x += residual + + x = self.layer_norm(x) + + return x diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 97cd8761..5e9220bd 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -59,6 +59,7 @@ if TYPE_CHECKING: from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipelin from .hand_static_pipeline import HandStaticPipeline from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline + from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline else: _import_structure = { @@ -132,6 +133,9 @@ else: 'referring_video_object_segmentation_pipeline': [ 'ReferringVideoObjectSegmentationPipeline' ], + 'language_guided_video_summarization_pipeline': [ + 'LanguageGuidedVideoSummarizationPipeline' + ] } import sys diff --git a/modelscope/pipelines/cv/language_guided_video_summarization_pipeline.py b/modelscope/pipelines/cv/language_guided_video_summarization_pipeline.py new file mode 100755 index 00000000..059dadb7 --- /dev/null +++ b/modelscope/pipelines/cv/language_guided_video_summarization_pipeline.py @@ -0,0 +1,250 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import os +import os.path as osp +import random +import shutil +import tempfile +from typing import Any, Dict + +import clip +import cv2 +import numpy as np +import torch +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.language_guided_video_summarization import \ + ClipItVideoSummarization +from modelscope.models.cv.language_guided_video_summarization.summarizer import ( + extract_video_features, video_features_to_txt) +from modelscope.models.cv.video_summarization import summary_format +from modelscope.models.cv.video_summarization.summarizer import ( + generate_summary, get_change_points) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.language_guided_video_summarization, + module_name=Pipelines.language_guided_video_summarization) +class LanguageGuidedVideoSummarizationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a language guided video summarization pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + logger.info(f'loading model from {model}') + self.model_dir = model + + self.tmp_dir = kwargs.get('tmp_dir', None) + if self.tmp_dir is None: + self.tmp_dir = tempfile.TemporaryDirectory().name + + config_path = osp.join(model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + + self.clip_model, self.clip_preprocess = clip.load( + 'ViT-B/32', + device=self.device, + download_root=os.path.join(self.model_dir, 'clip')) + + self.clipit_model = ClipItVideoSummarization(model) + self.clipit_model = self.clipit_model.to(self.device).eval() + + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if not isinstance(input, tuple): + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + + video_path, sentences = input + + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + frames = [] + picks = [] + cap = cv2.VideoCapture(video_path) + self.fps = cap.get(cv2.CAP_PROP_FPS) + self.frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) + frame_idx = 0 + # extract 1 frame every 15 frames in the video and save the frame index + while (cap.isOpened()): + ret, frame = cap.read() + if not ret: + break + if frame_idx % 15 == 0: + frames.append(frame) + picks.append(frame_idx) + frame_idx += 1 + n_frame = frame_idx + + if sentences is None: + logger.info('input sentences is none, using sentences from video!') + + tmp_path = os.path.join(self.tmp_dir, 'tmp') + i3d_flow_path = os.path.join(self.model_dir, 'i3d/i3d_flow.pt') + i3d_rgb_path = os.path.join(self.model_dir, 'i3d/i3d_rgb.pt') + kinetics_class_labels = os.path.join(self.model_dir, + 'i3d/label_map.txt') + pwc_path = os.path.join(self.model_dir, 'i3d/pwc_net.pt') + vggish_model_path = os.path.join(self.model_dir, + 'vggish/vggish_model.ckpt') + vggish_pca_path = os.path.join(self.model_dir, + 'vggish/vggish_pca_params.npz') + + device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + i3d_feats = extract_video_features( + video_path=video_path, + feature_type='i3d', + tmp_path=tmp_path, + 
i3d_flow_path=i3d_flow_path, + i3d_rgb_path=i3d_rgb_path, + kinetics_class_labels=kinetics_class_labels, + pwc_path=pwc_path, + vggish_model_path=vggish_model_path, + vggish_pca_path=vggish_pca_path, + extraction_fps=2, + device=device) + rgb = i3d_feats['rgb'] + flow = i3d_feats['flow'] + + device = '/gpu:0' if torch.cuda.is_available() else '/cpu:0' + vggish = extract_video_features( + video_path=video_path, + feature_type='vggish', + tmp_path=tmp_path, + i3d_flow_path=i3d_flow_path, + i3d_rgb_path=i3d_rgb_path, + kinetics_class_labels=kinetics_class_labels, + pwc_path=pwc_path, + vggish_model_path=vggish_model_path, + vggish_pca_path=vggish_pca_path, + extraction_fps=2, + device=device) + audio = vggish['audio'] + + duration_in_secs = float(self.frame_count) / self.fps + + txt = video_features_to_txt( + duration_in_secs=duration_in_secs, + pretrained_cap_model_path=os.path.join( + self.model_dir, 'bmt/sample/best_cap_model.pt'), + prop_generator_model_path=os.path.join( + self.model_dir, 'bmt/sample/best_prop_model.pt'), + features={ + 'rgb': rgb, + 'flow': flow, + 'audio': audio + }, + device_id=0) + sentences = [item['sentence'] for item in txt] + + clip_image_features = [] + for frame in frames: + x = self.clip_preprocess( + Image.fromarray(cv2.cvtColor( + frame, cv2.COLOR_BGR2RGB))).unsqueeze(0).to(self.device) + with torch.no_grad(): + f = self.clip_model.encode_image(x).squeeze(0).cpu().numpy() + clip_image_features.append(f) + + clip_txt_features = [] + for sentence in sentences: + text_input = clip.tokenize(sentence).to(self.device) + with torch.no_grad(): + text_feature = self.clip_model.encode_text(text_input).squeeze( + 0).cpu().numpy() + clip_txt_features.append(text_feature) + clip_txt_features = self.sample_txt_feateures(clip_txt_features) + clip_txt_features = np.array(clip_txt_features).reshape((1, -1)) + + result = { + 'video_name': video_path, + 'clip_image_features': np.array(clip_image_features), + 'clip_txt_features': np.array(clip_txt_features), + 'n_frame': n_frame, + 'picks': np.array(picks) + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + clip_image_features = input['clip_image_features'] + clip_txt_features = input['clip_txt_features'] + clip_image_features = self.norm_feature(clip_image_features) + clip_txt_features = self.norm_feature(clip_txt_features) + + change_points, n_frame_per_seg = get_change_points( + clip_image_features, input['n_frame']) + + summary = self.inference(clip_image_features, clip_txt_features, + input['n_frame'], input['picks'], + change_points) + + output = summary_format(summary, self.fps) + + return {OutputKeys.OUTPUT: output} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + if os.path.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + return inputs + + def inference(self, clip_image_features, clip_txt_features, n_frames, + picks, change_points): + clip_image_features = torch.from_numpy( + np.array(clip_image_features, np.float32)).unsqueeze(0) + clip_txt_features = torch.from_numpy( + np.array(clip_txt_features, np.float32)).unsqueeze(0) + picks = np.array(picks, np.int32) + + with torch.no_grad(): + results = self.clipit_model( + dict( + frame_features=clip_image_features, + txt_features=clip_txt_features)) + scores = results['scores'] + if not scores.device.type == 'cpu': + scores = scores.cpu() + scores = scores.squeeze(0).numpy().tolist() + summary = generate_summary([change_points], [scores], [n_frames], + [picks])[0] + + return summary.tolist() + + def 
sample_txt_feateures(self, feat, num=7): + while len(feat) < num: + feat.append(feat[-1]) + idxes = list(np.arange(0, len(feat))) + samples_idx = [] + for ii in range(num): + idx = random.choice(idxes) + while idx in samples_idx: + idx = random.choice(idxes) + samples_idx.append(idx) + samples_idx.sort() + + samples = [] + for idx in samples_idx: + samples.append(feat[idx]) + return samples + + def norm_feature(self, frames_feat): + for ii in range(len(frames_feat)): + frame_feat = frames_feat[ii] + frames_feat[ii] = frame_feat / np.linalg.norm(frame_feat) + frames_feat = frames_feat.reshape((frames_feat.shape[0], -1)) + return frames_feat diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index f0a97dbd..b1bccc4c 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -80,6 +80,7 @@ class CVTasks(object): video_embedding = 'video-embedding' virtual_try_on = 'virtual-try-on' movie_scene_segmentation = 'movie-scene-segmentation' + language_guided_video_summarization = 'language-guided-video-summarization' # video segmentation referring_video_object_segmentation = 'referring-video-object-segmentation' diff --git a/requirements/cv.txt b/requirements/cv.txt index e0444f1d..43eba7f9 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -1,5 +1,7 @@ albumentations>=1.0.3 av>=9.2.0 +bmt_clipit>=1.0 +clip>=1.0 easydict fairscale>=0.4.1 fastai>=1.0.51 @@ -33,3 +35,4 @@ tf_slim timm>=0.4.9 torchmetrics>=0.6.2 torchvision +videofeatures_clipit>=1.0 diff --git a/tests/pipelines/test_language_guided_video_summarization.py b/tests/pipelines/test_language_guided_video_summarization.py new file mode 100755 index 00000000..0f06d4f2 --- /dev/null +++ b/tests/pipelines/test_language_guided_video_summarization.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class LanguageGuidedVideoSummarizationTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.language_guided_video_summarization + self.model_id = 'damo/cv_clip-it_video-summarization_language-guided_en' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + video_path = 'data/test/videos/video_category_test_video.mp4' + # input can be sentences such as sentences=['phone', 'hand'], or sentences=None + sentences = None + summarization_pipeline = pipeline( + Tasks.language_guided_video_summarization, model=self.model_id) + result = summarization_pipeline((video_path, sentences)) + + print(f'video summarization output: \n{result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + video_path = 'data/test/videos/video_category_test_video.mp4' + summarization_pipeline = pipeline( + Tasks.language_guided_video_summarization) + result = summarization_pipeline(video_path) + + print(f'video summarization output:\n {result}.') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main()
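In transformer/models.py above, the flattened CLIP sentence features (num_sentence * d_frame_vec = 7 * 512) are projected to d_model and fused with the frame features through multi-head attention before the encoder; a two-layer regressor then turns the encoder output into one importance score per sampled frame. A minimal shape-check sketch against those defaults (the 120-frame input is an arbitrary example, not taken from the patch):

import torch

from modelscope.models.cv.language_guided_video_summarization.transformer import \
    Transformer

model = Transformer()  # defaults: d_frame_vec = d_model = 512, num_sentence = 7
model.eval()

frame_feats = torch.randn(1, 120, 512)  # CLIP image features, one per sampled frame
txt_feats = torch.randn(1, 1, 7 * 512)  # 7 sampled CLIP sentence features, flattened

with torch.no_grad():
    scores, dec_output = model(frame_feats, txt_feats, frame_feats)

print(scores.shape)      # torch.Size([1, 120]) -- per-frame importance scores in [0, 1]
print(dec_output.shape)  # torch.Size([1, 120, 512])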