shuying.shu / yingda.chen · committed 3 years ago · commit a9deb3895c
34 changed files with 1870 additions and 7 deletions
  1. data/test/videos/movie_scene_segmentation_test_video.mp4 (+3, -0)
  2. modelscope/metainfo.py (+6, -0)
  3. modelscope/metrics/__init__.py (+2, -0)
  4. modelscope/metrics/builder.py (+1, -0)
  5. modelscope/metrics/movie_scene_segmentation_metric.py (+52, -0)
  6. modelscope/models/cv/__init__.py (+3, -2)
  7. modelscope/models/cv/movie_scene_segmentation/__init__.py (+25, -0)
  8. modelscope/models/cv/movie_scene_segmentation/get_model.py (+45, -0)
  9. modelscope/models/cv/movie_scene_segmentation/model.py (+192, -0)
  10. modelscope/models/cv/movie_scene_segmentation/utils/__init__.py (+3, -0)
  11. modelscope/models/cv/movie_scene_segmentation/utils/head.py (+29, -0)
  12. modelscope/models/cv/movie_scene_segmentation/utils/save_op.py (+118, -0)
  13. modelscope/models/cv/movie_scene_segmentation/utils/shot_encoder.py (+331, -0)
  14. modelscope/models/cv/movie_scene_segmentation/utils/trn.py (+132, -0)
  15. modelscope/msdatasets/task_datasets/__init__.py (+3, -0)
  16. modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py (+1, -0)
  17. modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py (+173, -0)
  18. modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py (+102, -0)
  19. modelscope/outputs.py (+18, -0)
  20. modelscope/pipelines/builder.py (+3, -0)
  21. modelscope/pipelines/cv/__init__.py (+5, -1)
  22. modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py (+67, -0)
  23. modelscope/preprocessors/__init__.py (+2, -2)
  24. modelscope/preprocessors/movie_scene_segmentation/__init__.py (+19, -0)
  25. modelscope/preprocessors/movie_scene_segmentation/transforms.py (+312, -0)
  26. modelscope/preprocessors/video.py (+45, -0)
  27. modelscope/trainers/__init__.py (+3, -2)
  28. modelscope/trainers/cv/__init__.py (+2, -0)
  29. modelscope/trainers/cv/movie_scene_segmentation_trainer.py (+20, -0)
  30. modelscope/utils/constant.py (+1, -0)
  31. requirements/cv.txt (+1, -0)
  32. tests/msdatasets/test_ms_dataset.py (+6, -0)
  33. tests/pipelines/test_movie_scene_segmentation.py (+36, -0)
  34. tests/trainers/test_movie_scene_segmentation_trainer.py (+109, -0)

data/test/videos/movie_scene_segmentation_test_video.mp4 (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:59fa397b01dc4c9b67a19ca42f149287b9c4e7b2158aba5d07d2db88af87b23f
size 126815483

modelscope/metainfo.py (+6, -0)

@@ -27,6 +27,7 @@ class Models(object):
video_summarization = 'pgl-video-summarization'
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
resnet50_bert = 'resnet50-bert'

# EasyCV models
yolox = 'YOLOX'
@@ -133,6 +134,7 @@ class Pipelines(object):
video_summarization = 'googlenet_pgl_video_summarization'
image_semantic_segmentation = 'image-semantic-segmentation'
image_reid_person = 'passvitb-image-reid-person'
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'

# nlp tasks
sentence_similarity = 'sentence-similarity'
@@ -195,6 +197,7 @@ class Trainers(object):
image_instance_segmentation = 'image-instance-segmentation'
image_portrait_enhancement = 'image-portrait-enhancement'
video_summarization = 'video-summarization'
movie_scene_segmentation = 'movie-scene-segmentation'

# nlp trainers
bert_sentiment_analysis = 'bert-sentiment-analysis'
@@ -223,6 +226,7 @@ class Preprocessors(object):
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
video_summarization_preprocessor = 'video-summarization-preprocessor'
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'

# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -279,6 +283,8 @@ class Metrics(object):
# metrics for image-portrait-enhancement task
image_portrait_enhancement_metric = 'image-portrait-enhancement-metric'
video_summarization_metric = 'video-summarization-metric'
# metric for movie-scene-segmentation task
movie_scene_segmentation_metric = 'movie-scene-segmentation-metric'


class Optimizers(object):


modelscope/metrics/__init__.py (+2, -0)

@@ -16,6 +16,7 @@ if TYPE_CHECKING:
from .text_generation_metric import TextGenerationMetric
from .token_classification_metric import TokenClassificationMetric
from .video_summarization_metric import VideoSummarizationMetric
from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric

else:
_import_structure = {
@@ -32,6 +33,7 @@ else:
'text_generation_metric': ['TextGenerationMetric'],
'token_classification_metric': ['TokenClassificationMetric'],
'video_summarization_metric': ['VideoSummarizationMetric'],
'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'],
}

import sys


modelscope/metrics/builder.py (+1, -0)

@@ -34,6 +34,7 @@ task_default_metrics = {
Tasks.video_summarization: [Metrics.video_summarization_metric],
Tasks.image_captioning: [Metrics.text_gen_metric],
Tasks.visual_question_answering: [Metrics.text_gen_metric],
Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric],
}




modelscope/metrics/movie_scene_segmentation_metric.py (+52, -0)

@@ -0,0 +1,52 @@
from typing import Dict

import numpy as np

from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import (torch_nested_detach,
torch_nested_numpify)
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
group_key=default_group,
module_name=Metrics.movie_scene_segmentation_metric)
class MovieSceneSegmentationMetric(Metric):
"""The metric computation class for movie scene segmentation classes.
"""

def __init__(self):
self.preds = []
self.labels = []
self.eps = 1e-5

def add(self, outputs: Dict, inputs: Dict):
preds = outputs['pred']
labels = inputs['label']
self.preds.extend(preds)
self.labels.extend(labels)

def evaluate(self):
gts = np.array(torch_nested_numpify(torch_nested_detach(self.labels)))
prob = np.array(torch_nested_numpify(torch_nested_detach(self.preds)))

gt_one = gts == 1
gt_zero = gts == 0
pred_one = prob == 1
pred_zero = prob == 0

tp = (gt_one * pred_one).sum()
fp = (gt_zero * pred_one).sum()
fn = (gt_one * pred_zero).sum()

precision = 100.0 * tp / (tp + fp + self.eps)
recall = 100.0 * tp / (tp + fn + self.eps)
f1 = 2 * precision * recall / (precision + recall)

return {
MetricKeys.F1: f1,
MetricKeys.RECALL: recall,
MetricKeys.PRECISION: precision
}
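
A toy walk-through of the metric (made-up predictions and labels, assuming the package is installed so the lazy import above resolves):

    import torch
    from modelscope.metrics import MovieSceneSegmentationMetric

    metric = MovieSceneSegmentationMetric()
    # 'pred' holds hard 0/1 boundary decisions (the argmax from the model),
    # 'label' the ground-truth boundary flags.
    metric.add(outputs={'pred': torch.tensor([1, 0, 1, 0])},
               inputs={'label': torch.tensor([1, 0, 0, 0])})
    print(metric.evaluate())
    # tp=1, fp=1, fn=0 -> precision ~ 50.0, recall ~ 100.0, f1 ~ 66.7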

modelscope/models/cv/__init__.py (+3, -2)

@@ -9,8 +9,9 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
image_panoptic_segmentation, image_portrait_enhancement,
image_reid_person, image_semantic_segmentation,
image_to_image_generation, image_to_image_translation,
object_detection, product_retrieval_embedding,
realtime_object_detection, salient_detection, super_resolution,
movie_scene_segmentation, object_detection,
product_retrieval_embedding, realtime_object_detection,
salient_detection, super_resolution,
video_single_object_tracking, video_summarization, virual_tryon)

# yapf: enable

modelscope/models/cv/movie_scene_segmentation/__init__.py (+25, -0)

@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

from .model import MovieSceneSegmentationModel
from .datasets import MovieSceneSegmentationDataset

else:
_import_structure = {
'model': ['MovieSceneSegmentationModel'],
'datasets': ['MovieSceneSegmentationDataset'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/models/cv/movie_scene_segmentation/get_model.py (+45, -0)

@@ -0,0 +1,45 @@
# ------------------------------------------------------------------------------------
# BaSSL
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# Github: https://github.com/kakaobrain/bassl
# ------------------------------------------------------------------------------------

from .utils.shot_encoder import resnet50
from .utils.trn import TransformerCRN


def get_shot_encoder(cfg):
name = cfg['model']['shot_encoder']['name']
shot_encoder_args = cfg['model']['shot_encoder'][name]
if name == 'resnet':
depth = shot_encoder_args['depth']
if depth == 50:
shot_encoder = resnet50(**shot_encoder_args['params'], )
else:
raise NotImplementedError
else:
raise NotImplementedError

return shot_encoder


def get_contextual_relation_network(cfg):
crn = None

if cfg['model']['contextual_relation_network']['enabled']:
name = cfg['model']['contextual_relation_network']['name']
crn_args = cfg['model']['contextual_relation_network']['params'][name]
if name == 'trn':
sampling_name = cfg['model']['loss']['sampling_method']['name']
crn_args['neighbor_size'] = (
2 * cfg['model']['loss']['sampling_method']['params']
[sampling_name]['neighbor_size'])
crn = TransformerCRN(crn_args)
else:
raise NotImplementedError

return crn


__all__ = ['get_shot_encoder', 'get_contextual_relation_network']
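
For orientation, a hypothetical minimal config matching the lookups above; the real values ship in the model's configuration.json, so treat the names and numbers as placeholders:

    # get_contextual_relation_network() additionally needs BERT-style fields
    # (hidden_size, pooling_method, ...) reachable by attribute access, which the
    # ModelScope Config object provides; it is left disabled in this sketch.
    cfg = {
        'model': {
            'shot_encoder': {
                'name': 'resnet',
                'resnet': {'depth': 50, 'params': {}},
            },
            'contextual_relation_network': {'enabled': False},
        },
    }
    shot_encoder = get_shot_encoder(cfg)                  # ResNet-50 backbone
    assert get_contextual_relation_network(cfg) is None   # disabled above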

modelscope/models/cv/movie_scene_segmentation/model.py (+192, -0)

@@ -0,0 +1,192 @@
import os
import os.path as osp
from typing import Any, Dict

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as TF
from PIL import Image
from shotdetect_scenedetect_lgss import shot_detect

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .get_model import get_contextual_relation_network, get_shot_encoder
from .utils.save_op import get_pred_boundary, pred2scene, scene2video

logger = get_logger()


@MODELS.register_module(
Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
class MovieSceneSegmentationModel(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
"""str -- model file root."""
super().__init__(model_dir, *args, **kwargs)

model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
params = torch.load(model_path, map_location='cpu')

config_path = osp.join(model_dir, ModelFile.CONFIGURATION)
self.cfg = Config.from_file(config_path)

def load_param_with_prefix(prefix, model, src_params):
own_state = model.state_dict()
for name, param in own_state.items():
src_name = prefix + '.' + name
own_state[name] = src_params[src_name]

model.load_state_dict(own_state)

self.shot_encoder = get_shot_encoder(self.cfg)
load_param_with_prefix('shot_encoder', self.shot_encoder, params)
self.crn = get_contextual_relation_network(self.cfg)
load_param_with_prefix('crn', self.crn, params)

crn_name = self.cfg.model.contextual_relation_network.name
hdim = self.cfg.model.contextual_relation_network.params[crn_name][
'hidden_size']
self.head_sbd = nn.Linear(hdim, 2)
load_param_with_prefix('head_sbd', self.head_sbd, params)

self.test_transform = TF.Compose([
TF.Resize(size=256, interpolation=Image.BICUBIC),
TF.CenterCrop(224),
TF.ToTensor(),
TF.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

self.infer_result = {'vid': [], 'sid': [], 'pred': []}
sampling_method = self.cfg.dataset.sampling_method.name
self.neighbor_size = self.cfg.dataset.sampling_method.params[
sampling_method].neighbor_size

self.eps = 1e-5

def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
data = inputs['video']
labels = inputs['label']
outputs = self.shared_step(data)

loss = F.cross_entropy(
outputs.squeeze(), labels.squeeze(), reduction='none')
lpos = labels == 1
lneg = labels == 0

n_pos, n_neg = 1, 1  # equal weights for the boundary / non-boundary classes
wp = (n_pos / float(n_pos + n_neg)) * lpos / (lpos.sum() + self.eps)
wn = (n_neg / float(n_pos + n_neg)) * lneg / (lneg.sum() + self.eps)
w = wp + wn
loss = (w * loss).sum()

probs = torch.argmax(outputs, dim=1)

re = dict(pred=probs, loss=loss)
return re

def inference(self, batch):
logger.info('Begin scene detect ......')
bs = self.cfg.pipeline.batch_size_per_gpu
sids = batch['sid']
inputs = batch['shot_feat']

shot_num = len(sids)
cnt = (shot_num + bs - 1) // bs  # ceil division so the last batch is never empty

for i in range(cnt):
start = i * bs
end = (i + 1) * bs if (i + 1) * bs < shot_num else shot_num
input_ = inputs[start:end]
sid_ = sids[start:end]
input_ = torch.stack(input_)
outputs = self.shared_step(input_) # shape [b,2]
prob = F.softmax(outputs, dim=1)
self.infer_result['sid'].extend(sid_.cpu().detach().numpy())
self.infer_result['pred'].extend(prob[:, 1].cpu().detach().numpy())
self.infer_result['pred'] = np.stack(self.infer_result['pred'])

assert len(self.infer_result['sid']) == len(sids)
assert len(self.infer_result['pred']) == len(inputs)
return self.infer_result

def shared_step(self, inputs):
with torch.no_grad():
# infer shot encoder
shot_repr = self.extract_shot_representation(inputs)
assert len(shot_repr.shape) == 3

# infer CRN
_, pooled = self.crn(shot_repr, mask=None)
# infer boundary score
pred = self.head_sbd(pooled)
return pred

def save_shot_feat(self, _repr):
feat = _repr.float().cpu().numpy()
pth = self.cfg.dataset.img_path + '/features'
os.makedirs(pth, exist_ok=True)

for idx in range(_repr.shape[0]):
name = f'shot_{str(idx).zfill(4)}.npy'
name = osp.join(pth, name)
np.save(name, feat[idx])

def extract_shot_representation(self,
inputs: torch.Tensor) -> torch.Tensor:
""" inputs [b s k c h w] -> output [b d] """
assert len(inputs.shape) == 6 # (B Shot Keyframe C H W)
b, s, k, c, h, w = inputs.shape
inputs = einops.rearrange(inputs, 'b s k c h w -> (b s) k c h w', s=s)
keyframe_repr = [self.shot_encoder(inputs[:, _k]) for _k in range(k)]
# [k (b s) d] -> [(b s) d]
shot_repr = torch.stack(keyframe_repr).mean(dim=0)

shot_repr = einops.rearrange(shot_repr, '(b s) d -> b s d', s=s)
return shot_repr

def postprocess(self, inputs: Dict[str, Any], **kwargs):
logger.info('Generate scene .......')

pred_dict = inputs['feat']
thres = self.cfg.pipeline.save_threshold

anno_dict = get_pred_boundary(pred_dict, thres)
scene_dict, scene_list = pred2scene(self.shot2keyf, anno_dict)
if self.cfg.pipeline.save_split_scene:
re_dir = scene2video(inputs['input_video_pth'], scene_list, thres)
print(f'Split scene video saved to {re_dir}')
return len(scene_list), scene_dict

def preprocess(self, inputs):
logger.info('Begin shot detect......')
shot_keyf_lst, anno, shot2keyf = shot_detect(
inputs, **self.cfg.preprocessor.shot_detect)
logger.info('Shot detect done!')

single_shot_feat, sid = [], []
for idx, one_shot in enumerate(shot_keyf_lst):
one_shot = [
self.test_transform(one_frame) for one_frame in one_shot
]
one_shot = torch.stack(one_shot, dim=0)
single_shot_feat.append(one_shot)
sid.append(idx)
single_shot_feat = torch.stack(single_shot_feat, dim=0)
shot_feat = []
for idx, one_shot in enumerate(anno):
shot_idx = int(one_shot['shot_id']) + np.arange(
-self.neighbor_size, self.neighbor_size + 1)
shot_idx = np.clip(shot_idx, 0, one_shot['num_shot'])
_one_shot = single_shot_feat[shot_idx]
shot_feat.append(_one_shot)
self.shot2keyf = shot2keyf
self.anno = anno
return shot_feat, sid
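
The class-balanced loss in forward() is easy to miss; a small standalone sketch of that weighting with toy values (with the equal class weights used above, each class contributes half of the total weight):

    import torch
    import torch.nn.functional as F

    eps = 1e-5
    logits = torch.randn(4, 2)            # toy boundary logits for 4 shots
    labels = torch.tensor([1, 0, 0, 0])   # one boundary, three non-boundaries

    loss = F.cross_entropy(logits, labels, reduction='none')
    lpos, lneg = labels == 1, labels == 0
    wp = 0.5 * lpos / (lpos.sum() + eps)  # positives share half of the weight
    wn = 0.5 * lneg / (lneg.sum() + eps)  # negatives share the other half
    balanced_loss = ((wp + wn) * loss).sum()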

modelscope/models/cv/movie_scene_segmentation/utils/__init__.py (+3, -0)

@@ -0,0 +1,3 @@
from .save_op import get_pred_boundary, pred2scene, scene2video
from .shot_encoder import resnet50
from .trn import TransformerCRN

modelscope/models/cv/movie_scene_segmentation/utils/head.py (+29, -0)

@@ -0,0 +1,29 @@
# ------------------------------------------------------------------------------------
# BaSSL
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# Github: https://github.com/kakaobrain/bassl
# ------------------------------------------------------------------------------------

import torch.nn as nn
import torch.nn.functional as F


class MlpHead(nn.Module):

def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
super().__init__()
self.output_dim = output_dim
self.input_dim = input_dim
self.hidden_dim = hidden_dim

self.model = nn.Sequential(
nn.Linear(self.input_dim, self.hidden_dim, bias=True),
nn.ReLU(),
nn.Linear(self.hidden_dim, self.output_dim, bias=True),
)

def forward(self, x):
# x shape: [b t d] where t means the number of views
x = self.model(x)
return F.normalize(x, dim=-1)

modelscope/models/cv/movie_scene_segmentation/utils/save_op.py (+118, -0)

@@ -0,0 +1,118 @@
# ----------------------------------------------------------------------------------
# The code below partially refers to SceneSeg (LGSS).
# Github: https://github.com/AnyiRao/SceneSeg
# ----------------------------------------------------------------------------------
import os
import os.path as osp
import subprocess

import cv2
import numpy as np
from tqdm import tqdm


def get_pred_boundary(pred_dict, threshold=0.5):
pred = pred_dict['pred']
tmp = (pred > threshold).astype(np.int32)
anno_dict = {}
for idx in range(len(tmp)):
anno_dict.update({str(pred_dict['sid'][idx]).zfill(4): int(tmp[idx])})
return anno_dict


def pred2scene(shot2keyf, anno_dict):
scene_list, pair_list = get_demo_scene_list(shot2keyf, anno_dict)

scene_dict = {}
assert len(scene_list) == len(pair_list)
for scene_ind, scene_item in enumerate(scene_list):
scene_dict.update(
{scene_ind: {
'shot': pair_list[scene_ind],
'frame': scene_item
}})

return scene_dict, scene_list


def scene2video(source_movie_fn, scene_list, thres):

vcap = cv2.VideoCapture(source_movie_fn)
fps = vcap.get(cv2.CAP_PROP_FPS) # video.fps
out_video_dir_fn = os.path.join(os.getcwd(),
f'pred_result/scene_video_{thres}')
os.makedirs(out_video_dir_fn, exist_ok=True)

for scene_ind, scene_item in tqdm(enumerate(scene_list)):
scene = str(scene_ind).zfill(4)
start_frame = int(scene_item[0])
end_frame = int(scene_item[1])
start_time, end_time = start_frame / fps, end_frame / fps
duration_time = end_time - start_time
out_video_fn = os.path.join(out_video_dir_fn,
'scene_{}.mp4'.format(scene))
if os.path.exists(out_video_fn):
continue
call_list = ['ffmpeg']
call_list += ['-v', 'quiet']
call_list += [
'-y', '-ss',
str(start_time), '-t',
str(duration_time), '-i', source_movie_fn
]
call_list += ['-map_chapters', '-1']
call_list += [out_video_fn]
subprocess.call(call_list)
return osp.join(os.getcwd(), 'pred_result')


def get_demo_scene_list(shot2keyf, anno_dict):
pair_list = get_pair_list(anno_dict)

scene_list = []
for pair in pair_list:
start_shot, end_shot = int(pair[0]), int(pair[-1])
start_frame = shot2keyf[start_shot].split(' ')[0]
end_frame = shot2keyf[end_shot].split(' ')[1]
scene_list.append((start_frame, end_frame))
return scene_list, pair_list


def get_pair_list(anno_dict):
sort_anno_dict_key = sorted(anno_dict.keys())
tmp = 0
tmp_list = []
tmp_label_list = []
anno_list = []
anno_label_list = []
for key in sort_anno_dict_key:
value = anno_dict.get(key)
tmp += value
tmp_list.append(key)
tmp_label_list.append(value)
if tmp == 1:
anno_list.append(tmp_list)
anno_label_list.append(tmp_label_list)
tmp = 0
tmp_list = []
tmp_label_list = []
continue
if key == sort_anno_dict_key[-1]:
if len(tmp_list) > 0:
anno_list.append(tmp_list)
anno_label_list.append(tmp_label_list)
if len(anno_list) == 0:
return None
while [] in anno_list:
anno_list.remove([])
tmp_anno_list = [anno_list[0]]
pair_list = []
for ind in range(len(anno_list) - 1):
cont_count = int(anno_list[ind + 1][0]) - int(anno_list[ind][-1])
if cont_count > 1:
pair_list.extend(tmp_anno_list)
tmp_anno_list = [anno_list[ind + 1]]
continue
tmp_anno_list.append(anno_list[ind + 1])
pair_list.extend(tmp_anno_list)
return pair_list
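
A toy walk-through of how boundary probabilities become scene groups with the helpers above (numbers are made up):

    import numpy as np

    pred_dict = {'sid': [0, 1, 2, 3, 4],
                 'pred': np.array([0.1, 0.9, 0.2, 0.2, 0.8])}
    anno_dict = get_pred_boundary(pred_dict, threshold=0.5)
    # {'0000': 0, '0001': 1, '0002': 0, '0003': 0, '0004': 1}
    print(get_pair_list(anno_dict))
    # [['0000', '0001'], ['0002', '0003', '0004']] -> two scenes of shots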

modelscope/models/cv/movie_scene_segmentation/utils/shot_encoder.py (+331, -0)

@@ -0,0 +1,331 @@
"""
Modified from the original ResNet implementation in torchvision.
"""

from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor


def conv3x3(in_planes: int,
out_planes: int,
stride: int = 1,
groups: int = 1,
dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
expansion: int = 1

def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError(
'Dilation > 1 not supported in BasicBlock')
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride

def forward(self, x: Tensor) -> Tensor:
identity = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)

return out


class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

expansion: int = 4

def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x: Tensor) -> Tensor:
identity = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)

return out


class ResNet(nn.Module):

def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
in_channel_dim: int = 3,
zero_init_residual: bool = False,
use_last_block_grid: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer

self.use_last_block_grid = use_last_block_grid
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError('replace_stride_with_dilation should be None '
'or a 3-element tuple, got {}'.format(
replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(
in_channel_dim,
self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False,
)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(
block,
128,
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight,
0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight,
0) # type: ignore[arg-type]

def _make_layer(
self,
block: Type[Union[BasicBlock, Bottleneck]],
planes: int,
blocks: int,
stride: int = 1,
dilate: bool = False,
) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)

layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
self.groups,
self.base_width,
previous_dilation,
norm_layer,
))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer,
))

return nn.Sequential(*layers)

def _forward_impl(self, x: Tensor, grid: bool, level: List, both: bool,
grid_only: bool) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)

if grid:
x_grid = []

if 3 in level:
x_grid.append(x.detach().clone())
if not both and len(level) == 1:
return x_grid

x = self.layer4(x)

if 4 in level:
x_grid.append(x.detach().clone())
if not both and len(level) == 1:
return x_grid

x = self.avgpool(x)
x = torch.flatten(x, 1)

if not grid or len(level) == 0:
return x

if grid_only:
return x_grid

if both:
return x, x_grid

return x

def forward(
self,
x: Tensor,
grid: bool = False,
level: List = [],
both: bool = False,
grid_only: bool = False,
) -> Tensor:
return self._forward_impl(x, grid, level, both, grid_only)


def resnet50(**kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
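
A quick shape check for the shot encoder above (randomly initialized here; the pretrained weights are loaded in model.py from the checkpoint):

    import torch

    encoder = resnet50()
    feat = encoder(torch.randn(2, 3, 224, 224))
    print(feat.shape)  # torch.Size([2, 2048]) -- one pooled embedding per key-frame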

modelscope/models/cv/movie_scene_segmentation/utils/trn.py (+132, -0)

@@ -0,0 +1,132 @@
# ------------------------------------------------------------------------------------
# BaSSL
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# Github: https://github.com/kakaobrain/bassl
# ------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEncoder


class ShotEmbedding(nn.Module):

def __init__(self, cfg):
super().__init__()

nn_size = cfg.neighbor_size + 2 # +1 for center shot, +1 for cls
self.shot_embedding = nn.Linear(cfg.input_dim, cfg.hidden_size)
self.position_embedding = nn.Embedding(nn_size, cfg.hidden_size)
self.mask_embedding = nn.Embedding(2, cfg.input_dim, padding_idx=0)

# tf naming convention for layer norm
self.LayerNorm = nn.LayerNorm(cfg.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(cfg.hidden_dropout_prob)

self.register_buffer('pos_ids',
torch.arange(nn_size, dtype=torch.long))

def forward(
self,
shot_emb: torch.Tensor,
mask: torch.Tensor = None,
pos_ids: torch.Tensor = None,
) -> torch.Tensor:

assert len(shot_emb.size()) == 3

if pos_ids is None:
pos_ids = self.pos_ids

# this is for the mask embedding (un-masked positions remain unchanged)
if mask is not None:
self.mask_embedding.weight.data[0, :].fill_(0)
mask_emb = self.mask_embedding(mask.long())
shot_emb = (shot_emb * (1 - mask).float()[:, :, None]) + mask_emb

# we set [CLS] token to averaged feature
cls_emb = shot_emb.mean(dim=1)

# embedding shots
shot_emb = torch.cat([cls_emb[:, None, :], shot_emb], dim=1)
shot_emb = self.shot_embedding(shot_emb)
pos_emb = self.position_embedding(pos_ids)
embeddings = shot_emb + pos_emb[None, :]
embeddings = self.dropout(self.LayerNorm(embeddings))
return embeddings


class TransformerCRN(nn.Module):

def __init__(self, cfg):
super().__init__()

self.pooling_method = cfg.pooling_method
self.shot_embedding = ShotEmbedding(cfg)
self.encoder = BertEncoder(cfg)

nn_size = cfg.neighbor_size + 2 # +1 for center shot, +1 for cls
self.register_buffer(
'attention_mask',
self._get_extended_attention_mask(
torch.ones((1, nn_size)).float()),
)

def forward(
self,
shot: torch.Tensor,
mask: torch.Tensor = None,
pos_ids: torch.Tensor = None,
pooling_method: str = None,
):
if self.attention_mask.shape[1] != (shot.shape[1] + 1):
n_shot = shot.shape[1] + 1 # +1 for CLS token
attention_mask = self._get_extended_attention_mask(
torch.ones((1, n_shot), dtype=torch.float, device=shot.device))
else:
attention_mask = self.attention_mask

shot_emb = self.shot_embedding(shot, mask=mask, pos_ids=pos_ids)
encoded_emb = self.encoder(
shot_emb, attention_mask=attention_mask).last_hidden_state

return encoded_emb, self.pooler(
encoded_emb, pooling_method=pooling_method)

def pooler(self, sequence_output, pooling_method=None):
if pooling_method is None:
pooling_method = self.pooling_method

if pooling_method == 'cls':
return sequence_output[:, 0, :]
elif pooling_method == 'avg':
return sequence_output[:, 1:].mean(dim=1)
elif pooling_method == 'max':
return sequence_output[:, 1:].max(dim=1)[0]
elif pooling_method == 'center':
cidx = sequence_output.shape[1] // 2
return sequence_output[:, cidx, :]
else:
raise ValueError

def _get_extended_attention_mask(self, attention_mask):

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
f'Wrong shape for attention_mask (shape {attention_mask.shape})'
)

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
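
A hedged construction sketch for TransformerCRN: the shipped hyper-parameters come from the model's configuration.json; transformers.BertConfig is used here only because it keeps extra kwargs as plain attributes, which matches how the class reads its cfg. Numbers are placeholders.

    import torch
    from transformers import BertConfig

    cfg = BertConfig(
        hidden_size=768,
        num_hidden_layers=2,
        num_attention_heads=8,
        input_dim=2048,           # extra field used by ShotEmbedding
        neighbor_size=16,         # 2 * neighbor_size of the sampling window
        pooling_method='center',  # extra field used by the pooler
    )
    crn = TransformerCRN(cfg)
    shots = torch.randn(2, cfg.neighbor_size + 1, cfg.input_dim)  # [b, 2N+1, d]
    encoded, pooled = crn(shots)
    print(encoded.shape, pooled.shape)
    # torch.Size([2, 18, 768]) torch.Size([2, 768]) -- CLS token prepended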

modelscope/msdatasets/task_datasets/__init__.py (+3, -0)

@@ -9,7 +9,9 @@ if TYPE_CHECKING:
from .torch_base_dataset import TorchTaskDataset
from .veco_dataset import VecoDataset
from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
from .movie_scene_segmentation import MovieSceneSegmentationDataset
from .video_summarization_dataset import VideoSummarizationDataset

else:
_import_structure = {
'base': ['TaskDataset'],
@@ -19,6 +21,7 @@ else:
'image_instance_segmentation_coco_dataset':
['ImageInstanceSegmentationCocoDataset'],
'video_summarization_dataset': ['VideoSummarizationDataset'],
'movie_scene_segmentation': ['MovieSceneSegmentationDataset'],
}
import sys



modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset

modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py (+173, -0)

@@ -0,0 +1,173 @@
# ---------------------------------------------------------------------------------------------------
# The implementation is built upon BaSSL, publicly available at https://github.com/kakaobrain/bassl
# ---------------------------------------------------------------------------------------------------
import copy
import os
import os.path as osp
import random

import json
import torch
from torchvision.datasets.folder import pil_loader

from modelscope.metainfo import Models
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
TorchTaskDataset
from modelscope.utils.constant import Tasks
from . import sampler

DATASET_STRUCTURE = {
'train': {
'annotation': 'anno/train.json',
'images': 'keyf_240p',
'feat': 'feat'
},
'test': {
'annotation': 'anno/test.json',
'images': 'keyf_240p',
'feat': 'feat'
}
}


@TASK_DATASETS.register_module(
Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
class MovieSceneSegmentationDataset(TorchTaskDataset):
"""dataset for movie scene segmentation.

Args:
split_config (dict): Maps the split name to its data root, e.g. {"train": "xxxxx"}.
preprocessor: Preprocessor applied to the loaded key-frames.
test_mode (bool, optional): If set True, the preprocessor is switched to evaluation mode.
"""

def __init__(self, **kwargs):
split_config = kwargs['split_config']

self.data_root = next(iter(split_config.values()))
if not osp.exists(self.data_root):
self.data_root = osp.dirname(self.data_root)
assert osp.exists(self.data_root)

self.split = next(iter(split_config.keys()))
self.preprocessor = kwargs['preprocessor']

self.ann_file = osp.join(self.data_root,
DATASET_STRUCTURE[self.split]['annotation'])
self.img_prefix = osp.join(self.data_root,
DATASET_STRUCTURE[self.split]['images'])
self.feat_prefix = osp.join(self.data_root,
DATASET_STRUCTURE[self.split]['feat'])

self.test_mode = kwargs['test_mode']
if self.test_mode:
self.preprocessor.eval()
else:
self.preprocessor.train()

self.cfg = kwargs.pop('cfg', None)

self.num_keyframe = self.cfg.num_keyframe if self.cfg is not None else 3
self.use_single_keyframe = self.cfg.use_single_keyframe if self.cfg is not None else False

self.load_data()
self.init_sampler(self.cfg)

def __len__(self):
"""Total number of samples of data."""
return len(self.anno_data)

def __getitem__(self, idx: int):
data = self.anno_data[
idx] # {"video_id", "shot_id", "num_shot", "boundary_label"}
vid, sid = data['video_id'], data['shot_id']
num_shot = data['num_shot']

shot_idx = self.shot_sampler(int(sid), num_shot)

video = self.load_shot_list(vid, shot_idx)
if self.preprocessor is None:
video = torch.stack(video, dim=0)
video = video.view(-1, self.num_keyframe, 3, 224, 224)
else:
video = self.preprocessor(video)

payload = {
'idx': idx,
'vid': vid,
'sid': sid,
'video': video,
'label': abs(data['boundary_label']), # ignore -1 label.
}
return payload

def load_data(self):
self.tmpl = '{}/shot_{}_img_{}.jpg' # video_id, shot_id, keyframe_id

if not self.test_mode:
with open(self.ann_file) as f:
self.anno_data = json.load(f)
self.vidsid2label = {
f"{it['video_id']}_{it['shot_id']}": it['boundary_label']
for it in self.anno_data
}
else:
with open(self.ann_file) as f:
self.anno_data = json.load(f)

def init_sampler(self, cfg):
# shot sampler
if cfg is not None:
self.sampling_method = cfg.sampling_method.name
sampler_args = copy.deepcopy(
cfg.sampling_method.params.get(self.sampling_method, {}))
if self.sampling_method == 'instance':
self.shot_sampler = sampler.InstanceShotSampler()
elif self.sampling_method == 'temporal':
self.shot_sampler = sampler.TemporalShotSampler(**sampler_args)
elif self.sampling_method == 'shotcol':
self.shot_sampler = sampler.SequenceShotSampler(**sampler_args)
elif self.sampling_method == 'bassl':
self.shot_sampler = sampler.SequenceShotSampler(**sampler_args)
elif self.sampling_method == 'bassl+shotcol':
self.shot_sampler = sampler.SequenceShotSampler(**sampler_args)
elif self.sampling_method == 'sbd':
self.shot_sampler = sampler.NeighborShotSampler(**sampler_args)
else:
raise NotImplementedError
else:
self.shot_sampler = sampler.NeighborShotSampler()

def load_shot_list(self, vid, shot_idx):
shot_list = []
cache = {}
for sidx in shot_idx:
vidsid = f'{vid}_{sidx:04d}'
if vidsid in cache:
shot = cache[vidsid]
else:
shot_path = os.path.join(
self.img_prefix, self.tmpl.format(vid, f'{sidx:04d}',
'{}'))
shot = self.load_shot_keyframes(shot_path)
cache[vidsid] = shot
shot_list.extend(shot)
return shot_list

def load_shot_keyframes(self, path):
shot = None
if not self.test_mode and self.use_single_keyframe:
# load one randomly sampled keyframe
shot = [
pil_loader(
path.format(random.randint(0, self.num_keyframe - 1)))
]
else:
# load all keyframes
shot = [
pil_loader(path.format(i)) for i in range(self.num_keyframe)
]
assert shot is not None
return shot

modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py (+102, -0)

@@ -0,0 +1,102 @@
# ------------------------------------------------------------------------------------
# BaSSL
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# Github: https://github.com/kakaobrain/bassl
# ------------------------------------------------------------------------------------

import random

import numpy as np


class InstanceShotSampler:
""" This is for instance at pre-training stage """

def __call__(self, center_sid: int, *args, **kwargs):
return center_sid


class TemporalShotSampler:
""" This is for temporal at pre-training stage """

def __init__(self, neighbor_size: int):
self.N = neighbor_size

def __call__(self, center_sid: int, total_num_shot: int):
""" we randomly sample one shot from neighbor shots within local temporal window
"""
shot_idx = center_sid + np.arange(
-self.N, self.N + 1
) # total number of neighbor shots = 2N+1 (query (1) + neighbors (2*N))
shot_idx = np.clip(shot_idx, 0,
total_num_shot) # deal with out-of-boundary indices
shot_idx = random.choice(
np.unique(np.delete(shot_idx, np.where(shot_idx == center_sid))))
return shot_idx


class SequenceShotSampler:
""" This is for bassl or shotcol at pre-training stage """

def __init__(self, neighbor_size: int, neighbor_interval: int):
self.interval = neighbor_interval
self.window_size = neighbor_size * self.interval # temporal coverage

def __call__(self,
center_sid: int,
total_num_shot: int,
sparse_method: str = 'edge'):
"""
Args:
center_sid: index of center shot
total_num_shot: last index of shot for given video
sparse_stride: stride to sample sparse ones from dense sequence
for curriculum learning
"""

dense_shot_idx = center_sid + np.arange(
-self.window_size, self.window_size + 1,
self.interval) # total number of shots = 2*neighbor_size+1

if dense_shot_idx[0] < 0:
# if center_sid is near left-side of video, we shift window rightward
# so that the leftmost index is 0
dense_shot_idx -= dense_shot_idx[0]
elif dense_shot_idx[-1] > (total_num_shot - 1):
# if center_sid is near right-side of video, we shift window leftward
# so that the rightmost index is total_num_shot - 1
dense_shot_idx -= dense_shot_idx[-1] - (total_num_shot - 1)

# to deal with videos that have smaller number of shots than window size
dense_shot_idx = np.clip(dense_shot_idx, 0, total_num_shot)

if sparse_method == 'edge':
# in this case, we use two edge shots as sparse sequence
sparse_stride = len(dense_shot_idx) - 1
sparse_idx_to_dense = np.arange(0, len(dense_shot_idx),
sparse_stride)
elif sparse_method == 'edge+center':
# in this case, we use two edge shots + center shot as sparse sequence
sparse_idx_to_dense = np.array(
[0, len(dense_shot_idx) - 1,
len(dense_shot_idx) // 2])

shot_idx = [sparse_idx_to_dense, dense_shot_idx]
return shot_idx


class NeighborShotSampler:
""" This is for scene boundary detection (sbd), i.e., fine-tuning stage """

def __init__(self, neighbor_size: int = 8):
self.neighbor_size = neighbor_size

def __call__(self, center_sid: int, total_num_shot: int):
# total number of shots = 2 * neighbor_size + 1
shot_idx = center_sid + np.arange(-self.neighbor_size,
self.neighbor_size + 1)
shot_idx = np.clip(shot_idx, 0,
total_num_shot) # for out-of-boundary indices

return shot_idx
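
A toy illustration of the fine-tuning sampler (values made up):

    sampler = NeighborShotSampler(neighbor_size=2)
    print(sampler(center_sid=1, total_num_shot=20))   # [0 0 1 2 3] after clipping
    print(sampler(center_sid=10, total_num_shot=20))  # [ 8  9 10 11 12]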

modelscope/outputs.py (+18, -0)

@@ -35,6 +35,8 @@ class OutputKeys(object):
UUID = 'uuid'
WORD = 'word'
KWS_LIST = 'kws_list'
SPLIT_VIDEO_NUM = 'split_video_num'
SPLIT_META_DICT = 'split_meta_dict'


TASK_OUTPUTS = {
@@ -241,6 +243,22 @@ TASK_OUTPUTS = {
# }
Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG],

# movie scene segmentation result for a single video
# {
# "split_video_num":3,
# "split_meta_dict":
# {
# scene_id:
# {
# "shot": [0,1,2],
# "frame": [start_frame, end_frame]
# }
# }
#
# }
Tasks.movie_scene_segmentation:
[OutputKeys.SPLIT_VIDEO_NUM, OutputKeys.SPLIT_META_DICT],

# ============ nlp tasks ===================

# text classification result for single sample


modelscope/pipelines/builder.py (+3, -0)

@@ -144,6 +144,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_vitb_video-single-object-tracking_ostrack'),
Tasks.image_reid_person: (Pipelines.image_reid_person,
'damo/cv_passvitb_image-reid-person_market'),
Tasks.movie_scene_segmentation:
(Pipelines.movie_scene_segmentation,
'damo/cv_resnet50-bert_video-scene-segmentation_movienet')
}




modelscope/pipelines/cv/__init__.py (+5, -1)

@@ -42,6 +42,8 @@ if TYPE_CHECKING:
from .video_category_pipeline import VideoCategoryPipeline
from .virtual_try_on_pipeline import VirtualTryonPipeline
from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline
from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline

else:
_import_structure = {
'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -90,7 +92,9 @@ else:
'video_category_pipeline': ['VideoCategoryPipeline'],
'virtual_try_on_pipeline': ['VirtualTryonPipeline'],
'easycv_pipeline':
['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline']
['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'],
'movie_scene_segmentation_pipeline':
['MovieSceneSegmentationPipeline'],
}

import sys


modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py (+67, -0)

@@ -0,0 +1,67 @@
from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.movie_scene_segmentation,
module_name=Pipelines.movie_scene_segmentation)
class MovieSceneSegmentationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""use `model` to create a movie scene segmentation pipeline for prediction

Args:
model: model id on modelscope hub
"""
_device = kwargs.pop('device', 'gpu')
if torch.cuda.is_available() and _device == 'gpu':
device = 'gpu'
else:
device = 'cpu'
super().__init__(model=model, device=device, **kwargs)

logger.info('Load model done!')

def preprocess(self, input: Input) -> Dict[str, Any]:
""" use pyscenedetect to detect shot from the input video, and generate key-frame jpg, anno.ndjson, and shot-frame.txt
Then use shot-encoder to encoder feat of the detected key-frame

Args:
input: path of the input video

"""
self.input_video_pth = input
if isinstance(input, str):
shot_feat, sid = self.model.preprocess(input)
else:
raise TypeError(f'input should be a str,'
f' but got {type(input)}')

result = {'sid': sid, 'shot_feat': shot_feat}

return result

def forward(self, input: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
output = self.model.inference(input)
return output

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
data = {'input_video_pth': self.input_video_pth, 'feat': inputs}
video_num, meta_dict = self.model.postprocess(data)
result = {
OutputKeys.SPLIT_VIDEO_NUM: video_num,
OutputKeys.SPLIT_META_DICT: meta_dict
}
return result
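
Usage sketch, combining the pieces this commit wires together (the model id and test video are the ones used in tests/pipelines/test_movie_scene_segmentation.py below):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    segmenter = pipeline(
        Tasks.movie_scene_segmentation,
        model='damo/cv_resnet50-bert_video-scene-segmentation_movienet')
    result = segmenter('data/test/videos/movie_scene_segmentation_test_video.mp4')
    print(result[OutputKeys.SPLIT_VIDEO_NUM])   # number of detected scenes
    print(result[OutputKeys.SPLIT_META_DICT])   # per-scene shot ids and frame ranges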

modelscope/preprocessors/__init__.py (+2, -2)

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from .space import (DialogIntentPredictionPreprocessor,
DialogModelingPreprocessor,
DialogStateTrackingPreprocessor)
from .video import ReadVideoData
from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
from .star import ConversationalTextToSqlPreprocessor

else:
@@ -37,7 +37,7 @@ else:
'common': ['Compose', 'ToTensor', 'Filter'],
'audio': ['LinearAECAndFbank'],
'asr': ['WavToScp'],
'video': ['ReadVideoData'],
'video': ['ReadVideoData', 'MovieSceneSegmentationPreprocessor'],
'image': [
'LoadImage', 'load_image', 'ImageColorEnhanceFinetunePreprocessor',
'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor'


modelscope/preprocessors/movie_scene_segmentation/__init__.py (+19, -0)

@@ -0,0 +1,19 @@
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .transforms import get_transform
else:
_import_structure = {
'transforms': ['get_transform'],
}

import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/preprocessors/movie_scene_segmentation/transforms.py (+312, -0)

@@ -0,0 +1,312 @@
# ------------------------------------------------------------------------------------
# The code below partially refers to BaSSL
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# Github: https://github.com/kakaobrain/bassl
# ------------------------------------------------------------------------------------
import numbers
import os.path as osp
import random
from typing import List

import numpy as np
import torch
import torchvision.transforms as TF
import torchvision.transforms.functional as F
from PIL import Image, ImageFilter


def get_transform(lst):
assert len(lst) > 0
transform_lst = []
for item in lst:
transform_lst.append(build_transform(item))
transform = TF.Compose(transform_lst)
return transform


def build_transform(cfg):
assert isinstance(cfg, dict)
cfg = cfg.copy()
type = cfg.pop('type')

if type == 'VideoResizedCenterCrop':
return VideoResizedCenterCrop(**cfg)
elif type == 'VideoToTensor':
return VideoToTensor(**cfg)
elif type == 'VideoRandomResizedCrop':
return VideoRandomResizedCrop(**cfg)
elif type == 'VideoRandomHFlip':
return VideoRandomHFlip()
elif type == 'VideoRandomColorJitter':
return VideoRandomColorJitter(**cfg)
elif type == 'VideoRandomGaussianBlur':
return VideoRandomGaussianBlur(**cfg)
else:
raise NotImplementedError


class VideoResizedCenterCrop(torch.nn.Module):

def __init__(self, image_size, crop_size):
self.tfm = TF.Compose([
TF.Resize(size=image_size, interpolation=Image.BICUBIC),
TF.CenterCrop(crop_size),
])

def __call__(self, imgmap):
assert isinstance(imgmap, list)
return [self.tfm(img) for img in imgmap]


class VideoToTensor(torch.nn.Module):

def __init__(self, mean=None, std=None, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace

assert self.mean is not None
assert self.std is not None

def __to_tensor__(self, img):
return F.to_tensor(img)

def __normalize__(self, img):
return F.normalize(img, self.mean, self.std, self.inplace)

def __call__(self, imgmap):
assert isinstance(imgmap, list)
return [self.__normalize__(self.__to_tensor__(img)) for img in imgmap]


class VideoRandomResizedCrop(torch.nn.Module):

def __init__(self, size, bottom_area=0.2):
self.p = 1.0
self.interpolation = Image.BICUBIC
self.size = size
self.bottom_area = bottom_area

def __call__(self, imgmap):
assert isinstance(imgmap, list)
if random.random() < self.p: # do RandomResizedCrop, consistent=True
top, left, height, width = TF.RandomResizedCrop.get_params(
imgmap[0],
scale=(self.bottom_area, 1.0),
ratio=(3 / 4.0, 4 / 3.0))
return [
F.resized_crop(
img=img,
top=top,
left=left,
height=height,
width=width,
size=(self.size, self.size),
) for img in imgmap
]
else:
return [
F.resize(img=img, size=[self.size, self.size])
for img in imgmap
]


class VideoRandomHFlip(torch.nn.Module):

def __init__(self, consistent=True, command=None, seq_len=0):
self.consistent = consistent
if seq_len != 0:
self.consistent = False
if command == 'left':
self.threshold = 0
elif command == 'right':
self.threshold = 1
else:
self.threshold = 0.5
self.seq_len = seq_len

def __call__(self, imgmap):
assert isinstance(imgmap, list)
if self.consistent:
if random.random() < self.threshold:
return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap]
else:
return imgmap
else:
result = []
for idx, i in enumerate(imgmap):
if idx % self.seq_len == 0:
th = random.random()
if th < self.threshold:
result.append(i.transpose(Image.FLIP_LEFT_RIGHT))
else:
result.append(i)
assert len(result) == len(imgmap)
return result


class VideoRandomColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast and saturation of an image.
Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non negative numbers.
contrast (float or tuple of float (min, max)): How much to jitter contrast.
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non negative numbers.
saturation (float or tuple of float (min, max)): How much to jitter saturation.
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non negative numbers.
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
"""

def __init__(
self,
brightness=0,
contrast=0,
saturation=0,
hue=0,
consistent=True,
p=1.0,
seq_len=0,
):
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(
hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
self.consistent = consistent
self.threshold = p
self.seq_len = seq_len

def _check_input(self,
value,
name,
center=1,
bound=(0, float('inf')),
clip_first_on_zero=True):
if isinstance(value, numbers.Number):
if value < 0:
raise ValueError(
'If {} is a single number, it must be non negative.'.
format(name))
value = [center - value, center + value]
if clip_first_on_zero:
value[0] = max(value[0], 0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError('{} values should be between {}'.format(
name, bound))
else:
raise TypeError(
'{} should be a single number or a list/tuple with length 2.'.
format(name))

# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
return value

@staticmethod
def get_params(brightness, contrast, saturation, hue):
"""Get a randomized transform to be applied on image.
Arguments are same as that of __init__.
Returns:
Transform which randomly adjusts brightness, contrast and
saturation in a random order.
"""
transforms = []

if brightness is not None:
brightness_factor = random.uniform(brightness[0], brightness[1])
transforms.append(
TF.Lambda(
lambda img: F.adjust_brightness(img, brightness_factor)))

if contrast is not None:
contrast_factor = random.uniform(contrast[0], contrast[1])
transforms.append(
TF.Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

if saturation is not None:
saturation_factor = random.uniform(saturation[0], saturation[1])
transforms.append(
TF.Lambda(
lambda img: F.adjust_saturation(img, saturation_factor)))

if hue is not None:
hue_factor = random.uniform(hue[0], hue[1])
transforms.append(
TF.Lambda(lambda img: F.adjust_hue(img, hue_factor)))

random.shuffle(transforms)
transform = TF.Compose(transforms)

return transform

def __call__(self, imgmap):
assert isinstance(imgmap, list)
if random.random() < self.threshold: # do ColorJitter
if self.consistent:
transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)

return [transform(i) for i in imgmap]
else:
if self.seq_len == 0:
return [
self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)(img)
for img in imgmap
]
else:
result = []
for idx, img in enumerate(imgmap):
if idx % self.seq_len == 0:
transform = self.get_params(
self.brightness,
self.contrast,
self.saturation,
self.hue,
)
result.append(transform(img))
return result

else:
return imgmap

def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += 'brightness={0}'.format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation)
format_string += ', hue={0})'.format(self.hue)
return format_string


class VideoRandomGaussianBlur(torch.nn.Module):

def __init__(self, radius_min=0.1, radius_max=2.0, p=0.5):
self.radius_min = radius_min
self.radius_max = radius_max
self.p = p

def __call__(self, imgmap):
assert isinstance(imgmap, list)
if random.random() < self.p:
result = []
for _, img in enumerate(imgmap):
_radius = random.uniform(self.radius_min, self.radius_max)
result.append(
img.filter(ImageFilter.GaussianBlur(radius=_radius)))
return result
else:
return imgmap


def apply_transform(images, trans):
return torch.stack(trans(images), dim=0)
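
A hedged example of the config format get_transform() consumes, using the helpers above (the shipped transform lists live in the model's configuration.json; the values below are placeholders):

    from PIL import Image

    eval_cfg = [
        {'type': 'VideoResizedCenterCrop', 'image_size': 256, 'crop_size': 224},
        {'type': 'VideoToTensor',
         'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]},
    ]
    transform = get_transform(eval_cfg)
    frames = [Image.new('RGB', (320, 240)) for _ in range(3)]  # key-frames of one shot
    print(apply_transform(frames, transform).shape)  # torch.Size([3, 3, 224, 224])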

modelscope/preprocessors/video.py (+45, -0)

@@ -9,6 +9,12 @@ import torchvision.transforms._transforms_video as transforms
from decord import VideoReader
from torchvision.transforms import Compose

from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields, ModeKeys
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS


def ReadVideoData(cfg, video_path):
""" simple interface to load video frames from file
@@ -227,3 +233,42 @@ class KineticsResizedCrop(object):

def __call__(self, clip):
return self._get_controlled_crop(clip)


@PREPROCESSORS.register_module(
Fields.cv, module_name=Preprocessors.movie_scene_segmentation_preprocessor)
class MovieSceneSegmentationPreprocessor(Preprocessor):

def __init__(self, *args, **kwargs):
"""
movie scene segmentation preprocessor
"""
super().__init__(*args, **kwargs)

self.is_train = kwargs.pop('is_train', True)
self.preprocessor_train_cfg = kwargs.pop(ModeKeys.TRAIN, None)
self.preprocessor_test_cfg = kwargs.pop(ModeKeys.EVAL, None)
self.num_keyframe = kwargs.pop('num_keyframe', 3)

from .movie_scene_segmentation import get_transform
self.train_transform = get_transform(self.preprocessor_train_cfg)
self.test_transform = get_transform(self.preprocessor_test_cfg)

def train(self):
self.is_train = True
return

def eval(self):
self.is_train = False
return

@type_assert(object, object)
def __call__(self, results):
if self.is_train:
transforms = self.train_transform
else:
transforms = self.test_transform

results = torch.stack(transforms(results), dim=0)
results = results.view(-1, self.num_keyframe, 3, 224, 224)
return results
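
A hedged construction sketch for the preprocessor; in practice these kwargs come from the 'preprocessor' section of the model's configuration.json, and the transform list below is a placeholder:

    from PIL import Image
    from modelscope.utils.constant import ModeKeys

    tfm_cfg = [
        {'type': 'VideoResizedCenterCrop', 'image_size': 256, 'crop_size': 224},
        {'type': 'VideoToTensor',
         'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]},
    ]
    preprocessor = MovieSceneSegmentationPreprocessor(**{
        ModeKeys.TRAIN: tfm_cfg,
        ModeKeys.EVAL: tfm_cfg,
        'num_keyframe': 3,
    })
    preprocessor.eval()  # switch to the test-time transforms
    keyframes = [Image.new('RGB', (320, 240)) for _ in range(3)]  # one shot, 3 key-frames
    print(preprocessor(keyframes).shape)  # torch.Size([1, 3, 3, 224, 224])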

modelscope/trainers/__init__.py (+3, -2)

@@ -8,7 +8,8 @@ if TYPE_CHECKING:
from .base import DummyTrainer
from .builder import build_trainer
from .cv import (ImageInstanceSegmentationTrainer,
ImagePortraitEnhancementTrainer)
ImagePortraitEnhancementTrainer,
MovieSceneSegmentationTrainer)
from .multi_modal import CLIPTrainer
from .nlp import SequenceClassificationTrainer
from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer
@@ -21,7 +22,7 @@ else:
'builder': ['build_trainer'],
'cv': [
'ImageInstanceSegmentationTrainer',
'ImagePortraitEnhancementTrainer'
'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer'
],
'multi_modal': ['CLIPTrainer'],
'nlp': ['SequenceClassificationTrainer'],


+ 2
- 0
modelscope/trainers/cv/__init__.py View File

@@ -7,6 +7,7 @@ if TYPE_CHECKING:
from .image_instance_segmentation_trainer import \
ImageInstanceSegmentationTrainer
from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer
from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer

else:
_import_structure = {
@@ -14,6 +15,7 @@ else:
['ImageInstanceSegmentationTrainer'],
'image_portrait_enhancement_trainer':
['ImagePortraitEnhancementTrainer'],
'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer']
}

import sys


+ 20
- 0
modelscope/trainers/cv/movie_scene_segmentation_trainer.py View File

@@ -0,0 +1,20 @@
from modelscope.metainfo import Trainers
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.trainer import EpochBasedTrainer


@TRAINERS.register_module(module_name=Trainers.movie_scene_segmentation)
class MovieSceneSegmentationTrainer(EpochBasedTrainer):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def train(self, *args, **kwargs):
super().train(*args, **kwargs)

def evaluate(self, *args, **kwargs):
metric_values = super().evaluate(*args, **kwargs)
return metric_values

def prediction_step(self, model, inputs):
pass
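
For quick orientation, a condensed construction sketch. The model id and the movie_scene_seg_toydata dataset come from the tests further below; dataset loading is abbreviated here (the trainer test also passes the preprocessor cfg and test_mode explicitly), and work_dir is a placeholder.

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

# the name registered above resolves to MovieSceneSegmentationTrainer
trainer = build_trainer(
    name=Trainers.movie_scene_segmentation,
    default_args=dict(
        model='damo/cv_resnet50-bert_video-scene-segmentation_movienet',
        train_dataset=MsDataset.load('movie_scene_seg_toydata', split='train'),
        eval_dataset=MsDataset.load('movie_scene_seg_toydata', split='test'),
        work_dir='./work_dir'))
trainer.train()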

+ 1
- 0
modelscope/utils/constant.py View File

@@ -62,6 +62,7 @@ class CVTasks(object):
video_embedding = 'video-embedding'
virtual_try_on = 'virtual-try-on'
crowd_counting = 'crowd-counting'
movie_scene_segmentation = 'movie-scene-segmentation'

# reid and tracking
video_single_object_tracking = 'video-single-object-tracking'


+ 1
- 0
requirements/cv.txt View File

@@ -21,6 +21,7 @@ regex
scikit-image>=0.19.3
scikit-learn>=0.20.1
shapely
shotdetect_scenedetect_lgss
tensorflow-estimator>=1.15.1
tf_slim
timm>=0.4.9


+ 6
- 0
tests/msdatasets/test_ms_dataset.py View File

@@ -31,6 +31,12 @@ class ImgPreprocessor(Preprocessor):

class MsDatasetTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_movie_scene_seg_toydata(self):
ms_ds_train = MsDataset.load('movie_scene_seg_toydata', split='train')
print(ms_ds_train._hf_ds.config_kwargs)
assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_coco(self):
ms_ds_train = MsDataset.load(


+ 36
- 0
tests/pipelines/test_movie_scene_segmentation.py View File

@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class MovieSceneSegmentationTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_movie_scene_segmentation(self):
input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4'
model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'
movie_scene_segmentation_pipeline = pipeline(
Tasks.movie_scene_segmentation, model=model_id)
result = movie_scene_segmentation_pipeline(input_location)
if result:
print(result)
else:
raise ValueError('process error')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_movie_scene_segmentation_with_default_task(self):
input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4'
movie_scene_segmentation_pipeline = pipeline(
Tasks.movie_scene_segmentation)
result = movie_scene_segmentation_pipeline(input_location)
if result:
print(result)
else:
raise ValueError('process error')


if __name__ == '__main__':
unittest.main()

+ 109
- 0
tests/trainers/test_movie_scene_segmentation_trainer.py View File

@@ -0,0 +1,109 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import zipfile

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.cv.movie_scene_segmentation import \
MovieSceneSegmentationModel
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class TestMovieSceneSegmentationTrainer(unittest.TestCase):

model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

cache_path = snapshot_download(self.model_id)
config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
cfg = Config.from_file(config_path)

max_epochs = cfg.train.max_epochs

train_data_cfg = ConfigDict(
name='movie_scene_seg_toydata',
split='train',
cfg=cfg.preprocessor,
test_mode=False)

test_data_cfg = ConfigDict(
name='movie_scene_seg_toydata',
split='test',
cfg=cfg.preprocessor,
test_mode=True)

self.train_dataset = MsDataset.load(
dataset_name=train_data_cfg.name,
split=train_data_cfg.split,
namespace=train_data_cfg.namespace,
cfg=train_data_cfg.cfg,
test_mode=train_data_cfg.test_mode)
assert next(
iter(self.train_dataset.config_kwargs['split_config'].values()))

self.test_dataset = MsDataset.load(
dataset_name=test_data_cfg.name,
split=test_data_cfg.split,
namespace=test_data_cfg.namespace,
cfg=test_data_cfg.cfg,
test_mode=test_data_cfg.test_mode)
assert next(
iter(self.test_dataset.config_kwargs['split_config'].values()))

self.max_epochs = max_epochs

self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer(self):
kwargs = dict(
model=self.model_id,
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
work_dir=self.tmp_dir)

trainer = build_trainer(
name=Trainers.movie_scene_segmentation, default_args=kwargs)
trainer.train()
results_files = os.listdir(trainer.work_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_trainer_with_model_and_args(self):
tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)

cache_path = snapshot_download(self.model_id)
model = MovieSceneSegmentationModel.from_pretrained(cache_path)
kwargs = dict(
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
model=model,
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
work_dir=tmp_dir)

trainer = build_trainer(
name=Trainers.movie_scene_segmentation, default_args=kwargs)
trainer.train()
results_files = os.listdir(trainer.work_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)


if __name__ == '__main__':
unittest.main()
