From 5f326e46f3d4e8f2afc7953bd33d928411cc42fb Mon Sep 17 00:00:00 2001 From: "james.wjg" Date: Wed, 24 Aug 2022 13:41:47 +0800 Subject: [PATCH] cv/video_summarization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增video summarization模型的inference和finetune Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9840532 --- modelscope/metainfo.py | 5 + modelscope/metrics/__init__.py | 2 + modelscope/metrics/builder.py | 2 + .../metrics/video_summarization_metric.py | 78 +++++ modelscope/models/cv/__init__.py | 2 +- .../models/cv/video_summarization/__init__.py | 1 + .../cv/video_summarization/base_model.py | 118 +++++++ .../cv/video_summarization/kts/__init__.py | 0 .../cv/video_summarization/kts/cpd_auto.py | 35 ++ .../cv/video_summarization/kts/cpd_nonlin.py | 102 ++++++ .../models/cv/video_summarization/pgl_sum.py | 311 ++++++++++++++++++ .../cv/video_summarization/summarizer.py | 224 +++++++++++++ .../msdatasets/task_datasets/__init__.py | 5 +- .../video_summarization_dataset.py | 69 ++++ .../cv/video_summarization_pipeline.py | 109 ++++++ modelscope/preprocessors/image.py | 25 ++ modelscope/utils/constant.py | 1 + tests/pipelines/test_video_summarization.py | 32 ++ .../test_video_summarization_trainer.py | 75 +++++ 19 files changed, 1193 insertions(+), 3 deletions(-) create mode 100644 modelscope/metrics/video_summarization_metric.py create mode 100644 modelscope/models/cv/video_summarization/__init__.py create mode 100644 modelscope/models/cv/video_summarization/base_model.py create mode 100644 modelscope/models/cv/video_summarization/kts/__init__.py create mode 100644 modelscope/models/cv/video_summarization/kts/cpd_auto.py create mode 100644 modelscope/models/cv/video_summarization/kts/cpd_nonlin.py create mode 100644 modelscope/models/cv/video_summarization/pgl_sum.py create mode 100644 modelscope/models/cv/video_summarization/summarizer.py create mode 100644 modelscope/msdatasets/task_datasets/video_summarization_dataset.py create mode 100644 modelscope/pipelines/cv/video_summarization_pipeline.py create mode 100644 tests/pipelines/test_video_summarization.py create mode 100644 tests/trainers/test_video_summarization_trainer.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 4e759305..d0684ecd 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -21,6 +21,7 @@ class Models(object): body_2d_keypoints = 'body-2d-keypoints' crowd_counting = 'HRNetCrowdCounting' image_reid_person = 'passvitb' + video_summarization = 'pgl-video-summarization' # nlp models bert = 'bert' @@ -113,6 +114,7 @@ class Pipelines(object): tinynas_classification = 'tinynas-classification' crowd_counting = 'hrnet-crowd-counting' video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' + video_summarization = 'googlenet_pgl_video_summarization' image_reid_person = 'passvitb-image-reid-person' # nlp tasks @@ -170,6 +172,7 @@ class Trainers(object): # cv trainers image_instance_segmentation = 'image-instance-segmentation' image_portrait_enhancement = 'image-portrait-enhancement' + video_summarization = 'video-summarization' # nlp trainers bert_sentiment_analysis = 'bert-sentiment-analysis' @@ -194,6 +197,7 @@ class Preprocessors(object): image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' + video_summarization_preprocessor = 
'video-summarization-preprocessor' # nlp preprocessor sen_sim_tokenizer = 'sen-sim-tokenizer' @@ -246,6 +250,7 @@ class Metrics(object): image_color_enhance_metric = 'image-color-enhance-metric' # metrics for image-portrait-enhancement task image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' + video_summarization_metric = 'video-summarization-metric' class Optimizers(object): diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index 37f9bfec..d307f7c9 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from .sequence_classification_metric import SequenceClassificationMetric from .text_generation_metric import TextGenerationMetric from .token_classification_metric import TokenClassificationMetric + from .video_summarization_metric import VideoSummarizationMetric else: _import_structure = { @@ -28,6 +29,7 @@ else: 'sequence_classification_metric': ['SequenceClassificationMetric'], 'text_generation_metric': ['TextGenerationMetric'], 'token_classification_metric': ['TokenClassificationMetric'], + 'video_summarization_metric': ['VideoSummarizationMetric'], } import sys diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index bd20d37b..c76fe386 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -15,6 +15,7 @@ class MetricKeys(object): RECALL = 'recall' PSNR = 'psnr' SSIM = 'ssim' + FScore = 'fscore' task_default_metrics = { @@ -28,6 +29,7 @@ task_default_metrics = { Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], Tasks.image_portrait_enhancement: [Metrics.image_portrait_enhancement_metric], + Tasks.video_summarization: [Metrics.video_summarization_metric], } diff --git a/modelscope/metrics/video_summarization_metric.py b/modelscope/metrics/video_summarization_metric.py new file mode 100644 index 00000000..d1867600 --- /dev/null +++ b/modelscope/metrics/video_summarization_metric.py @@ -0,0 +1,78 @@ +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.models.cv.video_summarization.summarizer import \ + generate_summary +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +def evaluate_summary(predicted_summary, user_summary, eval_method): + """ Compare the predicted summary with the user defined one(s). + + :param ndarray predicted_summary: The generated summary from our model. + :param ndarray user_summary: The user defined ground truth summaries (or summary). + :param str eval_method: The proposed evaluation method; either 'max' (SumMe) or 'avg' (TVSum). 
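+    Example with hypothetical numbers: if the predicted summary selects 50 frames and 30
+    of them overlap a user summary that selects 60 frames, then precision = 30/50 = 0.6,
+    recall = 30/60 = 0.5, and the f-score is 2 * 0.6 * 0.5 * 100 / (0.6 + 0.5) ≈ 54.5.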
+ :return: The reduced fscore based on the eval_method + """ + max_len = max(len(predicted_summary), user_summary.shape[1]) + S = np.zeros(max_len, dtype=int) + G = np.zeros(max_len, dtype=int) + S[:len(predicted_summary)] = predicted_summary + + f_scores = [] + for user in range(user_summary.shape[0]): + G[:user_summary.shape[1]] = user_summary[user] + overlapped = S & G + + # Compute precision, recall, f-score + precision = sum(overlapped) / sum(S) + recall = sum(overlapped) / sum(G) + if precision + recall == 0: + f_scores.append(0) + else: + f_score = 2 * precision * recall * 100 / (precision + recall) + f_scores.append(f_score) + + if eval_method == 'max': + return max(f_scores) + else: + return sum(f_scores) / len(f_scores) + + +def calculate_f_score(outputs: Dict, inputs: Dict): + scores = outputs['scores'] + scores = scores.squeeze(0).cpu().numpy().tolist() + user_summary = inputs['user_summary'].cpu().numpy()[0] + sb = inputs['change_points'].cpu().numpy()[0] + n_frames = inputs['n_frames'].cpu().numpy()[0] + positions = inputs['positions'].cpu().numpy()[0] + summary = generate_summary([sb], [scores], [n_frames], [positions])[0] + f_score = evaluate_summary(summary, user_summary, 'avg') + return f_score + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.video_summarization_metric) +class VideoSummarizationMetric(Metric): + """The metric for video summarization task. + """ + + def __init__(self): + self.inputs = [] + self.outputs = [] + + def add(self, outputs: Dict, inputs: Dict): + self.outputs.append(outputs) + self.inputs.append(inputs) + + def evaluate(self): + f_scores = [ + calculate_f_score(output, input) + for output, input in zip(self.outputs, self.inputs) + ] + + return {MetricKeys.FScore: sum(f_scores) / len(f_scores)} diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index dd7e6724..168ac96c 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -7,4 +7,4 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, image_to_image_generation, image_to_image_translation, object_detection, product_retrieval_embedding, salient_detection, super_resolution, - video_single_object_tracking, virual_tryon) + video_single_object_tracking, video_summarization, virual_tryon) diff --git a/modelscope/models/cv/video_summarization/__init__.py b/modelscope/models/cv/video_summarization/__init__.py new file mode 100644 index 00000000..064110f7 --- /dev/null +++ b/modelscope/models/cv/video_summarization/__init__.py @@ -0,0 +1 @@ +from .summarizer import PGLVideoSummarization diff --git a/modelscope/models/cv/video_summarization/base_model.py b/modelscope/models/cv/video_summarization/base_model.py new file mode 100644 index 00000000..670da251 --- /dev/null +++ b/modelscope/models/cv/video_summarization/base_model.py @@ -0,0 +1,118 @@ +# The implementation is based on pytorch-caffe-models, available at https://github.com/crowsonkb/pytorch-caffe-models. 
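+# A minimal usage sketch (the frame path is hypothetical; pretrained weights are loaded
+# separately via load_state_dict, as the video summarization pipeline in this patch does).
+# The extractor resizes a BGR frame to 224x224, subtracts the BVLC channel means and
+# returns an L2-normalised 1024-d pool5 feature:
+#
+#     import cv2
+#     extractor = bvlc_googlenet()
+#     frame = cv2.imread('frame.jpg')   # HWC, BGR, uint8
+#     feat = extractor(frame)           # np.ndarray of shape (1024,)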
+ +import cv2 +import numpy as np +import torch +import torch.nn as nn + + +class Inception(nn.Module): + + def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, + pool_proj): + super().__init__() + self.conv_1x1 = nn.Conv2d(in_channels, ch1x1, 1) + self.relu_1x1 = nn.ReLU(inplace=True) + self.conv_3x3_reduce = nn.Conv2d(in_channels, ch3x3red, 1) + self.relu_3x3_reduce = nn.ReLU(inplace=True) + self.conv_3x3 = nn.Conv2d(ch3x3red, ch3x3, 3, padding=1) + self.relu_3x3 = nn.ReLU(inplace=True) + self.conv_5x5_reduce = nn.Conv2d(in_channels, ch5x5red, 1) + self.relu_5x5_reduce = nn.ReLU(inplace=True) + self.conv_5x5 = nn.Conv2d(ch5x5red, ch5x5, 5, padding=2) + self.relu_5x5 = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(3, stride=1, padding=1) + self.pool_proj = nn.Conv2d(in_channels, pool_proj, 1) + self.relu_pool_proj = nn.ReLU(inplace=True) + + def forward(self, x): + branch_1 = self.relu_1x1(self.conv_1x1(x)) + branch_2 = self.relu_3x3_reduce(self.conv_3x3_reduce(x)) + branch_2 = self.relu_3x3(self.conv_3x3(branch_2)) + branch_3 = self.relu_5x5_reduce(self.conv_5x5_reduce(x)) + branch_3 = self.relu_5x5(self.conv_5x5(branch_3)) + branch_4 = self.pool(x) + branch_4 = self.relu_pool_proj(self.pool_proj(branch_4)) + return torch.cat([branch_1, branch_2, branch_3, branch_4], dim=1) + + +class GoogLeNet(nn.Sequential): + + def __init__(self, num_classes=1000): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.norm1 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) + self.conv2_reduce = nn.Conv2d(64, 64, kernel_size=1) + self.relu2_reduce = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1) + self.relu2 = nn.ReLU(inplace=True) + self.norm2 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) + self.pool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.inception_3a = Inception(192, 64, 96, 128, 16, 32, 32) + self.inception_3b = Inception(256, 128, 128, 192, 32, 96, 64) + self.pool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.inception_4a = Inception(480, 192, 96, 208, 16, 48, 64) + self.inception_4b = Inception(512, 160, 112, 224, 24, 64, 64) + self.inception_4c = Inception(512, 128, 128, 256, 24, 64, 64) + self.inception_4d = Inception(512, 112, 144, 288, 32, 64, 64) + self.inception_4e = Inception(528, 256, 160, 320, 32, 128, 128) + self.pool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.inception_5a = Inception(832, 256, 160, 320, 32, 128, 128) + self.inception_5b = Inception(832, 384, 192, 384, 48, 128, 128) + self.pool5 = nn.AdaptiveAvgPool2d((1, 1)) + self.loss3_classifier = nn.Linear(1024, num_classes) + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.pool1(x) + x = self.norm1(x) + x = self.relu2_reduce(self.conv2_reduce(x)) + x = self.relu2(self.conv2(x)) + x = self.norm2(x) + x = self.pool2(x) + x = self.inception_3a(x) + x = self.inception_3b(x) + x = self.pool3(x) + x = self.inception_4a(x) + x = self.inception_4b(x) + x = self.inception_4c(x) + x = self.inception_4d(x) + x = self.inception_4e(x) + x = self.pool4(x) + x = self.inception_5a(x) + x = self.inception_5b(x) + x = self.pool5(x).flatten(1) + return x + + +class bvlc_googlenet(nn.Module): + + def __init__(self, input_size=224): + """model for the BVLC GoogLeNet, trained on ImageNet. 
+ URL: https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet""" + super(bvlc_googlenet, self).__init__() + + self.model = GoogLeNet(num_classes=1000) + + self.input_size = input_size + self.input_mean = (104.0, 117.0, 123.0) + + def forward(self, frame): + x = cv2.resize(frame, + (self.input_size, self.input_size)).astype(np.float32) + x = (x - self.input_mean).astype(np.float32) + x = np.transpose(x, [2, 0, 1]) + + x = np.expand_dims(x, 0) + x = torch.from_numpy(x) + if not next(self.model.parameters()).device.type == 'cpu': + x = x.cuda() + with torch.no_grad(): + frame_feat = self.model(x) + if not frame_feat.device.type == 'cpu': + frame_feat = frame_feat.cpu() + frame_feat = frame_feat.numpy() + frame_feat = frame_feat / np.linalg.norm(frame_feat) + return frame_feat.reshape(-1) diff --git a/modelscope/models/cv/video_summarization/kts/__init__.py b/modelscope/models/cv/video_summarization/kts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_summarization/kts/cpd_auto.py b/modelscope/models/cv/video_summarization/kts/cpd_auto.py new file mode 100644 index 00000000..a794ca26 --- /dev/null +++ b/modelscope/models/cv/video_summarization/kts/cpd_auto.py @@ -0,0 +1,35 @@ +# The implementation is based on KTS, available at https://github.com/TatsuyaShirakawa/KTS. + +import numpy as np + +from .cpd_nonlin import cpd_nonlin + + +def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs): + """Detect change points automatically selecting their number + + :param K: Kernel between each pair of frames in video + :param ncp: Maximum number of change points + :param vmax: Special parameter + :param desc_rate: Rate of descriptor sampling, vmax always corresponds to 1x + :param kwargs: Extra parameters for ``cpd_nonlin`` + :return: Tuple (cps, costs) + - cps - best selected change-points + - costs - costs for 0,1,2,...,m change-points + """ + m = ncp + _, scores = cpd_nonlin(K, m, backtrack=False, **kwargs) + + N = K.shape[0] + N2 = N * desc_rate # length of the video before down-sampling + + penalties = np.zeros(m + 1) + # Prevent division by zero (in case of 0 changes) + ncp = np.arange(1, m + 1) + penalties[1:] = (vmax * ncp / (2.0 * N2)) * (np.log(float(N2) / ncp) + 1) + + costs = scores / float(N) + penalties + m_best = np.argmin(costs) + cps, scores2 = cpd_nonlin(K, m_best, **kwargs) + + return cps, scores2 diff --git a/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py b/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py new file mode 100644 index 00000000..ef2eb6ef --- /dev/null +++ b/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py @@ -0,0 +1,102 @@ +# The implementation is based on KTS, available at https://github.com/TatsuyaShirakawa/KTS. 
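+# A sketch of how these KTS routines are typically driven (the feature array below is
+# hypothetical); get_change_points in summarizer.py builds the frame kernel the same way
+# and lets cpd_auto pick the number of change points:
+#
+#     import numpy as np
+#     from modelscope.models.cv.video_summarization.kts.cpd_auto import cpd_auto
+#
+#     feats = np.random.rand(200, 1024).astype(np.float32)  # sub-sampled frame features
+#     K = np.dot(feats, feats.T)                            # linear kernel between frames
+#     cps, costs = cpd_auto(K, ncp=20, vmax=1.0)            # change points and objective costs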
+ +import numpy as np + + +def calc_scatters(K): + """Calculate scatter matrix: scatters[i,j] = {scatter of the sequence with + starting frame i and ending frame j} + """ + n = K.shape[0] + K1 = np.cumsum([0] + list(np.diag(K))) + K2 = np.zeros((n + 1, n + 1)) + # TODO: use the fact that K - symmetric + K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1) + + diagK2 = np.diag(K2) + + i = np.arange(n).reshape((-1, 1)) + j = np.arange(n).reshape((1, -1)) + + ij_f32 = ((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32)) + diagK2_K2 = ( + diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape( + (-1, 1)) - K2[1:, :-1].T - K2[:-1, 1:]) + scatters = ( + K1[1:].reshape((1, -1)) - K1[:-1].reshape( + (-1, 1)) - diagK2_K2 / ij_f32) + + scatters[j < i] = 0 + + return scatters + + +def cpd_nonlin(K, + ncp, + lmin=1, + lmax=100000, + backtrack=True, + verbose=True, + out_scatters=None): + """Change point detection with dynamic programming + + :param K: Square kernel matrix + :param ncp: Number of change points to detect (ncp >= 0) + :param lmin: Minimal length of a segment + :param lmax: Maximal length of a segment + :param backtrack: If False - only evaluate objective scores (to save memory) + :param verbose: If true, print verbose message + :param out_scatters: Output scatters + :return: Tuple (cps, obj_vals) + - cps - detected array of change points: mean is thought to be constant + on [ cps[i], cps[i+1] ) + - obj_vals - values of the objective function for 0..m changepoints + """ + m = int(ncp) # prevent numpy.int64 + + n, n1 = K.shape + assert n == n1, 'Kernel matrix awaited.' + assert (m + 1) * lmin <= n <= (m + 1) * lmax + assert 1 <= lmin <= lmax + + if verbose: + print('Precomputing scatters...') + J = calc_scatters(K) + + if out_scatters is not None: + out_scatters[0] = J + + if verbose: + print('Inferring best change points...') + # Iden[k, l] - value of the objective for k change-points and l first frames + Iden = 1e101 * np.ones((m + 1, n + 1)) + Iden[0, lmin:lmax] = J[0, lmin - 1:lmax - 1] + + if backtrack: + # p[k, l] --- 'previous change' --- best t[k] when t[k+1] equals l + p = np.zeros((m + 1, n + 1), dtype=int) + else: + p = np.zeros((1, 1), dtype=int) + + for k in range(1, m + 1): + for l_frame in range((k + 1) * lmin, n + 1): + tmin = max(k * lmin, l_frame - lmax) + tmax = l_frame - lmin + 1 + c = J[tmin:tmax, l_frame - 1].reshape(-1) + \ + Iden[k - 1, tmin:tmax].reshape(-1) + Iden[k, l_frame] = np.min(c) + if backtrack: + p[k, l_frame] = np.argmin(c) + tmin + + # Collect change points + cps = np.zeros(m, dtype=int) + + if backtrack: + cur = n + for k in range(m, 0, -1): + cps[k - 1] = p[k, cur] + cur = cps[k - 1] + + scores = Iden[:, n].copy() + scores[scores > 1e99] = np.inf + return cps, scores diff --git a/modelscope/models/cv/video_summarization/pgl_sum.py b/modelscope/models/cv/video_summarization/pgl_sum.py new file mode 100644 index 00000000..ab3010c9 --- /dev/null +++ b/modelscope/models/cv/video_summarization/pgl_sum.py @@ -0,0 +1,311 @@ +# The implementation is based on PGL-SUM, available at https://github.com/e-apostolidis/PGL-SUM. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SelfAttention(nn.Module): + + def __init__(self, + input_size=1024, + output_size=1024, + freq=10000, + heads=1, + pos_enc=None): + """ The basic (multi-head) Attention 'cell' containing the learnable parameters of Q, K and V + + :param int input_size: Feature input size of Q, K, V. + :param int output_size: Feature -hidden- size of Q, K, V. 
+ :param int freq: The frequency of the sinusoidal positional encoding. + :param int heads: Number of heads for the attention module. + :param str | None pos_enc: The type of the positional encoding [supported: Absolute, Relative]. + """ + super(SelfAttention, self).__init__() + + self.permitted_encodings = ['absolute', 'relative'] + if pos_enc is not None: + pos_enc = pos_enc.lower() + assert pos_enc in self.permitted_encodings, f'Supported encodings: {*self.permitted_encodings,}' + + self.input_size = input_size + self.output_size = output_size + self.heads = heads + self.pos_enc = pos_enc + self.freq = freq + self.Wk, self.Wq, self.Wv = nn.ModuleList(), nn.ModuleList( + ), nn.ModuleList() + for _ in range(self.heads): + self.Wk.append( + nn.Linear( + in_features=input_size, + out_features=output_size // heads, + bias=False)) + self.Wq.append( + nn.Linear( + in_features=input_size, + out_features=output_size // heads, + bias=False)) + self.Wv.append( + nn.Linear( + in_features=input_size, + out_features=output_size // heads, + bias=False)) + self.out = nn.Linear( + in_features=output_size, out_features=input_size, bias=False) + + self.softmax = nn.Softmax(dim=-1) + self.drop = nn.Dropout(p=0.5) + + def getAbsolutePosition(self, T): + """Calculate the sinusoidal positional encoding based on the absolute position of each considered frame. + Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762) + + :param int T: Number of frames contained in Q, K and V + :return: Tensor with shape [T, T] + """ + freq = self.freq + d = self.input_size + + pos = torch.tensor([k for k in range(T)], + device=self.out.weight.device) + i = torch.tensor([k for k in range(T // 2)], + device=self.out.weight.device) + + # Reshape tensors each pos_k for each i indices + pos = pos.reshape(pos.shape[0], 1) + pos = pos.repeat_interleave(i.shape[0], dim=1) + i = i.repeat(pos.shape[0], 1) + + AP = torch.zeros(T, T, device=self.out.weight.device) + AP[pos, 2 * i] = torch.sin(pos / freq**((2 * i) / d)) + AP[pos, 2 * i + 1] = torch.cos(pos / freq**((2 * i) / d)) + return AP + + def getRelativePosition(self, T): + """Calculate the sinusoidal positional encoding based on the relative position of each considered frame. + r_pos calculations as here: https://theaisummer.com/positional-embeddings/ + + :param int T: Number of frames contained in Q, K and V + :return: Tensor with shape [T, T] + """ + freq = self.freq + d = 2 * T + min_rpos = -(T - 1) + + i = torch.tensor([k for k in range(T)], device=self.out.weight.device) + j = torch.tensor([k for k in range(T)], device=self.out.weight.device) + + # Reshape tensors each i for each j indices + i = i.reshape(i.shape[0], 1) + i = i.repeat_interleave(i.shape[0], dim=1) + j = j.repeat(i.shape[0], 1) + + # Calculate the relative positions + r_pos = j - i - min_rpos + + RP = torch.zeros(T, T, device=self.out.weight.device) + idx = torch.tensor([k for k in range(T // 2)], + device=self.out.weight.device) + RP[:, 2 * idx] = torch.sin( + r_pos[:, 2 * idx] / freq**((i[:, 2 * idx] + j[:, 2 * idx]) / d)) + RP[:, 2 * idx + 1] = torch.cos( + r_pos[:, 2 * idx + 1] + / freq**((i[:, 2 * idx + 1] + j[:, 2 * idx + 1]) / d)) + return RP + + def forward(self, x): + """ Compute the weighted frame features, based on either the global or local (multi-head) attention mechanism. 
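+        Per head h the computation is energies = Q_h @ K_h.T, optionally shifted by an
+        absolute or relative positional bias, followed by att = softmax(energies) and
+        y_h = dropout(att) @ V_h; the head outputs are concatenated and projected back
+        to input_size by self.out.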
+ + :param torch.tensor x: Frame features with shape [T, input_size] + :return: A tuple of: + y: Weighted features based on the attention weights, with shape [T, input_size] + att_weights : The attention weights (before dropout), with shape [T, T] + """ + outputs = [] + for head in range(self.heads): + K = self.Wk[head](x) + Q = self.Wq[head](x) + V = self.Wv[head](x) + + # Q *= 0.06 # scale factor VASNet + # Q /= np.sqrt(self.output_size) # scale factor (i.e 1 / sqrt(d_k) ) + energies = torch.matmul(Q, K.transpose(1, 0)) + if self.pos_enc is not None: + if self.pos_enc == 'absolute': + AP = self.getAbsolutePosition(T=energies.shape[0]) + energies = energies + AP + elif self.pos_enc == 'relative': + RP = self.getRelativePosition(T=energies.shape[0]) + energies = energies + RP + + att_weights = self.softmax(energies) + _att_weights = self.drop(att_weights) + y = torch.matmul(_att_weights, V) + + # Save the current head output + outputs.append(y) + y = self.out(torch.cat(outputs, dim=1)) + return y, att_weights.clone( + ) # for now we don't deal with the weights (probably max or avg pooling) + + +class MultiAttention(nn.Module): + + def __init__(self, + input_size=1024, + output_size=1024, + freq=10000, + pos_enc=None, + num_segments=None, + heads=1, + fusion=None): + """ Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters. + + :param int input_size: The expected input feature size. + :param int output_size: The hidden feature size of the attention mechanisms. + :param int freq: The frequency of the sinusoidal positional encoding. + :param None | str pos_enc: The selected positional encoding [absolute, relative]. + :param None | int num_segments: The selected number of segments to split the videos. + :param int heads: The selected number of global heads. + :param None | str fusion: The selected type of feature fusion. + """ + super(MultiAttention, self).__init__() + + # Global Attention, considering differences among all frames + self.attention = SelfAttention( + input_size=input_size, + output_size=output_size, + freq=freq, + pos_enc=pos_enc, + heads=heads) + + self.num_segments = num_segments + if self.num_segments is not None: + assert self.num_segments >= 2, 'num_segments must be None or 2+' + self.local_attention = nn.ModuleList() + for _ in range(self.num_segments): + # Local Attention, considering differences among the same segment with reduce hidden size + self.local_attention.append( + SelfAttention( + input_size=input_size, + output_size=output_size // num_segments, + freq=freq, + pos_enc=pos_enc, + heads=4)) + self.permitted_fusions = ['add', 'mult', 'avg', 'max'] + self.fusion = fusion + if self.fusion is not None: + self.fusion = self.fusion.lower() + assert self.fusion in self.permitted_fusions, f'Fusion method must be: {*self.permitted_fusions,}' + + def forward(self, x): + """ Compute the weighted frame features, based on the global and locals (multi-head) attention mechanisms. + + :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features. + :return: A tuple of: + weighted_value: Tensor with shape [T, input_size] containing the weighted frame features. + attn_weights: Tensor with shape [T, T] containing the attention weights. 
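+
+        When num_segments and fusion are both set, the frame axis is additionally split
+        into num_segments chunks, a 4-head local attention is run on each chunk, the
+        global and local outputs are L2-normalised per frame, and the two are combined
+        with the selected fusion ('add', 'mult', 'avg' or 'max').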
+ """ + weighted_value, attn_weights = self.attention(x) # global attention + + if self.num_segments is not None and self.fusion is not None: + segment_size = math.ceil(x.shape[0] / self.num_segments) + for segment in range(self.num_segments): + left_pos = segment * segment_size + right_pos = (segment + 1) * segment_size + local_x = x[left_pos:right_pos] + weighted_local_value, attn_local_weights = self.local_attention[ + segment](local_x) # local attentions + + # Normalize the features vectors + weighted_value[left_pos:right_pos] = F.normalize( + weighted_value[left_pos:right_pos].clone(), p=2, dim=1) + weighted_local_value = F.normalize( + weighted_local_value, p=2, dim=1) + if self.fusion == 'add': + weighted_value[left_pos:right_pos] += weighted_local_value + elif self.fusion == 'mult': + weighted_value[left_pos:right_pos] *= weighted_local_value + elif self.fusion == 'avg': + weighted_value[left_pos:right_pos] += weighted_local_value + weighted_value[left_pos:right_pos] /= 2 + elif self.fusion == 'max': + weighted_value[left_pos:right_pos] = torch.max( + weighted_value[left_pos:right_pos].clone(), + weighted_local_value) + + return weighted_value, attn_weights + + +class PGL_SUM(nn.Module): + + def __init__(self, + input_size=1024, + output_size=1024, + freq=10000, + pos_enc=None, + num_segments=None, + heads=1, + fusion=None): + """ Class wrapping the PGL-SUM model; its key modules and parameters. + + :param int input_size: The expected input feature size. + :param int output_size: The hidden feature size of the attention mechanisms. + :param int freq: The frequency of the sinusoidal positional encoding. + :param None | str pos_enc: The selected positional encoding [absolute, relative]. + :param None | int num_segments: The selected number of segments to split the videos. + :param int heads: The selected number of global heads. + :param None | str fusion: The selected type of feature fusion. + """ + super(PGL_SUM, self).__init__() + + self.attention = MultiAttention( + input_size=input_size, + output_size=output_size, + freq=freq, + pos_enc=pos_enc, + num_segments=num_segments, + heads=heads, + fusion=fusion) + self.linear_1 = nn.Linear( + in_features=input_size, out_features=input_size) + self.linear_2 = nn.Linear( + in_features=self.linear_1.out_features, out_features=1) + + self.drop = nn.Dropout(p=0.5) + self.norm_y = nn.LayerNorm(normalized_shape=input_size, eps=1e-6) + self.norm_linear = nn.LayerNorm( + normalized_shape=self.linear_1.out_features, eps=1e-6) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, frame_features): + """ Produce frames importance scores from the frame features, using the PGL-SUM model. + + :param torch.Tensor frame_features: Tensor of shape [T, input_size] containing the frame features produced by + using the pool5 layer of GoogleNet. + :return: A tuple of: + y: Tensor with shape [1, T] containing the frames importance scores in [0, 1]. + attn_weights: Tensor with shape [T, T] containing the attention weights. 
+ """ + frame_features = frame_features.reshape(-1, frame_features.shape[-1]) + residual = frame_features + weighted_value, attn_weights = self.attention(frame_features) + y = weighted_value + residual + y = self.drop(y) + y = self.norm_y(y) + + # 2-layer NN (Regressor Network) + y = self.linear_1(y) + y = self.relu(y) + y = self.drop(y) + y = self.norm_linear(y) + + y = self.linear_2(y) + y = self.sigmoid(y) + y = y.view(1, -1) + + return y, attn_weights diff --git a/modelscope/models/cv/video_summarization/summarizer.py b/modelscope/models/cv/video_summarization/summarizer.py new file mode 100644 index 00000000..c95da025 --- /dev/null +++ b/modelscope/models/cv/video_summarization/summarizer.py @@ -0,0 +1,224 @@ +# The implementation is based on PGL-SUM, available at https://github.com/e-apostolidis/PGL-SUM. + +import os.path as osp +from copy import deepcopy +from typing import Dict, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.video_summarization.kts.cpd_auto import cpd_auto +from modelscope.models.cv.video_summarization.pgl_sum import PGL_SUM +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def get_change_points(video_feat, n_frame): + video_feat = np.array(video_feat, np.float32) + K = np.dot(video_feat, video_feat.T) + change_points, _ = cpd_auto(K, ncp=120, vmax=2.2 / 4.0, lmin=1) + change_points = change_points * 15 + change_points = np.concatenate(([0], change_points, [n_frame - 1])) + + temp_change_points = [] + for idx in range(len(change_points) - 1): + segment = [change_points[idx], change_points[idx + 1] - 1] + if idx == len(change_points) - 2: + segment = [change_points[idx], change_points[idx + 1]] + + temp_change_points.append(segment) + change_points = np.array(list(temp_change_points)) + + temp_n_frame_per_seg = [] + for change_points_idx in range(len(change_points)): + n_frame = change_points[change_points_idx][1] - change_points[ + change_points_idx][0] + temp_n_frame_per_seg.append(n_frame) + n_frame_per_seg = np.array(list(temp_n_frame_per_seg)) + + return change_points, n_frame_per_seg + + +def knap_sack(W, wt, val, n): + """ Maximize the value that a knapsack of capacity W can hold. You can either put the item or discard it, there is + no concept of putting some part of item in the knapsack. + + :param int W: Maximum capacity -in frames- of the knapsack. + :param list[int] wt: The weights (lengths -in frames-) of each video shot. + :param list[float] val: The values (importance scores) of each video shot. + :param int n: The number of the shots. + :return: A list containing the indices of the selected shots. 
+ """ + K = [[0 for _ in range(W + 1)] for _ in range(n + 1)] + + # Build table K[][] in bottom up manner + for i in range(n + 1): + for w in range(W + 1): + if i == 0 or w == 0: + K[i][w] = 0 + elif wt[i - 1] <= w: + K[i][w] = max(val[i - 1] + K[i - 1][w - wt[i - 1]], + K[i - 1][w]) + else: + K[i][w] = K[i - 1][w] + + selected = [] + w = W + for i in range(n, 0, -1): + if K[i][w] != K[i - 1][w]: + selected.insert(0, i - 1) + w -= wt[i - 1] + + return selected + + +def generate_summary(all_shot_bound, all_scores, all_nframes, all_positions): + """ Generate the automatic machine summary, based on the video shots; the frame importance scores; the number of + frames in the original video and the position of the sub-sampled frames of the original video. + + :param list[np.ndarray] all_shot_bound: The video shots for all the -original- testing videos. + :param list[np.ndarray] all_scores: The calculated frame importance scores for all the sub-sampled testing videos. + :param list[np.ndarray] all_nframes: The number of frames for all the -original- testing videos. + :param list[np.ndarray] all_positions: The position of the sub-sampled frames for all the -original- testing videos. + :return: A list containing the indices of the selected frames for all the -original- testing videos. + """ + all_summaries = [] + for video_index in range(len(all_scores)): + # Get shots' boundaries + shot_bound = all_shot_bound[video_index] # [number_of_shots, 2] + frame_init_scores = all_scores[video_index] + n_frames = all_nframes[video_index] + positions = all_positions[video_index] + + # Compute the importance scores for the initial frame sequence (not the sub-sampled one) + frame_scores = np.zeros(n_frames, dtype=np.float32) + if positions.dtype != int: + positions = positions.astype(np.int32) + if positions[-1] != n_frames: + positions = np.concatenate([positions, [n_frames]]) + for i in range(len(positions) - 1): + pos_left, pos_right = positions[i], positions[i + 1] + if i == len(frame_init_scores): + frame_scores[pos_left:pos_right] = 0 + else: + frame_scores[pos_left:pos_right] = frame_init_scores[i] + + # Compute shot-level importance scores by taking the average importance scores of all frames in the shot + shot_imp_scores = [] + shot_lengths = [] + for shot in shot_bound: + shot_lengths.append(shot[1] - shot[0] + 1) + shot_imp_scores.append( + (frame_scores[shot[0]:shot[1] + 1].mean()).item()) + + # Select the best shots using the knapsack implementation + final_shot = shot_bound[-1] + final_max_length = int((final_shot[1] + 1) * 0.15) + + selected = knap_sack(final_max_length, shot_lengths, shot_imp_scores, + len(shot_lengths)) + + # Select all frames from each selected shot (by setting their value in the summary vector to 1) + summary = np.zeros(final_shot[1] + 1, dtype=np.int8) + for shot in selected: + summary[shot_bound[shot][0]:shot_bound[shot][1] + 1] = 1 + + all_summaries.append(summary) + + return all_summaries + + +@MODELS.register_module( + Tasks.video_summarization, module_name=Models.video_summarization) +class PGLVideoSummarization(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the video summarization model from the `model_dir` path. + + Args: + model_dir (str): the model path. 
+ """ + super().__init__(model_dir, *args, **kwargs) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + + self.loss = nn.MSELoss() + self.model = PGL_SUM( + input_size=1024, + output_size=1024, + num_segments=4, + heads=8, + fusion='add', + pos_enc='absolute') + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.model = self.model.to(self._device) + + self.model = self.load_pretrained(self.model, model_path) + + if self.training: + self.model.train() + else: + self.model.eval() + + def load_pretrained(self, net, load_path, strict=True, param_key='params'): + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + load_net = torch.load( + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info( + f'Loading: {param_key} does not exist, use params.') + if param_key in load_net: + load_net = load_net[param_key] + logger.info( + f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' + ) + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + logger.info('load model done.') + return net + + def _train_forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + frame_features = input['frame_features'] + gtscore = input['gtscore'] + preds, attn_weights = self.model(frame_features) + return {'loss': self.loss(preds, gtscore)} + + def _inference_forward(self, input: Dict[str, + Tensor]) -> Dict[str, Tensor]: + frame_features = input['frame_features'] + y, attn_weights = self.model(frame_features) + return {'scores': y} + + def forward(self, input: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Union[list, Tensor]]: results + """ + for key, value in input.items(): + input[key] = input[key].to(self._device) + if self.training: + return self._train_forward(input) + else: + return self._inference_forward(input) diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py index c80f8cd5..1905bf39 100644 --- a/modelscope/msdatasets/task_datasets/__init__.py +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from .torch_base_dataset import TorchTaskDataset from .veco_dataset import VecoDataset from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset - + from .video_summarization_dataset import VideoSummarizationDataset else: _import_structure = { 'base': ['TaskDataset'], @@ -17,7 +17,8 @@ else: 'torch_base_dataset': ['TorchTaskDataset'], 'veco_dataset': ['VecoDataset'], 'image_instance_segmentation_coco_dataset': - ['ImageInstanceSegmentationCocoDataset'] + ['ImageInstanceSegmentationCocoDataset'], + 'video_summarization_dataset': ['VideoSummarizationDataset'], } import sys diff --git a/modelscope/msdatasets/task_datasets/video_summarization_dataset.py b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py new file mode 100644 index 00000000..89deb7ba --- /dev/null +++ b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py @@ -0,0 +1,69 @@ +import os + +import h5py +import json +import numpy as np +import torch + +from 
modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset + + +class VideoSummarizationDataset(TorchTaskDataset): + + def __init__(self, mode, opt, root_dir): + self.mode = mode + self.data_filename = os.path.join(root_dir, opt.dataset_file) + self.split_filename = os.path.join(root_dir, opt.split_file) + self.split_index = opt.split_index # it represents the current split (varies from 0 to 4) + hdf = h5py.File(self.data_filename, 'r') + self.list_frame_features, self.list_gtscores = [], [] + self.list_user_summary = [] + self.list_change_points = [] + self.list_n_frames = [] + self.list_positions = [] + + with open(self.split_filename) as f: + data = json.loads(f.read()) + for i, split in enumerate(data): + if i == self.split_index: + self.split = split + break + + for video_name in self.split[self.mode + '_keys']: + frame_features = torch.Tensor( + np.array(hdf[video_name + '/features'])) + gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore'])) + user_summary = np.array(hdf[f'{video_name}/user_summary']) + change_points = np.array(hdf[f'{video_name}/change_points']) + n_frames = np.array(hdf[f'{video_name}/n_frames']) + positions = np.array(hdf[f'{video_name}/picks']) + + self.list_frame_features.append(frame_features) + self.list_gtscores.append(gtscore) + self.list_user_summary.append(user_summary) + self.list_change_points.append(change_points) + self.list_n_frames.append(n_frames) + self.list_positions.append(positions) + + hdf.close() + + def __len__(self): + self.len = len(self.split[self.mode + '_keys']) + return self.len + + def __getitem__(self, index): + frame_features = self.list_frame_features[index] + gtscore = self.list_gtscores[index] + user_summary = self.list_user_summary[index] + change_points = self.list_change_points[index] + n_frames = self.list_n_frames[index] + positions = self.list_positions[index] + + return dict( + frame_features=frame_features, + gtscore=gtscore, + user_summary=user_summary, + change_points=change_points, + n_frames=n_frames, + positions=positions) diff --git a/modelscope/pipelines/cv/video_summarization_pipeline.py b/modelscope/pipelines/cv/video_summarization_pipeline.py new file mode 100644 index 00000000..9ed9c867 --- /dev/null +++ b/modelscope/pipelines/cv/video_summarization_pipeline.py @@ -0,0 +1,109 @@ +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +from tqdm import tqdm + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.video_summarization import PGLVideoSummarization +from modelscope.models.cv.video_summarization.base_model import bvlc_googlenet +from modelscope.models.cv.video_summarization.summarizer import ( + generate_summary, get_change_points) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_summarization, module_name=Pipelines.video_summarization) +class VideoSummarizationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a video summarization pipeline for prediction + Args: + model: model id on modelscope hub. 
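+
+        Example (same model id and test clip as tests/pipelines/test_video_summarization.py;
+        the model is fetched from the hub):
+            >>> from modelscope.pipelines import pipeline
+            >>> from modelscope.utils.constant import Tasks
+            >>> p = pipeline(Tasks.video_summarization,
+            ...              model='damo/cv_googlenet_pgl-video-summarization')
+            >>> result = p('data/test/videos/video_category_test_video.mp4')
+            >>> result[OutputKeys.OUTPUT]   # 0/1 frame-selection mask over the original video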
+ """ + super().__init__(model=model, auto_collate=False, **kwargs) + logger.info(f'loading model from {model}') + googlenet_model_path = osp.join(model, 'bvlc_googlenet.pt') + config_path = osp.join(model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + + self.googlenet_model = bvlc_googlenet() + self.googlenet_model.model.load_state_dict( + torch.load( + googlenet_model_path, map_location=torch.device(self.device))) + self.googlenet_model = self.googlenet_model.to(self.device).eval() + + self.pgl_model = PGLVideoSummarization(model) + self.pgl_model = self.pgl_model.to(self.device).eval() + + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if not isinstance(input, str): + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + frames = [] + picks = [] + cap = cv2.VideoCapture(input) + frame_idx = 0 + while (cap.isOpened()): + ret, frame = cap.read() + if not ret: + break + if frame_idx % 15 == 0: + frames.append(frame) + picks.append(frame_idx) + frame_idx += 1 + n_frame = frame_idx + + result = { + 'video_name': input, + 'video_frames': np.array(frames), + 'n_frame': n_frame, + 'picks': np.array(picks) + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + frame_features = [] + for frame in tqdm(input['video_frames']): + feat = self.googlenet_model(frame) + frame_features.append(feat) + + change_points, n_frame_per_seg = get_change_points( + frame_features, input['n_frame']) + + summary = self.inference(frame_features, input['n_frame'], + input['picks'], change_points) + + return {OutputKeys.OUTPUT: summary} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + def inference(self, frame_features, n_frames, picks, change_points): + frame_features = torch.from_numpy(np.array(frame_features, np.float32)) + picks = np.array(picks, np.int32) + + with torch.no_grad(): + results = self.pgl_model(dict(frame_features=frame_features)) + scores = results['scores'] + if not scores.device.type == 'cpu': + scores = scores.cpu() + scores = scores.squeeze(0).numpy().tolist() + summary = generate_summary([change_points], [scores], [n_frames], + [picks])[0] + + return summary diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py index 6932371d..60f6e0eb 100644 --- a/modelscope/preprocessors/image.py +++ b/modelscope/preprocessors/image.py @@ -264,3 +264,28 @@ class ImageInstanceSegmentationPreprocessor(Preprocessor): return None return results + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.video_summarization_preprocessor) +class VideoSummarizationPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """ + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + self.model_dir: str = model_dir + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data Dict[str, Any] + + Returns: + Dict[str, Any]: the preprocessed data + """ + return data diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index fd679d74..d914767b 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -64,6 +64,7 @@ class CVTasks(object): # reid and tracking video_single_object_tracking = 'video-single-object-tracking' + video_summarization = 'video-summarization' image_reid_person = 'image-reid-person' diff --git 
a/tests/pipelines/test_video_summarization.py b/tests/pipelines/test_video_summarization.py new file mode 100644 index 00000000..36724332 --- /dev/null +++ b/tests/pipelines/test_video_summarization.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class VideoSummarizationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + + summarization_pipeline = pipeline( + Tasks.video_summarization, + model='damo/cv_googlenet_pgl-video-summarization') + result = summarization_pipeline( + 'data/test/videos/video_category_test_video.mp4') + + print(f'video summarization output: {result}.') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_modelhub_default_model(self): + summarization_pipeline = pipeline(Tasks.video_summarization) + result = summarization_pipeline( + 'data/test/videos/video_category_test_video.mp4') + + print(f'video summarization output: {result}.') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_video_summarization_trainer.py b/tests/trainers/test_video_summarization_trainer.py new file mode 100644 index 00000000..1cea1eea --- /dev/null +++ b/tests/trainers/test_video_summarization_trainer.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.cv.video_summarization import PGLVideoSummarization +from modelscope.msdatasets.task_datasets import VideoSummarizationDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class VideoSummarizationTrainerTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_googlenet_pgl-video-summarization' + self.cache_path = snapshot_download(self.model_id) + self.config = Config.from_file( + os.path.join(self.cache_path, ModelFile.CONFIGURATION)) + self.dataset_train = VideoSummarizationDataset('train', + self.config.dataset, + self.cache_path) + self.dataset_val = VideoSummarizationDataset('test', + self.config.dataset, + self.cache_path) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + work_dir=self.tmp_dir) + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + model = PGLVideoSummarization.from_pretrained(self.cache_path) + kwargs = dict( + 
cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + max_epochs=2, + work_dir=self.tmp_dir) + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main()
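+
+# Note: both new test modules gate their cases on test_utils.test_level(); the cases
+# decorated with skipUnless(test_level() >= 1, ...) (the default-model pipeline test and
+# the trainer-with-model-and-args test above) only run when the configured test level is
+# raised to 1 or higher.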