Add inference and finetuning for the video summarization model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9840532
master
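For quick review context, a minimal sketch of how the new pipeline and trainer are exercised end to end. The model id, test video path and dataset handling mirror the unit tests added at the bottom of this diff; the work_dir is an arbitrary local path.

import os

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.msdatasets.task_datasets import VideoSummarizationDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

# Inference: frames are sub-sampled every 15 frames, described with the bundled GoogLeNet,
# scored by PGL-SUM, and turned into a 0/1 per-frame summary via KTS shots + knapsack.
summarizer = pipeline(
    Tasks.video_summarization,
    model='damo/cv_googlenet_pgl-video-summarization')
result = summarizer('data/test/videos/video_category_test_video.mp4')

# Finetuning: the h5 feature file and split file referenced in the model's configuration
# are loaded from the downloaded model directory.
cache_path = snapshot_download('damo/cv_googlenet_pgl-video-summarization')
config = Config.from_file(os.path.join(cache_path, ModelFile.CONFIGURATION))
train_set = VideoSummarizationDataset('train', config.dataset, cache_path)
eval_set = VideoSummarizationDataset('test', config.dataset, cache_path)
trainer = build_trainer(
    default_args=dict(
        model='damo/cv_googlenet_pgl-video-summarization',
        train_dataset=train_set,
        eval_dataset=eval_set,
        work_dir='./work_dir'))
trainer.train()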
| @@ -21,6 +21,7 @@ class Models(object): | |||||
| body_2d_keypoints = 'body-2d-keypoints' | body_2d_keypoints = 'body-2d-keypoints' | ||||
| crowd_counting = 'HRNetCrowdCounting' | crowd_counting = 'HRNetCrowdCounting' | ||||
| image_reid_person = 'passvitb' | image_reid_person = 'passvitb' | ||||
| video_summarization = 'pgl-video-summarization' | |||||
| # nlp models | # nlp models | ||||
| bert = 'bert' | bert = 'bert' | ||||
| @@ -113,6 +114,7 @@ class Pipelines(object): | |||||
| tinynas_classification = 'tinynas-classification' | tinynas_classification = 'tinynas-classification' | ||||
| crowd_counting = 'hrnet-crowd-counting' | crowd_counting = 'hrnet-crowd-counting' | ||||
| video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' | video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' | ||||
| video_summarization = 'googlenet_pgl_video_summarization' | |||||
| image_reid_person = 'passvitb-image-reid-person' | image_reid_person = 'passvitb-image-reid-person' | ||||
| # nlp tasks | # nlp tasks | ||||
| @@ -170,6 +172,7 @@ class Trainers(object): | |||||
| # cv trainers | # cv trainers | ||||
| image_instance_segmentation = 'image-instance-segmentation' | image_instance_segmentation = 'image-instance-segmentation' | ||||
| image_portrait_enhancement = 'image-portrait-enhancement' | image_portrait_enhancement = 'image-portrait-enhancement' | ||||
| video_summarization = 'video-summarization' | |||||
| # nlp trainers | # nlp trainers | ||||
| bert_sentiment_analysis = 'bert-sentiment-analysis' | bert_sentiment_analysis = 'bert-sentiment-analysis' | ||||
| @@ -194,6 +197,7 @@ class Preprocessors(object): | |||||
| image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | ||||
| image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | ||||
| image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' | image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' | ||||
| video_summarization_preprocessor = 'video-summarization-preprocessor' | |||||
| # nlp preprocessor | # nlp preprocessor | ||||
| sen_sim_tokenizer = 'sen-sim-tokenizer' | sen_sim_tokenizer = 'sen-sim-tokenizer' | ||||
| @@ -246,6 +250,7 @@ class Metrics(object): | |||||
| image_color_enhance_metric = 'image-color-enhance-metric' | image_color_enhance_metric = 'image-color-enhance-metric' | ||||
| # metrics for image-portrait-enhancement task | # metrics for image-portrait-enhancement task | ||||
| image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' | image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' | ||||
| video_summarization_metric = 'video-summarization-metric' | |||||
| class Optimizers(object): | class Optimizers(object): | ||||
| @@ -14,6 +14,7 @@ if TYPE_CHECKING: | |||||
| from .sequence_classification_metric import SequenceClassificationMetric | from .sequence_classification_metric import SequenceClassificationMetric | ||||
| from .text_generation_metric import TextGenerationMetric | from .text_generation_metric import TextGenerationMetric | ||||
| from .token_classification_metric import TokenClassificationMetric | from .token_classification_metric import TokenClassificationMetric | ||||
| from .video_summarization_metric import VideoSummarizationMetric | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -28,6 +29,7 @@ else: | |||||
| 'sequence_classification_metric': ['SequenceClassificationMetric'], | 'sequence_classification_metric': ['SequenceClassificationMetric'], | ||||
| 'text_generation_metric': ['TextGenerationMetric'], | 'text_generation_metric': ['TextGenerationMetric'], | ||||
| 'token_classification_metric': ['TokenClassificationMetric'], | 'token_classification_metric': ['TokenClassificationMetric'], | ||||
| 'video_summarization_metric': ['VideoSummarizationMetric'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -15,6 +15,7 @@ class MetricKeys(object): | |||||
| RECALL = 'recall' | RECALL = 'recall' | ||||
| PSNR = 'psnr' | PSNR = 'psnr' | ||||
| SSIM = 'ssim' | SSIM = 'ssim' | ||||
| FScore = 'fscore' | |||||
| task_default_metrics = { | task_default_metrics = { | ||||
| @@ -28,6 +29,7 @@ task_default_metrics = { | |||||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | ||||
| Tasks.image_portrait_enhancement: | Tasks.image_portrait_enhancement: | ||||
| [Metrics.image_portrait_enhancement_metric], | [Metrics.image_portrait_enhancement_metric], | ||||
| Tasks.video_summarization: [Metrics.video_summarization_metric], | |||||
| } | } | ||||
| @@ -0,0 +1,78 @@ | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.models.cv.video_summarization.summarizer import \ | |||||
| generate_summary | |||||
| from modelscope.utils.registry import default_group | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| def evaluate_summary(predicted_summary, user_summary, eval_method): | |||||
| """ Compare the predicted summary with the user defined one(s). | |||||
| :param ndarray predicted_summary: The generated summary from our model. | |||||
| :param ndarray user_summary: The user defined ground truth summaries (or summary). | |||||
| :param str eval_method: The proposed evaluation method; either 'max' (SumMe) or 'avg' (TVSum). | |||||
| :return: The reduced fscore based on the eval_method | |||||
| """ | |||||
| max_len = max(len(predicted_summary), user_summary.shape[1]) | |||||
| S = np.zeros(max_len, dtype=int) | |||||
| G = np.zeros(max_len, dtype=int) | |||||
| S[:len(predicted_summary)] = predicted_summary | |||||
| f_scores = [] | |||||
| for user in range(user_summary.shape[0]): | |||||
| G[:user_summary.shape[1]] = user_summary[user] | |||||
| overlapped = S & G | |||||
| # Compute precision, recall, f-score | |||||
| precision = sum(overlapped) / sum(S) | |||||
| recall = sum(overlapped) / sum(G) | |||||
| if precision + recall == 0: | |||||
| f_scores.append(0) | |||||
| else: | |||||
| f_score = 2 * precision * recall * 100 / (precision + recall) | |||||
| f_scores.append(f_score) | |||||
| if eval_method == 'max': | |||||
| return max(f_scores) | |||||
| else: | |||||
| return sum(f_scores) / len(f_scores) | |||||
| def calculate_f_score(outputs: Dict, inputs: Dict): | |||||
| scores = outputs['scores'] | |||||
| scores = scores.squeeze(0).cpu().numpy().tolist() | |||||
| user_summary = inputs['user_summary'].cpu().numpy()[0] | |||||
| sb = inputs['change_points'].cpu().numpy()[0] | |||||
| n_frames = inputs['n_frames'].cpu().numpy()[0] | |||||
| positions = inputs['positions'].cpu().numpy()[0] | |||||
| summary = generate_summary([sb], [scores], [n_frames], [positions])[0] | |||||
| f_score = evaluate_summary(summary, user_summary, 'avg') | |||||
| return f_score | |||||
| @METRICS.register_module( | |||||
| group_key=default_group, module_name=Metrics.video_summarization_metric) | |||||
| class VideoSummarizationMetric(Metric): | |||||
| """The metric for video summarization task. | |||||
| """ | |||||
| def __init__(self): | |||||
| self.inputs = [] | |||||
| self.outputs = [] | |||||
| def add(self, outputs: Dict, inputs: Dict): | |||||
| self.outputs.append(outputs) | |||||
| self.inputs.append(inputs) | |||||
| def evaluate(self): | |||||
| f_scores = [ | |||||
| calculate_f_score(output, input) | |||||
| for output, input in zip(self.outputs, self.inputs) | |||||
| ] | |||||
| return {MetricKeys.FScore: sum(f_scores) / len(f_scores)} | |||||
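A toy check of the F-score reduction above, with made-up summaries for a 10-frame video and two annotators; this assumes the new file lands under modelscope/metrics/ as the __init__ change suggests.

import numpy as np

from modelscope.metrics.video_summarization_metric import evaluate_summary

predicted = np.array([1, 1, 0, 0, 1, 1, 0, 0, 0, 0])
user = np.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]])
print(evaluate_summary(predicted, user, 'avg'))  # TVSum protocol: mean over annotators (~53.6)
print(evaluate_summary(predicted, user, 'max'))  # SumMe protocol: best annotator (~57.1)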
| @@ -7,4 +7,4 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||||
| image_to_image_generation, image_to_image_translation, | image_to_image_generation, image_to_image_translation, | ||||
| object_detection, product_retrieval_embedding, | object_detection, product_retrieval_embedding, | ||||
| salient_detection, super_resolution, | salient_detection, super_resolution, | ||||
| video_single_object_tracking, virual_tryon) | |||||
| video_single_object_tracking, video_summarization, virual_tryon) | |||||
| @@ -0,0 +1 @@ | |||||
| from .summarizer import PGLVideoSummarization | |||||
| @@ -0,0 +1,118 @@ | |||||
| # The implementation is based on pytorch-caffe-models, available at https://github.com/crowsonkb/pytorch-caffe-models. | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| class Inception(nn.Module): | |||||
| def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, | |||||
| pool_proj): | |||||
| super().__init__() | |||||
| self.conv_1x1 = nn.Conv2d(in_channels, ch1x1, 1) | |||||
| self.relu_1x1 = nn.ReLU(inplace=True) | |||||
| self.conv_3x3_reduce = nn.Conv2d(in_channels, ch3x3red, 1) | |||||
| self.relu_3x3_reduce = nn.ReLU(inplace=True) | |||||
| self.conv_3x3 = nn.Conv2d(ch3x3red, ch3x3, 3, padding=1) | |||||
| self.relu_3x3 = nn.ReLU(inplace=True) | |||||
| self.conv_5x5_reduce = nn.Conv2d(in_channels, ch5x5red, 1) | |||||
| self.relu_5x5_reduce = nn.ReLU(inplace=True) | |||||
| self.conv_5x5 = nn.Conv2d(ch5x5red, ch5x5, 5, padding=2) | |||||
| self.relu_5x5 = nn.ReLU(inplace=True) | |||||
| self.pool = nn.MaxPool2d(3, stride=1, padding=1) | |||||
| self.pool_proj = nn.Conv2d(in_channels, pool_proj, 1) | |||||
| self.relu_pool_proj = nn.ReLU(inplace=True) | |||||
| def forward(self, x): | |||||
| branch_1 = self.relu_1x1(self.conv_1x1(x)) | |||||
| branch_2 = self.relu_3x3_reduce(self.conv_3x3_reduce(x)) | |||||
| branch_2 = self.relu_3x3(self.conv_3x3(branch_2)) | |||||
| branch_3 = self.relu_5x5_reduce(self.conv_5x5_reduce(x)) | |||||
| branch_3 = self.relu_5x5(self.conv_5x5(branch_3)) | |||||
| branch_4 = self.pool(x) | |||||
| branch_4 = self.relu_pool_proj(self.pool_proj(branch_4)) | |||||
| return torch.cat([branch_1, branch_2, branch_3, branch_4], dim=1) | |||||
| class GoogLeNet(nn.Sequential): | |||||
| def __init__(self, num_classes=1000): | |||||
| super().__init__() | |||||
| self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) | |||||
| self.relu1 = nn.ReLU(inplace=True) | |||||
| self.pool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) | |||||
| self.norm1 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) | |||||
| self.conv2_reduce = nn.Conv2d(64, 64, kernel_size=1) | |||||
| self.relu2_reduce = nn.ReLU(inplace=True) | |||||
| self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1) | |||||
| self.relu2 = nn.ReLU(inplace=True) | |||||
| self.norm2 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) | |||||
| self.pool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) | |||||
| self.inception_3a = Inception(192, 64, 96, 128, 16, 32, 32) | |||||
| self.inception_3b = Inception(256, 128, 128, 192, 32, 96, 64) | |||||
| self.pool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) | |||||
| self.inception_4a = Inception(480, 192, 96, 208, 16, 48, 64) | |||||
| self.inception_4b = Inception(512, 160, 112, 224, 24, 64, 64) | |||||
| self.inception_4c = Inception(512, 128, 128, 256, 24, 64, 64) | |||||
| self.inception_4d = Inception(512, 112, 144, 288, 32, 64, 64) | |||||
| self.inception_4e = Inception(528, 256, 160, 320, 32, 128, 128) | |||||
| self.pool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) | |||||
| self.inception_5a = Inception(832, 256, 160, 320, 32, 128, 128) | |||||
| self.inception_5b = Inception(832, 384, 192, 384, 48, 128, 128) | |||||
| self.pool5 = nn.AdaptiveAvgPool2d((1, 1)) | |||||
| self.loss3_classifier = nn.Linear(1024, num_classes) | |||||
| def forward(self, x): | |||||
| x = self.relu1(self.conv1(x)) | |||||
| x = self.pool1(x) | |||||
| x = self.norm1(x) | |||||
| x = self.relu2_reduce(self.conv2_reduce(x)) | |||||
| x = self.relu2(self.conv2(x)) | |||||
| x = self.norm2(x) | |||||
| x = self.pool2(x) | |||||
| x = self.inception_3a(x) | |||||
| x = self.inception_3b(x) | |||||
| x = self.pool3(x) | |||||
| x = self.inception_4a(x) | |||||
| x = self.inception_4b(x) | |||||
| x = self.inception_4c(x) | |||||
| x = self.inception_4d(x) | |||||
| x = self.inception_4e(x) | |||||
| x = self.pool4(x) | |||||
| x = self.inception_5a(x) | |||||
| x = self.inception_5b(x) | |||||
| x = self.pool5(x).flatten(1) | |||||
| return x | |||||
| class bvlc_googlenet(nn.Module): | |||||
| def __init__(self, input_size=224): | |||||
| """model for the BVLC GoogLeNet, trained on ImageNet. | |||||
| URL: https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet""" | |||||
| super(bvlc_googlenet, self).__init__() | |||||
| self.model = GoogLeNet(num_classes=1000) | |||||
| self.input_size = input_size | |||||
| self.input_mean = (104.0, 117.0, 123.0) | |||||
| def forward(self, frame): | |||||
| x = cv2.resize(frame, | |||||
| (self.input_size, self.input_size)).astype(np.float32) | |||||
| x = (x - self.input_mean).astype(np.float32) | |||||
| x = np.transpose(x, [2, 0, 1]) | |||||
| x = np.expand_dims(x, 0) | |||||
| x = torch.from_numpy(x) | |||||
| if not next(self.model.parameters()).device.type == 'cpu': | |||||
| x = x.cuda() | |||||
| with torch.no_grad(): | |||||
| frame_feat = self.model(x) | |||||
| if not frame_feat.device.type == 'cpu': | |||||
| frame_feat = frame_feat.cpu() | |||||
| frame_feat = frame_feat.numpy() | |||||
| frame_feat = frame_feat / np.linalg.norm(frame_feat) | |||||
| return frame_feat.reshape(-1) | |||||
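Sketch of using the wrapper above to embed a single frame. 'frame.jpg' is a placeholder; 'bvlc_googlenet.pt' is the checkpoint name the new pipeline loads from the model directory.

import cv2
import numpy as np
import torch

from modelscope.models.cv.video_summarization.base_model import bvlc_googlenet

extractor = bvlc_googlenet()
extractor.model.load_state_dict(torch.load('bvlc_googlenet.pt', map_location='cpu'))
extractor.model.eval()

frame = cv2.imread('frame.jpg')           # BGR uint8 frame of any size (placeholder path)
feat = extractor(frame)                   # resized to 224x224 and mean-subtracted internally
print(feat.shape, np.linalg.norm(feat))   # (1024,), ~1.0 (features are L2-normalized)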
| @@ -0,0 +1,35 @@ | |||||
| # The implementation is based on KTS, available at https://github.com/TatsuyaShirakawa/KTS. | |||||
| import numpy as np | |||||
| from .cpd_nonlin import cpd_nonlin | |||||
| def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs): | |||||
| """Detect change points automatically selecting their number | |||||
| :param K: Kernel between each pair of frames in video | |||||
| :param ncp: Maximum number of change points | |||||
| :param vmax: Penalty factor; larger values lead to fewer selected change points | |||||
| :param desc_rate: Rate of descriptor sampling, vmax always corresponds to 1x | |||||
| :param kwargs: Extra parameters for ``cpd_nonlin`` | |||||
| :return: Tuple (cps, costs) | |||||
| - cps - best selected change-points | |||||
| - costs - costs for 0,1,2,...,m change-points | |||||
| """ | |||||
| m = ncp | |||||
| _, scores = cpd_nonlin(K, m, backtrack=False, **kwargs) | |||||
| N = K.shape[0] | |||||
| N2 = N * desc_rate # length of the video before down-sampling | |||||
| penalties = np.zeros(m + 1) | |||||
| # Prevent division by zero (in case of 0 changes) | |||||
| ncp = np.arange(1, m + 1) | |||||
| penalties[1:] = (vmax * ncp / (2.0 * N2)) * (np.log(float(N2) / ncp) + 1) | |||||
| costs = scores / float(N) + penalties | |||||
| m_best = np.argmin(costs) | |||||
| cps, scores2 = cpd_nonlin(K, m_best, **kwargs) | |||||
| return cps, scores2 | |||||
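How the pipeline drives this (see get_change_points in summarizer.py further down in this diff); the random features here are placeholders standing in for the per-15-frame GoogLeNet descriptors.

import numpy as np

from modelscope.models.cv.video_summarization.kts.cpd_auto import cpd_auto

features = np.random.randn(200, 1024).astype(np.float32)   # placeholder frame descriptors
K = np.dot(features, features.T)                           # pairwise frame kernel
cps, costs = cpd_auto(K, ncp=120, vmax=2.2 / 4.0, lmin=1)  # same arguments as get_change_points
print(cps)   # change points in the sub-sampled timeline (multiply by 15 for frame indices)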
| @@ -0,0 +1,102 @@ | |||||
| # The implementation is based on KTS, available at https://github.com/TatsuyaShirakawa/KTS. | |||||
| import numpy as np | |||||
| def calc_scatters(K): | |||||
| """Calculate scatter matrix: scatters[i,j] = {scatter of the sequence with | |||||
| starting frame i and ending frame j} | |||||
| """ | |||||
| n = K.shape[0] | |||||
| K1 = np.cumsum([0] + list(np.diag(K))) | |||||
| K2 = np.zeros((n + 1, n + 1)) | |||||
| # TODO: use the fact that K - symmetric | |||||
| K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1) | |||||
| diagK2 = np.diag(K2) | |||||
| i = np.arange(n).reshape((-1, 1)) | |||||
| j = np.arange(n).reshape((1, -1)) | |||||
| ij_f32 = ((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32)) | |||||
| diagK2_K2 = ( | |||||
| diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape( | |||||
| (-1, 1)) - K2[1:, :-1].T - K2[:-1, 1:]) | |||||
| scatters = ( | |||||
| K1[1:].reshape((1, -1)) - K1[:-1].reshape( | |||||
| (-1, 1)) - diagK2_K2 / ij_f32) | |||||
| scatters[j < i] = 0 | |||||
| return scatters | |||||
| def cpd_nonlin(K, | |||||
| ncp, | |||||
| lmin=1, | |||||
| lmax=100000, | |||||
| backtrack=True, | |||||
| verbose=True, | |||||
| out_scatters=None): | |||||
| """Change point detection with dynamic programming | |||||
| :param K: Square kernel matrix | |||||
| :param ncp: Number of change points to detect (ncp >= 0) | |||||
| :param lmin: Minimal length of a segment | |||||
| :param lmax: Maximal length of a segment | |||||
| :param backtrack: If False - only evaluate objective scores (to save memory) | |||||
| :param verbose: If true, print verbose message | |||||
| :param out_scatters: Output scatters | |||||
| :return: Tuple (cps, obj_vals) | |||||
| - cps - detected array of change points: mean is thought to be constant | |||||
| on [ cps[i], cps[i+1] ) | |||||
| - obj_vals - values of the objective function for 0..m changepoints | |||||
| """ | |||||
| m = int(ncp) # prevent numpy.int64 | |||||
| n, n1 = K.shape | |||||
| assert n == n1, 'Square kernel matrix expected.' | |||||
| assert (m + 1) * lmin <= n <= (m + 1) * lmax | |||||
| assert 1 <= lmin <= lmax | |||||
| if verbose: | |||||
| print('Precomputing scatters...') | |||||
| J = calc_scatters(K) | |||||
| if out_scatters is not None: | |||||
| out_scatters[0] = J | |||||
| if verbose: | |||||
| print('Inferring best change points...') | |||||
| # Iden[k, l] - value of the objective for k change-points and l first frames | |||||
| Iden = 1e101 * np.ones((m + 1, n + 1)) | |||||
| Iden[0, lmin:lmax] = J[0, lmin - 1:lmax - 1] | |||||
| if backtrack: | |||||
| # p[k, l] --- 'previous change' --- best t[k] when t[k+1] equals l | |||||
| p = np.zeros((m + 1, n + 1), dtype=int) | |||||
| else: | |||||
| p = np.zeros((1, 1), dtype=int) | |||||
| for k in range(1, m + 1): | |||||
| for l_frame in range((k + 1) * lmin, n + 1): | |||||
| tmin = max(k * lmin, l_frame - lmax) | |||||
| tmax = l_frame - lmin + 1 | |||||
| c = J[tmin:tmax, l_frame - 1].reshape(-1) + \ | |||||
| Iden[k - 1, tmin:tmax].reshape(-1) | |||||
| Iden[k, l_frame] = np.min(c) | |||||
| if backtrack: | |||||
| p[k, l_frame] = np.argmin(c) + tmin | |||||
| # Collect change points | |||||
| cps = np.zeros(m, dtype=int) | |||||
| if backtrack: | |||||
| cur = n | |||||
| for k in range(m, 0, -1): | |||||
| cps[k - 1] = p[k, cur] | |||||
| cur = cps[k - 1] | |||||
| scores = Iden[:, n].copy() | |||||
| scores[scores > 1e99] = np.inf | |||||
| return cps, scores | |||||
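A quick synthetic sanity check of the dynamic program above; the assumption is that with a clear mean shift in the toy signal, the single detected change point lands near frame 30.

import numpy as np

from modelscope.models.cv.video_summarization.kts.cpd_nonlin import cpd_nonlin

np.random.seed(0)
X = np.concatenate([np.random.randn(30, 8), np.random.randn(30, 8) + 5.0]).astype(np.float32)
K = np.dot(X, X.T)
cps, scores = cpd_nonlin(K, 1, lmin=1, lmax=10000, verbose=False)
print(cps)   # typically [30] (or very close) for this toy signal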
| @@ -0,0 +1,311 @@ | |||||
| # The implementation is based on PGL-SUM, available at https://github.com/e-apostolidis/PGL-SUM. | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| class SelfAttention(nn.Module): | |||||
| def __init__(self, | |||||
| input_size=1024, | |||||
| output_size=1024, | |||||
| freq=10000, | |||||
| heads=1, | |||||
| pos_enc=None): | |||||
| """ The basic (multi-head) Attention 'cell' containing the learnable parameters of Q, K and V | |||||
| :param int input_size: Feature input size of Q, K, V. | |||||
| :param int output_size: Feature -hidden- size of Q, K, V. | |||||
| :param int freq: The frequency of the sinusoidal positional encoding. | |||||
| :param int heads: Number of heads for the attention module. | |||||
| :param str | None pos_enc: The type of the positional encoding [supported: Absolute, Relative]. | |||||
| """ | |||||
| super(SelfAttention, self).__init__() | |||||
| self.permitted_encodings = ['absolute', 'relative'] | |||||
| if pos_enc is not None: | |||||
| pos_enc = pos_enc.lower() | |||||
| assert pos_enc in self.permitted_encodings, f'Supported encodings: {*self.permitted_encodings,}' | |||||
| self.input_size = input_size | |||||
| self.output_size = output_size | |||||
| self.heads = heads | |||||
| self.pos_enc = pos_enc | |||||
| self.freq = freq | |||||
| self.Wk, self.Wq, self.Wv = nn.ModuleList(), nn.ModuleList( | |||||
| ), nn.ModuleList() | |||||
| for _ in range(self.heads): | |||||
| self.Wk.append( | |||||
| nn.Linear( | |||||
| in_features=input_size, | |||||
| out_features=output_size // heads, | |||||
| bias=False)) | |||||
| self.Wq.append( | |||||
| nn.Linear( | |||||
| in_features=input_size, | |||||
| out_features=output_size // heads, | |||||
| bias=False)) | |||||
| self.Wv.append( | |||||
| nn.Linear( | |||||
| in_features=input_size, | |||||
| out_features=output_size // heads, | |||||
| bias=False)) | |||||
| self.out = nn.Linear( | |||||
| in_features=output_size, out_features=input_size, bias=False) | |||||
| self.softmax = nn.Softmax(dim=-1) | |||||
| self.drop = nn.Dropout(p=0.5) | |||||
| def getAbsolutePosition(self, T): | |||||
| """Calculate the sinusoidal positional encoding based on the absolute position of each considered frame. | |||||
| Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762) | |||||
| :param int T: Number of frames contained in Q, K and V | |||||
| :return: Tensor with shape [T, T] | |||||
| """ | |||||
| freq = self.freq | |||||
| d = self.input_size | |||||
| pos = torch.tensor([k for k in range(T)], | |||||
| device=self.out.weight.device) | |||||
| i = torch.tensor([k for k in range(T // 2)], | |||||
| device=self.out.weight.device) | |||||
| # Reshape tensors each pos_k for each i indices | |||||
| pos = pos.reshape(pos.shape[0], 1) | |||||
| pos = pos.repeat_interleave(i.shape[0], dim=1) | |||||
| i = i.repeat(pos.shape[0], 1) | |||||
| AP = torch.zeros(T, T, device=self.out.weight.device) | |||||
| AP[pos, 2 * i] = torch.sin(pos / freq**((2 * i) / d)) | |||||
| AP[pos, 2 * i + 1] = torch.cos(pos / freq**((2 * i) / d)) | |||||
| return AP | |||||
| def getRelativePosition(self, T): | |||||
| """Calculate the sinusoidal positional encoding based on the relative position of each considered frame. | |||||
| r_pos calculations as here: https://theaisummer.com/positional-embeddings/ | |||||
| :param int T: Number of frames contained in Q, K and V | |||||
| :return: Tensor with shape [T, T] | |||||
| """ | |||||
| freq = self.freq | |||||
| d = 2 * T | |||||
| min_rpos = -(T - 1) | |||||
| i = torch.tensor([k for k in range(T)], device=self.out.weight.device) | |||||
| j = torch.tensor([k for k in range(T)], device=self.out.weight.device) | |||||
| # Reshape tensors each i for each j indices | |||||
| i = i.reshape(i.shape[0], 1) | |||||
| i = i.repeat_interleave(i.shape[0], dim=1) | |||||
| j = j.repeat(i.shape[0], 1) | |||||
| # Calculate the relative positions | |||||
| r_pos = j - i - min_rpos | |||||
| RP = torch.zeros(T, T, device=self.out.weight.device) | |||||
| idx = torch.tensor([k for k in range(T // 2)], | |||||
| device=self.out.weight.device) | |||||
| RP[:, 2 * idx] = torch.sin( | |||||
| r_pos[:, 2 * idx] / freq**((i[:, 2 * idx] + j[:, 2 * idx]) / d)) | |||||
| RP[:, 2 * idx + 1] = torch.cos( | |||||
| r_pos[:, 2 * idx + 1] | |||||
| / freq**((i[:, 2 * idx + 1] + j[:, 2 * idx + 1]) / d)) | |||||
| return RP | |||||
| def forward(self, x): | |||||
| """ Compute the weighted frame features, based on either the global or local (multi-head) attention mechanism. | |||||
| :param torch.tensor x: Frame features with shape [T, input_size] | |||||
| :return: A tuple of: | |||||
| y: Weighted features based on the attention weights, with shape [T, input_size] | |||||
| att_weights : The attention weights (before dropout), with shape [T, T] | |||||
| """ | |||||
| outputs = [] | |||||
| for head in range(self.heads): | |||||
| K = self.Wk[head](x) | |||||
| Q = self.Wq[head](x) | |||||
| V = self.Wv[head](x) | |||||
| # Q *= 0.06 # scale factor VASNet | |||||
| # Q /= np.sqrt(self.output_size) # scale factor (i.e 1 / sqrt(d_k) ) | |||||
| energies = torch.matmul(Q, K.transpose(1, 0)) | |||||
| if self.pos_enc is not None: | |||||
| if self.pos_enc == 'absolute': | |||||
| AP = self.getAbsolutePosition(T=energies.shape[0]) | |||||
| energies = energies + AP | |||||
| elif self.pos_enc == 'relative': | |||||
| RP = self.getRelativePosition(T=energies.shape[0]) | |||||
| energies = energies + RP | |||||
| att_weights = self.softmax(energies) | |||||
| _att_weights = self.drop(att_weights) | |||||
| y = torch.matmul(_att_weights, V) | |||||
| # Save the current head output | |||||
| outputs.append(y) | |||||
| y = self.out(torch.cat(outputs, dim=1)) | |||||
| return y, att_weights.clone( | |||||
| ) # for now we don't deal with the weights (probably max or avg pooling) | |||||
| class MultiAttention(nn.Module): | |||||
| def __init__(self, | |||||
| input_size=1024, | |||||
| output_size=1024, | |||||
| freq=10000, | |||||
| pos_enc=None, | |||||
| num_segments=None, | |||||
| heads=1, | |||||
| fusion=None): | |||||
| """ Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters. | |||||
| :param int input_size: The expected input feature size. | |||||
| :param int output_size: The hidden feature size of the attention mechanisms. | |||||
| :param int freq: The frequency of the sinusoidal positional encoding. | |||||
| :param None | str pos_enc: The selected positional encoding [absolute, relative]. | |||||
| :param None | int num_segments: The selected number of segments to split the videos. | |||||
| :param int heads: The selected number of global heads. | |||||
| :param None | str fusion: The selected type of feature fusion. | |||||
| """ | |||||
| super(MultiAttention, self).__init__() | |||||
| # Global Attention, considering differences among all frames | |||||
| self.attention = SelfAttention( | |||||
| input_size=input_size, | |||||
| output_size=output_size, | |||||
| freq=freq, | |||||
| pos_enc=pos_enc, | |||||
| heads=heads) | |||||
| self.num_segments = num_segments | |||||
| if self.num_segments is not None: | |||||
| assert self.num_segments >= 2, 'num_segments must be None or 2+' | |||||
| self.local_attention = nn.ModuleList() | |||||
| for _ in range(self.num_segments): | |||||
| # Local Attention, considering differences within the same segment, with reduced hidden size | |||||
| self.local_attention.append( | |||||
| SelfAttention( | |||||
| input_size=input_size, | |||||
| output_size=output_size // num_segments, | |||||
| freq=freq, | |||||
| pos_enc=pos_enc, | |||||
| heads=4)) | |||||
| self.permitted_fusions = ['add', 'mult', 'avg', 'max'] | |||||
| self.fusion = fusion | |||||
| if self.fusion is not None: | |||||
| self.fusion = self.fusion.lower() | |||||
| assert self.fusion in self.permitted_fusions, f'Fusion method must be: {*self.permitted_fusions,}' | |||||
| def forward(self, x): | |||||
| """ Compute the weighted frame features, based on the global and locals (multi-head) attention mechanisms. | |||||
| :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features. | |||||
| :return: A tuple of: | |||||
| weighted_value: Tensor with shape [T, input_size] containing the weighted frame features. | |||||
| attn_weights: Tensor with shape [T, T] containing the attention weights. | |||||
| """ | |||||
| weighted_value, attn_weights = self.attention(x) # global attention | |||||
| if self.num_segments is not None and self.fusion is not None: | |||||
| segment_size = math.ceil(x.shape[0] / self.num_segments) | |||||
| for segment in range(self.num_segments): | |||||
| left_pos = segment * segment_size | |||||
| right_pos = (segment + 1) * segment_size | |||||
| local_x = x[left_pos:right_pos] | |||||
| weighted_local_value, attn_local_weights = self.local_attention[ | |||||
| segment](local_x) # local attentions | |||||
| # Normalize the features vectors | |||||
| weighted_value[left_pos:right_pos] = F.normalize( | |||||
| weighted_value[left_pos:right_pos].clone(), p=2, dim=1) | |||||
| weighted_local_value = F.normalize( | |||||
| weighted_local_value, p=2, dim=1) | |||||
| if self.fusion == 'add': | |||||
| weighted_value[left_pos:right_pos] += weighted_local_value | |||||
| elif self.fusion == 'mult': | |||||
| weighted_value[left_pos:right_pos] *= weighted_local_value | |||||
| elif self.fusion == 'avg': | |||||
| weighted_value[left_pos:right_pos] += weighted_local_value | |||||
| weighted_value[left_pos:right_pos] /= 2 | |||||
| elif self.fusion == 'max': | |||||
| weighted_value[left_pos:right_pos] = torch.max( | |||||
| weighted_value[left_pos:right_pos].clone(), | |||||
| weighted_local_value) | |||||
| return weighted_value, attn_weights | |||||
| class PGL_SUM(nn.Module): | |||||
| def __init__(self, | |||||
| input_size=1024, | |||||
| output_size=1024, | |||||
| freq=10000, | |||||
| pos_enc=None, | |||||
| num_segments=None, | |||||
| heads=1, | |||||
| fusion=None): | |||||
| """ Class wrapping the PGL-SUM model; its key modules and parameters. | |||||
| :param int input_size: The expected input feature size. | |||||
| :param int output_size: The hidden feature size of the attention mechanisms. | |||||
| :param int freq: The frequency of the sinusoidal positional encoding. | |||||
| :param None | str pos_enc: The selected positional encoding [absolute, relative]. | |||||
| :param None | int num_segments: The selected number of segments to split the videos. | |||||
| :param int heads: The selected number of global heads. | |||||
| :param None | str fusion: The selected type of feature fusion. | |||||
| """ | |||||
| super(PGL_SUM, self).__init__() | |||||
| self.attention = MultiAttention( | |||||
| input_size=input_size, | |||||
| output_size=output_size, | |||||
| freq=freq, | |||||
| pos_enc=pos_enc, | |||||
| num_segments=num_segments, | |||||
| heads=heads, | |||||
| fusion=fusion) | |||||
| self.linear_1 = nn.Linear( | |||||
| in_features=input_size, out_features=input_size) | |||||
| self.linear_2 = nn.Linear( | |||||
| in_features=self.linear_1.out_features, out_features=1) | |||||
| self.drop = nn.Dropout(p=0.5) | |||||
| self.norm_y = nn.LayerNorm(normalized_shape=input_size, eps=1e-6) | |||||
| self.norm_linear = nn.LayerNorm( | |||||
| normalized_shape=self.linear_1.out_features, eps=1e-6) | |||||
| self.relu = nn.ReLU() | |||||
| self.sigmoid = nn.Sigmoid() | |||||
| def forward(self, frame_features): | |||||
| """ Produce frames importance scores from the frame features, using the PGL-SUM model. | |||||
| :param torch.Tensor frame_features: Tensor of shape [T, input_size] containing the frame features produced by | |||||
| using the pool5 layer of GoogLeNet. | |||||
| :return: A tuple of: | |||||
| y: Tensor with shape [1, T] containing the frames importance scores in [0, 1]. | |||||
| attn_weights: Tensor with shape [T, T] containing the attention weights. | |||||
| """ | |||||
| frame_features = frame_features.reshape(-1, frame_features.shape[-1]) | |||||
| residual = frame_features | |||||
| weighted_value, attn_weights = self.attention(frame_features) | |||||
| y = weighted_value + residual | |||||
| y = self.drop(y) | |||||
| y = self.norm_y(y) | |||||
| # 2-layer NN (Regressor Network) | |||||
| y = self.linear_1(y) | |||||
| y = self.relu(y) | |||||
| y = self.drop(y) | |||||
| y = self.norm_linear(y) | |||||
| y = self.linear_2(y) | |||||
| y = self.sigmoid(y) | |||||
| y = y.view(1, -1) | |||||
| return y, attn_weights | |||||
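Standalone sketch of scoring a clip's sub-sampled frames with the hyper-parameters PGLVideoSummarization uses below (4 local segments, 8 global heads, additive fusion, absolute positional encoding); the random tensor stands in for real GoogLeNet features.

import torch

from modelscope.models.cv.video_summarization.pgl_sum import PGL_SUM

model = PGL_SUM(input_size=1024, output_size=1024, num_segments=4,
                heads=8, fusion='add', pos_enc='absolute').eval()
frame_features = torch.randn(1, 200, 1024)   # [1, T, 1024] pool5 features for T sampled frames
with torch.no_grad():
    scores, attn_weights = model(frame_features)
print(scores.shape, attn_weights.shape)      # torch.Size([1, 200]), torch.Size([200, 200])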
| @@ -0,0 +1,224 @@ | |||||
| # The implementation is based on PGL-SUM, available at https://github.com/e-apostolidis/PGL-SUM. | |||||
| import os.path as osp | |||||
| from copy import deepcopy | |||||
| from typing import Dict, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Tensor, TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.cv.video_summarization.kts.cpd_auto import cpd_auto | |||||
| from modelscope.models.cv.video_summarization.pgl_sum import PGL_SUM | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| def get_change_points(video_feat, n_frame): | |||||
| video_feat = np.array(video_feat, np.float32) | |||||
| K = np.dot(video_feat, video_feat.T) | |||||
| change_points, _ = cpd_auto(K, ncp=120, vmax=2.2 / 4.0, lmin=1) | |||||
| change_points = change_points * 15 | |||||
| change_points = np.concatenate(([0], change_points, [n_frame - 1])) | |||||
| temp_change_points = [] | |||||
| for idx in range(len(change_points) - 1): | |||||
| segment = [change_points[idx], change_points[idx + 1] - 1] | |||||
| if idx == len(change_points) - 2: | |||||
| segment = [change_points[idx], change_points[idx + 1]] | |||||
| temp_change_points.append(segment) | |||||
| change_points = np.array(list(temp_change_points)) | |||||
| temp_n_frame_per_seg = [] | |||||
| for change_points_idx in range(len(change_points)): | |||||
| n_frame = change_points[change_points_idx][1] - change_points[ | |||||
| change_points_idx][0] | |||||
| temp_n_frame_per_seg.append(n_frame) | |||||
| n_frame_per_seg = np.array(list(temp_n_frame_per_seg)) | |||||
| return change_points, n_frame_per_seg | |||||
| def knap_sack(W, wt, val, n): | |||||
| """ Maximize the value that a knapsack of capacity W can hold. You can either put the item or discard it, there is | |||||
| no concept of putting some part of item in the knapsack. | |||||
| :param int W: Maximum capacity -in frames- of the knapsack. | |||||
| :param list[int] wt: The weights (lengths -in frames-) of each video shot. | |||||
| :param list[float] val: The values (importance scores) of each video shot. | |||||
| :param int n: The number of the shots. | |||||
| :return: A list containing the indices of the selected shots. | |||||
| """ | |||||
| K = [[0 for _ in range(W + 1)] for _ in range(n + 1)] | |||||
| # Build table K[][] in bottom up manner | |||||
| for i in range(n + 1): | |||||
| for w in range(W + 1): | |||||
| if i == 0 or w == 0: | |||||
| K[i][w] = 0 | |||||
| elif wt[i - 1] <= w: | |||||
| K[i][w] = max(val[i - 1] + K[i - 1][w - wt[i - 1]], | |||||
| K[i - 1][w]) | |||||
| else: | |||||
| K[i][w] = K[i - 1][w] | |||||
| selected = [] | |||||
| w = W | |||||
| for i in range(n, 0, -1): | |||||
| if K[i][w] != K[i - 1][w]: | |||||
| selected.insert(0, i - 1) | |||||
| w -= wt[i - 1] | |||||
| return selected | |||||
| def generate_summary(all_shot_bound, all_scores, all_nframes, all_positions): | |||||
| """ Generate the automatic machine summary, based on the video shots; the frame importance scores; the number of | |||||
| frames in the original video and the position of the sub-sampled frames of the original video. | |||||
| :param list[np.ndarray] all_shot_bound: The video shots for all the -original- testing videos. | |||||
| :param list[np.ndarray] all_scores: The calculated frame importance scores for all the sub-sampled testing videos. | |||||
| :param list[np.ndarray] all_nframes: The number of frames for all the -original- testing videos. | |||||
| :param list[np.ndarray] all_positions: The position of the sub-sampled frames for all the -original- testing videos. | |||||
| :return: A list containing the indices of the selected frames for all the -original- testing videos. | |||||
| """ | |||||
| all_summaries = [] | |||||
| for video_index in range(len(all_scores)): | |||||
| # Get shots' boundaries | |||||
| shot_bound = all_shot_bound[video_index] # [number_of_shots, 2] | |||||
| frame_init_scores = all_scores[video_index] | |||||
| n_frames = all_nframes[video_index] | |||||
| positions = all_positions[video_index] | |||||
| # Compute the importance scores for the initial frame sequence (not the sub-sampled one) | |||||
| frame_scores = np.zeros(n_frames, dtype=np.float32) | |||||
| if positions.dtype != int: | |||||
| positions = positions.astype(np.int32) | |||||
| if positions[-1] != n_frames: | |||||
| positions = np.concatenate([positions, [n_frames]]) | |||||
| for i in range(len(positions) - 1): | |||||
| pos_left, pos_right = positions[i], positions[i + 1] | |||||
| if i == len(frame_init_scores): | |||||
| frame_scores[pos_left:pos_right] = 0 | |||||
| else: | |||||
| frame_scores[pos_left:pos_right] = frame_init_scores[i] | |||||
| # Compute shot-level importance scores by taking the average importance scores of all frames in the shot | |||||
| shot_imp_scores = [] | |||||
| shot_lengths = [] | |||||
| for shot in shot_bound: | |||||
| shot_lengths.append(shot[1] - shot[0] + 1) | |||||
| shot_imp_scores.append( | |||||
| (frame_scores[shot[0]:shot[1] + 1].mean()).item()) | |||||
| # Select the best shots using the knapsack implementation | |||||
| final_shot = shot_bound[-1] | |||||
| final_max_length = int((final_shot[1] + 1) * 0.15) | |||||
| selected = knap_sack(final_max_length, shot_lengths, shot_imp_scores, | |||||
| len(shot_lengths)) | |||||
| # Select all frames from each selected shot (by setting their value in the summary vector to 1) | |||||
| summary = np.zeros(final_shot[1] + 1, dtype=np.int8) | |||||
| for shot in selected: | |||||
| summary[shot_bound[shot][0]:shot_bound[shot][1] + 1] = 1 | |||||
| all_summaries.append(summary) | |||||
| return all_summaries | |||||
| @MODELS.register_module( | |||||
| Tasks.video_summarization, module_name=Models.video_summarization) | |||||
| class PGLVideoSummarization(TorchModel): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the video summarization model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||||
| self.loss = nn.MSELoss() | |||||
| self.model = PGL_SUM( | |||||
| input_size=1024, | |||||
| output_size=1024, | |||||
| num_segments=4, | |||||
| heads=8, | |||||
| fusion='add', | |||||
| pos_enc='absolute') | |||||
| if torch.cuda.is_available(): | |||||
| self._device = torch.device('cuda') | |||||
| else: | |||||
| self._device = torch.device('cpu') | |||||
| self.model = self.model.to(self._device) | |||||
| self.model = self.load_pretrained(self.model, model_path) | |||||
| if self.training: | |||||
| self.model.train() | |||||
| else: | |||||
| self.model.eval() | |||||
| def load_pretrained(self, net, load_path, strict=True, param_key='params'): | |||||
| if isinstance(net, (DataParallel, DistributedDataParallel)): | |||||
| net = net.module | |||||
| load_net = torch.load( | |||||
| load_path, map_location=lambda storage, loc: storage) | |||||
| if param_key is not None: | |||||
| if param_key not in load_net and 'params' in load_net: | |||||
| param_key = 'params' | |||||
| logger.info( | |||||
| f'Loading: {param_key} does not exist, use params.') | |||||
| if param_key in load_net: | |||||
| load_net = load_net[param_key] | |||||
| logger.info( | |||||
| f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' | |||||
| ) | |||||
| # remove unnecessary 'module.' | |||||
| for k, v in deepcopy(load_net).items(): | |||||
| if k.startswith('module.'): | |||||
| load_net[k[7:]] = v | |||||
| load_net.pop(k) | |||||
| net.load_state_dict(load_net, strict=strict) | |||||
| logger.info('load model done.') | |||||
| return net | |||||
| def _train_forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| frame_features = input['frame_features'] | |||||
| gtscore = input['gtscore'] | |||||
| preds, attn_weights = self.model(frame_features) | |||||
| return {'loss': self.loss(preds, gtscore)} | |||||
| def _inference_forward(self, input: Dict[str, | |||||
| Tensor]) -> Dict[str, Tensor]: | |||||
| frame_features = input['frame_features'] | |||||
| y, attn_weights = self.model(frame_features) | |||||
| return {'scores': y} | |||||
| def forward(self, input: Dict[str, | |||||
| Tensor]) -> Dict[str, Union[list, Tensor]]: | |||||
| """return the result by the model | |||||
| Args: | |||||
| input (Dict[str, Tensor]): the preprocessed data | |||||
| Returns: | |||||
| Dict[str, Union[list, Tensor]]: results | |||||
| """ | |||||
| for key, value in input.items(): | |||||
| input[key] = input[key].to(self._device) | |||||
| if self.training: | |||||
| return self._train_forward(input) | |||||
| else: | |||||
| return self._inference_forward(input) | |||||
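Toy check of the knap_sack shot selection used by generate_summary above, with made-up shot lengths and scores; the budget is 15% of a 300-frame video, i.e. 45 frames.

from modelscope.models.cv.video_summarization.summarizer import knap_sack

shot_lengths = [40, 60, 50]     # shot durations in frames (hypothetical)
shot_scores = [0.9, 0.4, 0.7]   # mean importance score per shot (hypothetical)
picked = knap_sack(45, shot_lengths, shot_scores, len(shot_lengths))
print(picked)                   # [0] -- only the 40-frame shot fits under the 45-frame budget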
| @@ -9,7 +9,7 @@ if TYPE_CHECKING: | |||||
| from .torch_base_dataset import TorchTaskDataset | from .torch_base_dataset import TorchTaskDataset | ||||
| from .veco_dataset import VecoDataset | from .veco_dataset import VecoDataset | ||||
| from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | ||||
| from .video_summarization_dataset import VideoSummarizationDataset | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'base': ['TaskDataset'], | 'base': ['TaskDataset'], | ||||
| @@ -17,7 +17,8 @@ else: | |||||
| 'torch_base_dataset': ['TorchTaskDataset'], | 'torch_base_dataset': ['TorchTaskDataset'], | ||||
| 'veco_dataset': ['VecoDataset'], | 'veco_dataset': ['VecoDataset'], | ||||
| 'image_instance_segmentation_coco_dataset': | 'image_instance_segmentation_coco_dataset': | ||||
| ['ImageInstanceSegmentationCocoDataset'] | |||||
| ['ImageInstanceSegmentationCocoDataset'], | |||||
| 'video_summarization_dataset': ['VideoSummarizationDataset'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,69 @@ | |||||
| import os | |||||
| import h5py | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||||
| TorchTaskDataset | |||||
| class VideoSummarizationDataset(TorchTaskDataset): | |||||
| def __init__(self, mode, opt, root_dir): | |||||
| self.mode = mode | |||||
| self.data_filename = os.path.join(root_dir, opt.dataset_file) | |||||
| self.split_filename = os.path.join(root_dir, opt.split_file) | |||||
| self.split_index = opt.split_index # it represents the current split (varies from 0 to 4) | |||||
| hdf = h5py.File(self.data_filename, 'r') | |||||
| self.list_frame_features, self.list_gtscores = [], [] | |||||
| self.list_user_summary = [] | |||||
| self.list_change_points = [] | |||||
| self.list_n_frames = [] | |||||
| self.list_positions = [] | |||||
| with open(self.split_filename) as f: | |||||
| data = json.loads(f.read()) | |||||
| for i, split in enumerate(data): | |||||
| if i == self.split_index: | |||||
| self.split = split | |||||
| break | |||||
| for video_name in self.split[self.mode + '_keys']: | |||||
| frame_features = torch.Tensor( | |||||
| np.array(hdf[video_name + '/features'])) | |||||
| gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore'])) | |||||
| user_summary = np.array(hdf[f'{video_name}/user_summary']) | |||||
| change_points = np.array(hdf[f'{video_name}/change_points']) | |||||
| n_frames = np.array(hdf[f'{video_name}/n_frames']) | |||||
| positions = np.array(hdf[f'{video_name}/picks']) | |||||
| self.list_frame_features.append(frame_features) | |||||
| self.list_gtscores.append(gtscore) | |||||
| self.list_user_summary.append(user_summary) | |||||
| self.list_change_points.append(change_points) | |||||
| self.list_n_frames.append(n_frames) | |||||
| self.list_positions.append(positions) | |||||
| hdf.close() | |||||
| def __len__(self): | |||||
| self.len = len(self.split[self.mode + '_keys']) | |||||
| return self.len | |||||
| def __getitem__(self, index): | |||||
| frame_features = self.list_frame_features[index] | |||||
| gtscore = self.list_gtscores[index] | |||||
| user_summary = self.list_user_summary[index] | |||||
| change_points = self.list_change_points[index] | |||||
| n_frames = self.list_n_frames[index] | |||||
| positions = self.list_positions[index] | |||||
| return dict( | |||||
| frame_features=frame_features, | |||||
| gtscore=gtscore, | |||||
| user_summary=user_summary, | |||||
| change_points=change_points, | |||||
| n_frames=n_frames, | |||||
| positions=positions) | |||||
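For reference, the h5 layout the dataset class reads (TVSum/SumMe-style files as in the upstream PGL-SUM repo). The file name and video key below are illustrative; the actual dataset_file and split_file names come from the model's configuration.

import h5py

with h5py.File('eccv16_dataset_tvsum_google_pool5.h5', 'r') as hdf:
    key = list(hdf.keys())[0]                  # e.g. 'video_1'
    print(hdf[f'{key}/features'].shape)        # [n_sampled_frames, 1024]
    print(hdf[f'{key}/gtscore'].shape)         # [n_sampled_frames]
    print(hdf[f'{key}/user_summary'].shape)    # [n_annotators, n_frames]
    print(hdf[f'{key}/change_points'].shape)   # [n_shots, 2]
    print(hdf[f'{key}/n_frames'][()])          # scalar frame count of the original video
    print(hdf[f'{key}/picks'].shape)           # [n_sampled_frames] original frame indices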
| @@ -0,0 +1,109 @@ | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| from tqdm import tqdm | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.video_summarization import PGLVideoSummarization | |||||
| from modelscope.models.cv.video_summarization.base_model import bvlc_googlenet | |||||
| from modelscope.models.cv.video_summarization.summarizer import ( | |||||
| generate_summary, get_change_points) | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.video_summarization, module_name=Pipelines.video_summarization) | |||||
| class VideoSummarizationPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| use `model` to create a video summarization pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, auto_collate=False, **kwargs) | |||||
| logger.info(f'loading model from {model}') | |||||
| googlenet_model_path = osp.join(model, 'bvlc_googlenet.pt') | |||||
| config_path = osp.join(model, ModelFile.CONFIGURATION) | |||||
| logger.info(f'loading config from {config_path}') | |||||
| self.cfg = Config.from_file(config_path) | |||||
| self.googlenet_model = bvlc_googlenet() | |||||
| self.googlenet_model.model.load_state_dict( | |||||
| torch.load( | |||||
| googlenet_model_path, map_location=torch.device(self.device))) | |||||
| self.googlenet_model = self.googlenet_model.to(self.device).eval() | |||||
| self.pgl_model = PGLVideoSummarization(model) | |||||
| self.pgl_model = self.pgl_model.to(self.device).eval() | |||||
| logger.info('load model done') | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| if not isinstance(input, str): | |||||
| raise TypeError(f'input should be a str,' | |||||
| f' but got {type(input)}') | |||||
| frames = [] | |||||
| picks = [] | |||||
| cap = cv2.VideoCapture(input) | |||||
| frame_idx = 0 | |||||
| while (cap.isOpened()): | |||||
| ret, frame = cap.read() | |||||
| if not ret: | |||||
| break | |||||
| if frame_idx % 15 == 0: | |||||
| frames.append(frame) | |||||
| picks.append(frame_idx) | |||||
| frame_idx += 1 | |||||
| n_frame = frame_idx | |||||
| result = { | |||||
| 'video_name': input, | |||||
| 'video_frames': np.array(frames), | |||||
| 'n_frame': n_frame, | |||||
| 'picks': np.array(picks) | |||||
| } | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| frame_features = [] | |||||
| for frame in tqdm(input['video_frames']): | |||||
| feat = self.googlenet_model(frame) | |||||
| frame_features.append(feat) | |||||
| change_points, n_frame_per_seg = get_change_points( | |||||
| frame_features, input['n_frame']) | |||||
| summary = self.inference(frame_features, input['n_frame'], | |||||
| input['picks'], change_points) | |||||
| return {OutputKeys.OUTPUT: summary} | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| def inference(self, frame_features, n_frames, picks, change_points): | |||||
| frame_features = torch.from_numpy(np.array(frame_features, np.float32)) | |||||
| picks = np.array(picks, np.int32) | |||||
| with torch.no_grad(): | |||||
| results = self.pgl_model(dict(frame_features=frame_features)) | |||||
| scores = results['scores'] | |||||
| if not scores.device.type == 'cpu': | |||||
| scores = scores.cpu() | |||||
| scores = scores.squeeze(0).numpy().tolist() | |||||
| summary = generate_summary([change_points], [scores], [n_frames], | |||||
| [picks])[0] | |||||
| return summary | |||||
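The pipeline returns a 0/1 per-frame mask under OutputKeys.OUTPUT. An illustrative helper (not part of this PR) to turn that mask into frame ranges, e.g. for cutting the summary clip:

import numpy as np

def mask_to_segments(mask):
    """Convert a 0/1 per-frame summary mask into (start_frame, end_frame) ranges."""
    mask = np.asarray(mask).astype(bool)
    segments, start = [], None
    for idx, selected in enumerate(mask):
        if selected and start is None:
            start = idx
        elif not selected and start is not None:
            segments.append((start, idx - 1))
            start = None
    if start is not None:
        segments.append((start, len(mask) - 1))
    return segments

print(mask_to_segments([0, 1, 1, 1, 0, 0, 1, 1]))   # [(1, 3), (6, 7)]
# With a real run: mask_to_segments(result[OutputKeys.OUTPUT]), where `result` comes from a
# pipeline call like the one in the tests below.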
| @@ -264,3 +264,28 @@ class ImageInstanceSegmentationPreprocessor(Preprocessor): | |||||
| return None | return None | ||||
| return results | return results | ||||
| @PREPROCESSORS.register_module( | |||||
| Fields.cv, module_name=Preprocessors.video_summarization_preprocessor) | |||||
| class VideoSummarizationPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """ | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model_dir: str = model_dir | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data Dict[str, Any] | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| return data | |||||
| @@ -64,6 +64,7 @@ class CVTasks(object): | |||||
| # reid and tracking | # reid and tracking | ||||
| video_single_object_tracking = 'video-single-object-tracking' | video_single_object_tracking = 'video-single-object-tracking' | ||||
| video_summarization = 'video-summarization' | |||||
| image_reid_person = 'image-reid-person' | image_reid_person = 'image-reid-person' | ||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class VideoSummarizationTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| summarization_pipeline = pipeline( | |||||
| Tasks.video_summarization, | |||||
| model='damo/cv_googlenet_pgl-video-summarization') | |||||
| result = summarization_pipeline( | |||||
| 'data/test/videos/video_category_test_video.mp4') | |||||
| print(f'video summarization output: {result}.') | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_modelhub_default_model(self): | |||||
| summarization_pipeline = pipeline(Tasks.video_summarization) | |||||
| result = summarization_pipeline( | |||||
| 'data/test/videos/video_category_test_video.mp4') | |||||
| print(f'video summarization output: {result}.') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,75 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models.cv.video_summarization import PGLVideoSummarization | |||||
| from modelscope.msdatasets.task_datasets import VideoSummarizationDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | |||||
| logger = get_logger() | |||||
| class VideoSummarizationTrainerTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| self.model_id = 'damo/cv_googlenet_pgl-video-summarization' | |||||
| self.cache_path = snapshot_download(self.model_id) | |||||
| self.config = Config.from_file( | |||||
| os.path.join(self.cache_path, ModelFile.CONFIGURATION)) | |||||
| self.dataset_train = VideoSummarizationDataset('train', | |||||
| self.config.dataset, | |||||
| self.cache_path) | |||||
| self.dataset_val = VideoSummarizationDataset('test', | |||||
| self.config.dataset, | |||||
| self.cache_path) | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||||
| super().tearDown() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer(self): | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| train_dataset=self.dataset_train, | |||||
| eval_dataset=self.dataset_val, | |||||
| work_dir=self.tmp_dir) | |||||
| trainer = build_trainer(default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| for i in range(2): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_trainer_with_model_and_args(self): | |||||
| model = PGLVideoSummarization.from_pretrained(self.cache_path) | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION), | |||||
| model=model, | |||||
| train_dataset=self.dataset_train, | |||||
| eval_dataset=self.dataset_val, | |||||
| max_epochs=2, | |||||
| work_dir=self.tmp_dir) | |||||
| trainer = build_trainer(default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| for i in range(2): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||