cv/video_summarization

Add inference and finetuning support for the video summarization model
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9840532
Branch: master
Authors: james.wjg, yingda.chen · 3 years ago
Commit: 5f326e46f3
19 changed files with 1193 additions and 3 deletions
  1. modelscope/metainfo.py (+5, -0)
  2. modelscope/metrics/__init__.py (+2, -0)
  3. modelscope/metrics/builder.py (+2, -0)
  4. modelscope/metrics/video_summarization_metric.py (+78, -0)
  5. modelscope/models/cv/__init__.py (+1, -1)
  6. modelscope/models/cv/video_summarization/__init__.py (+1, -0)
  7. modelscope/models/cv/video_summarization/base_model.py (+118, -0)
  8. modelscope/models/cv/video_summarization/kts/__init__.py (+0, -0)
  9. modelscope/models/cv/video_summarization/kts/cpd_auto.py (+35, -0)
  10. modelscope/models/cv/video_summarization/kts/cpd_nonlin.py (+102, -0)
  11. modelscope/models/cv/video_summarization/pgl_sum.py (+311, -0)
  12. modelscope/models/cv/video_summarization/summarizer.py (+224, -0)
  13. modelscope/msdatasets/task_datasets/__init__.py (+3, -2)
  14. modelscope/msdatasets/task_datasets/video_summarization_dataset.py (+69, -0)
  15. modelscope/pipelines/cv/video_summarization_pipeline.py (+109, -0)
  16. modelscope/preprocessors/image.py (+25, -0)
  17. modelscope/utils/constant.py (+1, -0)
  18. tests/pipelines/test_video_summarization.py (+32, -0)
  19. tests/trainers/test_video_summarization_trainer.py (+75, -0)
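
For a quick sense of the intended usage, here is a minimal inference sketch that mirrors the new tests/pipelines/test_video_summarization.py (the model id and test video path are taken from that test, not guaranteed beyond it):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the new video summarization pipeline from the hub model registered for this task.
summarizer = pipeline(
    Tasks.video_summarization,
    model='damo/cv_googlenet_pgl-video-summarization')

# The result dict's OutputKeys.OUTPUT entry is a binary per-frame mask of the kept frames.
result = summarizer('data/test/videos/video_category_test_video.mp4')
print(result)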

modelscope/metainfo.py (+5, -0)

@@ -21,6 +21,7 @@ class Models(object):
     body_2d_keypoints = 'body-2d-keypoints'
     crowd_counting = 'HRNetCrowdCounting'
     image_reid_person = 'passvitb'
+    video_summarization = 'pgl-video-summarization'
 
     # nlp models
     bert = 'bert'
@@ -113,6 +114,7 @@ class Pipelines(object):
     tinynas_classification = 'tinynas-classification'
     crowd_counting = 'hrnet-crowd-counting'
     video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
+    video_summarization = 'googlenet_pgl_video_summarization'
     image_reid_person = 'passvitb-image-reid-person'
 
     # nlp tasks
@@ -170,6 +172,7 @@ class Trainers(object):
     # cv trainers
     image_instance_segmentation = 'image-instance-segmentation'
     image_portrait_enhancement = 'image-portrait-enhancement'
+    video_summarization = 'video-summarization'
 
     # nlp trainers
     bert_sentiment_analysis = 'bert-sentiment-analysis'
@@ -194,6 +197,7 @@ class Preprocessors(object):
     image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
     image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
     image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
+    video_summarization_preprocessor = 'video-summarization-preprocessor'
 
     # nlp preprocessor
     sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -246,6 +250,7 @@ class Metrics(object):
     image_color_enhance_metric = 'image-color-enhance-metric'
     # metrics for image-portrait-enhancement task
     image_portrait_enhancement_metric = 'image-portrait-enhancement-metric'
+    video_summarization_metric = 'video-summarization-metric'
 
 
 class Optimizers(object):


modelscope/metrics/__init__.py (+2, -0)

@@ -14,6 +14,7 @@ if TYPE_CHECKING:
     from .sequence_classification_metric import SequenceClassificationMetric
     from .text_generation_metric import TextGenerationMetric
     from .token_classification_metric import TokenClassificationMetric
+    from .video_summarization_metric import VideoSummarizationMetric
 
 else:
     _import_structure = {
@@ -28,6 +29,7 @@ else:
         'sequence_classification_metric': ['SequenceClassificationMetric'],
         'text_generation_metric': ['TextGenerationMetric'],
         'token_classification_metric': ['TokenClassificationMetric'],
+        'video_summarization_metric': ['VideoSummarizationMetric'],
     }
 
     import sys


modelscope/metrics/builder.py (+2, -0)

@@ -15,6 +15,7 @@ class MetricKeys(object):
     RECALL = 'recall'
     PSNR = 'psnr'
     SSIM = 'ssim'
+    FScore = 'fscore'
 
 
 task_default_metrics = {
@@ -28,6 +29,7 @@ task_default_metrics = {
     Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric],
     Tasks.image_portrait_enhancement:
     [Metrics.image_portrait_enhancement_metric],
+    Tasks.video_summarization: [Metrics.video_summarization_metric],
 }


modelscope/metrics/video_summarization_metric.py (+78, -0)

@@ -0,0 +1,78 @@
from typing import Dict

import numpy as np

from modelscope.metainfo import Metrics
from modelscope.models.cv.video_summarization.summarizer import \
generate_summary
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys


def evaluate_summary(predicted_summary, user_summary, eval_method):
""" Compare the predicted summary with the user defined one(s).

:param ndarray predicted_summary: The generated summary from our model.
:param ndarray user_summary: The user defined ground truth summaries (or summary).
:param str eval_method: The proposed evaluation method; either 'max' (SumMe) or 'avg' (TVSum).
:return: The reduced fscore based on the eval_method
"""
max_len = max(len(predicted_summary), user_summary.shape[1])
S = np.zeros(max_len, dtype=int)
G = np.zeros(max_len, dtype=int)
S[:len(predicted_summary)] = predicted_summary

f_scores = []
for user in range(user_summary.shape[0]):
G[:user_summary.shape[1]] = user_summary[user]
overlapped = S & G

# Compute precision, recall, f-score
precision = sum(overlapped) / sum(S)
recall = sum(overlapped) / sum(G)
if precision + recall == 0:
f_scores.append(0)
else:
f_score = 2 * precision * recall * 100 / (precision + recall)
f_scores.append(f_score)

if eval_method == 'max':
return max(f_scores)
else:
return sum(f_scores) / len(f_scores)


def calculate_f_score(outputs: Dict, inputs: Dict):
scores = outputs['scores']
scores = scores.squeeze(0).cpu().numpy().tolist()
user_summary = inputs['user_summary'].cpu().numpy()[0]
sb = inputs['change_points'].cpu().numpy()[0]
n_frames = inputs['n_frames'].cpu().numpy()[0]
positions = inputs['positions'].cpu().numpy()[0]
summary = generate_summary([sb], [scores], [n_frames], [positions])[0]
f_score = evaluate_summary(summary, user_summary, 'avg')
return f_score


@METRICS.register_module(
group_key=default_group, module_name=Metrics.video_summarization_metric)
class VideoSummarizationMetric(Metric):
"""The metric for video summarization task.
"""

def __init__(self):
self.inputs = []
self.outputs = []

def add(self, outputs: Dict, inputs: Dict):
self.outputs.append(outputs)
self.inputs.append(inputs)

def evaluate(self):
f_scores = [
calculate_f_score(output, input)
for output, input in zip(self.outputs, self.inputs)
]

return {MetricKeys.FScore: sum(f_scores) / len(f_scores)}
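
As a sanity check of the metric logic above, a small sketch with toy arrays (hypothetical data, not part of the commit):

import numpy as np

from modelscope.metrics.video_summarization_metric import evaluate_summary

# Toy 10-frame video: one predicted summary and two annotator summaries.
predicted = np.array([1, 1, 0, 0, 0, 0, 0, 1, 1, 0])
user = np.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])

# 'avg' averages the per-annotator f-scores (TVSum-style); 'max' keeps the best one (SumMe-style).
print(evaluate_summary(predicted, user, 'avg'))
print(evaluate_summary(predicted, user, 'max'))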

modelscope/models/cv/__init__.py (+1, -1)

@@ -7,4 +7,4 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                image_to_image_generation, image_to_image_translation,
                object_detection, product_retrieval_embedding,
                salient_detection, super_resolution,
-               video_single_object_tracking, virual_tryon)
+               video_single_object_tracking, video_summarization, virual_tryon)

modelscope/models/cv/video_summarization/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .summarizer import PGLVideoSummarization

modelscope/models/cv/video_summarization/base_model.py (+118, -0)

@@ -0,0 +1,118 @@
# The implementation is based on pytorch-caffe-models, available at https://github.com/crowsonkb/pytorch-caffe-models.

import cv2
import numpy as np
import torch
import torch.nn as nn


class Inception(nn.Module):

def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5,
pool_proj):
super().__init__()
self.conv_1x1 = nn.Conv2d(in_channels, ch1x1, 1)
self.relu_1x1 = nn.ReLU(inplace=True)
self.conv_3x3_reduce = nn.Conv2d(in_channels, ch3x3red, 1)
self.relu_3x3_reduce = nn.ReLU(inplace=True)
self.conv_3x3 = nn.Conv2d(ch3x3red, ch3x3, 3, padding=1)
self.relu_3x3 = nn.ReLU(inplace=True)
self.conv_5x5_reduce = nn.Conv2d(in_channels, ch5x5red, 1)
self.relu_5x5_reduce = nn.ReLU(inplace=True)
self.conv_5x5 = nn.Conv2d(ch5x5red, ch5x5, 5, padding=2)
self.relu_5x5 = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(3, stride=1, padding=1)
self.pool_proj = nn.Conv2d(in_channels, pool_proj, 1)
self.relu_pool_proj = nn.ReLU(inplace=True)

def forward(self, x):
branch_1 = self.relu_1x1(self.conv_1x1(x))
branch_2 = self.relu_3x3_reduce(self.conv_3x3_reduce(x))
branch_2 = self.relu_3x3(self.conv_3x3(branch_2))
branch_3 = self.relu_5x5_reduce(self.conv_5x5_reduce(x))
branch_3 = self.relu_5x5(self.conv_5x5(branch_3))
branch_4 = self.pool(x)
branch_4 = self.relu_pool_proj(self.pool_proj(branch_4))
return torch.cat([branch_1, branch_2, branch_3, branch_4], dim=1)


class GoogLeNet(nn.Sequential):

def __init__(self, num_classes=1000):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.norm1 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75)
self.conv2_reduce = nn.Conv2d(64, 64, kernel_size=1)
self.relu2_reduce = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1)
self.relu2 = nn.ReLU(inplace=True)
self.norm2 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75)
self.pool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.inception_3a = Inception(192, 64, 96, 128, 16, 32, 32)
self.inception_3b = Inception(256, 128, 128, 192, 32, 96, 64)
self.pool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.inception_4a = Inception(480, 192, 96, 208, 16, 48, 64)
self.inception_4b = Inception(512, 160, 112, 224, 24, 64, 64)
self.inception_4c = Inception(512, 128, 128, 256, 24, 64, 64)
self.inception_4d = Inception(512, 112, 144, 288, 32, 64, 64)
self.inception_4e = Inception(528, 256, 160, 320, 32, 128, 128)
self.pool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.inception_5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.inception_5b = Inception(832, 384, 192, 384, 48, 128, 128)
self.pool5 = nn.AdaptiveAvgPool2d((1, 1))
self.loss3_classifier = nn.Linear(1024, num_classes)

def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.pool1(x)
x = self.norm1(x)
x = self.relu2_reduce(self.conv2_reduce(x))
x = self.relu2(self.conv2(x))
x = self.norm2(x)
x = self.pool2(x)
x = self.inception_3a(x)
x = self.inception_3b(x)
x = self.pool3(x)
x = self.inception_4a(x)
x = self.inception_4b(x)
x = self.inception_4c(x)
x = self.inception_4d(x)
x = self.inception_4e(x)
x = self.pool4(x)
x = self.inception_5a(x)
x = self.inception_5b(x)
x = self.pool5(x).flatten(1)
return x


class bvlc_googlenet(nn.Module):

def __init__(self, input_size=224):
"""model for the BVLC GoogLeNet, trained on ImageNet.
URL: https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet"""
super(bvlc_googlenet, self).__init__()

self.model = GoogLeNet(num_classes=1000)

self.input_size = input_size
self.input_mean = (104.0, 117.0, 123.0)

def forward(self, frame):
x = cv2.resize(frame,
(self.input_size, self.input_size)).astype(np.float32)
x = (x - self.input_mean).astype(np.float32)
x = np.transpose(x, [2, 0, 1])

x = np.expand_dims(x, 0)
x = torch.from_numpy(x)
if not next(self.model.parameters()).device.type == 'cpu':
x = x.cuda()
with torch.no_grad():
frame_feat = self.model(x)
if not frame_feat.device.type == 'cpu':
frame_feat = frame_feat.cpu()
frame_feat = frame_feat.numpy()
frame_feat = frame_feat / np.linalg.norm(frame_feat)
return frame_feat.reshape(-1)
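
A minimal sketch of how this feature extractor is meant to be called (random frame and randomly initialized weights here; the pipeline below loads bvlc_googlenet.pt before using it for real features):

import numpy as np

from modelscope.models.cv.video_summarization.base_model import bvlc_googlenet

# A dummy BGR frame of the kind cv2.VideoCapture returns; real inputs are decoded video frames.
frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

extractor = bvlc_googlenet()
extractor.model.eval()

# The wrapper resizes, mean-subtracts and runs GoogLeNet, returning an L2-normalized pool5 feature.
feat = extractor(frame)
print(feat.shape)  # (1024,)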

modelscope/models/cv/video_summarization/kts/__init__.py (+0, -0)


modelscope/models/cv/video_summarization/kts/cpd_auto.py (+35, -0)

@@ -0,0 +1,35 @@
# The implementation is based on KTS, available at https://github.com/TatsuyaShirakawa/KTS.

import numpy as np

from .cpd_nonlin import cpd_nonlin


def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs):
"""Detect change points automatically selecting their number

:param K: Kernel between each pair of frames in video
:param ncp: Maximum number of change points
:param vmax: Special parameter
:param desc_rate: Rate of descriptor sampling, vmax always corresponds to 1x
:param kwargs: Extra parameters for ``cpd_nonlin``
:return: Tuple (cps, costs)
- cps - best selected change-points
- costs - costs for 0,1,2,...,m change-points
"""
m = ncp
_, scores = cpd_nonlin(K, m, backtrack=False, **kwargs)

N = K.shape[0]
N2 = N * desc_rate # length of the video before down-sampling

penalties = np.zeros(m + 1)
# Prevent division by zero (in case of 0 changes)
ncp = np.arange(1, m + 1)
penalties[1:] = (vmax * ncp / (2.0 * N2)) * (np.log(float(N2) / ncp) + 1)

costs = scores / float(N) + penalties
m_best = np.argmin(costs)
cps, scores2 = cpd_nonlin(K, m_best, **kwargs)

return cps, scores2
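
A small usage sketch of the KTS entry point, mirroring how get_change_points in summarizer.py builds the frame kernel (synthetic features, so the detected change points are illustrative only):

import numpy as np

from modelscope.models.cv.video_summarization.kts.cpd_auto import cpd_auto

# Synthetic sub-sampled frame features with two visually distinct halves.
feats = np.concatenate([
    np.random.randn(40, 1024) + 5.0,
    np.random.randn(40, 1024) - 5.0,
]).astype(np.float32)

# Linear kernel between every pair of sub-sampled frames, as in get_change_points.
K = np.dot(feats, feats.T)

# cps are the selected change points; the second value holds the cpd_nonlin objective scores.
cps, costs = cpd_auto(K, ncp=10, vmax=2.2 / 4.0, lmin=1, verbose=False)
print(cps)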

modelscope/models/cv/video_summarization/kts/cpd_nonlin.py (+102, -0)

@@ -0,0 +1,102 @@
# The implementation is based on KTS, available at https://github.com/TatsuyaShirakawa/KTS.

import numpy as np


def calc_scatters(K):
"""Calculate scatter matrix: scatters[i,j] = {scatter of the sequence with
starting frame i and ending frame j}
"""
n = K.shape[0]
K1 = np.cumsum([0] + list(np.diag(K)))
K2 = np.zeros((n + 1, n + 1))
# TODO: use the fact that K - symmetric
K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1)

diagK2 = np.diag(K2)

i = np.arange(n).reshape((-1, 1))
j = np.arange(n).reshape((1, -1))

ij_f32 = ((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32))
diagK2_K2 = (
diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape(
(-1, 1)) - K2[1:, :-1].T - K2[:-1, 1:])
scatters = (
K1[1:].reshape((1, -1)) - K1[:-1].reshape(
(-1, 1)) - diagK2_K2 / ij_f32)

scatters[j < i] = 0

return scatters


def cpd_nonlin(K,
ncp,
lmin=1,
lmax=100000,
backtrack=True,
verbose=True,
out_scatters=None):
"""Change point detection with dynamic programming

:param K: Square kernel matrix
:param ncp: Number of change points to detect (ncp >= 0)
:param lmin: Minimal length of a segment
:param lmax: Maximal length of a segment
:param backtrack: If False - only evaluate objective scores (to save memory)
:param verbose: If true, print verbose message
:param out_scatters: Output scatters
:return: Tuple (cps, obj_vals)
- cps - detected array of change points: mean is thought to be constant
on [ cps[i], cps[i+1] )
- obj_vals - values of the objective function for 0..m changepoints
"""
m = int(ncp) # prevent numpy.int64

n, n1 = K.shape
assert n == n1, 'Kernel matrix awaited.'
assert (m + 1) * lmin <= n <= (m + 1) * lmax
assert 1 <= lmin <= lmax

if verbose:
print('Precomputing scatters...')
J = calc_scatters(K)

if out_scatters is not None:
out_scatters[0] = J

if verbose:
print('Inferring best change points...')
# Iden[k, l] - value of the objective for k change-points and l first frames
Iden = 1e101 * np.ones((m + 1, n + 1))
Iden[0, lmin:lmax] = J[0, lmin - 1:lmax - 1]

if backtrack:
# p[k, l] --- 'previous change' --- best t[k] when t[k+1] equals l
p = np.zeros((m + 1, n + 1), dtype=int)
else:
p = np.zeros((1, 1), dtype=int)

for k in range(1, m + 1):
for l_frame in range((k + 1) * lmin, n + 1):
tmin = max(k * lmin, l_frame - lmax)
tmax = l_frame - lmin + 1
c = J[tmin:tmax, l_frame - 1].reshape(-1) + \
Iden[k - 1, tmin:tmax].reshape(-1)
Iden[k, l_frame] = np.min(c)
if backtrack:
p[k, l_frame] = np.argmin(c) + tmin

# Collect change points
cps = np.zeros(m, dtype=int)

if backtrack:
cur = n
for k in range(m, 0, -1):
cps[k - 1] = p[k, cur]
cur = cps[k - 1]

scores = Iden[:, n].copy()
scores[scores > 1e99] = np.inf
return cps, scores

modelscope/models/cv/video_summarization/pgl_sum.py (+311, -0)

@@ -0,0 +1,311 @@
# The implementation is based on PGL-SUM, available at https://github.com/e-apostolidis/PGL-SUM.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):

def __init__(self,
input_size=1024,
output_size=1024,
freq=10000,
heads=1,
pos_enc=None):
""" The basic (multi-head) Attention 'cell' containing the learnable parameters of Q, K and V

:param int input_size: Feature input size of Q, K, V.
:param int output_size: Feature -hidden- size of Q, K, V.
:param int freq: The frequency of the sinusoidal positional encoding.
:param int heads: Number of heads for the attention module.
:param str | None pos_enc: The type of the positional encoding [supported: Absolute, Relative].
"""
super(SelfAttention, self).__init__()

self.permitted_encodings = ['absolute', 'relative']
if pos_enc is not None:
pos_enc = pos_enc.lower()
assert pos_enc in self.permitted_encodings, f'Supported encodings: {*self.permitted_encodings,}'

self.input_size = input_size
self.output_size = output_size
self.heads = heads
self.pos_enc = pos_enc
self.freq = freq
self.Wk, self.Wq, self.Wv = nn.ModuleList(), nn.ModuleList(
), nn.ModuleList()
for _ in range(self.heads):
self.Wk.append(
nn.Linear(
in_features=input_size,
out_features=output_size // heads,
bias=False))
self.Wq.append(
nn.Linear(
in_features=input_size,
out_features=output_size // heads,
bias=False))
self.Wv.append(
nn.Linear(
in_features=input_size,
out_features=output_size // heads,
bias=False))
self.out = nn.Linear(
in_features=output_size, out_features=input_size, bias=False)

self.softmax = nn.Softmax(dim=-1)
self.drop = nn.Dropout(p=0.5)

def getAbsolutePosition(self, T):
"""Calculate the sinusoidal positional encoding based on the absolute position of each considered frame.
Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762)

:param int T: Number of frames contained in Q, K and V
:return: Tensor with shape [T, T]
"""
freq = self.freq
d = self.input_size

pos = torch.tensor([k for k in range(T)],
device=self.out.weight.device)
i = torch.tensor([k for k in range(T // 2)],
device=self.out.weight.device)

# Reshape tensors each pos_k for each i indices
pos = pos.reshape(pos.shape[0], 1)
pos = pos.repeat_interleave(i.shape[0], dim=1)
i = i.repeat(pos.shape[0], 1)

AP = torch.zeros(T, T, device=self.out.weight.device)
AP[pos, 2 * i] = torch.sin(pos / freq**((2 * i) / d))
AP[pos, 2 * i + 1] = torch.cos(pos / freq**((2 * i) / d))
return AP

def getRelativePosition(self, T):
"""Calculate the sinusoidal positional encoding based on the relative position of each considered frame.
r_pos calculations as here: https://theaisummer.com/positional-embeddings/

:param int T: Number of frames contained in Q, K and V
:return: Tensor with shape [T, T]
"""
freq = self.freq
d = 2 * T
min_rpos = -(T - 1)

i = torch.tensor([k for k in range(T)], device=self.out.weight.device)
j = torch.tensor([k for k in range(T)], device=self.out.weight.device)

# Reshape tensors each i for each j indices
i = i.reshape(i.shape[0], 1)
i = i.repeat_interleave(i.shape[0], dim=1)
j = j.repeat(i.shape[0], 1)

# Calculate the relative positions
r_pos = j - i - min_rpos

RP = torch.zeros(T, T, device=self.out.weight.device)
idx = torch.tensor([k for k in range(T // 2)],
device=self.out.weight.device)
RP[:, 2 * idx] = torch.sin(
r_pos[:, 2 * idx] / freq**((i[:, 2 * idx] + j[:, 2 * idx]) / d))
RP[:, 2 * idx + 1] = torch.cos(
r_pos[:, 2 * idx + 1]
/ freq**((i[:, 2 * idx + 1] + j[:, 2 * idx + 1]) / d))
return RP

def forward(self, x):
""" Compute the weighted frame features, based on either the global or local (multi-head) attention mechanism.

:param torch.tensor x: Frame features with shape [T, input_size]
:return: A tuple of:
y: Weighted features based on the attention weights, with shape [T, input_size]
att_weights : The attention weights (before dropout), with shape [T, T]
"""
outputs = []
for head in range(self.heads):
K = self.Wk[head](x)
Q = self.Wq[head](x)
V = self.Wv[head](x)

# Q *= 0.06 # scale factor VASNet
# Q /= np.sqrt(self.output_size) # scale factor (i.e 1 / sqrt(d_k) )
energies = torch.matmul(Q, K.transpose(1, 0))
if self.pos_enc is not None:
if self.pos_enc == 'absolute':
AP = self.getAbsolutePosition(T=energies.shape[0])
energies = energies + AP
elif self.pos_enc == 'relative':
RP = self.getRelativePosition(T=energies.shape[0])
energies = energies + RP

att_weights = self.softmax(energies)
_att_weights = self.drop(att_weights)
y = torch.matmul(_att_weights, V)

# Save the current head output
outputs.append(y)
y = self.out(torch.cat(outputs, dim=1))
return y, att_weights.clone(
) # for now we don't deal with the weights (probably max or avg pooling)


class MultiAttention(nn.Module):

def __init__(self,
input_size=1024,
output_size=1024,
freq=10000,
pos_enc=None,
num_segments=None,
heads=1,
fusion=None):
""" Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters.

:param int input_size: The expected input feature size.
:param int output_size: The hidden feature size of the attention mechanisms.
:param int freq: The frequency of the sinusoidal positional encoding.
:param None | str pos_enc: The selected positional encoding [absolute, relative].
:param None | int num_segments: The selected number of segments to split the videos.
:param int heads: The selected number of global heads.
:param None | str fusion: The selected type of feature fusion.
"""
super(MultiAttention, self).__init__()

# Global Attention, considering differences among all frames
self.attention = SelfAttention(
input_size=input_size,
output_size=output_size,
freq=freq,
pos_enc=pos_enc,
heads=heads)

self.num_segments = num_segments
if self.num_segments is not None:
assert self.num_segments >= 2, 'num_segments must be None or 2+'
self.local_attention = nn.ModuleList()
for _ in range(self.num_segments):
# Local Attention, considering differences among the same segment with reduce hidden size
self.local_attention.append(
SelfAttention(
input_size=input_size,
output_size=output_size // num_segments,
freq=freq,
pos_enc=pos_enc,
heads=4))
self.permitted_fusions = ['add', 'mult', 'avg', 'max']
self.fusion = fusion
if self.fusion is not None:
self.fusion = self.fusion.lower()
assert self.fusion in self.permitted_fusions, f'Fusion method must be: {*self.permitted_fusions,}'

def forward(self, x):
""" Compute the weighted frame features, based on the global and locals (multi-head) attention mechanisms.

:param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features.
:return: A tuple of:
weighted_value: Tensor with shape [T, input_size] containing the weighted frame features.
attn_weights: Tensor with shape [T, T] containing the attention weights.
"""
weighted_value, attn_weights = self.attention(x) # global attention

if self.num_segments is not None and self.fusion is not None:
segment_size = math.ceil(x.shape[0] / self.num_segments)
for segment in range(self.num_segments):
left_pos = segment * segment_size
right_pos = (segment + 1) * segment_size
local_x = x[left_pos:right_pos]
weighted_local_value, attn_local_weights = self.local_attention[
segment](local_x) # local attentions

# Normalize the features vectors
weighted_value[left_pos:right_pos] = F.normalize(
weighted_value[left_pos:right_pos].clone(), p=2, dim=1)
weighted_local_value = F.normalize(
weighted_local_value, p=2, dim=1)
if self.fusion == 'add':
weighted_value[left_pos:right_pos] += weighted_local_value
elif self.fusion == 'mult':
weighted_value[left_pos:right_pos] *= weighted_local_value
elif self.fusion == 'avg':
weighted_value[left_pos:right_pos] += weighted_local_value
weighted_value[left_pos:right_pos] /= 2
elif self.fusion == 'max':
weighted_value[left_pos:right_pos] = torch.max(
weighted_value[left_pos:right_pos].clone(),
weighted_local_value)

return weighted_value, attn_weights


class PGL_SUM(nn.Module):

def __init__(self,
input_size=1024,
output_size=1024,
freq=10000,
pos_enc=None,
num_segments=None,
heads=1,
fusion=None):
""" Class wrapping the PGL-SUM model; its key modules and parameters.

:param int input_size: The expected input feature size.
:param int output_size: The hidden feature size of the attention mechanisms.
:param int freq: The frequency of the sinusoidal positional encoding.
:param None | str pos_enc: The selected positional encoding [absolute, relative].
:param None | int num_segments: The selected number of segments to split the videos.
:param int heads: The selected number of global heads.
:param None | str fusion: The selected type of feature fusion.
"""
super(PGL_SUM, self).__init__()

self.attention = MultiAttention(
input_size=input_size,
output_size=output_size,
freq=freq,
pos_enc=pos_enc,
num_segments=num_segments,
heads=heads,
fusion=fusion)
self.linear_1 = nn.Linear(
in_features=input_size, out_features=input_size)
self.linear_2 = nn.Linear(
in_features=self.linear_1.out_features, out_features=1)

self.drop = nn.Dropout(p=0.5)
self.norm_y = nn.LayerNorm(normalized_shape=input_size, eps=1e-6)
self.norm_linear = nn.LayerNorm(
normalized_shape=self.linear_1.out_features, eps=1e-6)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()

def forward(self, frame_features):
""" Produce frames importance scores from the frame features, using the PGL-SUM model.

:param torch.Tensor frame_features: Tensor of shape [T, input_size] containing the frame features produced by
using the pool5 layer of GoogleNet.
:return: A tuple of:
y: Tensor with shape [1, T] containing the frames importance scores in [0, 1].
attn_weights: Tensor with shape [T, T] containing the attention weights.
"""
frame_features = frame_features.reshape(-1, frame_features.shape[-1])
residual = frame_features
weighted_value, attn_weights = self.attention(frame_features)
y = weighted_value + residual
y = self.drop(y)
y = self.norm_y(y)

# 2-layer NN (Regressor Network)
y = self.linear_1(y)
y = self.relu(y)
y = self.drop(y)
y = self.norm_linear(y)

y = self.linear_2(y)
y = self.sigmoid(y)
y = y.view(1, -1)

return y, attn_weights
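
A quick sketch of the scoring head in isolation, using the same hyper-parameters that PGLVideoSummarization passes in summarizer.py (random features and untrained weights, so the scores themselves are meaningless here):

import torch

from modelscope.models.cv.video_summarization.pgl_sum import PGL_SUM

model = PGL_SUM(
    input_size=1024,
    output_size=1024,
    num_segments=4,
    heads=8,
    fusion='add',
    pos_enc='absolute').eval()

# 200 sub-sampled frames, each a 1024-dim GoogLeNet pool5 feature.
frame_features = torch.randn(200, 1024)

with torch.no_grad():
    scores, attn_weights = model(frame_features)

print(scores.shape)        # torch.Size([1, 200]), importance scores in [0, 1]
print(attn_weights.shape)  # torch.Size([200, 200])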

modelscope/models/cv/video_summarization/summarizer.py (+224, -0)

@@ -0,0 +1,224 @@
# The implementation is based on PGL-SUM, available at https://github.com/e-apostolidis/PGL-SUM.

import os.path as osp
from copy import deepcopy
from typing import Dict, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.video_summarization.kts.cpd_auto import cpd_auto
from modelscope.models.cv.video_summarization.pgl_sum import PGL_SUM
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


def get_change_points(video_feat, n_frame):
video_feat = np.array(video_feat, np.float32)
K = np.dot(video_feat, video_feat.T)
change_points, _ = cpd_auto(K, ncp=120, vmax=2.2 / 4.0, lmin=1)
change_points = change_points * 15
change_points = np.concatenate(([0], change_points, [n_frame - 1]))

temp_change_points = []
for idx in range(len(change_points) - 1):
segment = [change_points[idx], change_points[idx + 1] - 1]
if idx == len(change_points) - 2:
segment = [change_points[idx], change_points[idx + 1]]

temp_change_points.append(segment)
change_points = np.array(list(temp_change_points))

temp_n_frame_per_seg = []
for change_points_idx in range(len(change_points)):
n_frame = change_points[change_points_idx][1] - change_points[
change_points_idx][0]
temp_n_frame_per_seg.append(n_frame)
n_frame_per_seg = np.array(list(temp_n_frame_per_seg))

return change_points, n_frame_per_seg


def knap_sack(W, wt, val, n):
""" Maximize the value that a knapsack of capacity W can hold. You can either put the item or discard it, there is
no concept of putting some part of item in the knapsack.

:param int W: Maximum capacity -in frames- of the knapsack.
:param list[int] wt: The weights (lengths -in frames-) of each video shot.
:param list[float] val: The values (importance scores) of each video shot.
:param int n: The number of the shots.
:return: A list containing the indices of the selected shots.
"""
K = [[0 for _ in range(W + 1)] for _ in range(n + 1)]

# Build table K[][] in bottom up manner
for i in range(n + 1):
for w in range(W + 1):
if i == 0 or w == 0:
K[i][w] = 0
elif wt[i - 1] <= w:
K[i][w] = max(val[i - 1] + K[i - 1][w - wt[i - 1]],
K[i - 1][w])
else:
K[i][w] = K[i - 1][w]

selected = []
w = W
for i in range(n, 0, -1):
if K[i][w] != K[i - 1][w]:
selected.insert(0, i - 1)
w -= wt[i - 1]

return selected


def generate_summary(all_shot_bound, all_scores, all_nframes, all_positions):
""" Generate the automatic machine summary, based on the video shots; the frame importance scores; the number of
frames in the original video and the position of the sub-sampled frames of the original video.

:param list[np.ndarray] all_shot_bound: The video shots for all the -original- testing videos.
:param list[np.ndarray] all_scores: The calculated frame importance scores for all the sub-sampled testing videos.
:param list[np.ndarray] all_nframes: The number of frames for all the -original- testing videos.
:param list[np.ndarray] all_positions: The position of the sub-sampled frames for all the -original- testing videos.
:return: A list containing the indices of the selected frames for all the -original- testing videos.
"""
all_summaries = []
for video_index in range(len(all_scores)):
# Get shots' boundaries
shot_bound = all_shot_bound[video_index] # [number_of_shots, 2]
frame_init_scores = all_scores[video_index]
n_frames = all_nframes[video_index]
positions = all_positions[video_index]

# Compute the importance scores for the initial frame sequence (not the sub-sampled one)
frame_scores = np.zeros(n_frames, dtype=np.float32)
if positions.dtype != int:
positions = positions.astype(np.int32)
if positions[-1] != n_frames:
positions = np.concatenate([positions, [n_frames]])
for i in range(len(positions) - 1):
pos_left, pos_right = positions[i], positions[i + 1]
if i == len(frame_init_scores):
frame_scores[pos_left:pos_right] = 0
else:
frame_scores[pos_left:pos_right] = frame_init_scores[i]

# Compute shot-level importance scores by taking the average importance scores of all frames in the shot
shot_imp_scores = []
shot_lengths = []
for shot in shot_bound:
shot_lengths.append(shot[1] - shot[0] + 1)
shot_imp_scores.append(
(frame_scores[shot[0]:shot[1] + 1].mean()).item())

# Select the best shots using the knapsack implementation
final_shot = shot_bound[-1]
final_max_length = int((final_shot[1] + 1) * 0.15)

selected = knap_sack(final_max_length, shot_lengths, shot_imp_scores,
len(shot_lengths))

# Select all frames from each selected shot (by setting their value in the summary vector to 1)
summary = np.zeros(final_shot[1] + 1, dtype=np.int8)
for shot in selected:
summary[shot_bound[shot][0]:shot_bound[shot][1] + 1] = 1

all_summaries.append(summary)

return all_summaries


@MODELS.register_module(
Tasks.video_summarization, module_name=Models.video_summarization)
class PGLVideoSummarization(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the video summarization model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)

model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)

self.loss = nn.MSELoss()
self.model = PGL_SUM(
input_size=1024,
output_size=1024,
num_segments=4,
heads=8,
fusion='add',
pos_enc='absolute')
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.model = self.model.to(self._device)

self.model = self.load_pretrained(self.model, model_path)

if self.training:
self.model.train()
else:
self.model.eval()

def load_pretrained(self, net, load_path, strict=True, param_key='params'):
if isinstance(net, (DataParallel, DistributedDataParallel)):
net = net.module
load_net = torch.load(
load_path, map_location=lambda storage, loc: storage)
if param_key is not None:
if param_key not in load_net and 'params' in load_net:
param_key = 'params'
logger.info(
f'Loading: {param_key} does not exist, use params.')
if param_key in load_net:
load_net = load_net[param_key]
logger.info(
f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
)
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
net.load_state_dict(load_net, strict=strict)
logger.info('load model done.')
return net

def _train_forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
frame_features = input['frame_features']
gtscore = input['gtscore']
preds, attn_weights = self.model(frame_features)
return {'loss': self.loss(preds, gtscore)}

def _inference_forward(self, input: Dict[str,
Tensor]) -> Dict[str, Tensor]:
frame_features = input['frame_features']
y, attn_weights = self.model(frame_features)
return {'scores': y}

def forward(self, input: Dict[str,
Tensor]) -> Dict[str, Union[list, Tensor]]:
"""return the result by the model

Args:
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, Union[list, Tensor]]: results
"""
for key, value in input.items():
input[key] = input[key].to(self._device)
if self.training:
return self._train_forward(input)
else:
return self._inference_forward(input)
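
A toy walk-through of the shot-selection path above (hypothetical numbers: a 100-frame video cut into ten 10-frame shots, sub-sampled every 15 frames as the pipeline does):

import numpy as np

from modelscope.models.cv.video_summarization.summarizer import generate_summary

# Ten shots given as inclusive [start, end] frame indices.
shot_bound = np.array([[i * 10, i * 10 + 9] for i in range(10)])
n_frames = 100
# Positions of the sub-sampled frames (every 15th frame).
positions = np.arange(0, n_frames, 15)
# One importance score per sub-sampled frame; frames 30-44 score highest.
scores = [0.1, 0.2, 0.9, 0.3, 0.2, 0.1, 0.1]

summary = generate_summary([shot_bound], [scores], [n_frames], [positions])[0]
# With the 15% length budget (15 frames), the knapsack keeps only the top-scoring shot.
print(summary.sum())            # 10
print(np.flatnonzero(summary))  # frames 30..39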

modelscope/msdatasets/task_datasets/__init__.py (+3, -2)

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
     from .torch_base_dataset import TorchTaskDataset
     from .veco_dataset import VecoDataset
     from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
+    from .video_summarization_dataset import VideoSummarizationDataset
 else:
     _import_structure = {
         'base': ['TaskDataset'],
@@ -17,7 +17,8 @@ else:
         'torch_base_dataset': ['TorchTaskDataset'],
         'veco_dataset': ['VecoDataset'],
         'image_instance_segmentation_coco_dataset':
-        ['ImageInstanceSegmentationCocoDataset']
+        ['ImageInstanceSegmentationCocoDataset'],
+        'video_summarization_dataset': ['VideoSummarizationDataset'],
     }
     import sys



modelscope/msdatasets/task_datasets/video_summarization_dataset.py (+69, -0)

@@ -0,0 +1,69 @@
import os

import h5py
import json
import numpy as np
import torch

from modelscope.msdatasets.task_datasets.torch_base_dataset import \
TorchTaskDataset


class VideoSummarizationDataset(TorchTaskDataset):

def __init__(self, mode, opt, root_dir):
self.mode = mode
self.data_filename = os.path.join(root_dir, opt.dataset_file)
self.split_filename = os.path.join(root_dir, opt.split_file)
self.split_index = opt.split_index # it represents the current split (varies from 0 to 4)
hdf = h5py.File(self.data_filename, 'r')
self.list_frame_features, self.list_gtscores = [], []
self.list_user_summary = []
self.list_change_points = []
self.list_n_frames = []
self.list_positions = []

with open(self.split_filename) as f:
data = json.loads(f.read())
for i, split in enumerate(data):
if i == self.split_index:
self.split = split
break

for video_name in self.split[self.mode + '_keys']:
frame_features = torch.Tensor(
np.array(hdf[video_name + '/features']))
gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore']))
user_summary = np.array(hdf[f'{video_name}/user_summary'])
change_points = np.array(hdf[f'{video_name}/change_points'])
n_frames = np.array(hdf[f'{video_name}/n_frames'])
positions = np.array(hdf[f'{video_name}/picks'])

self.list_frame_features.append(frame_features)
self.list_gtscores.append(gtscore)
self.list_user_summary.append(user_summary)
self.list_change_points.append(change_points)
self.list_n_frames.append(n_frames)
self.list_positions.append(positions)

hdf.close()

def __len__(self):
self.len = len(self.split[self.mode + '_keys'])
return self.len

def __getitem__(self, index):
frame_features = self.list_frame_features[index]
gtscore = self.list_gtscores[index]
user_summary = self.list_user_summary[index]
change_points = self.list_change_points[index]
n_frames = self.list_n_frames[index]
positions = self.list_positions[index]

return dict(
frame_features=frame_features,
gtscore=gtscore,
user_summary=user_summary,
change_points=change_points,
n_frames=n_frames,
positions=positions)
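
The dataset reads TVSum/SumMe-style HDF5 files (per-video keys: features, gtscore, user_summary, change_points, n_frames, picks). A rough construction sketch following the new trainer test; that the hub model actually bundles the data and split files referenced by its configuration.json is an assumption taken from that test:

import os

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.msdatasets.task_datasets import VideoSummarizationDataset
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile

cache_path = snapshot_download('damo/cv_googlenet_pgl-video-summarization')
config = Config.from_file(os.path.join(cache_path, ModelFile.CONFIGURATION))

# mode selects the 'train_keys' / 'test_keys' entries of the chosen split.
train_set = VideoSummarizationDataset('train', config.dataset, cache_path)
sample = train_set[0]
# Each item bundles frame_features, gtscore, user_summary, change_points, n_frames, positions.
print(sample['frame_features'].shape, sample['n_frames'])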

modelscope/pipelines/cv/video_summarization_pipeline.py (+109, -0)

@@ -0,0 +1,109 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch
from tqdm import tqdm

from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_summarization import PGLVideoSummarization
from modelscope.models.cv.video_summarization.base_model import bvlc_googlenet
from modelscope.models.cv.video_summarization.summarizer import (
generate_summary, get_change_points)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.video_summarization, module_name=Pipelines.video_summarization)
class VideoSummarizationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a video summarization pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, auto_collate=False, **kwargs)
logger.info(f'loading model from {model}')
googlenet_model_path = osp.join(model, 'bvlc_googlenet.pt')
config_path = osp.join(model, ModelFile.CONFIGURATION)
logger.info(f'loading config from {config_path}')
self.cfg = Config.from_file(config_path)

self.googlenet_model = bvlc_googlenet()
self.googlenet_model.model.load_state_dict(
torch.load(
googlenet_model_path, map_location=torch.device(self.device)))
self.googlenet_model = self.googlenet_model.to(self.device).eval()

self.pgl_model = PGLVideoSummarization(model)
self.pgl_model = self.pgl_model.to(self.device).eval()

logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if not isinstance(input, str):
raise TypeError(f'input should be a str,'
f' but got {type(input)}')
frames = []
picks = []
cap = cv2.VideoCapture(input)
frame_idx = 0
while (cap.isOpened()):
ret, frame = cap.read()
if not ret:
break
if frame_idx % 15 == 0:
frames.append(frame)
picks.append(frame_idx)
frame_idx += 1
n_frame = frame_idx

result = {
'video_name': input,
'video_frames': np.array(frames),
'n_frame': n_frame,
'picks': np.array(picks)
}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

frame_features = []
for frame in tqdm(input['video_frames']):
feat = self.googlenet_model(frame)
frame_features.append(feat)

change_points, n_frame_per_seg = get_change_points(
frame_features, input['n_frame'])

summary = self.inference(frame_features, input['n_frame'],
input['picks'], change_points)

return {OutputKeys.OUTPUT: summary}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

def inference(self, frame_features, n_frames, picks, change_points):
frame_features = torch.from_numpy(np.array(frame_features, np.float32))
picks = np.array(picks, np.int32)

with torch.no_grad():
results = self.pgl_model(dict(frame_features=frame_features))
scores = results['scores']
if not scores.device.type == 'cpu':
scores = scores.cpu()
scores = scores.squeeze(0).numpy().tolist()
summary = generate_summary([change_points], [scores], [n_frames],
[picks])[0]

return summary

modelscope/preprocessors/image.py (+25, -0)

@@ -264,3 +264,28 @@ class ImageInstanceSegmentationPreprocessor(Preprocessor):
             return None
 
         return results


@PREPROCESSORS.register_module(
Fields.cv, module_name=Preprocessors.video_summarization_preprocessor)
class VideoSummarizationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""

Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)
self.model_dir: str = model_dir

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""process the raw input data

Args:
data Dict[str, Any]

Returns:
Dict[str, Any]: the preprocessed data
"""
return data

modelscope/utils/constant.py (+1, -0)

@@ -64,6 +64,7 @@ class CVTasks(object):
 
     # reid and tracking
     video_single_object_tracking = 'video-single-object-tracking'
+    video_summarization = 'video-summarization'
     image_reid_person = 'image-reid-person'


tests/pipelines/test_video_summarization.py (+32, -0)

@@ -0,0 +1,32 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class VideoSummarizationTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):

summarization_pipeline = pipeline(
Tasks.video_summarization,
model='damo/cv_googlenet_pgl-video-summarization')
result = summarization_pipeline(
'data/test/videos/video_category_test_video.mp4')

print(f'video summarization output: {result}.')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_modelhub_default_model(self):
summarization_pipeline = pipeline(Tasks.video_summarization)
result = summarization_pipeline(
'data/test/videos/video_category_test_video.mp4')

print(f'video summarization output: {result}.')


if __name__ == '__main__':
unittest.main()

tests/trainers/test_video_summarization_trainer.py (+75, -0)

@@ -0,0 +1,75 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.video_summarization import PGLVideoSummarization
from modelscope.msdatasets.task_datasets import VideoSummarizationDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class VideoSummarizationTrainerTest(unittest.TestCase):

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

self.model_id = 'damo/cv_googlenet_pgl-video-summarization'
self.cache_path = snapshot_download(self.model_id)
self.config = Config.from_file(
os.path.join(self.cache_path, ModelFile.CONFIGURATION))
self.dataset_train = VideoSummarizationDataset('train',
self.config.dataset,
self.cache_path)
self.dataset_val = VideoSummarizationDataset('test',
self.config.dataset,
self.cache_path)

def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
super().tearDown()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer(self):
kwargs = dict(
model=self.model_id,
train_dataset=self.dataset_train,
eval_dataset=self.dataset_val,
work_dir=self.tmp_dir)
trainer = build_trainer(default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(2):
self.assertIn(f'epoch_{i+1}.pth', results_files)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer_with_model_and_args(self):
model = PGLVideoSummarization.from_pretrained(self.cache_path)
kwargs = dict(
cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
model=model,
train_dataset=self.dataset_train,
eval_dataset=self.dataset_val,
max_epochs=2,
work_dir=self.tmp_dir)
trainer = build_trainer(default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(2):
self.assertIn(f'epoch_{i+1}.pth', results_files)


if __name__ == '__main__':
unittest.main()
