
add support for cv/language_guided_video_summarization

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10636269
Branch: master
james.wjg, yingda.chen (3 years ago)
Commit 541e460f8b
14 changed files with 944 additions and 4 deletions
  1. modelscope/metainfo.py (+2, -0)
  2. modelscope/models/cv/__init__.py (+4, -4)
  3. modelscope/models/cv/language_guided_video_summarization/__init__.py (+25, -0)
  4. modelscope/models/cv/language_guided_video_summarization/summarizer.py (+194, -0)
  5. modelscope/models/cv/language_guided_video_summarization/transformer/__init__.py (+25, -0)
  6. modelscope/models/cv/language_guided_video_summarization/transformer/layers.py (+48, -0)
  7. modelscope/models/cv/language_guided_video_summarization/transformer/models.py (+229, -0)
  8. modelscope/models/cv/language_guided_video_summarization/transformer/modules.py (+27, -0)
  9. modelscope/models/cv/language_guided_video_summarization/transformer/sub_layers.py (+83, -0)
  10. modelscope/pipelines/cv/__init__.py (+4, -0)
  11. modelscope/pipelines/cv/language_guided_video_summarization_pipeline.py (+250, -0)
  12. modelscope/utils/constant.py (+1, -0)
  13. requirements/cv.txt (+3, -0)
  14. tests/pipelines/test_language_guided_video_summarization.py (+49, -0)

modelscope/metainfo.py (+2, -0)

@@ -32,6 +32,7 @@ class Models(object):
     image_reid_person = 'passvitb'
     image_inpainting = 'FFTInpainting'
     video_summarization = 'pgl-video-summarization'
+    language_guided_video_summarization = 'clip-it-language-guided-video-summarization'
     swinL_semantic_segmentation = 'swinL-semantic-segmentation'
     vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
     text_driven_segmentation = 'text-driven-segmentation'
@@ -200,6 +201,7 @@ class Pipelines(object):
     video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
     image_panoptic_segmentation = 'image-panoptic-segmentation'
     video_summarization = 'googlenet_pgl_video_summarization'
+    language_guided_video_summarization = 'clip-it-video-summarization'
     image_semantic_segmentation = 'image-semantic-segmentation'
     image_reid_person = 'passvitb-image-reid-person'
     image_inpainting = 'fft-inpainting'
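
The two keys added above are the glue between the task constant, the model class and the pipeline class introduced later in this change. A minimal sketch of what they resolve to (values taken directly from this diff), not part of the commit:

from modelscope.metainfo import Models, Pipelines

# The register_module decorators in summarizer.py and
# language_guided_video_summarization_pipeline.py look these keys up by name.
assert Models.language_guided_video_summarization == 'clip-it-language-guided-video-summarization'
assert Pipelines.language_guided_video_summarization == 'clip-it-video-summarization'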


modelscope/models/cv/__init__.py (+4, -4)

@@ -10,10 +10,10 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                image_panoptic_segmentation, image_portrait_enhancement,
                image_reid_person, image_semantic_segmentation,
                image_to_image_generation, image_to_image_translation,
-               movie_scene_segmentation, object_detection,
-               product_retrieval_embedding, realtime_object_detection,
-               referring_video_object_segmentation, salient_detection,
-               shop_segmentation, super_resolution,
+               language_guided_video_summarization, movie_scene_segmentation,
+               object_detection, product_retrieval_embedding,
+               realtime_object_detection, referring_video_object_segmentation,
+               salient_detection, shop_segmentation, super_resolution,
                video_single_object_tracking, video_summarization, virual_tryon)

 # yapf: enable

modelscope/models/cv/language_guided_video_summarization/__init__.py (+25, -0)

@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .summarizer import (
ClipItVideoSummarization, )

else:
_import_structure = {
'summarizer': [
'ClipItVideoSummarization',
]
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
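
As with the other lazily imported ModelScope modules, the table above means the heavy `summarizer` module is only imported when its symbol is first accessed. A hedged usage sketch, not part of the commit:

# Importing the package itself stays cheap; accessing the class below is what
# triggers the lazy import of .summarizer (and its torch/bmt_clipit dependencies).
from modelscope.models.cv.language_guided_video_summarization import \
    ClipItVideoSummarization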

modelscope/models/cv/language_guided_video_summarization/summarizer.py (+194, -0)

@@ -0,0 +1,194 @@
# Part of the implementation is borrowed and modified from BMT and video_features,
# publicly available at https://github.com/v-iashin/BMT
# and https://github.com/v-iashin/video_features

import argparse
import os
import os.path as osp
from copy import deepcopy
from typing import Dict, Union

import numpy as np
import torch
import torch.nn as nn
from bmt_clipit.sample.single_video_prediction import (caption_proposals,
generate_proposals,
load_cap_model,
load_prop_model)
from bmt_clipit.utilities.proposal_utils import non_max_suppresion
from torch.nn.parallel import DataParallel, DistributedDataParallel
from videofeatures_clipit.models.i3d.extract_i3d import ExtractI3D
from videofeatures_clipit.models.vggish.extract_vggish import ExtractVGGish
from videofeatures_clipit.utils.utils import (fix_tensorflow_gpu_allocation,
form_list_from_user_input)

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.language_guided_video_summarization.transformer import \
Transformer
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


def extract_text(args):
# Loading models and other essential stuff
cap_cfg, cap_model, train_dataset = load_cap_model(
args.pretrained_cap_model_path, args.device_id)
prop_cfg, prop_model = load_prop_model(args.device_id,
args.prop_generator_model_path,
args.pretrained_cap_model_path,
args.max_prop_per_vid)
# Proposal
proposals = generate_proposals(prop_model, args.features,
train_dataset.pad_idx, prop_cfg,
args.device_id, args.duration_in_secs)
# NMS if specified
if args.nms_tiou_thresh is not None:
proposals = non_max_suppresion(proposals.squeeze(),
args.nms_tiou_thresh)
proposals = proposals.unsqueeze(0)
# Captions for each proposal
captions = caption_proposals(cap_model, args.features, train_dataset,
cap_cfg, args.device_id, proposals,
args.duration_in_secs)
return captions


def extract_video_features(video_path, tmp_path, feature_type, i3d_flow_path,
i3d_rgb_path, kinetics_class_labels, pwc_path,
vggish_model_path, vggish_pca_path, extraction_fps,
device):
default_args = dict(
device=device,
extraction_fps=extraction_fps,
feature_type=feature_type,
file_with_video_paths=None,
i3d_flow_path=i3d_flow_path,
i3d_rgb_path=i3d_rgb_path,
keep_frames=False,
kinetics_class_labels=kinetics_class_labels,
min_side_size=256,
pwc_path=pwc_path,
show_kinetics_pred=False,
stack_size=64,
step_size=64,
tmp_path=tmp_path,
vggish_model_path=vggish_model_path,
vggish_pca_path=vggish_pca_path,
)
args = argparse.Namespace(**default_args)

if args.feature_type == 'i3d':
extractor = ExtractI3D(args)
elif args.feature_type == 'vggish':
extractor = ExtractVGGish(args)

feats = extractor(video_path)
return feats


def video_features_to_txt(duration_in_secs, pretrained_cap_model_path,
prop_generator_model_path, features, device_id):
default_args = dict(
device_id=device_id,
duration_in_secs=duration_in_secs,
features=features,
pretrained_cap_model_path=pretrained_cap_model_path,
prop_generator_model_path=prop_generator_model_path,
max_prop_per_vid=100,
nms_tiou_thresh=0.4,
)
args = argparse.Namespace(**default_args)
txt = extract_text(args)
return txt


@MODELS.register_module(
Tasks.language_guided_video_summarization,
module_name=Models.language_guided_video_summarization)
class ClipItVideoSummarization(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the video summarization model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)

model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)

self.loss = nn.MSELoss()
self.model = Transformer()
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.model = self.model.to(self._device)

self.model = self.load_pretrained(self.model, model_path)

if self.training:
self.model.train()
else:
self.model.eval()

def load_pretrained(self, net, load_path, strict=True, param_key='params'):
if isinstance(net, (DataParallel, DistributedDataParallel)):
net = net.module
load_net = torch.load(
load_path, map_location=lambda storage, loc: storage)
if param_key is not None:
if param_key not in load_net and 'params' in load_net:
param_key = 'params'
logger.info(
f'Loading: {param_key} does not exist, use params.')
if param_key in load_net:
load_net = load_net[param_key]
logger.info(
f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
)
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
net.load_state_dict(load_net, strict=strict)
logger.info('load model done.')
return net

def _train_forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
frame_features = input['frame_features']
txt_features = input['txt_features']
gtscore = input['gtscore']
preds, attn_weights = self.model(frame_features, txt_features,
frame_features)
return {'loss': self.loss(preds, gtscore)}

def _inference_forward(self, input: Dict[str,
Tensor]) -> Dict[str, Tensor]:
frame_features = input['frame_features']
txt_features = input['txt_features']
y, dec_output = self.model(frame_features, txt_features,
frame_features)
return {'scores': y}

def forward(self, input: Dict[str,
Tensor]) -> Dict[str, Union[list, Tensor]]:
"""return the result by the model

Args:
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, Union[list, Tensor]]: results
"""
for key, value in input.items():
input[key] = input[key].to(self._device)
if self.training:
return self._train_forward(input)
else:
return self._inference_forward(input)
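
The model consumes a dict of tensors; in training mode it expects an additional 'gtscore' entry and returns {'loss': ...}, otherwise it returns {'scores': ...}. A hedged shape sketch, assuming CLIP ViT-B/32 features (dimension 512) and the seven-sentence text input produced by the pipeline below:

import torch

n_frames = 30
inputs = {
    'frame_features': torch.randn(1, n_frames, 512),  # one CLIP image feature per sampled frame
    'txt_features': torch.randn(1, 1, 7 * 512),       # seven CLIP text features, flattened
    'gtscore': torch.rand(1, n_frames),               # frame importance labels, training only
}
# model = ClipItVideoSummarization(model_dir)  # needs a model_dir containing the torch checkpoint
# model.eval()
# scores = model(inputs)['scores']             # shape (1, n_frames), one score per sampled frame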

modelscope/models/cv/language_guided_video_summarization/transformer/__init__.py (+25, -0)

@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .models import (
Transformer, )

else:
_import_structure = {
'models': [
'Transformer',
]
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/models/cv/language_guided_video_summarization/transformer/layers.py (+48, -0)

@@ -0,0 +1,48 @@
# Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch,
# publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch
import torch
import torch.nn as nn

from .sub_layers import MultiHeadAttention, PositionwiseFeedForward


class EncoderLayer(nn.Module):
"""Compose with two layers"""

def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout)

def forward(self, enc_input, slf_attn_mask=None):
enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn


class DecoderLayer(nn.Module):
"""Compose with three layers"""

def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(DecoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout)
self.enc_attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(
d_model, d_inner, dropout=dropout)

def forward(self,
dec_input,
enc_output,
slf_attn_mask=None,
dec_enc_attn_mask=None):
dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, mask=slf_attn_mask)
dec_output, dec_enc_attn = self.enc_attn(
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
dec_output = self.pos_ffn(dec_output)
return dec_output, dec_slf_attn, dec_enc_attn
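
A quick hedged shape check for the layers above, using the hyper-parameters that the Encoder/Decoder in models.py pass in; a sketch, not part of the commit:

import torch

enc_layer = EncoderLayer(d_model=512, d_inner=2048, n_head=8, d_k=256, d_v=256)
x = torch.randn(1, 30, 512)
out, attn = enc_layer(x)
print(out.shape)   # torch.Size([1, 30, 512]); the layer preserves the input shape
print(attn.shape)  # torch.Size([1, 8, 30, 30]); per-head self-attention weights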

modelscope/models/cv/language_guided_video_summarization/transformer/models.py (+229, -0)

@@ -0,0 +1,229 @@
# Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch,
# publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch

import numpy as np
import torch
import torch.nn as nn

from .layers import DecoderLayer, EncoderLayer
from .sub_layers import MultiHeadAttention


class PositionalEncoding(nn.Module):

def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()

# Not a parameter
self.register_buffer(
'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table"""

# TODO: make it with torch instead of numpy

def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]

sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1

return torch.FloatTensor(sinusoid_table).unsqueeze(0)

def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()


class Encoder(nn.Module):
"""A encoder model with self attention mechanism."""

def __init__(self,
d_word_vec=1024,
n_layers=6,
n_head=8,
d_k=64,
d_v=64,
d_model=512,
d_inner=2048,
dropout=0.1,
n_position=200):

super().__init__()

self.position_enc = PositionalEncoding(
d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.d_model = d_model

def forward(self, enc_output, return_attns=False):

enc_slf_attn_list = []
# -- Forward
enc_output = self.dropout(self.position_enc(enc_output))
enc_output = self.layer_norm(enc_output)

for enc_layer in self.layer_stack:
enc_output, enc_slf_attn = enc_layer(enc_output)
enc_slf_attn_list += [enc_slf_attn] if return_attns else []

if return_attns:
return enc_output, enc_slf_attn_list
return enc_output,


class Decoder(nn.Module):
"""A decoder model with self attention mechanism."""

def __init__(self,
d_word_vec=1024,
n_layers=6,
n_head=8,
d_k=64,
d_v=64,
d_model=512,
d_inner=2048,
n_position=200,
dropout=0.1):

super().__init__()

self.position_enc = PositionalEncoding(
d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.d_model = d_model

def forward(self,
dec_output,
enc_output,
src_mask=None,
trg_mask=None,
return_attns=False):

dec_slf_attn_list, dec_enc_attn_list = [], []

# -- Forward
dec_output = self.dropout(self.position_enc(dec_output))
dec_output = self.layer_norm(dec_output)

for dec_layer in self.layer_stack:
dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
dec_output,
enc_output,
slf_attn_mask=trg_mask,
dec_enc_attn_mask=src_mask)
dec_slf_attn_list += [dec_slf_attn] if return_attns else []
dec_enc_attn_list += [dec_enc_attn] if return_attns else []

if return_attns:
return dec_output, dec_slf_attn_list, dec_enc_attn_list
return dec_output,


class Transformer(nn.Module):
"""A sequence to sequence model with attention mechanism."""

def __init__(self,
num_sentence=7,
txt_atten_head=4,
d_frame_vec=512,
d_model=512,
d_inner=2048,
n_layers=6,
n_head=8,
d_k=256,
d_v=256,
dropout=0.1,
n_position=4000):

super().__init__()

self.d_model = d_model

self.layer_norm_img_src = nn.LayerNorm(d_frame_vec, eps=1e-6)
self.layer_norm_img_trg = nn.LayerNorm(d_frame_vec, eps=1e-6)
self.layer_norm_txt = nn.LayerNorm(
num_sentence * d_frame_vec, eps=1e-6)

self.linear_txt = nn.Linear(
in_features=num_sentence * d_frame_vec, out_features=d_model)
self.lg_attention = MultiHeadAttention(
n_head=txt_atten_head, d_model=d_model, d_k=d_k, d_v=d_v)

self.encoder = Encoder(
n_position=n_position,
d_word_vec=d_frame_vec,
d_model=d_model,
d_inner=d_inner,
n_layers=n_layers,
n_head=n_head,
d_k=d_k,
d_v=d_v,
dropout=dropout)

self.decoder = Decoder(
n_position=n_position,
d_word_vec=d_frame_vec,
d_model=d_model,
d_inner=d_inner,
n_layers=n_layers,
n_head=n_head,
d_k=d_k,
d_v=d_v,
dropout=dropout)

for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)

assert d_model == d_frame_vec, 'the dimensions of all module outputs shall be the same.'

self.linear_1 = nn.Linear(in_features=d_model, out_features=d_model)
self.linear_2 = nn.Linear(
in_features=self.linear_1.out_features, out_features=1)

self.drop = nn.Dropout(p=0.5)
self.norm_y = nn.LayerNorm(normalized_shape=d_model, eps=1e-6)
self.norm_linear = nn.LayerNorm(
normalized_shape=self.linear_1.out_features, eps=1e-6)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()

def forward(self, src_seq, src_txt, trg_seq):

features_txt = self.linear_txt(src_txt)
atten_seq, txt_attn = self.lg_attention(src_seq, features_txt,
features_txt)

enc_output, *_ = self.encoder(atten_seq)
dec_output, *_ = self.decoder(trg_seq, enc_output)

y = self.drop(enc_output)
y = self.norm_y(y)

# 2-layer NN (Regressor Network)
y = self.linear_1(y)
y = self.relu(y)
y = self.drop(y)
y = self.norm_linear(y)

y = self.linear_2(y)
y = self.sigmoid(y)
y = y.view(1, -1)

return y, dec_output
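
Putting the pieces together: the Transformer takes frame features as queries, attends them over the flattened sentence features, runs the encoder/decoder stacks and regresses one importance score per frame. A minimal smoke test of the default configuration (a sketch, not part of the commit):

import torch

model = Transformer()                    # defaults: d_model = d_frame_vec = 512, num_sentence = 7
model.eval()

src_seq = torch.randn(1, 30, 512)        # CLIP frame features
src_txt = torch.randn(1, 1, 7 * 512)     # seven concatenated CLIP sentence features
with torch.no_grad():
    y, dec_output = model(src_seq, src_txt, src_seq)  # trg_seq reuses the frame features
print(y.shape)           # torch.Size([1, 30]): one score in [0, 1] per frame
print(dec_output.shape)  # torch.Size([1, 30, 512])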

modelscope/models/cv/language_guided_video_summarization/transformer/modules.py (+27, -0)

@@ -0,0 +1,27 @@
# Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch,
# publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
"""Scaled Dot-Product Attention"""

def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)

def forward(self, q, k, v, mask=None):

attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

if mask is not None:
attn = attn.masked_fill(mask == 0, -1e9)

attn = self.dropout(F.softmax(attn, dim=-1))
output = torch.matmul(attn, v)

return output, attn
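
For reference, this module is the standard scaled dot-product attention of Vaswani et al.; the caller in sub_layers.py sets the temperature to sqrt(d_k), masked positions are filled with -1e9 before the softmax, and dropout is applied to the attention weights:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V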

modelscope/models/cv/language_guided_video_summarization/transformer/sub_layers.py (+83, -0)

@@ -0,0 +1,83 @@
# Part of the implementation is borrowed and modified from attention-is-all-you-need-pytorch,
# publicly available at https://github.com/jadore801120/attention-is-all-you-need-pytorch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from .modules import ScaledDotProductAttention


class MultiHeadAttention(nn.Module):
"""Multi-Head Attention module"""

def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()

self.n_head = n_head
self.d_k = d_k
self.d_v = d_v

self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

self.attention = ScaledDotProductAttention(temperature=d_k**0.5)

self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

def forward(self, q, k, v, mask=None):

d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

residual = q

# Pass through the pre-attention projection: b x lq x (n*dv)
# Separate different heads: b x lq x n x dv
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

# Transpose for attention dot product: b x n x lq x dv
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

if mask is not None:
mask = mask.unsqueeze(1) # For head axis broadcasting.

q, attn = self.attention(q, k, v, mask=mask)

# Transpose to move the head dimension back: b x lq x n x dv
# Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
q = self.dropout(self.fc(q))
q += residual

q = self.layer_norm(q)

return q, attn


class PositionwiseFeedForward(nn.Module):
"""A two-feed-forward-layer module"""

def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__()
self.w_1 = nn.Linear(d_in, d_hid) # position-wise
self.w_2 = nn.Linear(d_hid, d_in) # position-wise
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.dropout = nn.Dropout(dropout)

def forward(self, x):

residual = x

x = self.w_2(F.relu(self.w_1(x)))
x = self.dropout(x)
x += residual

x = self.layer_norm(x)

return x
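
Besides the usual self-attention inside the encoder and decoder layers, MultiHeadAttention is reused as the language-guided attention (lg_attention) in models.py, where the frame features attend over the projected sentence features. A hedged shape sketch of that cross-attention use:

import torch

mha = MultiHeadAttention(n_head=4, d_model=512, d_k=256, d_v=256)
frames = torch.randn(1, 30, 512)   # queries: one feature per sampled frame
text = torch.randn(1, 1, 512)      # keys/values: projected sentence features
out, attn = mha(frames, text, text)
print(out.shape)   # torch.Size([1, 30, 512]); the residual keeps the query shape
print(attn.shape)  # torch.Size([1, 4, 30, 1]); each frame's weight on the text token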

modelscope/pipelines/cv/__init__.py (+4, -0)

@@ -59,6 +59,7 @@ if TYPE_CHECKING:
     from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline
     from .hand_static_pipeline import HandStaticPipeline
     from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
+    from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline

 else:
     _import_structure = {
@@ -132,6 +133,9 @@ else:
         'referring_video_object_segmentation_pipeline': [
             'ReferringVideoObjectSegmentationPipeline'
         ],
+        'language_guided_video_summarization_pipeline': [
+            'LanguageGuidedVideoSummarizationPipeline'
+        ]
     }

     import sys


modelscope/pipelines/cv/language_guided_video_summarization_pipeline.py (+250, -0)

@@ -0,0 +1,250 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import os.path as osp
import random
import shutil
import tempfile
from typing import Any, Dict

import clip
import cv2
import numpy as np
import torch
from PIL import Image

from modelscope.metainfo import Pipelines
from modelscope.models.cv.language_guided_video_summarization import \
ClipItVideoSummarization
from modelscope.models.cv.language_guided_video_summarization.summarizer import (
extract_video_features, video_features_to_txt)
from modelscope.models.cv.video_summarization import summary_format
from modelscope.models.cv.video_summarization.summarizer import (
generate_summary, get_change_points)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.language_guided_video_summarization,
module_name=Pipelines.language_guided_video_summarization)
class LanguageGuidedVideoSummarizationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
Use `model` to create a language-guided video summarization pipeline for prediction.
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, auto_collate=False, **kwargs)
logger.info(f'loading model from {model}')
self.model_dir = model

self.tmp_dir = kwargs.get('tmp_dir', None)
if self.tmp_dir is None:
self.tmp_dir = tempfile.TemporaryDirectory().name

config_path = osp.join(model, ModelFile.CONFIGURATION)
logger.info(f'loading config from {config_path}')
self.cfg = Config.from_file(config_path)

self.clip_model, self.clip_preprocess = clip.load(
'ViT-B/32',
device=self.device,
download_root=os.path.join(self.model_dir, 'clip'))

self.clipit_model = ClipItVideoSummarization(model)
self.clipit_model = self.clipit_model.to(self.device).eval()

logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if not isinstance(input, tuple):
raise TypeError(f'input should be a tuple of (video_path, sentences),'
f' but got {type(input)}')

video_path, sentences = input

if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

frames = []
picks = []
cap = cv2.VideoCapture(video_path)
self.fps = cap.get(cv2.CAP_PROP_FPS)
self.frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
frame_idx = 0
# extract 1 frame every 15 frames in the video and save the frame index
while (cap.isOpened()):
ret, frame = cap.read()
if not ret:
break
if frame_idx % 15 == 0:
frames.append(frame)
picks.append(frame_idx)
frame_idx += 1
n_frame = frame_idx

if sentences is None:
logger.info('input sentences is None, generating sentences from the video instead.')

tmp_path = os.path.join(self.tmp_dir, 'tmp')
i3d_flow_path = os.path.join(self.model_dir, 'i3d/i3d_flow.pt')
i3d_rgb_path = os.path.join(self.model_dir, 'i3d/i3d_rgb.pt')
kinetics_class_labels = os.path.join(self.model_dir,
'i3d/label_map.txt')
pwc_path = os.path.join(self.model_dir, 'i3d/pwc_net.pt')
vggish_model_path = os.path.join(self.model_dir,
'vggish/vggish_model.ckpt')
vggish_pca_path = os.path.join(self.model_dir,
'vggish/vggish_pca_params.npz')

device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
i3d_feats = extract_video_features(
video_path=video_path,
feature_type='i3d',
tmp_path=tmp_path,
i3d_flow_path=i3d_flow_path,
i3d_rgb_path=i3d_rgb_path,
kinetics_class_labels=kinetics_class_labels,
pwc_path=pwc_path,
vggish_model_path=vggish_model_path,
vggish_pca_path=vggish_pca_path,
extraction_fps=2,
device=device)
rgb = i3d_feats['rgb']
flow = i3d_feats['flow']

device = '/gpu:0' if torch.cuda.is_available() else '/cpu:0'
vggish = extract_video_features(
video_path=video_path,
feature_type='vggish',
tmp_path=tmp_path,
i3d_flow_path=i3d_flow_path,
i3d_rgb_path=i3d_rgb_path,
kinetics_class_labels=kinetics_class_labels,
pwc_path=pwc_path,
vggish_model_path=vggish_model_path,
vggish_pca_path=vggish_pca_path,
extraction_fps=2,
device=device)
audio = vggish['audio']

duration_in_secs = float(self.frame_count) / self.fps

txt = video_features_to_txt(
duration_in_secs=duration_in_secs,
pretrained_cap_model_path=os.path.join(
self.model_dir, 'bmt/sample/best_cap_model.pt'),
prop_generator_model_path=os.path.join(
self.model_dir, 'bmt/sample/best_prop_model.pt'),
features={
'rgb': rgb,
'flow': flow,
'audio': audio
},
device_id=0)
sentences = [item['sentence'] for item in txt]

clip_image_features = []
for frame in frames:
x = self.clip_preprocess(
Image.fromarray(cv2.cvtColor(
frame, cv2.COLOR_BGR2RGB))).unsqueeze(0).to(self.device)
with torch.no_grad():
f = self.clip_model.encode_image(x).squeeze(0).cpu().numpy()
clip_image_features.append(f)

clip_txt_features = []
for sentence in sentences:
text_input = clip.tokenize(sentence).to(self.device)
with torch.no_grad():
text_feature = self.clip_model.encode_text(text_input).squeeze(
0).cpu().numpy()
clip_txt_features.append(text_feature)
clip_txt_features = self.sample_txt_feateures(clip_txt_features)
clip_txt_features = np.array(clip_txt_features).reshape((1, -1))

result = {
'video_name': video_path,
'clip_image_features': np.array(clip_image_features),
'clip_txt_features': np.array(clip_txt_features),
'n_frame': n_frame,
'picks': np.array(picks)
}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
clip_image_features = input['clip_image_features']
clip_txt_features = input['clip_txt_features']
clip_image_features = self.norm_feature(clip_image_features)
clip_txt_features = self.norm_feature(clip_txt_features)

change_points, n_frame_per_seg = get_change_points(
clip_image_features, input['n_frame'])

summary = self.inference(clip_image_features, clip_txt_features,
input['n_frame'], input['picks'],
change_points)

output = summary_format(summary, self.fps)

return {OutputKeys.OUTPUT: output}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if os.path.exists(self.tmp_dir):
shutil.rmtree(self.tmp_dir)
return inputs

def inference(self, clip_image_features, clip_txt_features, n_frames,
picks, change_points):
clip_image_features = torch.from_numpy(
np.array(clip_image_features, np.float32)).unsqueeze(0)
clip_txt_features = torch.from_numpy(
np.array(clip_txt_features, np.float32)).unsqueeze(0)
picks = np.array(picks, np.int32)

with torch.no_grad():
results = self.clipit_model(
dict(
frame_features=clip_image_features,
txt_features=clip_txt_features))
scores = results['scores']
if not scores.device.type == 'cpu':
scores = scores.cpu()
scores = scores.squeeze(0).numpy().tolist()
summary = generate_summary([change_points], [scores], [n_frames],
[picks])[0]

return summary.tolist()

def sample_txt_feateures(self, feat, num=7):
while len(feat) < num:
feat.append(feat[-1])
idxes = list(np.arange(0, len(feat)))
samples_idx = []
for ii in range(num):
idx = random.choice(idxes)
while idx in samples_idx:
idx = random.choice(idxes)
samples_idx.append(idx)
samples_idx.sort()

samples = []
for idx in samples_idx:
samples.append(feat[idx])
return samples

def norm_feature(self, frames_feat):
for ii in range(len(frames_feat)):
frame_feat = frames_feat[ii]
frames_feat[ii] = frame_feat / np.linalg.norm(frame_feat)
frames_feat = frames_feat.reshape((frames_feat.shape[0], -1))
return frames_feat
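
End to end, the pipeline takes a (video_path, sentences) tuple; with sentences=None it captions the video itself via BMT, otherwise the given sentences guide the summary. A hedged usage sketch mirroring the test below (the video path is the test asset and the sentences follow the test's comment):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

summarizer = pipeline(
    Tasks.language_guided_video_summarization,
    model='damo/cv_clip-it_video-summarization_language-guided_en')
result = summarizer(('data/test/videos/video_category_test_video.mp4',
                     ['phone', 'hand']))
print(result[OutputKeys.OUTPUT])  # selected summary segments, formatted by summary_format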

modelscope/utils/constant.py (+1, -0)

@@ -80,6 +80,7 @@ class CVTasks(object):
     video_embedding = 'video-embedding'
     virtual_try_on = 'virtual-try-on'
     movie_scene_segmentation = 'movie-scene-segmentation'
+    language_guided_video_summarization = 'language-guided-video-summarization'

     # video segmentation
     referring_video_object_segmentation = 'referring-video-object-segmentation'


requirements/cv.txt (+3, -0)

@@ -1,5 +1,7 @@
 albumentations>=1.0.3
 av>=9.2.0
+bmt_clipit>=1.0
+clip>=1.0
 easydict
 fairscale>=0.4.1
 fastai>=1.0.51
@@ -33,3 +35,4 @@ tf_slim
 timm>=0.4.9
 torchmetrics>=0.6.2
 torchvision
+videofeatures_clipit>=1.0

tests/pipelines/test_language_guided_video_summarization.py (+49, -0)

@@ -0,0 +1,49 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import shutil
import tempfile
import unittest

import torch

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class LanguageGuidedVideoSummarizationTest(unittest.TestCase,
DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = Tasks.language_guided_video_summarization
self.model_id = 'damo/cv_clip-it_video-summarization_language-guided_en'

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
video_path = 'data/test/videos/video_category_test_video.mp4'
# input can be sentences such as sentences=['phone', 'hand'], or sentences=None
sentences = None
summarization_pipeline = pipeline(
Tasks.language_guided_video_summarization, model=self.model_id)
result = summarization_pipeline((video_path, sentences))

print(f'video summarization output: \n{result}.')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
video_path = 'data/test/videos/video_category_test_video.mp4'
summarization_pipeline = pipeline(
Tasks.language_guided_video_summarization)
result = summarization_pipeline(video_path)

print(f'video summarization output:\n {result}.')

@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()


if __name__ == '__main__':
unittest.main()
