diff --git a/data/test/videos/movie_scene_segmentation_test_video.mp4 b/data/test/videos/movie_scene_segmentation_test_video.mp4 new file mode 100644 index 00000000..ee6ed528 --- /dev/null +++ b/data/test/videos/movie_scene_segmentation_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59fa397b01dc4c9b67a19ca42f149287b9c4e7b2158aba5d07d2db88af87b23f +size 126815483 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 908ee011..f1179be8 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -27,6 +27,7 @@ class Models(object): video_summarization = 'pgl-video-summarization' swinL_semantic_segmentation = 'swinL-semantic-segmentation' vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' + resnet50_bert = 'resnet50-bert' # EasyCV models yolox = 'YOLOX' @@ -133,6 +134,7 @@ class Pipelines(object): video_summarization = 'googlenet_pgl_video_summarization' image_semantic_segmentation = 'image-semantic-segmentation' image_reid_person = 'passvitb-image-reid-person' + movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' # nlp tasks sentence_similarity = 'sentence-similarity' @@ -195,6 +197,7 @@ class Trainers(object): image_instance_segmentation = 'image-instance-segmentation' image_portrait_enhancement = 'image-portrait-enhancement' video_summarization = 'video-summarization' + movie_scene_segmentation = 'movie-scene-segmentation' # nlp trainers bert_sentiment_analysis = 'bert-sentiment-analysis' @@ -223,6 +226,7 @@ class Preprocessors(object): image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' video_summarization_preprocessor = 'video-summarization-preprocessor' + movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor' # nlp preprocessor sen_sim_tokenizer = 'sen-sim-tokenizer' @@ -279,6 +283,8 @@ class Metrics(object): # metrics for image-portrait-enhancement task image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' video_summarization_metric = 'video-summarization-metric' + # metric for movie-scene-segmentation task + movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' class Optimizers(object): diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index c74b475e..d3975a2c 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from .text_generation_metric import TextGenerationMetric from .token_classification_metric import TokenClassificationMetric from .video_summarization_metric import VideoSummarizationMetric + from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric else: _import_structure = { @@ -32,6 +33,7 @@ else: 'text_generation_metric': ['TextGenerationMetric'], 'token_classification_metric': ['TokenClassificationMetric'], 'video_summarization_metric': ['VideoSummarizationMetric'], + 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], } import sys diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 869a1ab2..800e3508 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -34,6 +34,7 @@ task_default_metrics = { Tasks.video_summarization: [Metrics.video_summarization_metric], Tasks.image_captioning: [Metrics.text_gen_metric], Tasks.visual_question_answering: [Metrics.text_gen_metric], + Tasks.movie_scene_segmentation: 
[Metrics.movie_scene_segmentation_metric], } diff --git a/modelscope/metrics/movie_scene_segmentation_metric.py b/modelscope/metrics/movie_scene_segmentation_metric.py new file mode 100644 index 00000000..56bdbd1c --- /dev/null +++ b/modelscope/metrics/movie_scene_segmentation_metric.py @@ -0,0 +1,52 @@ +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, + module_name=Metrics.movie_scene_segmentation_metric) +class MovieSceneSegmentationMetric(Metric): + """The metric computation class for movie scene segmentation classes. + """ + + def __init__(self): + self.preds = [] + self.labels = [] + self.eps = 1e-5 + + def add(self, outputs: Dict, inputs: Dict): + preds = outputs['pred'] + labels = inputs['label'] + self.preds.extend(preds) + self.labels.extend(labels) + + def evaluate(self): + gts = np.array(torch_nested_numpify(torch_nested_detach(self.labels))) + prob = np.array(torch_nested_numpify(torch_nested_detach(self.preds))) + + gt_one = gts == 1 + gt_zero = gts == 0 + pred_one = prob == 1 + pred_zero = prob == 0 + + tp = (gt_one * pred_one).sum() + fp = (gt_zero * pred_one).sum() + fn = (gt_one * pred_zero).sum() + + precision = 100.0 * tp / (tp + fp + self.eps) + recall = 100.0 * tp / (tp + fn + self.eps) + f1 = 2 * precision * recall / (precision + recall) + + return { + MetricKeys.F1: f1, + MetricKeys.RECALL: recall, + MetricKeys.PRECISION: precision + } diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 10040637..331f23bd 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -9,8 +9,9 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, image_panoptic_segmentation, image_portrait_enhancement, image_reid_person, image_semantic_segmentation, image_to_image_generation, image_to_image_translation, - object_detection, product_retrieval_embedding, - realtime_object_detection, salient_detection, super_resolution, + movie_scene_segmentation, object_detection, + product_retrieval_embedding, realtime_object_detection, + salient_detection, super_resolution, video_single_object_tracking, video_summarization, virual_tryon) # yapf: enable diff --git a/modelscope/models/cv/movie_scene_segmentation/__init__.py b/modelscope/models/cv/movie_scene_segmentation/__init__.py new file mode 100644 index 00000000..25dcda96 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
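# A minimal usage sketch of the MovieSceneSegmentationMetric defined above (this
# snippet is illustrative and not part of the patch): per-batch boundary
# predictions from the model ("pred") and ground-truth labels ("label") are
# accumulated with add() and reduced to precision/recall/F1 by evaluate().
# The tensor values below are made-up toy data.
import torch

from modelscope.metrics import MovieSceneSegmentationMetric

metric = MovieSceneSegmentationMetric()
metric.add(
    outputs={'pred': torch.tensor([1, 0, 1, 1])},   # argmax predictions per shot
    inputs={'label': torch.tensor([1, 0, 0, 1])})   # ground-truth boundary labels per shot
print(metric.evaluate())  # dict with precision, recall and F1 (0-100 scale)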
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .model import MovieSceneSegmentationModel + from .datasets import MovieSceneSegmentationDataset + +else: + _import_structure = { + 'model': ['MovieSceneSegmentationModel'], + 'datasets': ['MovieSceneSegmentationDataset'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/movie_scene_segmentation/get_model.py b/modelscope/models/cv/movie_scene_segmentation/get_model.py new file mode 100644 index 00000000..5c66fc02 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/get_model.py @@ -0,0 +1,45 @@ +# ------------------------------------------------------------------------------------ +# BaSSL +# Copyright (c) 2021 KakaoBrain. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# Github: https://github.com/kakaobrain/bassl +# ------------------------------------------------------------------------------------ + +from .utils.shot_encoder import resnet50 +from .utils.trn import TransformerCRN + + +def get_shot_encoder(cfg): + name = cfg['model']['shot_encoder']['name'] + shot_encoder_args = cfg['model']['shot_encoder'][name] + if name == 'resnet': + depth = shot_encoder_args['depth'] + if depth == 50: + shot_encoder = resnet50(**shot_encoder_args['params'], ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + return shot_encoder + + +def get_contextual_relation_network(cfg): + crn = None + + if cfg['model']['contextual_relation_network']['enabled']: + name = cfg['model']['contextual_relation_network']['name'] + crn_args = cfg['model']['contextual_relation_network']['params'][name] + if name == 'trn': + sampling_name = cfg['model']['loss']['sampling_method']['name'] + crn_args['neighbor_size'] = ( + 2 * cfg['model']['loss']['sampling_method']['params'] + [sampling_name]['neighbor_size']) + crn = TransformerCRN(crn_args) + else: + raise NotImplementedError + + return crn + + +__all__ = ['get_shot_encoder', 'get_contextual_relation_network'] diff --git a/modelscope/models/cv/movie_scene_segmentation/model.py b/modelscope/models/cv/movie_scene_segmentation/model.py new file mode 100644 index 00000000..e9576963 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/model.py @@ -0,0 +1,192 @@ +import os +import os.path as osp +from typing import Any, Dict + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as TF +from PIL import Image +from shotdetect_scenedetect_lgss import shot_detect + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .get_model import get_contextual_relation_network, get_shot_encoder +from .utils.save_op import get_pred_boundary, pred2scene, scene2video + +logger = get_logger() + + +@MODELS.register_module( + Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert) +class MovieSceneSegmentationModel(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, *args, **kwargs) + + model_path = 
osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + params = torch.load(model_path, map_location='cpu') + + config_path = osp.join(model_dir, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(config_path) + + def load_param_with_prefix(prefix, model, src_params): + own_state = model.state_dict() + for name, param in own_state.items(): + src_name = prefix + '.' + name + own_state[name] = src_params[src_name] + + model.load_state_dict(own_state) + + self.shot_encoder = get_shot_encoder(self.cfg) + load_param_with_prefix('shot_encoder', self.shot_encoder, params) + self.crn = get_contextual_relation_network(self.cfg) + load_param_with_prefix('crn', self.crn, params) + + crn_name = self.cfg.model.contextual_relation_network.name + hdim = self.cfg.model.contextual_relation_network.params[crn_name][ + 'hidden_size'] + self.head_sbd = nn.Linear(hdim, 2) + load_param_with_prefix('head_sbd', self.head_sbd, params) + + self.test_transform = TF.Compose([ + TF.Resize(size=256, interpolation=Image.BICUBIC), + TF.CenterCrop(224), + TF.ToTensor(), + TF.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + self.infer_result = {'vid': [], 'sid': [], 'pred': []} + sampling_method = self.cfg.dataset.sampling_method.name + self.neighbor_size = self.cfg.dataset.sampling_method.params[ + sampling_method].neighbor_size + + self.eps = 1e-5 + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: + data = inputs['video'] + labels = inputs['label'] + outputs = self.shared_step(data) + + loss = F.cross_entropy( + outputs.squeeze(), labels.squeeze(), reduction='none') + lpos = labels == 1 + lneg = labels == 0 + + pp, nn = 1, 1 + wp = (pp / float(pp + nn)) * lpos / (lpos.sum() + self.eps) + wn = (nn / float(pp + nn)) * lneg / (lneg.sum() + self.eps) + w = wp + wn + loss = (w * loss).sum() + + probs = torch.argmax(outputs, dim=1) + + re = dict(pred=probs, loss=loss) + return re + + def inference(self, batch): + logger.info('Begin scene detect ......') + bs = self.cfg.pipeline.batch_size_per_gpu + sids = batch['sid'] + inputs = batch['shot_feat'] + + shot_num = len(sids) + cnt = shot_num // bs + 1 + + for i in range(cnt): + start = i * bs + end = (i + 1) * bs if (i + 1) * bs < shot_num else shot_num + input_ = inputs[start:end] + sid_ = sids[start:end] + input_ = torch.stack(input_) + outputs = self.shared_step(input_) # shape [b,2] + prob = F.softmax(outputs, dim=1) + self.infer_result['sid'].extend(sid_.cpu().detach().numpy()) + self.infer_result['pred'].extend(prob[:, 1].cpu().detach().numpy()) + self.infer_result['pred'] = np.stack(self.infer_result['pred']) + + assert len(self.infer_result['sid']) == len(sids) + assert len(self.infer_result['pred']) == len(inputs) + return self.infer_result + + def shared_step(self, inputs): + with torch.no_grad(): + # infer shot encoder + shot_repr = self.extract_shot_representation(inputs) + assert len(shot_repr.shape) == 3 + + # infer CRN + _, pooled = self.crn(shot_repr, mask=None) + # infer boundary score + pred = self.head_sbd(pooled) + return pred + + def save_shot_feat(self, _repr): + feat = _repr.float().cpu().numpy() + pth = self.cfg.dataset.img_path + '/features' + os.makedirs(pth) + + for idx in range(_repr.shape[0]): + name = f'shot_{str(idx).zfill(4)}.npy' + name = osp.join(pth, name) + np.save(name, feat[idx]) + + def extract_shot_representation(self, + inputs: torch.Tensor) -> torch.Tensor: + """ inputs [b s k c h w] -> output [b d] """ + assert len(inputs.shape) == 6 # (B Shot Keyframe C H W) + b, s, k, c, h, w = 
inputs.shape + inputs = einops.rearrange(inputs, 'b s k c h w -> (b s) k c h w', s=s) + keyframe_repr = [self.shot_encoder(inputs[:, _k]) for _k in range(k)] + # [k (b s) d] -> [(b s) d] + shot_repr = torch.stack(keyframe_repr).mean(dim=0) + + shot_repr = einops.rearrange(shot_repr, '(b s) d -> b s d', s=s) + return shot_repr + + def postprocess(self, inputs: Dict[str, Any], **kwargs): + logger.info('Generate scene .......') + + pred_dict = inputs['feat'] + thres = self.cfg.pipeline.save_threshold + + anno_dict = get_pred_boundary(pred_dict, thres) + scene_dict, scene_list = pred2scene(self.shot2keyf, anno_dict) + if self.cfg.pipeline.save_split_scene: + re_dir = scene2video(inputs['input_video_pth'], scene_list, thres) + print(f'Split scene video saved to {re_dir}') + return len(scene_list), scene_dict + + def preprocess(self, inputs): + logger.info('Begin shot detect......') + shot_keyf_lst, anno, shot2keyf = shot_detect( + inputs, **self.cfg.preprocessor.shot_detect) + logger.info('Shot detect done!') + + single_shot_feat, sid = [], [] + for idx, one_shot in enumerate(shot_keyf_lst): + one_shot = [ + self.test_transform(one_frame) for one_frame in one_shot + ] + one_shot = torch.stack(one_shot, dim=0) + single_shot_feat.append(one_shot) + sid.append(idx) + single_shot_feat = torch.stack(single_shot_feat, dim=0) + shot_feat = [] + for idx, one_shot in enumerate(anno): + shot_idx = int(one_shot['shot_id']) + np.arange( + -self.neighbor_size, self.neighbor_size + 1) + shot_idx = np.clip(shot_idx, 0, one_shot['num_shot']) + _one_shot = single_shot_feat[shot_idx] + shot_feat.append(_one_shot) + self.shot2keyf = shot2keyf + self.anno = anno + return shot_feat, sid diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/__init__.py b/modelscope/models/cv/movie_scene_segmentation/utils/__init__.py new file mode 100644 index 00000000..3682726f --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/__init__.py @@ -0,0 +1,3 @@ +from .save_op import get_pred_boundary, pred2scene, scene2video +from .shot_encoder import resnet50 +from .trn import TransformerCRN diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/head.py b/modelscope/models/cv/movie_scene_segmentation/utils/head.py new file mode 100644 index 00000000..20a87e66 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/head.py @@ -0,0 +1,29 @@ +# ------------------------------------------------------------------------------------ +# BaSSL +# Copyright (c) 2021 KakaoBrain. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# Github: https://github.com/kakaobrain/bassl +# ------------------------------------------------------------------------------------ + +import torch.nn as nn +import torch.nn.functional as F + + +class MlpHead(nn.Module): + + def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128): + super().__init__() + self.output_dim = output_dim + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.model = nn.Sequential( + nn.Linear(self.input_dim, self.hidden_dim, bias=True), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.output_dim, bias=True), + ) + + def forward(self, x): + # x shape: [b t d] where t means the number of views + x = self.model(x) + return F.normalize(x, dim=-1) diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py b/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py new file mode 100644 index 00000000..d7c8c0ed --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py @@ -0,0 +1,118 @@ +# ---------------------------------------------------------------------------------- +# The codes below partially refer to the SceneSeg LGSS. +# Github: https://github.com/AnyiRao/SceneSeg +# ---------------------------------------------------------------------------------- +import os +import os.path as osp +import subprocess + +import cv2 +import numpy as np +from tqdm import tqdm + + +def get_pred_boundary(pred_dict, threshold=0.5): + pred = pred_dict['pred'] + tmp = (pred > threshold).astype(np.int32) + anno_dict = {} + for idx in range(len(tmp)): + anno_dict.update({str(pred_dict['sid'][idx]).zfill(4): int(tmp[idx])}) + return anno_dict + + +def pred2scene(shot2keyf, anno_dict): + scene_list, pair_list = get_demo_scene_list(shot2keyf, anno_dict) + + scene_dict = {} + assert len(scene_list) == len(pair_list) + for scene_ind, scene_item in enumerate(scene_list): + scene_dict.update( + {scene_ind: { + 'shot': pair_list[scene_ind], + 'frame': scene_item + }}) + + return scene_dict, scene_list + + +def scene2video(source_movie_fn, scene_list, thres): + + vcap = cv2.VideoCapture(source_movie_fn) + fps = vcap.get(cv2.CAP_PROP_FPS) # video.fps + out_video_dir_fn = os.path.join(os.getcwd(), + f'pred_result/scene_video_{thres}') + os.makedirs(out_video_dir_fn, exist_ok=True) + + for scene_ind, scene_item in tqdm(enumerate(scene_list)): + scene = str(scene_ind).zfill(4) + start_frame = int(scene_item[0]) + end_frame = int(scene_item[1]) + start_time, end_time = start_frame / fps, end_frame / fps + duration_time = end_time - start_time + out_video_fn = os.path.join(out_video_dir_fn, + 'scene_{}.mp4'.format(scene)) + if os.path.exists(out_video_fn): + continue + call_list = ['ffmpeg'] + call_list += ['-v', 'quiet'] + call_list += [ + '-y', '-ss', + str(start_time), '-t', + str(duration_time), '-i', source_movie_fn + ] + call_list += ['-map_chapters', '-1'] + call_list += [out_video_fn] + subprocess.call(call_list) + return osp.join(os.getcwd(), 'pred_result') + + +def get_demo_scene_list(shot2keyf, anno_dict): + pair_list = get_pair_list(anno_dict) + + scene_list = [] + for pair in pair_list: + start_shot, end_shot = int(pair[0]), int(pair[-1]) + start_frame = shot2keyf[start_shot].split(' ')[0] + end_frame = shot2keyf[end_shot].split(' ')[1] + scene_list.append((start_frame, end_frame)) + return scene_list, pair_list + + +def get_pair_list(anno_dict): + sort_anno_dict_key = sorted(anno_dict.keys()) + tmp = 0 + tmp_list = [] + tmp_label_list = [] + 
anno_list = [] + anno_label_list = [] + for key in sort_anno_dict_key: + value = anno_dict.get(key) + tmp += value + tmp_list.append(key) + tmp_label_list.append(value) + if tmp == 1: + anno_list.append(tmp_list) + anno_label_list.append(tmp_label_list) + tmp = 0 + tmp_list = [] + tmp_label_list = [] + continue + if key == sort_anno_dict_key[-1]: + if len(tmp_list) > 0: + anno_list.append(tmp_list) + anno_label_list.append(tmp_label_list) + if len(anno_list) == 0: + return None + while [] in anno_list: + anno_list.remove([]) + tmp_anno_list = [anno_list[0]] + pair_list = [] + for ind in range(len(anno_list) - 1): + cont_count = int(anno_list[ind + 1][0]) - int(anno_list[ind][-1]) + if cont_count > 1: + pair_list.extend(tmp_anno_list) + tmp_anno_list = [anno_list[ind + 1]] + continue + tmp_anno_list.append(anno_list[ind + 1]) + pair_list.extend(tmp_anno_list) + return pair_list diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/shot_encoder.py b/modelscope/models/cv/movie_scene_segmentation/utils/shot_encoder.py new file mode 100644 index 00000000..7ad1907f --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/shot_encoder.py @@ -0,0 +1,331 @@ +""" +Modified from original implementation in torchvision +""" + +from typing import Any, Callable, List, Optional, Type, Union + +import torch +import torch.nn as nn +from torch import Tensor + + +def conv3x3(in_planes: int, + out_planes: int, + stride: int = 1, + groups: int = 1, + dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 
+ # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + in_channel_dim: int = 3, + zero_init_residual: bool = False, + use_last_block_grid: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.use_last_block_grid = use_last_block_grid + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + in_channel_dim, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False, + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, 
and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, + 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, + 0) # type: ignore[arg-type] + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + )) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + )) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor, grid: bool, level: List, both: bool, + grid_only: bool) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + if grid: + x_grid = [] + + if 3 in level: + x_grid.append(x.detach().clone()) + if not both and len(level) == 1: + return x_grid + + x = self.layer4(x) + + if 4 in level: + x_grid.append(x.detach().clone()) + if not both and len(level) == 1: + return x_grid + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + if not grid or len(level) == 0: + return x + + if grid_only: + return x_grid + + if both: + return x, x_grid + + return x + + def forward( + self, + x: Tensor, + grid: bool = False, + level: List = [], + both: bool = False, + grid_only: bool = False, + ) -> Tensor: + return self._forward_impl(x, grid, level, both, grid_only) + + +def resnet50(**kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + """ + return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/trn.py b/modelscope/models/cv/movie_scene_segmentation/utils/trn.py new file mode 100644 index 00000000..769e9ee4 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/trn.py @@ -0,0 +1,132 @@ +# ------------------------------------------------------------------------------------ +# BaSSL +# Copyright (c) 2021 KakaoBrain. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# Github: https://github.com/kakaobrain/bassl +# ------------------------------------------------------------------------------------ + +import torch +import torch.nn as nn +from transformers.models.bert.modeling_bert import BertEncoder + + +class ShotEmbedding(nn.Module): + + def __init__(self, cfg): + super().__init__() + + nn_size = cfg.neighbor_size + 2 # +1 for center shot, +1 for cls + self.shot_embedding = nn.Linear(cfg.input_dim, cfg.hidden_size) + self.position_embedding = nn.Embedding(nn_size, cfg.hidden_size) + self.mask_embedding = nn.Embedding(2, cfg.input_dim, padding_idx=0) + + # tf naming convention for layer norm + self.LayerNorm = nn.LayerNorm(cfg.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(cfg.hidden_dropout_prob) + + self.register_buffer('pos_ids', + torch.arange(nn_size, dtype=torch.long)) + + def forward( + self, + shot_emb: torch.Tensor, + mask: torch.Tensor = None, + pos_ids: torch.Tensor = None, + ) -> torch.Tensor: + + assert len(shot_emb.size()) == 3 + + if pos_ids is None: + pos_ids = self.pos_ids + + # this for mask embedding (un-masked ones remain unchanged) + if mask is not None: + self.mask_embedding.weight.data[0, :].fill_(0) + mask_emb = self.mask_embedding(mask.long()) + shot_emb = (shot_emb * (1 - mask).float()[:, :, None]) + mask_emb + + # we set [CLS] token to averaged feature + cls_emb = shot_emb.mean(dim=1) + + # embedding shots + shot_emb = torch.cat([cls_emb[:, None, :], shot_emb], dim=1) + shot_emb = self.shot_embedding(shot_emb) + pos_emb = self.position_embedding(pos_ids) + embeddings = shot_emb + pos_emb[None, :] + embeddings = self.dropout(self.LayerNorm(embeddings)) + return embeddings + + +class TransformerCRN(nn.Module): + + def __init__(self, cfg): + super().__init__() + + self.pooling_method = cfg.pooling_method + self.shot_embedding = ShotEmbedding(cfg) + self.encoder = BertEncoder(cfg) + + nn_size = cfg.neighbor_size + 2 # +1 for center shot, +1 for cls + self.register_buffer( + 'attention_mask', + self._get_extended_attention_mask( + torch.ones((1, nn_size)).float()), + ) + + def forward( + self, + shot: torch.Tensor, + mask: torch.Tensor = None, + pos_ids: torch.Tensor = None, + pooling_method: str = None, + ): + if self.attention_mask.shape[1] != (shot.shape[1] + 1): + n_shot = shot.shape[1] + 1 # +1 for CLS token + attention_mask = self._get_extended_attention_mask( + torch.ones((1, n_shot), dtype=torch.float, device=shot.device)) + else: + attention_mask = self.attention_mask + + shot_emb = self.shot_embedding(shot, mask=mask, pos_ids=pos_ids) + encoded_emb = self.encoder( + shot_emb, attention_mask=attention_mask).last_hidden_state + + return encoded_emb, self.pooler( + encoded_emb, pooling_method=pooling_method) + + def pooler(self, sequence_output, pooling_method=None): + if pooling_method is None: + pooling_method = self.pooling_method + + if pooling_method == 'cls': + return sequence_output[:, 0, :] + elif pooling_method == 'avg': + return sequence_output[:, 1:].mean(dim=1) + elif pooling_method == 'max': + return sequence_output[:, 1:].max(dim=1)[0] + elif pooling_method == 'center': + cidx = sequence_output.shape[1] // 2 + return sequence_output[:, cidx, :] + else: + raise ValueError + + def _get_extended_attention_mask(self, attention_mask): + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f'Wrong shape for attention_mask (shape {attention_mask.shape})' + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py index 1905bf39..f97ff8b2 100644 --- a/modelscope/msdatasets/task_datasets/__init__.py +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -9,7 +9,9 @@ if TYPE_CHECKING: from .torch_base_dataset import TorchTaskDataset from .veco_dataset import VecoDataset from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset + from .movie_scene_segmentation import MovieSceneSegmentationDataset from .video_summarization_dataset import VideoSummarizationDataset + else: _import_structure = { 'base': ['TaskDataset'], @@ -19,6 +21,7 @@ else: 'image_instance_segmentation_coco_dataset': ['ImageInstanceSegmentationCocoDataset'], 'video_summarization_dataset': ['VideoSummarizationDataset'], + 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], } import sys diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py new file mode 100644 index 00000000..e56039ac --- /dev/null +++ b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py @@ -0,0 +1 @@ +from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py new file mode 100644 index 00000000..925d6281 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py @@ -0,0 +1,173 @@ +# --------------------------------------------------------------------------------------------------- +# The implementation is built upon BaSSL, publicly available at https://github.com/kakaobrain/bassl +# --------------------------------------------------------------------------------------------------- +import copy +import os +import os.path as osp +import random + +import json +import torch +from torchvision.datasets.folder import pil_loader + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from . 
import sampler + +DATASET_STRUCTURE = { + 'train': { + 'annotation': 'anno/train.json', + 'images': 'keyf_240p', + 'feat': 'feat' + }, + 'test': { + 'annotation': 'anno/test.json', + 'images': 'keyf_240p', + 'feat': 'feat' + } +} + + +@TASK_DATASETS.register_module( + Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert) +class MovieSceneSegmentationDataset(TorchTaskDataset): + """dataset for movie scene segmentation. + + Args: + split_config (dict): Annotation file path. {"train":"xxxxx"} + data_root (str, optional): Data root for ``ann_file``, + ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. + test_mode (bool, optional): If set True, annotation will not be loaded. + """ + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + + self.data_root = next(iter(split_config.values())) + if not osp.exists(self.data_root): + self.data_root = osp.dirname(self.data_root) + assert osp.exists(self.data_root) + + self.split = next(iter(split_config.keys())) + self.preprocessor = kwargs['preprocessor'] + + self.ann_file = osp.join(self.data_root, + DATASET_STRUCTURE[self.split]['annotation']) + self.img_prefix = osp.join(self.data_root, + DATASET_STRUCTURE[self.split]['images']) + self.feat_prefix = osp.join(self.data_root, + DATASET_STRUCTURE[self.split]['feat']) + + self.test_mode = kwargs['test_mode'] + if self.test_mode: + self.preprocessor.eval() + else: + self.preprocessor.train() + + self.cfg = kwargs.pop('cfg', None) + + self.num_keyframe = self.cfg.num_keyframe if self.cfg is not None else 3 + self.use_single_keyframe = self.cfg.use_single_keyframe if self.cfg is not None else False + + self.load_data() + self.init_sampler(self.cfg) + + def __len__(self): + """Total number of samples of data.""" + return len(self.anno_data) + + def __getitem__(self, idx: int): + data = self.anno_data[ + idx] # {"video_id", "shot_id", "num_shot", "boundary_label"} + vid, sid = data['video_id'], data['shot_id'] + num_shot = data['num_shot'] + + shot_idx = self.shot_sampler(int(sid), num_shot) + + video = self.load_shot_list(vid, shot_idx) + if self.preprocessor is None: + video = torch.stack(video, dim=0) + video = video.view(-1, self.num_keyframe, 3, 224, 224) + else: + video = self.preprocessor(video) + + payload = { + 'idx': idx, + 'vid': vid, + 'sid': sid, + 'video': video, + 'label': abs(data['boundary_label']), # ignore -1 label. 
+ } + return payload + + def load_data(self): + self.tmpl = '{}/shot_{}_img_{}.jpg' # video_id, shot_id, shot_num + + if not self.test_mode: + with open(self.ann_file) as f: + self.anno_data = json.load(f) + self.vidsid2label = { + f"{it['video_id']}_{it['shot_id']}": it['boundary_label'] + for it in self.anno_data + } + else: + with open(self.ann_file) as f: + self.anno_data = json.load(f) + + def init_sampler(self, cfg): + # shot sampler + if cfg is not None: + self.sampling_method = cfg.sampling_method.name + sampler_args = copy.deepcopy( + cfg.sampling_method.params.get(self.sampling_method, {})) + if self.sampling_method == 'instance': + self.shot_sampler = sampler.InstanceShotSampler() + elif self.sampling_method == 'temporal': + self.shot_sampler = sampler.TemporalShotSampler(**sampler_args) + elif self.sampling_method == 'shotcol': + self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) + elif self.sampling_method == 'bassl': + self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) + elif self.sampling_method == 'bassl+shotcol': + self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) + elif self.sampling_method == 'sbd': + self.shot_sampler = sampler.NeighborShotSampler(**sampler_args) + else: + raise NotImplementedError + else: + self.shot_sampler = sampler.NeighborShotSampler() + + def load_shot_list(self, vid, shot_idx): + shot_list = [] + cache = {} + for sidx in shot_idx: + vidsid = f'{vid}_{sidx:04d}' + if vidsid in cache: + shot = cache[vidsid] + else: + shot_path = os.path.join( + self.img_prefix, self.tmpl.format(vid, f'{sidx:04d}', + '{}')) + shot = self.load_shot_keyframes(shot_path) + cache[vidsid] = shot + shot_list.extend(shot) + return shot_list + + def load_shot_keyframes(self, path): + shot = None + if not self.test_mode and self.use_single_keyframe: + # load one randomly sampled keyframe + shot = [ + pil_loader( + path.format(random.randint(0, self.num_keyframe - 1))) + ] + else: + # load all keyframes + shot = [ + pil_loader(path.format(i)) for i in range(self.num_keyframe) + ] + assert shot is not None + return shot diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py new file mode 100644 index 00000000..0fc2fe0f --- /dev/null +++ b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py @@ -0,0 +1,102 @@ +# ------------------------------------------------------------------------------------ +# BaSSL +# Copyright (c) 2021 KakaoBrain. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# Github: https://github.com/kakaobrain/bassl +# ------------------------------------------------------------------------------------ + +import random + +import numpy as np + + +class InstanceShotSampler: + """ This is for instance at pre-training stage """ + + def __call__(self, center_sid: int, *args, **kwargs): + return center_sid + + +class TemporalShotSampler: + """ This is for temporal at pre-training stage """ + + def __init__(self, neighbor_size: int): + self.N = neighbor_size + + def __call__(self, center_sid: int, total_num_shot: int): + """ we randomly sample one shot from neighbor shots within local temporal window + """ + shot_idx = center_sid + np.arange( + -self.N, self.N + 1 + ) # total number of neighbor shots = 2N+1 (query (1) + neighbors (2*N)) + shot_idx = np.clip(shot_idx, 0, + total_num_shot) # deal with out-of-boundary indices + shot_idx = random.choice( + np.unique(np.delete(shot_idx, np.where(shot_idx == center_sid)))) + return shot_idx + + +class SequenceShotSampler: + """ This is for bassl or shotcol at pre-training stage """ + + def __init__(self, neighbor_size: int, neighbor_interval: int): + self.interval = neighbor_interval + self.window_size = neighbor_size * self.interval # temporal coverage + + def __call__(self, + center_sid: int, + total_num_shot: int, + sparse_method: str = 'edge'): + """ + Args: + center_sid: index of center shot + total_num_shot: last index of shot for given video + sparse_stride: stride to sample sparse ones from dense sequence + for curriculum learning + """ + + dense_shot_idx = center_sid + np.arange( + -self.window_size, self.window_size + 1, + self.interval) # total number of shots = 2*neighbor_size+1 + + if dense_shot_idx[0] < 0: + # if center_sid is near left-side of video, we shift window rightward + # so that the leftmost index is 0 + dense_shot_idx -= dense_shot_idx[0] + elif dense_shot_idx[-1] > (total_num_shot - 1): + # if center_sid is near right-side of video, we shift window leftward + # so that the rightmost index is total_num_shot - 1 + dense_shot_idx -= dense_shot_idx[-1] - (total_num_shot - 1) + + # to deal with videos that have smaller number of shots than window size + dense_shot_idx = np.clip(dense_shot_idx, 0, total_num_shot) + + if sparse_method == 'edge': + # in this case, we use two edge shots as sparse sequence + sparse_stride = len(dense_shot_idx) - 1 + sparse_idx_to_dense = np.arange(0, len(dense_shot_idx), + sparse_stride) + elif sparse_method == 'edge+center': + # in this case, we use two edge shots + center shot as sparse sequence + sparse_idx_to_dense = np.array( + [0, len(dense_shot_idx) - 1, + len(dense_shot_idx) // 2]) + + shot_idx = [sparse_idx_to_dense, dense_shot_idx] + return shot_idx + + +class NeighborShotSampler: + """ This is for scene boundary detection (sbd), i.e., fine-tuning stage """ + + def __init__(self, neighbor_size: int = 8): + self.neighbor_size = neighbor_size + + def __call__(self, center_sid: int, total_num_shot: int): + # total number of shots = 2 * neighbor_size + 1 + shot_idx = center_sid + np.arange(-self.neighbor_size, + self.neighbor_size + 1) + shot_idx = np.clip(shot_idx, 0, + total_num_shot) # for out-of-boundary indices + + return shot_idx diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 1c42a5f3..7c0e08dc 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -35,6 +35,8 @@ class OutputKeys(object): UUID = 'uuid' WORD = 'word' KWS_LIST = 'kws_list' + 
SPLIT_VIDEO_NUM = 'split_video_num' + SPLIT_META_DICT = 'split_meta_dict' TASK_OUTPUTS = { @@ -241,6 +243,22 @@ TASK_OUTPUTS = { # } Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG], + # movide scene segmentation result for a single video + # { + # "split_video_num":3, + # "split_meta_dict": + # { + # scene_id: + # { + # "shot": [0,1,2], + # "frame": [start_frame, end_frame] + # } + # } + # + # } + Tasks.movie_scene_segmentation: + [OutputKeys.SPLIT_VIDEO_NUM, OutputKeys.SPLIT_META_DICT], + # ============ nlp tasks =================== # text classification result for single sample diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 53f55b06..943578fb 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -144,6 +144,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_vitb_video-single-object-tracking_ostrack'), Tasks.image_reid_person: (Pipelines.image_reid_person, 'damo/cv_passvitb_image-reid-person_market'), + Tasks.movie_scene_segmentation: + (Pipelines.movie_scene_segmentation, + 'damo/cv_resnet50-bert_video-scene-segmentation_movienet') } diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 640ffd4c..bd175578 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -42,6 +42,8 @@ if TYPE_CHECKING: from .video_category_pipeline import VideoCategoryPipeline from .virtual_try_on_pipeline import VirtualTryonPipeline from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline + from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline + else: _import_structure = { 'action_recognition_pipeline': ['ActionRecognitionPipeline'], @@ -90,7 +92,9 @@ else: 'video_category_pipeline': ['VideoCategoryPipeline'], 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], 'easycv_pipeline': - ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'] + ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'], + 'movie_scene_segmentation_pipeline': + ['MovieSceneSegmentationPipeline'], } import sys diff --git a/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py b/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py new file mode 100644 index 00000000..0ef0261d --- /dev/null +++ b/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py @@ -0,0 +1,67 @@ +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.movie_scene_segmentation, + module_name=Pipelines.movie_scene_segmentation) +class MovieSceneSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """use `model` to create a movie scene segmentation pipeline for prediction + + Args: + model: model id on modelscope hub + """ + _device = kwargs.pop('device', 'gpu') + if torch.cuda.is_available() and _device == 'gpu': + device = 'gpu' + else: + device = 'cpu' + super().__init__(model=model, device=device, **kwargs) + + logger.info('Load model done!') + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ use pyscenedetect to detect shot from the input video, and generate key-frame jpg, anno.ndjson, and shot-frame.txt + Then use shot-encoder to encoder feat of the detected key-frame + + Args: + 
input: path of the input video + + """ + self.input_video_pth = input + if isinstance(input, str): + shot_feat, sid = self.model.preprocess(input) + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + + result = {'sid': sid, 'shot_feat': shot_feat} + + return result + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + output = self.model.inference(input) + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + data = {'input_video_pth': self.input_video_pth, 'feat': inputs} + video_num, meta_dict = self.model.postprocess(data) + result = { + OutputKeys.SPLIT_VIDEO_NUM: video_num, + OutputKeys.SPLIT_META_DICT: meta_dict + } + return result diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index f5ac0e4e..d365b6fa 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from .space import (DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, DialogStateTrackingPreprocessor) - from .video import ReadVideoData + from .video import ReadVideoData, MovieSceneSegmentationPreprocessor from .star import ConversationalTextToSqlPreprocessor else: @@ -37,7 +37,7 @@ else: 'common': ['Compose', 'ToTensor', 'Filter'], 'audio': ['LinearAECAndFbank'], 'asr': ['WavToScp'], - 'video': ['ReadVideoData'], + 'video': ['ReadVideoData', 'MovieSceneSegmentationPreprocessor'], 'image': [ 'LoadImage', 'load_image', 'ImageColorEnhanceFinetunePreprocessor', 'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor' diff --git a/modelscope/preprocessors/movie_scene_segmentation/__init__.py b/modelscope/preprocessors/movie_scene_segmentation/__init__.py new file mode 100644 index 00000000..73da792d --- /dev/null +++ b/modelscope/preprocessors/movie_scene_segmentation/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .transforms import get_transform +else: + _import_structure = { + 'transforms': ['get_transform'], + } + + import sys + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/movie_scene_segmentation/transforms.py b/modelscope/preprocessors/movie_scene_segmentation/transforms.py new file mode 100644 index 00000000..b4e57420 --- /dev/null +++ b/modelscope/preprocessors/movie_scene_segmentation/transforms.py @@ -0,0 +1,312 @@ +# ------------------------------------------------------------------------------------ +# The codes below partially refer to the BaSSL +# Copyright (c) 2021 KakaoBrain. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# Github: https://github.com/kakaobrain/bassl +# ------------------------------------------------------------------------------------ +import numbers +import os.path as osp +import random +from typing import List + +import numpy as np +import torch +import torchvision.transforms as TF +import torchvision.transforms.functional as F +from PIL import Image, ImageFilter + + +def get_transform(lst): + assert len(lst) > 0 + transform_lst = [] + for item in lst: + transform_lst.append(build_transform(item)) + transform = TF.Compose(transform_lst) + return transform + + +def build_transform(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + + if type == 'VideoResizedCenterCrop': + return VideoResizedCenterCrop(**cfg) + elif type == 'VideoToTensor': + return VideoToTensor(**cfg) + elif type == 'VideoRandomResizedCrop': + return VideoRandomResizedCrop(**cfg) + elif type == 'VideoRandomHFlip': + return VideoRandomHFlip() + elif type == 'VideoRandomColorJitter': + return VideoRandomColorJitter(**cfg) + elif type == 'VideoRandomGaussianBlur': + return VideoRandomGaussianBlur(**cfg) + else: + raise NotImplementedError + + +class VideoResizedCenterCrop(torch.nn.Module): + + def __init__(self, image_size, crop_size): + self.tfm = TF.Compose([ + TF.Resize(size=image_size, interpolation=Image.BICUBIC), + TF.CenterCrop(crop_size), + ]) + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + return [self.tfm(img) for img in imgmap] + + +class VideoToTensor(torch.nn.Module): + + def __init__(self, mean=None, std=None, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + assert self.mean is not None + assert self.std is not None + + def __to_tensor__(self, img): + return F.to_tensor(img) + + def __normalize__(self, img): + return F.normalize(img, self.mean, self.std, self.inplace) + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + return [self.__normalize__(self.__to_tensor__(img)) for img in imgmap] + + +class VideoRandomResizedCrop(torch.nn.Module): + + def __init__(self, size, bottom_area=0.2): + self.p = 1.0 + self.interpolation = Image.BICUBIC + self.size = size + self.bottom_area = bottom_area + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + if random.random() < self.p: # do RandomResizedCrop, consistent=True + top, left, height, width = TF.RandomResizedCrop.get_params( + imgmap[0], + scale=(self.bottom_area, 1.0), + ratio=(3 / 4.0, 4 / 3.0)) + return [ + F.resized_crop( + img=img, + top=top, + left=left, + height=height, + width=width, + size=(self.size, self.size), + ) for img in imgmap + ] + else: + return [ + F.resize(img=img, size=[self.size, self.size]) + for img in imgmap + ] + + +class VideoRandomHFlip(torch.nn.Module): + + def __init__(self, consistent=True, command=None, seq_len=0): + self.consistent = consistent + if seq_len != 0: + self.consistent = False + if command == 'left': + self.threshold = 0 + elif command == 'right': + self.threshold = 1 + else: + self.threshold = 0.5 + self.seq_len = seq_len + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + if self.consistent: + if random.random() < self.threshold: + return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] + else: + return imgmap + else: + result = [] + for idx, i in enumerate(imgmap): + if idx % self.seq_len == 0: + th = random.random() + if th < self.threshold: + result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) + else: + 
result.append(i) + assert len(result) == len(imgmap) + return result + + +class VideoRandomColorJitter(torch.nn.Module): + """Randomly change the brightness, contrast and saturation of an image. + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__( + self, + brightness=0, + contrast=0, + saturation=0, + hue=0, + consistent=True, + p=1.0, + seq_len=0, + ): + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input( + hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + self.consistent = consistent + self.threshold = p + self.seq_len = seq_len + + def _check_input(self, + value, + name, + center=1, + bound=(0, float('inf')), + clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + 'If {} is a single number, it must be non negative.'. + format(name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError('{} values should be between {}'.format( + name, bound)) + else: + raise TypeError( + '{} should be a single number or a list/tuple with lenght 2.'. + format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + Arguments are same as that of __init__. + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. 
+ """ + transforms = [] + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) + transforms.append( + TF.Lambda( + lambda img: F.adjust_brightness(img, brightness_factor))) + + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) + transforms.append( + TF.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) + transforms.append( + TF.Lambda( + lambda img: F.adjust_saturation(img, saturation_factor))) + + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) + transforms.append( + TF.Lambda(lambda img: F.adjust_hue(img, hue_factor))) + + random.shuffle(transforms) + transform = TF.Compose(transforms) + + return transform + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + if random.random() < self.threshold: # do ColorJitter + if self.consistent: + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + + return [transform(i) for i in imgmap] + else: + if self.seq_len == 0: + return [ + self.get_params(self.brightness, self.contrast, + self.saturation, self.hue)(img) + for img in imgmap + ] + else: + result = [] + for idx, img in enumerate(imgmap): + if idx % self.seq_len == 0: + transform = self.get_params( + self.brightness, + self.contrast, + self.saturation, + self.hue, + ) + result.append(transform(img)) + return result + + else: + return imgmap + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += 'brightness={0}'.format(self.brightness) + format_string += ', contrast={0}'.format(self.contrast) + format_string += ', saturation={0}'.format(self.saturation) + format_string += ', hue={0})'.format(self.hue) + return format_string + + +class VideoRandomGaussianBlur(torch.nn.Module): + + def __init__(self, radius_min=0.1, radius_max=2.0, p=0.5): + self.radius_min = radius_min + self.radius_max = radius_max + self.p = p + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + if random.random() < self.p: + result = [] + for _, img in enumerate(imgmap): + _radius = random.uniform(self.radius_min, self.radius_max) + result.append( + img.filter(ImageFilter.GaussianBlur(radius=_radius))) + return result + else: + return imgmap + + +def apply_transform(images, trans): + return torch.stack(trans(images), dim=0) diff --git a/modelscope/preprocessors/video.py b/modelscope/preprocessors/video.py index 36110d1b..0d2e8c3e 100644 --- a/modelscope/preprocessors/video.py +++ b/modelscope/preprocessors/video.py @@ -9,6 +9,12 @@ import torchvision.transforms._transforms_video as transforms from decord import VideoReader from torchvision.transforms import Compose +from modelscope.metainfo import Preprocessors +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .base import Preprocessor +from .builder import PREPROCESSORS + def ReadVideoData(cfg, video_path): """ simple interface to load video frames from file @@ -227,3 +233,42 @@ class KineticsResizedCrop(object): def __call__(self, clip): return self._get_controlled_crop(clip) + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.movie_scene_segmentation_preprocessor) +class MovieSceneSegmentationPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + """ + movie scene segmentation preprocessor + """ + super().__init__(*args, **kwargs) + + self.is_train = 
kwargs.pop('is_train', True) + self.preprocessor_train_cfg = kwargs.pop(ModeKeys.TRAIN, None) + self.preprocessor_test_cfg = kwargs.pop(ModeKeys.EVAL, None) + self.num_keyframe = kwargs.pop('num_keyframe', 3) + + from .movie_scene_segmentation import get_transform + self.train_transform = get_transform(self.preprocessor_train_cfg) + self.test_transform = get_transform(self.preprocessor_test_cfg) + + def train(self): + self.is_train = True + return + + def eval(self): + self.is_train = False + return + + @type_assert(object, object) + def __call__(self, results): + if self.is_train: + transforms = self.train_transform + else: + transforms = self.test_transform + + results = torch.stack(transforms(results), dim=0) + results = results.view(-1, self.num_keyframe, 3, 224, 224) + return results diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index 32ff674f..8f8938c8 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -8,7 +8,8 @@ if TYPE_CHECKING: from .base import DummyTrainer from .builder import build_trainer from .cv import (ImageInstanceSegmentationTrainer, - ImagePortraitEnhancementTrainer) + ImagePortraitEnhancementTrainer, + MovieSceneSegmentationTrainer) from .multi_modal import CLIPTrainer from .nlp import SequenceClassificationTrainer from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer @@ -21,7 +22,7 @@ else: 'builder': ['build_trainer'], 'cv': [ 'ImageInstanceSegmentationTrainer', - 'ImagePortraitEnhancementTrainer' + 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer' ], 'multi_modal': ['CLIPTrainer'], 'nlp': ['SequenceClassificationTrainer'], diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py index 99c2aea5..4c65870e 100644 --- a/modelscope/trainers/cv/__init__.py +++ b/modelscope/trainers/cv/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from .image_instance_segmentation_trainer import \ ImageInstanceSegmentationTrainer from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer + from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer else: _import_structure = { @@ -14,6 +15,7 @@ else: ['ImageInstanceSegmentationTrainer'], 'image_portrait_enhancement_trainer': ['ImagePortraitEnhancementTrainer'], + 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'] } import sys diff --git a/modelscope/trainers/cv/movie_scene_segmentation_trainer.py b/modelscope/trainers/cv/movie_scene_segmentation_trainer.py new file mode 100644 index 00000000..ee4dd849 --- /dev/null +++ b/modelscope/trainers/cv/movie_scene_segmentation_trainer.py @@ -0,0 +1,20 @@ +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer + + +@TRAINERS.register_module(module_name=Trainers.movie_scene_segmentation) +class MovieSceneSegmentationTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def train(self, *args, **kwargs): + super().train(*args, **kwargs) + + def evaluate(self, *args, **kwargs): + metric_values = super().evaluate(*args, **kwargs) + return metric_values + + def prediction_step(self, model, inputs): + pass diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index dae7117e..a9d1345d 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -62,6 +62,7 @@ class CVTasks(object): video_embedding = 'video-embedding' virtual_try_on = 'virtual-try-on' 
crowd_counting = 'crowd-counting' + movie_scene_segmentation = 'movie-scene-segmentation' # reid and tracking video_single_object_tracking = 'video-single-object-tracking' diff --git a/requirements/cv.txt b/requirements/cv.txt index b7b3e4e8..ebb61851 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -21,6 +21,7 @@ regex scikit-image>=0.19.3 scikit-learn>=0.20.1 shapely +shotdetect_scenedetect_lgss tensorflow-estimator>=1.15.1 tf_slim timm>=0.4.9 diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py index ed07def7..9780ac4b 100644 --- a/tests/msdatasets/test_ms_dataset.py +++ b/tests/msdatasets/test_ms_dataset.py @@ -31,6 +31,12 @@ class ImgPreprocessor(Preprocessor): class MsDatasetTest(unittest.TestCase): + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_movie_scene_seg_toydata(self): + ms_ds_train = MsDataset.load('movie_scene_seg_toydata', split='train') + print(ms_ds_train._hf_ds.config_kwargs) + assert next(iter(ms_ds_train.config_kwargs['split_config'].values())) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_coco(self): ms_ds_train = MsDataset.load( diff --git a/tests/pipelines/test_movie_scene_segmentation.py b/tests/pipelines/test_movie_scene_segmentation.py new file mode 100644 index 00000000..5993c634 --- /dev/null +++ b/tests/pipelines/test_movie_scene_segmentation.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class MovieSceneSegmentationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_movie_scene_segmentation(self): + input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4' + model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' + movie_scene_segmentation_pipeline = pipeline( + Tasks.movie_scene_segmentation, model=model_id) + result = movie_scene_segmentation_pipeline(input_location) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_movie_scene_segmentation_with_default_task(self): + input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4' + movie_scene_segmentation_pipeline = pipeline( + Tasks.movie_scene_segmentation) + result = movie_scene_segmentation_pipeline(input_location) + if result: + print(result) + else: + raise ValueError('process error') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_movie_scene_segmentation_trainer.py b/tests/trainers/test_movie_scene_segmentation_trainer.py new file mode 100644 index 00000000..f25dc92a --- /dev/null +++ b/tests/trainers/test_movie_scene_segmentation_trainer.py @@ -0,0 +1,109 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import tempfile +import unittest +import zipfile + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.movie_scene_segmentation import \ + MovieSceneSegmentationModel +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImageInstanceSegmentationTrainer(unittest.TestCase): + + model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + + max_epochs = cfg.train.max_epochs + + train_data_cfg = ConfigDict( + name='movie_scene_seg_toydata', + split='train', + cfg=cfg.preprocessor, + test_mode=False) + + test_data_cfg = ConfigDict( + name='movie_scene_seg_toydata', + split='test', + cfg=cfg.preprocessor, + test_mode=True) + + self.train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + namespace=train_data_cfg.namespace, + cfg=train_data_cfg.cfg, + test_mode=train_data_cfg.test_mode) + assert next( + iter(self.train_dataset.config_kwargs['split_config'].values())) + + self.test_dataset = MsDataset.load( + dataset_name=test_data_cfg.name, + split=test_data_cfg.split, + namespace=test_data_cfg.namespace, + cfg=test_data_cfg.cfg, + test_mode=test_data_cfg.test_mode) + assert next( + iter(self.test_dataset.config_kwargs['split_config'].values())) + + self.max_epochs = max_epochs + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.movie_scene_segmentation, default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download(self.model_id) + model = MovieSceneSegmentationModel.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir=tmp_dir) + + trainer = build_trainer( + name=Trainers.movie_scene_segmentation, default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + +if __name__ == '__main__': + unittest.main()
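
A minimal usage sketch (not part of the patch) for the config-driven transform utilities added above. The absolute module path is assumed from the relative import `from .movie_scene_segmentation import get_transform` in modelscope/preprocessors/video.py, and every numeric value below (crop size, jitter strengths, normalization mean/std) is an illustrative assumption rather than the released configuration.

import torch
from PIL import Image

# Module path assumed from the relative import used in modelscope/preprocessors/video.py.
from modelscope.preprocessors.movie_scene_segmentation import get_transform

# Each dict mirrors one of the constructor signatures handled by build_transform();
# the concrete values are illustrative only.
train_cfg = [
    dict(type='VideoRandomResizedCrop', size=224, bottom_area=0.14),
    dict(type='VideoRandomHFlip'),
    dict(type='VideoRandomColorJitter',
         brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
    dict(type='VideoRandomGaussianBlur', radius_min=0.1, radius_max=2.0, p=0.5),
    dict(type='VideoToTensor',
         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = get_transform(train_cfg)

# Every Video* transform consumes and returns a *list* of PIL images, so one
# random draw (crop box, flip decision, jitter order) is shared across the
# key-frames of a shot when consistent sampling is enabled.
keyframes = [Image.new('RGB', (256, 256)) for _ in range(3)]
clip = torch.stack(transform(keyframes), dim=0)  # same as the apply_transform() helper
print(clip.shape)  # torch.Size([3, 3, 224, 224])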
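
A hedged sketch (also not part of the patch) of driving the registered MovieSceneSegmentationPreprocessor directly. The train/eval transform lists reuse the dict format consumed by get_transform(), ModeKeys.TRAIN and ModeKeys.EVAL are the keys popped in its __init__, and all values are assumptions, not the settings shipped in the model's configuration.json.

from PIL import Image

from modelscope.preprocessors.video import MovieSceneSegmentationPreprocessor
from modelscope.utils.constant import ModeKeys

# Illustrative transform configs; the released configuration.json provides its own.
to_tensor = dict(type='VideoToTensor',
                 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
cfg = {
    'num_keyframe': 3,
    ModeKeys.TRAIN: [
        dict(type='VideoRandomResizedCrop', size=224, bottom_area=0.14),
        dict(type='VideoRandomHFlip'),
        to_tensor,
    ],
    ModeKeys.EVAL: [
        dict(type='VideoResizedCenterCrop', image_size=256, crop_size=224),
        to_tensor,
    ],
}

preprocessor = MovieSceneSegmentationPreprocessor(**cfg)
preprocessor.eval()  # switch to the test-time transform

# A flat list of key-frames covering two shots (num_keyframe=3 each) comes back
# as a (num_shots, num_keyframe, 3, 224, 224) tensor, ready for the model.
frames = [Image.new('RGB', (320, 240)) for _ in range(2 * 3)]
batch = preprocessor(frames)
print(batch.shape)  # torch.Size([2, 3, 3, 224, 224])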