Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9872869 (master)
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:59fa397b01dc4c9b67a19ca42f149287b9c4e7b2158aba5d07d2db88af87b23f | |||
| size 126815483 | |||
| @@ -27,6 +27,7 @@ class Models(object): | |||
| video_summarization = 'pgl-video-summarization' | |||
| swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
| resnet50_bert = 'resnet50-bert' | |||
| # EasyCV models | |||
| yolox = 'YOLOX' | |||
| @@ -133,6 +134,7 @@ class Pipelines(object): | |||
| video_summarization = 'googlenet_pgl_video_summarization' | |||
| image_semantic_segmentation = 'image-semantic-segmentation' | |||
| image_reid_person = 'passvitb-image-reid-person' | |||
| movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | |||
| # nlp tasks | |||
| sentence_similarity = 'sentence-similarity' | |||
| @@ -195,6 +197,7 @@ class Trainers(object): | |||
| image_instance_segmentation = 'image-instance-segmentation' | |||
| image_portrait_enhancement = 'image-portrait-enhancement' | |||
| video_summarization = 'video-summarization' | |||
| movie_scene_segmentation = 'movie-scene-segmentation' | |||
| # nlp trainers | |||
| bert_sentiment_analysis = 'bert-sentiment-analysis' | |||
| @@ -223,6 +226,7 @@ class Preprocessors(object): | |||
| image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | |||
| image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' | |||
| video_summarization_preprocessor = 'video-summarization-preprocessor' | |||
| movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor' | |||
| # nlp preprocessor | |||
| sen_sim_tokenizer = 'sen-sim-tokenizer' | |||
| @@ -279,6 +283,8 @@ class Metrics(object): | |||
| # metrics for image-portrait-enhancement task | |||
| image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' | |||
| video_summarization_metric = 'video-summarization-metric' | |||
| # metric for movie-scene-segmentation task | |||
| movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | |||
| class Optimizers(object): | |||
| @@ -16,6 +16,7 @@ if TYPE_CHECKING: | |||
| from .text_generation_metric import TextGenerationMetric | |||
| from .token_classification_metric import TokenClassificationMetric | |||
| from .video_summarization_metric import VideoSummarizationMetric | |||
| from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | |||
| else: | |||
| _import_structure = { | |||
| @@ -32,6 +33,7 @@ else: | |||
| 'text_generation_metric': ['TextGenerationMetric'], | |||
| 'token_classification_metric': ['TokenClassificationMetric'], | |||
| 'video_summarization_metric': ['VideoSummarizationMetric'], | |||
| 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | |||
| } | |||
| import sys | |||
| @@ -34,6 +34,7 @@ task_default_metrics = { | |||
| Tasks.video_summarization: [Metrics.video_summarization_metric], | |||
| Tasks.image_captioning: [Metrics.text_gen_metric], | |||
| Tasks.visual_question_answering: [Metrics.text_gen_metric], | |||
| Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | |||
| } | |||
| @@ -0,0 +1,52 @@ | |||
| from typing import Dict | |||
| import numpy as np | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.utils.registry import default_group | |||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
| torch_nested_numpify) | |||
| from .base import Metric | |||
| from .builder import METRICS, MetricKeys | |||
| @METRICS.register_module( | |||
| group_key=default_group, | |||
| module_name=Metrics.movie_scene_segmentation_metric) | |||
| class MovieSceneSegmentationMetric(Metric): | |||
| """The metric computation class for movie scene segmentation classes. | |||
| """ | |||
| def __init__(self): | |||
| self.preds = [] | |||
| self.labels = [] | |||
| self.eps = 1e-5 | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| preds = outputs['pred'] | |||
| labels = inputs['label'] | |||
| self.preds.extend(preds) | |||
| self.labels.extend(labels) | |||
| def evaluate(self): | |||
| gts = np.array(torch_nested_numpify(torch_nested_detach(self.labels))) | |||
| prob = np.array(torch_nested_numpify(torch_nested_detach(self.preds))) | |||
| gt_one = gts == 1 | |||
| gt_zero = gts == 0 | |||
| pred_one = prob == 1 | |||
| pred_zero = prob == 0 | |||
| tp = (gt_one * pred_one).sum() | |||
| fp = (gt_zero * pred_one).sum() | |||
| fn = (gt_one * pred_zero).sum() | |||
| precision = 100.0 * tp / (tp + fp + self.eps) | |||
| recall = 100.0 * tp / (tp + fn + self.eps) | |||
| f1 = 2 * precision * recall / (precision + recall) | |||
| return { | |||
| MetricKeys.F1: f1, | |||
| MetricKeys.RECALL: recall, | |||
| MetricKeys.PRECISION: precision | |||
| } | |||
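A minimal sketch of how the new metric behaves in isolation (the import goes through the lazy metrics package added above; inside the trainer it is resolved via the METRICS registry rather than called by hand):

    import torch
    from modelscope.metrics import MovieSceneSegmentationMetric

    metric = MovieSceneSegmentationMetric()
    # 'pred' holds argmax class ids (1 = scene boundary), 'label' the ground truth
    metric.add(outputs={'pred': torch.tensor([1, 0, 1])}, inputs={'label': torch.tensor([1, 0, 0])})
    metric.add(outputs={'pred': torch.tensor([0, 1])}, inputs={'label': torch.tensor([0, 1])})
    print(metric.evaluate())  # dict keyed by MetricKeys.F1 / MetricKeys.RECALL / MetricKeys.PRECISION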
| @@ -9,8 +9,9 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
| image_panoptic_segmentation, image_portrait_enhancement, | |||
| image_reid_person, image_semantic_segmentation, | |||
| image_to_image_generation, image_to_image_translation, | |||
| object_detection, product_retrieval_embedding, | |||
| realtime_object_detection, salient_detection, super_resolution, | |||
| movie_scene_segmentation, object_detection, | |||
| product_retrieval_embedding, realtime_object_detection, | |||
| salient_detection, super_resolution, | |||
| video_single_object_tracking, video_summarization, virual_tryon) | |||
| # yapf: enable | |||
| @@ -0,0 +1,25 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .model import MovieSceneSegmentationModel | |||
| from .datasets import MovieSceneSegmentationDataset | |||
| else: | |||
| _import_structure = { | |||
| 'model': ['MovieSceneSegmentationModel'], | |||
| 'datasets': ['MovieSceneSegmentationDataset'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,45 @@ | |||
| # ------------------------------------------------------------------------------------ | |||
| # BaSSL | |||
| # Copyright (c) 2021 KakaoBrain. All Rights Reserved. | |||
| # Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |||
| # Github: https://github.com/kakaobrain/bassl | |||
| # ------------------------------------------------------------------------------------ | |||
| from .utils.shot_encoder import resnet50 | |||
| from .utils.trn import TransformerCRN | |||
| def get_shot_encoder(cfg): | |||
| name = cfg['model']['shot_encoder']['name'] | |||
| shot_encoder_args = cfg['model']['shot_encoder'][name] | |||
| if name == 'resnet': | |||
| depth = shot_encoder_args['depth'] | |||
| if depth == 50: | |||
| shot_encoder = resnet50(**shot_encoder_args['params'], ) | |||
| else: | |||
| raise NotImplementedError | |||
| else: | |||
| raise NotImplementedError | |||
| return shot_encoder | |||
| def get_contextual_relation_network(cfg): | |||
| crn = None | |||
| if cfg['model']['contextual_relation_network']['enabled']: | |||
| name = cfg['model']['contextual_relation_network']['name'] | |||
| crn_args = cfg['model']['contextual_relation_network']['params'][name] | |||
| if name == 'trn': | |||
| sampling_name = cfg['model']['loss']['sampling_method']['name'] | |||
| crn_args['neighbor_size'] = ( | |||
| 2 * cfg['model']['loss']['sampling_method']['params'] | |||
| [sampling_name]['neighbor_size']) | |||
| crn = TransformerCRN(crn_args) | |||
| else: | |||
| raise NotImplementedError | |||
| return crn | |||
| __all__ = ['get_shot_encoder', 'get_contextual_relation_network'] | |||
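For orientation, a hedged sketch of the config subtree that `get_shot_encoder` reads; the field values are illustrative stand-ins for what ships in the model's configuration file, and the kwargs under 'params' must match `ResNet.__init__` further below. The contextual relation network is not sketched here because `TransformerCRN` additionally expects an attribute-style, BertConfig-like config object.

    # module path assumed from this diff
    from modelscope.models.cv.movie_scene_segmentation.get_model import get_shot_encoder

    cfg = {
        'model': {
            'shot_encoder': {
                'name': 'resnet',
                'resnet': {
                    'depth': 50,
                    # forwarded as resnet50(**params); keys must be ResNet.__init__ kwargs
                    'params': {'in_channel_dim': 3, 'zero_init_residual': False},
                },
            },
        },
    }
    shot_encoder = get_shot_encoder(cfg)  # ResNet-50 backbone (Bottleneck, [3, 4, 6, 3])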
| @@ -0,0 +1,192 @@ | |||
| import os | |||
| import os.path as osp | |||
| from typing import Any, Dict | |||
| import einops | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchvision.transforms as TF | |||
| from PIL import Image | |||
| from shotdetect_scenedetect_lgss import shot_detect | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base.base_torch_model import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from .get_model import get_contextual_relation_network, get_shot_encoder | |||
| from .utils.save_op import get_pred_boundary, pred2scene, scene2video | |||
| logger = get_logger() | |||
| @MODELS.register_module( | |||
| Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert) | |||
| class MovieSceneSegmentationModel(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """str -- model file root.""" | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
| params = torch.load(model_path, map_location='cpu') | |||
| config_path = osp.join(model_dir, ModelFile.CONFIGURATION) | |||
| self.cfg = Config.from_file(config_path) | |||
| def load_param_with_prefix(prefix, model, src_params): | |||
| own_state = model.state_dict() | |||
| for name, param in own_state.items(): | |||
| src_name = prefix + '.' + name | |||
| own_state[name] = src_params[src_name] | |||
| model.load_state_dict(own_state) | |||
| self.shot_encoder = get_shot_encoder(self.cfg) | |||
| load_param_with_prefix('shot_encoder', self.shot_encoder, params) | |||
| self.crn = get_contextual_relation_network(self.cfg) | |||
| load_param_with_prefix('crn', self.crn, params) | |||
| crn_name = self.cfg.model.contextual_relation_network.name | |||
| hdim = self.cfg.model.contextual_relation_network.params[crn_name][ | |||
| 'hidden_size'] | |||
| self.head_sbd = nn.Linear(hdim, 2) | |||
| load_param_with_prefix('head_sbd', self.head_sbd, params) | |||
| self.test_transform = TF.Compose([ | |||
| TF.Resize(size=256, interpolation=Image.BICUBIC), | |||
| TF.CenterCrop(224), | |||
| TF.ToTensor(), | |||
| TF.Normalize( | |||
| mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||
| ]) | |||
| self.infer_result = {'vid': [], 'sid': [], 'pred': []} | |||
| sampling_method = self.cfg.dataset.sampling_method.name | |||
| self.neighbor_size = self.cfg.dataset.sampling_method.params[ | |||
| sampling_method].neighbor_size | |||
| self.eps = 1e-5 | |||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: | |||
| data = inputs['video'] | |||
| labels = inputs['label'] | |||
| outputs = self.shared_step(data) | |||
| loss = F.cross_entropy( | |||
| outputs.squeeze(), labels.squeeze(), reduction='none') | |||
| lpos = labels == 1 | |||
| lneg = labels == 0 | |||
| n_pos, n_neg = 1, 1  # class-balancing weights; renamed to avoid shadowing the torch.nn alias | |||
| wp = (n_pos / float(n_pos + n_neg)) * lpos / (lpos.sum() + self.eps) | |||
| wn = (n_neg / float(n_pos + n_neg)) * lneg / (lneg.sum() + self.eps) | |||
| w = wp + wn | |||
| loss = (w * loss).sum() | |||
| probs = torch.argmax(outputs, dim=1) | |||
| re = dict(pred=probs, loss=loss) | |||
| return re | |||
| def inference(self, batch): | |||
| logger.info('Begin scene detect ......') | |||
| bs = self.cfg.pipeline.batch_size_per_gpu | |||
| sids = batch['sid'] | |||
| inputs = batch['shot_feat'] | |||
| shot_num = len(sids) | |||
| cnt = (shot_num + bs - 1) // bs  # ceil division; avoids an empty trailing batch when shot_num % bs == 0 | |||
| for i in range(cnt): | |||
| start = i * bs | |||
| end = (i + 1) * bs if (i + 1) * bs < shot_num else shot_num | |||
| input_ = inputs[start:end] | |||
| sid_ = sids[start:end] | |||
| input_ = torch.stack(input_) | |||
| outputs = self.shared_step(input_) # shape [b,2] | |||
| prob = F.softmax(outputs, dim=1) | |||
| self.infer_result['sid'].extend(sid_.cpu().detach().numpy()) | |||
| self.infer_result['pred'].extend(prob[:, 1].cpu().detach().numpy()) | |||
| self.infer_result['pred'] = np.stack(self.infer_result['pred']) | |||
| assert len(self.infer_result['sid']) == len(sids) | |||
| assert len(self.infer_result['pred']) == len(inputs) | |||
| return self.infer_result | |||
| def shared_step(self, inputs): | |||
| with torch.no_grad(): | |||
| # infer shot encoder | |||
| shot_repr = self.extract_shot_representation(inputs) | |||
| assert len(shot_repr.shape) == 3 | |||
| # infer CRN | |||
| _, pooled = self.crn(shot_repr, mask=None) | |||
| # infer boundary score | |||
| pred = self.head_sbd(pooled) | |||
| return pred | |||
| def save_shot_feat(self, _repr): | |||
| feat = _repr.float().cpu().numpy() | |||
| pth = osp.join(self.cfg.dataset.img_path, 'features') | |||
| os.makedirs(pth, exist_ok=True) | |||
| for idx in range(_repr.shape[0]): | |||
| name = f'shot_{str(idx).zfill(4)}.npy' | |||
| name = osp.join(pth, name) | |||
| np.save(name, feat[idx]) | |||
| def extract_shot_representation(self, | |||
| inputs: torch.Tensor) -> torch.Tensor: | |||
| """ inputs [b s k c h w] -> output [b d] """ | |||
| assert len(inputs.shape) == 6 # (B Shot Keyframe C H W) | |||
| b, s, k, c, h, w = inputs.shape | |||
| inputs = einops.rearrange(inputs, 'b s k c h w -> (b s) k c h w', s=s) | |||
| keyframe_repr = [self.shot_encoder(inputs[:, _k]) for _k in range(k)] | |||
| # [k (b s) d] -> [(b s) d] | |||
| shot_repr = torch.stack(keyframe_repr).mean(dim=0) | |||
| shot_repr = einops.rearrange(shot_repr, '(b s) d -> b s d', s=s) | |||
| return shot_repr | |||
| def postprocess(self, inputs: Dict[str, Any], **kwargs): | |||
| logger.info('Generate scene .......') | |||
| pred_dict = inputs['feat'] | |||
| thres = self.cfg.pipeline.save_threshold | |||
| anno_dict = get_pred_boundary(pred_dict, thres) | |||
| scene_dict, scene_list = pred2scene(self.shot2keyf, anno_dict) | |||
| if self.cfg.pipeline.save_split_scene: | |||
| re_dir = scene2video(inputs['input_video_pth'], scene_list, thres) | |||
| logger.info(f'Split scene videos saved to {re_dir}') | |||
| return len(scene_list), scene_dict | |||
| def preprocess(self, inputs): | |||
| logger.info('Begin shot detect......') | |||
| shot_keyf_lst, anno, shot2keyf = shot_detect( | |||
| inputs, **self.cfg.preprocessor.shot_detect) | |||
| logger.info('Shot detect done!') | |||
| single_shot_feat, sid = [], [] | |||
| for idx, one_shot in enumerate(shot_keyf_lst): | |||
| one_shot = [ | |||
| self.test_transform(one_frame) for one_frame in one_shot | |||
| ] | |||
| one_shot = torch.stack(one_shot, dim=0) | |||
| single_shot_feat.append(one_shot) | |||
| sid.append(idx) | |||
| single_shot_feat = torch.stack(single_shot_feat, dim=0) | |||
| shot_feat = [] | |||
| for idx, one_shot in enumerate(anno): | |||
| shot_idx = int(one_shot['shot_id']) + np.arange( | |||
| -self.neighbor_size, self.neighbor_size + 1) | |||
| shot_idx = np.clip(shot_idx, 0, one_shot['num_shot']) | |||
| _one_shot = single_shot_feat[shot_idx] | |||
| shot_feat.append(_one_shot) | |||
| self.shot2keyf = shot2keyf | |||
| self.anno = anno | |||
| return shot_feat, sid | |||
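To make the shapes in `extract_shot_representation` concrete, a small self-contained walk-through with a stand-in encoder (b, s and k are illustrative; d = 2048 mirrors the ResNet-50 output):

    import einops
    import torch

    b, s, k, d = 2, 17, 3, 2048                 # batch, shots (2*neighbor_size+1), keyframes, feat dim
    frames = torch.randn(b, s, k, 3, 224, 224)
    flat = einops.rearrange(frames, 'b s k c h w -> (b s) k c h w')

    def dummy_encoder(x):                       # stand-in for the ResNet-50 shot encoder
        return torch.randn(x.shape[0], d)

    keyframe_repr = [dummy_encoder(flat[:, i]) for i in range(k)]   # k tensors of shape [(b s), d]
    shot_repr = torch.stack(keyframe_repr).mean(dim=0)              # average over keyframes
    shot_repr = einops.rearrange(shot_repr, '(b s) d -> b s d', s=s)
    assert shot_repr.shape == (b, s, d)         # fed to the CRN, which pools it before head_sbd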
| @@ -0,0 +1,3 @@ | |||
| from .save_op import get_pred_boundary, pred2scene, scene2video | |||
| from .shot_encoder import resnet50 | |||
| from .trn import TransformerCRN | |||
| @@ -0,0 +1,29 @@ | |||
| # ------------------------------------------------------------------------------------ | |||
| # BaSSL | |||
| # Copyright (c) 2021 KakaoBrain. All Rights Reserved. | |||
| # Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |||
| # Github: https://github.com/kakaobrain/bassl | |||
| # ------------------------------------------------------------------------------------ | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class MlpHead(nn.Module): | |||
| def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128): | |||
| super().__init__() | |||
| self.output_dim = output_dim | |||
| self.input_dim = input_dim | |||
| self.hidden_dim = hidden_dim | |||
| self.model = nn.Sequential( | |||
| nn.Linear(self.input_dim, self.hidden_dim, bias=True), | |||
| nn.ReLU(), | |||
| nn.Linear(self.hidden_dim, self.output_dim, bias=True), | |||
| ) | |||
| def forward(self, x): | |||
| # x shape: [b t d] where t means the number of views | |||
| x = self.model(x) | |||
| return F.normalize(x, dim=-1) | |||
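A quick shape check of MlpHead, which projects shot features and L2-normalizes them (assuming the class above is importable; batch and view sizes are illustrative):

    import torch

    head = MlpHead(input_dim=2048, hidden_dim=2048, output_dim=128)
    out = head(torch.randn(4, 7, 2048))   # [b, t, 2048] -> [b, t, 128]
    print(out.norm(dim=-1))               # ~1.0 everywhere, thanks to F.normalize on the last dim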
| @@ -0,0 +1,118 @@ | |||
| # ---------------------------------------------------------------------------------- | |||
| # The code below is partially adapted from SceneSeg (LGSS). | |||
| # Github: https://github.com/AnyiRao/SceneSeg | |||
| # ---------------------------------------------------------------------------------- | |||
| import os | |||
| import os.path as osp | |||
| import subprocess | |||
| import cv2 | |||
| import numpy as np | |||
| from tqdm import tqdm | |||
| def get_pred_boundary(pred_dict, threshold=0.5): | |||
| pred = pred_dict['pred'] | |||
| tmp = (pred > threshold).astype(np.int32) | |||
| anno_dict = {} | |||
| for idx in range(len(tmp)): | |||
| anno_dict.update({str(pred_dict['sid'][idx]).zfill(4): int(tmp[idx])}) | |||
| return anno_dict | |||
| def pred2scene(shot2keyf, anno_dict): | |||
| scene_list, pair_list = get_demo_scene_list(shot2keyf, anno_dict) | |||
| scene_dict = {} | |||
| assert len(scene_list) == len(pair_list) | |||
| for scene_ind, scene_item in enumerate(scene_list): | |||
| scene_dict.update( | |||
| {scene_ind: { | |||
| 'shot': pair_list[scene_ind], | |||
| 'frame': scene_item | |||
| }}) | |||
| return scene_dict, scene_list | |||
| def scene2video(source_movie_fn, scene_list, thres): | |||
| vcap = cv2.VideoCapture(source_movie_fn) | |||
| fps = vcap.get(cv2.CAP_PROP_FPS) # video.fps | |||
| out_video_dir_fn = os.path.join(os.getcwd(), | |||
| f'pred_result/scene_video_{thres}') | |||
| os.makedirs(out_video_dir_fn, exist_ok=True) | |||
| for scene_ind, scene_item in tqdm(enumerate(scene_list)): | |||
| scene = str(scene_ind).zfill(4) | |||
| start_frame = int(scene_item[0]) | |||
| end_frame = int(scene_item[1]) | |||
| start_time, end_time = start_frame / fps, end_frame / fps | |||
| duration_time = end_time - start_time | |||
| out_video_fn = os.path.join(out_video_dir_fn, | |||
| 'scene_{}.mp4'.format(scene)) | |||
| if os.path.exists(out_video_fn): | |||
| continue | |||
| call_list = ['ffmpeg'] | |||
| call_list += ['-v', 'quiet'] | |||
| call_list += [ | |||
| '-y', '-ss', | |||
| str(start_time), '-t', | |||
| str(duration_time), '-i', source_movie_fn | |||
| ] | |||
| call_list += ['-map_chapters', '-1'] | |||
| call_list += [out_video_fn] | |||
| subprocess.call(call_list) | |||
| return osp.join(os.getcwd(), 'pred_result') | |||
| def get_demo_scene_list(shot2keyf, anno_dict): | |||
| pair_list = get_pair_list(anno_dict) | |||
| scene_list = [] | |||
| for pair in pair_list: | |||
| start_shot, end_shot = int(pair[0]), int(pair[-1]) | |||
| start_frame = shot2keyf[start_shot].split(' ')[0] | |||
| end_frame = shot2keyf[end_shot].split(' ')[1] | |||
| scene_list.append((start_frame, end_frame)) | |||
| return scene_list, pair_list | |||
| def get_pair_list(anno_dict): | |||
| sort_anno_dict_key = sorted(anno_dict.keys()) | |||
| tmp = 0 | |||
| tmp_list = [] | |||
| tmp_label_list = [] | |||
| anno_list = [] | |||
| anno_label_list = [] | |||
| for key in sort_anno_dict_key: | |||
| value = anno_dict.get(key) | |||
| tmp += value | |||
| tmp_list.append(key) | |||
| tmp_label_list.append(value) | |||
| if tmp == 1: | |||
| anno_list.append(tmp_list) | |||
| anno_label_list.append(tmp_label_list) | |||
| tmp = 0 | |||
| tmp_list = [] | |||
| tmp_label_list = [] | |||
| continue | |||
| if key == sort_anno_dict_key[-1]: | |||
| if len(tmp_list) > 0: | |||
| anno_list.append(tmp_list) | |||
| anno_label_list.append(tmp_label_list) | |||
| if len(anno_list) == 0: | |||
| return None | |||
| while [] in anno_list: | |||
| anno_list.remove([]) | |||
| tmp_anno_list = [anno_list[0]] | |||
| pair_list = [] | |||
| for ind in range(len(anno_list) - 1): | |||
| cont_count = int(anno_list[ind + 1][0]) - int(anno_list[ind][-1]) | |||
| if cont_count > 1: | |||
| pair_list.extend(tmp_anno_list) | |||
| tmp_anno_list = [anno_list[ind + 1]] | |||
| continue | |||
| tmp_anno_list.append(anno_list[ind + 1]) | |||
| pair_list.extend(tmp_anno_list) | |||
| return pair_list | |||
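A small worked example of the boundary-to-scene grouping implemented above (shot ids and scores are made up; the import path is assumed from this diff):

    import numpy as np
    from modelscope.models.cv.movie_scene_segmentation.utils.save_op import (
        get_pair_list, get_pred_boundary)  # path assumed

    pred_dict = {'sid': np.array([0, 1, 2, 3, 4]),
                 'pred': np.array([0.1, 0.9, 0.2, 0.3, 0.8])}
    anno_dict = get_pred_boundary(pred_dict, threshold=0.5)
    # {'0000': 0, '0001': 1, '0002': 0, '0003': 0, '0004': 1}
    pair_list = get_pair_list(anno_dict)
    # [['0000', '0001'], ['0002', '0003', '0004']]  -> each sub-list is the shots of one scene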
| @@ -0,0 +1,331 @@ | |||
| """ | |||
| Modified from original implementation in torchvision | |||
| """ | |||
| from typing import Any, Callable, List, Optional, Type, Union | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch import Tensor | |||
| def conv3x3(in_planes: int, | |||
| out_planes: int, | |||
| stride: int = 1, | |||
| groups: int = 1, | |||
| dilation: int = 1) -> nn.Conv2d: | |||
| """3x3 convolution with padding""" | |||
| return nn.Conv2d( | |||
| in_planes, | |||
| out_planes, | |||
| kernel_size=3, | |||
| stride=stride, | |||
| padding=dilation, | |||
| groups=groups, | |||
| bias=False, | |||
| dilation=dilation, | |||
| ) | |||
| def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: | |||
| """1x1 convolution""" | |||
| return nn.Conv2d( | |||
| in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
| class BasicBlock(nn.Module): | |||
| expansion: int = 1 | |||
| def __init__( | |||
| self, | |||
| inplanes: int, | |||
| planes: int, | |||
| stride: int = 1, | |||
| downsample: Optional[nn.Module] = None, | |||
| groups: int = 1, | |||
| base_width: int = 64, | |||
| dilation: int = 1, | |||
| norm_layer: Optional[Callable[..., nn.Module]] = None, | |||
| ) -> None: | |||
| super(BasicBlock, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| if groups != 1 or base_width != 64: | |||
| raise ValueError( | |||
| 'BasicBlock only supports groups=1 and base_width=64') | |||
| if dilation > 1: | |||
| raise NotImplementedError( | |||
| 'Dilation > 1 not supported in BasicBlock') | |||
| # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |||
| self.conv1 = conv3x3(inplanes, planes, stride) | |||
| self.bn1 = norm_layer(planes) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.conv2 = conv3x3(planes, planes) | |||
| self.bn2 = norm_layer(planes) | |||
| self.downsample = downsample | |||
| self.stride = stride | |||
| def forward(self, x: Tensor) -> Tensor: | |||
| identity = x | |||
| out = self.conv1(x) | |||
| out = self.bn1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| out = self.bn2(out) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out += identity | |||
| out = self.relu(out) | |||
| return out | |||
| class Bottleneck(nn.Module): | |||
| # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) | |||
| # while original implementation places the stride at the first 1x1 convolution(self.conv1) | |||
| # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. | |||
| # This variant is also known as ResNet V1.5 and improves accuracy according to | |||
| # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. | |||
| expansion: int = 4 | |||
| def __init__( | |||
| self, | |||
| inplanes: int, | |||
| planes: int, | |||
| stride: int = 1, | |||
| downsample: Optional[nn.Module] = None, | |||
| groups: int = 1, | |||
| base_width: int = 64, | |||
| dilation: int = 1, | |||
| norm_layer: Optional[Callable[..., nn.Module]] = None, | |||
| ) -> None: | |||
| super(Bottleneck, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| width = int(planes * (base_width / 64.0)) * groups | |||
| # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||
| self.conv1 = conv1x1(inplanes, width) | |||
| self.bn1 = norm_layer(width) | |||
| self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||
| self.bn2 = norm_layer(width) | |||
| self.conv3 = conv1x1(width, planes * self.expansion) | |||
| self.bn3 = norm_layer(planes * self.expansion) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.downsample = downsample | |||
| self.stride = stride | |||
| def forward(self, x: Tensor) -> Tensor: | |||
| identity = x | |||
| out = self.conv1(x) | |||
| out = self.bn1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| out = self.bn2(out) | |||
| out = self.relu(out) | |||
| out = self.conv3(out) | |||
| out = self.bn3(out) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out += identity | |||
| out = self.relu(out) | |||
| return out | |||
| class ResNet(nn.Module): | |||
| def __init__( | |||
| self, | |||
| block: Type[Union[BasicBlock, Bottleneck]], | |||
| layers: List[int], | |||
| in_channel_dim: int = 3, | |||
| zero_init_residual: bool = False, | |||
| use_last_block_grid: bool = False, | |||
| groups: int = 1, | |||
| width_per_group: int = 64, | |||
| replace_stride_with_dilation: Optional[List[bool]] = None, | |||
| norm_layer: Optional[Callable[..., nn.Module]] = None, | |||
| ) -> None: | |||
| super(ResNet, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| self._norm_layer = norm_layer | |||
| self.use_last_block_grid = use_last_block_grid | |||
| self.inplanes = 64 | |||
| self.dilation = 1 | |||
| if replace_stride_with_dilation is None: | |||
| # each element in the tuple indicates if we should replace | |||
| # the 2x2 stride with a dilated convolution instead | |||
| replace_stride_with_dilation = [False, False, False] | |||
| if len(replace_stride_with_dilation) != 3: | |||
| raise ValueError('replace_stride_with_dilation should be None ' | |||
| 'or a 3-element tuple, got {}'.format( | |||
| replace_stride_with_dilation)) | |||
| self.groups = groups | |||
| self.base_width = width_per_group | |||
| self.conv1 = nn.Conv2d( | |||
| in_channel_dim, | |||
| self.inplanes, | |||
| kernel_size=7, | |||
| stride=2, | |||
| padding=3, | |||
| bias=False, | |||
| ) | |||
| self.bn1 = norm_layer(self.inplanes) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
| self.layer1 = self._make_layer(block, 64, layers[0]) | |||
| self.layer2 = self._make_layer( | |||
| block, | |||
| 128, | |||
| layers[1], | |||
| stride=2, | |||
| dilate=replace_stride_with_dilation[0]) | |||
| self.layer3 = self._make_layer( | |||
| block, | |||
| 256, | |||
| layers[2], | |||
| stride=2, | |||
| dilate=replace_stride_with_dilation[1]) | |||
| self.layer4 = self._make_layer( | |||
| block, | |||
| 512, | |||
| layers[3], | |||
| stride=2, | |||
| dilate=replace_stride_with_dilation[2]) | |||
| self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d): | |||
| nn.init.kaiming_normal_( | |||
| m.weight, mode='fan_out', nonlinearity='relu') | |||
| elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | |||
| nn.init.constant_(m.weight, 1) | |||
| nn.init.constant_(m.bias, 0) | |||
| # Zero-initialize the last BN in each residual branch, | |||
| # so that the residual branch starts with zeros, and each residual block behaves like an identity. | |||
| # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | |||
| if zero_init_residual: | |||
| for m in self.modules(): | |||
| if isinstance(m, Bottleneck): | |||
| nn.init.constant_(m.bn3.weight, | |||
| 0) # type: ignore[arg-type] | |||
| elif isinstance(m, BasicBlock): | |||
| nn.init.constant_(m.bn2.weight, | |||
| 0) # type: ignore[arg-type] | |||
| def _make_layer( | |||
| self, | |||
| block: Type[Union[BasicBlock, Bottleneck]], | |||
| planes: int, | |||
| blocks: int, | |||
| stride: int = 1, | |||
| dilate: bool = False, | |||
| ) -> nn.Sequential: | |||
| norm_layer = self._norm_layer | |||
| downsample = None | |||
| previous_dilation = self.dilation | |||
| if dilate: | |||
| self.dilation *= stride | |||
| stride = 1 | |||
| if stride != 1 or self.inplanes != planes * block.expansion: | |||
| downsample = nn.Sequential( | |||
| conv1x1(self.inplanes, planes * block.expansion, stride), | |||
| norm_layer(planes * block.expansion), | |||
| ) | |||
| layers = [] | |||
| layers.append( | |||
| block( | |||
| self.inplanes, | |||
| planes, | |||
| stride, | |||
| downsample, | |||
| self.groups, | |||
| self.base_width, | |||
| previous_dilation, | |||
| norm_layer, | |||
| )) | |||
| self.inplanes = planes * block.expansion | |||
| for _ in range(1, blocks): | |||
| layers.append( | |||
| block( | |||
| self.inplanes, | |||
| planes, | |||
| groups=self.groups, | |||
| base_width=self.base_width, | |||
| dilation=self.dilation, | |||
| norm_layer=norm_layer, | |||
| )) | |||
| return nn.Sequential(*layers) | |||
| def _forward_impl(self, x: Tensor, grid: bool, level: List, both: bool, | |||
| grid_only: bool) -> Tensor: | |||
| # See note [TorchScript super()] | |||
| x = self.conv1(x) | |||
| x = self.bn1(x) | |||
| x = self.relu(x) | |||
| x = self.maxpool(x) | |||
| x = self.layer1(x) | |||
| x = self.layer2(x) | |||
| x = self.layer3(x) | |||
| if grid: | |||
| x_grid = [] | |||
| if 3 in level: | |||
| x_grid.append(x.detach().clone()) | |||
| if not both and len(level) == 1: | |||
| return x_grid | |||
| x = self.layer4(x) | |||
| if 4 in level: | |||
| x_grid.append(x.detach().clone()) | |||
| if not both and len(level) == 1: | |||
| return x_grid | |||
| x = self.avgpool(x) | |||
| x = torch.flatten(x, 1) | |||
| if not grid or len(level) == 0: | |||
| return x | |||
| if grid_only: | |||
| return x_grid | |||
| if both: | |||
| return x, x_grid | |||
| return x | |||
| def forward( | |||
| self, | |||
| x: Tensor, | |||
| grid: bool = False, | |||
| level: List = [], | |||
| both: bool = False, | |||
| grid_only: bool = False, | |||
| ) -> Tensor: | |||
| return self._forward_impl(x, grid, level, both, grid_only) | |||
| def resnet50(**kwargs: Any) -> ResNet: | |||
| r"""ResNet-50 model from | |||
| `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. | |||
| """ | |||
| return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) | |||
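A quick shape check of the shot encoder: called without the grid options, the ResNet-50 above returns pooled 2048-d features per image (import path assumed from this diff).

    import torch
    from modelscope.models.cv.movie_scene_segmentation.utils.shot_encoder import resnet50  # path assumed

    encoder = resnet50(in_channel_dim=3)
    feats = encoder(torch.randn(2, 3, 224, 224))
    print(feats.shape)   # torch.Size([2, 2048]); these per-keyframe features are averaged per shot upstream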
| @@ -0,0 +1,132 @@ | |||
| # ------------------------------------------------------------------------------------ | |||
| # BaSSL | |||
| # Copyright (c) 2021 KakaoBrain. All Rights Reserved. | |||
| # Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |||
| # Github: https://github.com/kakaobrain/bassl | |||
| # ------------------------------------------------------------------------------------ | |||
| import torch | |||
| import torch.nn as nn | |||
| from transformers.models.bert.modeling_bert import BertEncoder | |||
| class ShotEmbedding(nn.Module): | |||
| def __init__(self, cfg): | |||
| super().__init__() | |||
| nn_size = cfg.neighbor_size + 2 # +1 for center shot, +1 for cls | |||
| self.shot_embedding = nn.Linear(cfg.input_dim, cfg.hidden_size) | |||
| self.position_embedding = nn.Embedding(nn_size, cfg.hidden_size) | |||
| self.mask_embedding = nn.Embedding(2, cfg.input_dim, padding_idx=0) | |||
| # tf naming convention for layer norm | |||
| self.LayerNorm = nn.LayerNorm(cfg.hidden_size, eps=1e-12) | |||
| self.dropout = nn.Dropout(cfg.hidden_dropout_prob) | |||
| self.register_buffer('pos_ids', | |||
| torch.arange(nn_size, dtype=torch.long)) | |||
| def forward( | |||
| self, | |||
| shot_emb: torch.Tensor, | |||
| mask: torch.Tensor = None, | |||
| pos_ids: torch.Tensor = None, | |||
| ) -> torch.Tensor: | |||
| assert len(shot_emb.size()) == 3 | |||
| if pos_ids is None: | |||
| pos_ids = self.pos_ids | |||
| # this is for the mask embedding (un-masked positions remain unchanged) | |||
| if mask is not None: | |||
| self.mask_embedding.weight.data[0, :].fill_(0) | |||
| mask_emb = self.mask_embedding(mask.long()) | |||
| shot_emb = (shot_emb * (1 - mask).float()[:, :, None]) + mask_emb | |||
| # we set [CLS] token to averaged feature | |||
| cls_emb = shot_emb.mean(dim=1) | |||
| # embedding shots | |||
| shot_emb = torch.cat([cls_emb[:, None, :], shot_emb], dim=1) | |||
| shot_emb = self.shot_embedding(shot_emb) | |||
| pos_emb = self.position_embedding(pos_ids) | |||
| embeddings = shot_emb + pos_emb[None, :] | |||
| embeddings = self.dropout(self.LayerNorm(embeddings)) | |||
| return embeddings | |||
| class TransformerCRN(nn.Module): | |||
| def __init__(self, cfg): | |||
| super().__init__() | |||
| self.pooling_method = cfg.pooling_method | |||
| self.shot_embedding = ShotEmbedding(cfg) | |||
| self.encoder = BertEncoder(cfg) | |||
| nn_size = cfg.neighbor_size + 2 # +1 for center shot, +1 for cls | |||
| self.register_buffer( | |||
| 'attention_mask', | |||
| self._get_extended_attention_mask( | |||
| torch.ones((1, nn_size)).float()), | |||
| ) | |||
| def forward( | |||
| self, | |||
| shot: torch.Tensor, | |||
| mask: torch.Tensor = None, | |||
| pos_ids: torch.Tensor = None, | |||
| pooling_method: str = None, | |||
| ): | |||
| if self.attention_mask.shape[1] != (shot.shape[1] + 1): | |||
| n_shot = shot.shape[1] + 1 # +1 for CLS token | |||
| attention_mask = self._get_extended_attention_mask( | |||
| torch.ones((1, n_shot), dtype=torch.float, device=shot.device)) | |||
| else: | |||
| attention_mask = self.attention_mask | |||
| shot_emb = self.shot_embedding(shot, mask=mask, pos_ids=pos_ids) | |||
| encoded_emb = self.encoder( | |||
| shot_emb, attention_mask=attention_mask).last_hidden_state | |||
| return encoded_emb, self.pooler( | |||
| encoded_emb, pooling_method=pooling_method) | |||
| def pooler(self, sequence_output, pooling_method=None): | |||
| if pooling_method is None: | |||
| pooling_method = self.pooling_method | |||
| if pooling_method == 'cls': | |||
| return sequence_output[:, 0, :] | |||
| elif pooling_method == 'avg': | |||
| return sequence_output[:, 1:].mean(dim=1) | |||
| elif pooling_method == 'max': | |||
| return sequence_output[:, 1:].max(dim=1)[0] | |||
| elif pooling_method == 'center': | |||
| cidx = sequence_output.shape[1] // 2 | |||
| return sequence_output[:, cidx, :] | |||
| else: | |||
| raise ValueError | |||
| def _get_extended_attention_mask(self, attention_mask): | |||
| # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] | |||
| # ourselves in which case we just need to make it broadcastable to all heads. | |||
| if attention_mask.dim() == 3: | |||
| extended_attention_mask = attention_mask[:, None, :, :] | |||
| elif attention_mask.dim() == 2: | |||
| extended_attention_mask = attention_mask[:, None, None, :] | |||
| else: | |||
| raise ValueError( | |||
| f'Wrong shape for attention_mask (shape {attention_mask.shape})' | |||
| ) | |||
| # Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |||
| # masked positions, this operation will create a tensor which is 0.0 for | |||
| # positions we want to attend and -10000.0 for masked positions. | |||
| # Since we are adding it to the raw scores before the softmax, this is | |||
| # effectively the same as removing these entirely. | |||
| extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||
| return extended_attention_mask | |||
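A tiny numeric illustration of `_get_extended_attention_mask`: positions marked 0 receive a large negative bias, so the softmax over attention scores effectively ignores them (values illustrative).

    import torch

    mask = torch.tensor([[1., 1., 1., 0.]])            # 1 = attend, 0 = masked; shape [batch, seq]
    ext = (1.0 - mask[:, None, None, :]) * -10000.0    # broadcastable across heads and query positions
    print(ext)   # last position becomes -10000.0, added to the raw attention scores before softmax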
| @@ -9,7 +9,9 @@ if TYPE_CHECKING: | |||
| from .torch_base_dataset import TorchTaskDataset | |||
| from .veco_dataset import VecoDataset | |||
| from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | |||
| from .movie_scene_segmentation import MovieSceneSegmentationDataset | |||
| from .video_summarization_dataset import VideoSummarizationDataset | |||
| else: | |||
| _import_structure = { | |||
| 'base': ['TaskDataset'], | |||
| @@ -19,6 +21,7 @@ else: | |||
| 'image_instance_segmentation_coco_dataset': | |||
| ['ImageInstanceSegmentationCocoDataset'], | |||
| 'video_summarization_dataset': ['VideoSummarizationDataset'], | |||
| 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1 @@ | |||
| from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset | |||
| @@ -0,0 +1,173 @@ | |||
| # --------------------------------------------------------------------------------------------------- | |||
| # The implementation is built upon BaSSL, publicly available at https://github.com/kakaobrain/bassl | |||
| # --------------------------------------------------------------------------------------------------- | |||
| import copy | |||
| import os | |||
| import os.path as osp | |||
| import random | |||
| import json | |||
| import torch | |||
| from torchvision.datasets.folder import pil_loader | |||
| from modelscope.metainfo import Models | |||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||
| from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||
| TorchTaskDataset | |||
| from modelscope.utils.constant import Tasks | |||
| from . import sampler | |||
| DATASET_STRUCTURE = { | |||
| 'train': { | |||
| 'annotation': 'anno/train.json', | |||
| 'images': 'keyf_240p', | |||
| 'feat': 'feat' | |||
| }, | |||
| 'test': { | |||
| 'annotation': 'anno/test.json', | |||
| 'images': 'keyf_240p', | |||
| 'feat': 'feat' | |||
| } | |||
| } | |||
| @TASK_DATASETS.register_module( | |||
| Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert) | |||
| class MovieSceneSegmentationDataset(TorchTaskDataset): | |||
| """dataset for movie scene segmentation. | |||
| Args: | |||
| split_config (dict): Annotation file path. {"train":"xxxxx"} | |||
| data_root (str, optional): Data root for ``ann_file``, | |||
| ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. | |||
| test_mode (bool, optional): If set True, annotation will not be loaded. | |||
| """ | |||
| def __init__(self, **kwargs): | |||
| split_config = kwargs['split_config'] | |||
| self.data_root = next(iter(split_config.values())) | |||
| if not osp.exists(self.data_root): | |||
| self.data_root = osp.dirname(self.data_root) | |||
| assert osp.exists(self.data_root) | |||
| self.split = next(iter(split_config.keys())) | |||
| self.preprocessor = kwargs['preprocessor'] | |||
| self.ann_file = osp.join(self.data_root, | |||
| DATASET_STRUCTURE[self.split]['annotation']) | |||
| self.img_prefix = osp.join(self.data_root, | |||
| DATASET_STRUCTURE[self.split]['images']) | |||
| self.feat_prefix = osp.join(self.data_root, | |||
| DATASET_STRUCTURE[self.split]['feat']) | |||
| self.test_mode = kwargs['test_mode'] | |||
| if self.test_mode: | |||
| self.preprocessor.eval() | |||
| else: | |||
| self.preprocessor.train() | |||
| self.cfg = kwargs.pop('cfg', None) | |||
| self.num_keyframe = self.cfg.num_keyframe if self.cfg is not None else 3 | |||
| self.use_single_keyframe = self.cfg.use_single_keyframe if self.cfg is not None else False | |||
| self.load_data() | |||
| self.init_sampler(self.cfg) | |||
| def __len__(self): | |||
| """Total number of samples of data.""" | |||
| return len(self.anno_data) | |||
| def __getitem__(self, idx: int): | |||
| data = self.anno_data[ | |||
| idx] # {"video_id", "shot_id", "num_shot", "boundary_label"} | |||
| vid, sid = data['video_id'], data['shot_id'] | |||
| num_shot = data['num_shot'] | |||
| shot_idx = self.shot_sampler(int(sid), num_shot) | |||
| video = self.load_shot_list(vid, shot_idx) | |||
| if self.preprocessor is None: | |||
| video = torch.stack(video, dim=0) | |||
| video = video.view(-1, self.num_keyframe, 3, 224, 224) | |||
| else: | |||
| video = self.preprocessor(video) | |||
| payload = { | |||
| 'idx': idx, | |||
| 'vid': vid, | |||
| 'sid': sid, | |||
| 'video': video, | |||
| 'label': abs(data['boundary_label']), # ignore -1 label. | |||
| } | |||
| return payload | |||
| def load_data(self): | |||
| self.tmpl = '{}/shot_{}_img_{}.jpg'  # video_id, shot_id, keyframe_id | |||
| if not self.test_mode: | |||
| with open(self.ann_file) as f: | |||
| self.anno_data = json.load(f) | |||
| self.vidsid2label = { | |||
| f"{it['video_id']}_{it['shot_id']}": it['boundary_label'] | |||
| for it in self.anno_data | |||
| } | |||
| else: | |||
| with open(self.ann_file) as f: | |||
| self.anno_data = json.load(f) | |||
| def init_sampler(self, cfg): | |||
| # shot sampler | |||
| if cfg is not None: | |||
| self.sampling_method = cfg.sampling_method.name | |||
| sampler_args = copy.deepcopy( | |||
| cfg.sampling_method.params.get(self.sampling_method, {})) | |||
| if self.sampling_method == 'instance': | |||
| self.shot_sampler = sampler.InstanceShotSampler() | |||
| elif self.sampling_method == 'temporal': | |||
| self.shot_sampler = sampler.TemporalShotSampler(**sampler_args) | |||
| elif self.sampling_method == 'shotcol': | |||
| self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) | |||
| elif self.sampling_method == 'bassl': | |||
| self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) | |||
| elif self.sampling_method == 'bassl+shotcol': | |||
| self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) | |||
| elif self.sampling_method == 'sbd': | |||
| self.shot_sampler = sampler.NeighborShotSampler(**sampler_args) | |||
| else: | |||
| raise NotImplementedError | |||
| else: | |||
| self.shot_sampler = sampler.NeighborShotSampler() | |||
| def load_shot_list(self, vid, shot_idx): | |||
| shot_list = [] | |||
| cache = {} | |||
| for sidx in shot_idx: | |||
| vidsid = f'{vid}_{sidx:04d}' | |||
| if vidsid in cache: | |||
| shot = cache[vidsid] | |||
| else: | |||
| shot_path = os.path.join( | |||
| self.img_prefix, self.tmpl.format(vid, f'{sidx:04d}', | |||
| '{}')) | |||
| shot = self.load_shot_keyframes(shot_path) | |||
| cache[vidsid] = shot | |||
| shot_list.extend(shot) | |||
| return shot_list | |||
| def load_shot_keyframes(self, path): | |||
| shot = None | |||
| if not self.test_mode and self.use_single_keyframe: | |||
| # load one randomly sampled keyframe | |||
| shot = [ | |||
| pil_loader( | |||
| path.format(random.randint(0, self.num_keyframe - 1))) | |||
| ] | |||
| else: | |||
| # load all keyframes | |||
| shot = [ | |||
| pil_loader(path.format(i)) for i in range(self.num_keyframe) | |||
| ] | |||
| assert shot is not None | |||
| return shot | |||
| @@ -0,0 +1,102 @@ | |||
| # ------------------------------------------------------------------------------------ | |||
| # BaSSL | |||
| # Copyright (c) 2021 KakaoBrain. All Rights Reserved. | |||
| # Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |||
| # Github: https://github.com/kakaobrain/bassl | |||
| # ------------------------------------------------------------------------------------ | |||
| import random | |||
| import numpy as np | |||
| class InstanceShotSampler: | |||
| """ This is for instance at pre-training stage """ | |||
| def __call__(self, center_sid: int, *args, **kwargs): | |||
| return center_sid | |||
| class TemporalShotSampler: | |||
| """ This is for temporal at pre-training stage """ | |||
| def __init__(self, neighbor_size: int): | |||
| self.N = neighbor_size | |||
| def __call__(self, center_sid: int, total_num_shot: int): | |||
| """ we randomly sample one shot from neighbor shots within local temporal window | |||
| """ | |||
| shot_idx = center_sid + np.arange( | |||
| -self.N, self.N + 1 | |||
| ) # total number of neighbor shots = 2N+1 (query (1) + neighbors (2*N)) | |||
| shot_idx = np.clip(shot_idx, 0, | |||
| total_num_shot) # deal with out-of-boundary indices | |||
| shot_idx = random.choice( | |||
| np.unique(np.delete(shot_idx, np.where(shot_idx == center_sid)))) | |||
| return shot_idx | |||
| class SequenceShotSampler: | |||
| """ This is for bassl or shotcol at pre-training stage """ | |||
| def __init__(self, neighbor_size: int, neighbor_interval: int): | |||
| self.interval = neighbor_interval | |||
| self.window_size = neighbor_size * self.interval # temporal coverage | |||
| def __call__(self, | |||
| center_sid: int, | |||
| total_num_shot: int, | |||
| sparse_method: str = 'edge'): | |||
| """ | |||
| Args: | |||
| center_sid: index of center shot | |||
| total_num_shot: last index of shot for given video | |||
| sparse_method: how to pick the sparse shots ('edge' or 'edge+center') | |||
| from the dense sequence for curriculum learning | |||
| """ | |||
| dense_shot_idx = center_sid + np.arange( | |||
| -self.window_size, self.window_size + 1, | |||
| self.interval) # total number of shots = 2*neighbor_size+1 | |||
| if dense_shot_idx[0] < 0: | |||
| # if center_sid is near left-side of video, we shift window rightward | |||
| # so that the leftmost index is 0 | |||
| dense_shot_idx -= dense_shot_idx[0] | |||
| elif dense_shot_idx[-1] > (total_num_shot - 1): | |||
| # if center_sid is near right-side of video, we shift window leftward | |||
| # so that the rightmost index is total_num_shot - 1 | |||
| dense_shot_idx -= dense_shot_idx[-1] - (total_num_shot - 1) | |||
| # to deal with videos that have smaller number of shots than window size | |||
| dense_shot_idx = np.clip(dense_shot_idx, 0, total_num_shot) | |||
| if sparse_method == 'edge': | |||
| # in this case, we use two edge shots as sparse sequence | |||
| sparse_stride = len(dense_shot_idx) - 1 | |||
| sparse_idx_to_dense = np.arange(0, len(dense_shot_idx), | |||
| sparse_stride) | |||
| elif sparse_method == 'edge+center': | |||
| # in this case, we use two edge shots + center shot as sparse sequence | |||
| sparse_idx_to_dense = np.array( | |||
| [0, len(dense_shot_idx) - 1, | |||
| len(dense_shot_idx) // 2]) | |||
| shot_idx = [sparse_idx_to_dense, dense_shot_idx] | |||
| return shot_idx | |||
| class NeighborShotSampler: | |||
| """ This is for scene boundary detection (sbd), i.e., fine-tuning stage """ | |||
| def __init__(self, neighbor_size: int = 8): | |||
| self.neighbor_size = neighbor_size | |||
| def __call__(self, center_sid: int, total_num_shot: int): | |||
| # total number of shots = 2 * neighbor_size + 1 | |||
| shot_idx = center_sid + np.arange(-self.neighbor_size, | |||
| self.neighbor_size + 1) | |||
| shot_idx = np.clip(shot_idx, 0, | |||
| total_num_shot) # for out-of-boundary indices | |||
| return shot_idx | |||
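A hedged example of what NeighborShotSampler yields at fine-tuning time (import path assumed from this diff; note how indices are clipped at the video boundary, duplicating the edge shot):

    from modelscope.msdatasets.task_datasets.movie_scene_segmentation.sampler import (
        NeighborShotSampler)  # path assumed

    sampler = NeighborShotSampler(neighbor_size=2)
    print(sampler(center_sid=1, total_num_shot=10))
    # [0 0 1 2 3]  -> 2*neighbor_size + 1 shot indices centred on shot 1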
| @@ -35,6 +35,8 @@ class OutputKeys(object): | |||
| UUID = 'uuid' | |||
| WORD = 'word' | |||
| KWS_LIST = 'kws_list' | |||
| SPLIT_VIDEO_NUM = 'split_video_num' | |||
| SPLIT_META_DICT = 'split_meta_dict' | |||
| TASK_OUTPUTS = { | |||
| @@ -241,6 +243,22 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG], | |||
| # movie scene segmentation result for a single video | |||
| # { | |||
| # "split_video_num":3, | |||
| # "split_meta_dict": | |||
| # { | |||
| # scene_id: | |||
| # { | |||
| # "shot": [0,1,2], | |||
| # "frame": [start_frame, end_frame] | |||
| # } | |||
| # } | |||
| # | |||
| # } | |||
| Tasks.movie_scene_segmentation: | |||
| [OutputKeys.SPLIT_VIDEO_NUM, OutputKeys.SPLIT_META_DICT], | |||
| # ============ nlp tasks =================== | |||
| # text classification result for single sample | |||
| @@ -144,6 +144,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_vitb_video-single-object-tracking_ostrack'), | |||
| Tasks.image_reid_person: (Pipelines.image_reid_person, | |||
| 'damo/cv_passvitb_image-reid-person_market'), | |||
| Tasks.movie_scene_segmentation: | |||
| (Pipelines.movie_scene_segmentation, | |||
| 'damo/cv_resnet50-bert_video-scene-segmentation_movienet') | |||
| } | |||
| @@ -42,6 +42,8 @@ if TYPE_CHECKING: | |||
| from .video_category_pipeline import VideoCategoryPipeline | |||
| from .virtual_try_on_pipeline import VirtualTryonPipeline | |||
| from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline | |||
| from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | |||
| else: | |||
| _import_structure = { | |||
| 'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
| @@ -90,7 +92,9 @@ else: | |||
| 'video_category_pipeline': ['VideoCategoryPipeline'], | |||
| 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | |||
| 'easycv_pipeline': | |||
| ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'] | |||
| ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'], | |||
| 'movie_scene_segmentation_pipeline': | |||
| ['MovieSceneSegmentationPipeline'], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,67 @@ | |||
| from typing import Any, Dict | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.movie_scene_segmentation, | |||
| module_name=Pipelines.movie_scene_segmentation) | |||
| class MovieSceneSegmentationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """use `model` to create a movie scene segmentation pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub | |||
| """ | |||
| _device = kwargs.pop('device', 'gpu') | |||
| if torch.cuda.is_available() and _device == 'gpu': | |||
| device = 'gpu' | |||
| else: | |||
| device = 'cpu' | |||
| super().__init__(model=model, device=device, **kwargs) | |||
| logger.info('Load model done!') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| """ use pyscenedetect to detect shot from the input video, and generate key-frame jpg, anno.ndjson, and shot-frame.txt | |||
| Then use shot-encoder to encoder feat of the detected key-frame | |||
| Args: | |||
| input: path of the input video | |||
| """ | |||
| self.input_video_pth = input | |||
| if isinstance(input, str): | |||
| shot_feat, sid = self.model.preprocess(input) | |||
| else: | |||
| raise TypeError(f'input should be a str,' | |||
| f' but got {type(input)}') | |||
| result = {'sid': sid, 'shot_feat': shot_feat} | |||
| return result | |||
| def forward(self, input: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| output = self.model.inference(input) | |||
| return output | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| data = {'input_video_pth': self.input_video_pth, 'feat': inputs} | |||
| video_num, meta_dict = self.model.postprocess(data) | |||
| result = { | |||
| OutputKeys.SPLIT_VIDEO_NUM: video_num, | |||
| OutputKeys.SPLIT_META_DICT: meta_dict | |||
| } | |||
| return result | |||
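Putting the pieces together, a minimal end-to-end usage sketch of the new pipeline (the model id comes from the DEFAULT_MODEL_FOR_PIPELINE entry above; the local video path is a placeholder):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    scene_seg = pipeline(
        Tasks.movie_scene_segmentation,
        model='damo/cv_resnet50-bert_video-scene-segmentation_movienet')
    result = scene_seg('path/to/movie.mp4')          # placeholder local video
    print(result[OutputKeys.SPLIT_VIDEO_NUM])        # number of detected scenes
    print(result[OutputKeys.SPLIT_META_DICT])        # {scene_id: {'shot': [...], 'frame': [start, end]}}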
| @@ -27,7 +27,7 @@ if TYPE_CHECKING: | |||
| from .space import (DialogIntentPredictionPreprocessor, | |||
| DialogModelingPreprocessor, | |||
| DialogStateTrackingPreprocessor) | |||
| from .video import ReadVideoData | |||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | |||
| from .star import ConversationalTextToSqlPreprocessor | |||
| else: | |||
| @@ -37,7 +37,7 @@ else: | |||
| 'common': ['Compose', 'ToTensor', 'Filter'], | |||
| 'audio': ['LinearAECAndFbank'], | |||
| 'asr': ['WavToScp'], | |||
| 'video': ['ReadVideoData'], | |||
| 'video': ['ReadVideoData', 'MovieSceneSegmentationPreprocessor'], | |||
| 'image': [ | |||
| 'LoadImage', 'load_image', 'ImageColorEnhanceFinetunePreprocessor', | |||
| 'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor' | |||
| @@ -0,0 +1,19 @@ | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .transforms import get_transform | |||
| else: | |||
| _import_structure = { | |||
| 'transforms': ['get_transform'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,312 @@ | |||
| # ------------------------------------------------------------------------------------ | |||
| # The code below is partially adapted from BaSSL | |||
| # Copyright (c) 2021 KakaoBrain. All Rights Reserved. | |||
| # Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |||
| # Github: https://github.com/kakaobrain/bassl | |||
| # ------------------------------------------------------------------------------------ | |||
| import numbers | |||
| import os.path as osp | |||
| import random | |||
| from typing import List | |||
| import numpy as np | |||
| import torch | |||
| import torchvision.transforms as TF | |||
| import torchvision.transforms.functional as F | |||
| from PIL import Image, ImageFilter | |||
| def get_transform(lst): | |||
| assert len(lst) > 0 | |||
| transform_lst = [] | |||
| for item in lst: | |||
| transform_lst.append(build_transform(item)) | |||
| transform = TF.Compose(transform_lst) | |||
| return transform | |||
| def build_transform(cfg): | |||
| assert isinstance(cfg, dict) | |||
| cfg = cfg.copy() | |||
| type = cfg.pop('type') | |||
| if type == 'VideoResizedCenterCrop': | |||
| return VideoResizedCenterCrop(**cfg) | |||
| elif type == 'VideoToTensor': | |||
| return VideoToTensor(**cfg) | |||
| elif type == 'VideoRandomResizedCrop': | |||
| return VideoRandomResizedCrop(**cfg) | |||
| elif type == 'VideoRandomHFlip': | |||
| return VideoRandomHFlip() | |||
| elif type == 'VideoRandomColorJitter': | |||
| return VideoRandomColorJitter(**cfg) | |||
| elif type == 'VideoRandomGaussianBlur': | |||
| return VideoRandomGaussianBlur(**cfg) | |||
| else: | |||
| raise NotImplementedError | |||
| class VideoResizedCenterCrop(torch.nn.Module): | |||
| def __init__(self, image_size, crop_size): | |||
| self.tfm = TF.Compose([ | |||
| TF.Resize(size=image_size, interpolation=Image.BICUBIC), | |||
| TF.CenterCrop(crop_size), | |||
| ]) | |||
| def __call__(self, imgmap): | |||
| assert isinstance(imgmap, list) | |||
| return [self.tfm(img) for img in imgmap] | |||
| class VideoToTensor(torch.nn.Module): | |||
| def __init__(self, mean=None, std=None, inplace=False): | |||
| self.mean = mean | |||
| self.std = std | |||
| self.inplace = inplace | |||
| assert self.mean is not None | |||
| assert self.std is not None | |||
| def __to_tensor__(self, img): | |||
| return F.to_tensor(img) | |||
| def __normalize__(self, img): | |||
| return F.normalize(img, self.mean, self.std, self.inplace) | |||
| def __call__(self, imgmap): | |||
| assert isinstance(imgmap, list) | |||
| return [self.__normalize__(self.__to_tensor__(img)) for img in imgmap] | |||
| class VideoRandomResizedCrop(torch.nn.Module): | |||
| def __init__(self, size, bottom_area=0.2): | |||
| self.p = 1.0 | |||
| self.interpolation = Image.BICUBIC | |||
| self.size = size | |||
| self.bottom_area = bottom_area | |||
| def __call__(self, imgmap): | |||
| assert isinstance(imgmap, list) | |||
| if random.random() < self.p: # do RandomResizedCrop, consistent=True | |||
| top, left, height, width = TF.RandomResizedCrop.get_params( | |||
| imgmap[0], | |||
| scale=(self.bottom_area, 1.0), | |||
| ratio=(3 / 4.0, 4 / 3.0)) | |||
| return [ | |||
| F.resized_crop( | |||
| img=img, | |||
| top=top, | |||
| left=left, | |||
| height=height, | |||
| width=width, | |||
| size=(self.size, self.size), | |||
| ) for img in imgmap | |||
| ] | |||
| else: | |||
| return [ | |||
| F.resize(img=img, size=[self.size, self.size]) | |||
| for img in imgmap | |||
| ] | |||
| class VideoRandomHFlip(torch.nn.Module): | |||
| def __init__(self, consistent=True, command=None, seq_len=0): | |||
| self.consistent = consistent | |||
| if seq_len != 0: | |||
| self.consistent = False | |||
| if command == 'left': | |||
| self.threshold = 0 | |||
| elif command == 'right': | |||
| self.threshold = 1 | |||
| else: | |||
| self.threshold = 0.5 | |||
| self.seq_len = seq_len | |||
| def __call__(self, imgmap): | |||
| assert isinstance(imgmap, list) | |||
| if self.consistent: | |||
| if random.random() < self.threshold: | |||
| return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] | |||
| else: | |||
| return imgmap | |||
| else: | |||
| result = [] | |||
| for idx, i in enumerate(imgmap): | |||
| if idx % self.seq_len == 0: | |||
| th = random.random() | |||
| if th < self.threshold: | |||
| result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) | |||
| else: | |||
| result.append(i) | |||
| assert len(result) == len(imgmap) | |||
| return result | |||
| class VideoRandomColorJitter(torch.nn.Module): | |||
| """Randomly change the brightness, contrast and saturation of an image. | |||
| Args: | |||
| brightness (float or tuple of float (min, max)): How much to jitter brightness. | |||
| brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] | |||
| or the given [min, max]. Should be non negative numbers. | |||
| contrast (float or tuple of float (min, max)): How much to jitter contrast. | |||
| contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] | |||
| or the given [min, max]. Should be non negative numbers. | |||
| saturation (float or tuple of float (min, max)): How much to jitter saturation. | |||
| saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] | |||
| or the given [min, max]. Should be non negative numbers. | |||
| hue (float or tuple of float (min, max)): How much to jitter hue. | |||
| hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. | |||
| Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| brightness=0, | |||
| contrast=0, | |||
| saturation=0, | |||
| hue=0, | |||
| consistent=True, | |||
| p=1.0, | |||
| seq_len=0, | |||
| ): | |||
| super().__init__() | |||
| self.brightness = self._check_input(brightness, 'brightness') | |||
| self.contrast = self._check_input(contrast, 'contrast') | |||
| self.saturation = self._check_input(saturation, 'saturation') | |||
| self.hue = self._check_input( | |||
| hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) | |||
| self.consistent = consistent | |||
| self.threshold = p | |||
| self.seq_len = seq_len | |||
| def _check_input(self, | |||
| value, | |||
| name, | |||
| center=1, | |||
| bound=(0, float('inf')), | |||
| clip_first_on_zero=True): | |||
| if isinstance(value, numbers.Number): | |||
| if value < 0: | |||
| raise ValueError( | |||
| 'If {} is a single number, it must be non negative.'. | |||
| format(name)) | |||
| value = [center - value, center + value] | |||
| if clip_first_on_zero: | |||
| value[0] = max(value[0], 0) | |||
| elif isinstance(value, (tuple, list)) and len(value) == 2: | |||
| if not bound[0] <= value[0] <= value[1] <= bound[1]: | |||
| raise ValueError('{} values should be between {}'.format( | |||
| name, bound)) | |||
| else: | |||
| raise TypeError( | |||
| '{} should be a single number or a list/tuple with length 2.'. | |||
| format(name)) | |||
| # if value is 0 or (1., 1.) for brightness/contrast/saturation | |||
| # or (0., 0.) for hue, do nothing | |||
| if value[0] == value[1] == center: | |||
| value = None | |||
| return value | |||
| @staticmethod | |||
| def get_params(brightness, contrast, saturation, hue): | |||
| """Get a randomized transform to be applied on image. | |||
| Arguments are same as that of __init__. | |||
| Returns: | |||
| Transform which randomly adjusts brightness, contrast and | |||
| saturation in a random order. | |||
| """ | |||
| transforms = [] | |||
| if brightness is not None: | |||
| brightness_factor = random.uniform(brightness[0], brightness[1]) | |||
| transforms.append( | |||
| TF.Lambda( | |||
| lambda img: F.adjust_brightness(img, brightness_factor))) | |||
| if contrast is not None: | |||
| contrast_factor = random.uniform(contrast[0], contrast[1]) | |||
| transforms.append( | |||
| TF.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) | |||
| if saturation is not None: | |||
| saturation_factor = random.uniform(saturation[0], saturation[1]) | |||
| transforms.append( | |||
| TF.Lambda( | |||
| lambda img: F.adjust_saturation(img, saturation_factor))) | |||
| if hue is not None: | |||
| hue_factor = random.uniform(hue[0], hue[1]) | |||
| transforms.append( | |||
| TF.Lambda(lambda img: F.adjust_hue(img, hue_factor))) | |||
| random.shuffle(transforms) | |||
| transform = TF.Compose(transforms) | |||
| return transform | |||
| def __call__(self, imgmap): | |||
| assert isinstance(imgmap, list) | |||
| if random.random() < self.threshold: # do ColorJitter | |||
| if self.consistent: | |||
| transform = self.get_params(self.brightness, self.contrast, | |||
| self.saturation, self.hue) | |||
| return [transform(i) for i in imgmap] | |||
| else: | |||
| if self.seq_len == 0: | |||
| return [ | |||
| self.get_params(self.brightness, self.contrast, | |||
| self.saturation, self.hue)(img) | |||
| for img in imgmap | |||
| ] | |||
| else: | |||
| result = [] | |||
| for idx, img in enumerate(imgmap): | |||
| if idx % self.seq_len == 0: | |||
| transform = self.get_params( | |||
| self.brightness, | |||
| self.contrast, | |||
| self.saturation, | |||
| self.hue, | |||
| ) | |||
| result.append(transform(img)) | |||
| return result | |||
| else: | |||
| return imgmap | |||
| def __repr__(self): | |||
| format_string = self.__class__.__name__ + '(' | |||
| format_string += 'brightness={0}'.format(self.brightness) | |||
| format_string += ', contrast={0}'.format(self.contrast) | |||
| format_string += ', saturation={0}'.format(self.saturation) | |||
| format_string += ', hue={0})'.format(self.hue) | |||
| return format_string | |||
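A short usage sketch with illustrative parameter values: with consistent=True one jitter is sampled and shared by the whole clip, while consistent=False and seq_len=0 samples an independent jitter per frame.

from PIL import Image

# Illustrative values; with p=1.0 the jitter branch is always taken.
clip = [Image.new('RGB', (224, 224), color=(128, 128, 128)) for _ in range(4)]
clip_jitter = VideoRandomColorJitter(brightness=0.4, contrast=0.4, saturation=0.4,
                                     hue=0.1, consistent=True, p=1.0)
frame_jitter = VideoRandomColorJitter(brightness=0.4, contrast=0.4, saturation=0.4,
                                      hue=0.1, consistent=False, p=1.0)
same_shift = clip_jitter(clip)    # one color shift shared by all 4 frames
indep_shift = frame_jitter(clip)  # a fresh color shift per frame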
| class VideoRandomGaussianBlur(torch.nn.Module): | |||
| def __init__(self, radius_min=0.1, radius_max=2.0, p=0.5): | |||
| super().__init__() | |||
| self.radius_min = radius_min | |||
| self.radius_max = radius_max | |||
| self.p = p | |||
| def __call__(self, imgmap): | |||
| assert isinstance(imgmap, list) | |||
| if random.random() < self.p: | |||
| result = [] | |||
| for _, img in enumerate(imgmap): | |||
| _radius = random.uniform(self.radius_min, self.radius_max) | |||
| result.append( | |||
| img.filter(ImageFilter.GaussianBlur(radius=_radius))) | |||
| return result | |||
| else: | |||
| return imgmap | |||
| def apply_transform(images, trans): | |||
| return torch.stack(trans(images), dim=0) | |||
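A minimal sketch of the helper above, meant to run alongside the transforms defined in this module (the mean/std values and sizes are illustrative, and the TF alias is assumed to be torchvision.transforms): compose the deterministic eval transforms and stack a short clip of PIL frames into one tensor.

from PIL import Image
import torchvision.transforms as TF  # assumed to match the TF alias used above

clip = [Image.new('RGB', (456, 256)) for _ in range(3)]
eval_tfm = TF.Compose([
    VideoResizedCenterCrop(image_size=256, crop_size=224),
    VideoToTensor(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
tensor_clip = apply_transform(clip, eval_tfm)  # shape: (3, 3, 224, 224)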
| @@ -9,6 +9,12 @@ import torchvision.transforms._transforms_video as transforms | |||
| from decord import VideoReader | |||
| from torchvision.transforms import Compose | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields, ModeKeys | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| def ReadVideoData(cfg, video_path): | |||
| """ simple interface to load video frames from file | |||
| @@ -227,3 +233,42 @@ class KineticsResizedCrop(object): | |||
| def __call__(self, clip): | |||
| return self._get_controlled_crop(clip) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.cv, module_name=Preprocessors.movie_scene_segmentation_preprocessor) | |||
| class MovieSceneSegmentationPreprocessor(Preprocessor): | |||
| def __init__(self, *args, **kwargs): | |||
| """ | |||
| movie scene segmentation preprocessor | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.is_train = kwargs.pop('is_train', True) | |||
| self.preprocessor_train_cfg = kwargs.pop(ModeKeys.TRAIN, None) | |||
| self.preprocessor_test_cfg = kwargs.pop(ModeKeys.EVAL, None) | |||
| self.num_keyframe = kwargs.pop('num_keyframe', 3) | |||
| from .movie_scene_segmentation import get_transform | |||
| self.train_transform = get_transform(self.preprocessor_train_cfg) | |||
| self.test_transform = get_transform(self.preprocessor_test_cfg) | |||
| def train(self): | |||
| self.is_train = True | |||
| return | |||
| def eval(self): | |||
| self.is_train = False | |||
| return | |||
| @type_assert(object, object) | |||
| def __call__(self, results): | |||
| if self.is_train: | |||
| transforms = self.train_transform | |||
| else: | |||
| transforms = self.test_transform | |||
| results = torch.stack(transforms(results), dim=0) | |||
| results = results.view(-1, self.num_keyframe, 3, 224, 224) | |||
| return results | |||
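For reference, a small sketch of the shape contract enforced by __call__ (sizes are illustrative): the preprocessor receives a flat list of num_shots * num_keyframe keyframes, stacks the transformed frames, and regroups them per shot.

import torch

# Illustrative shapes only: 8 shots x 3 keyframes, each already transformed to 3x224x224.
num_shots, num_keyframe = 8, 3
frames = [torch.rand(3, 224, 224) for _ in range(num_shots * num_keyframe)]
batch = torch.stack(frames, dim=0).view(-1, num_keyframe, 3, 224, 224)
assert batch.shape == (num_shots, num_keyframe, 3, 224, 224)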
| @@ -8,7 +8,8 @@ if TYPE_CHECKING: | |||
| from .base import DummyTrainer | |||
| from .builder import build_trainer | |||
| from .cv import (ImageInstanceSegmentationTrainer, | |||
| ImagePortraitEnhancementTrainer) | |||
| ImagePortraitEnhancementTrainer, | |||
| MovieSceneSegmentationTrainer) | |||
| from .multi_modal import CLIPTrainer | |||
| from .nlp import SequenceClassificationTrainer | |||
| from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer | |||
| @@ -21,7 +22,7 @@ else: | |||
| 'builder': ['build_trainer'], | |||
| 'cv': [ | |||
| 'ImageInstanceSegmentationTrainer', | |||
| 'ImagePortraitEnhancementTrainer' | |||
| 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer' | |||
| ], | |||
| 'multi_modal': ['CLIPTrainer'], | |||
| 'nlp': ['SequenceClassificationTrainer'], | |||
| @@ -7,6 +7,7 @@ if TYPE_CHECKING: | |||
| from .image_instance_segmentation_trainer import \ | |||
| ImageInstanceSegmentationTrainer | |||
| from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | |||
| from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer | |||
| else: | |||
| _import_structure = { | |||
| @@ -14,6 +15,7 @@ else: | |||
| ['ImageInstanceSegmentationTrainer'], | |||
| 'image_portrait_enhancement_trainer': | |||
| ['ImagePortraitEnhancementTrainer'], | |||
| 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'] | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,20 @@ | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.trainers.trainer import EpochBasedTrainer | |||
| @TRAINERS.register_module(module_name=Trainers.movie_scene_segmentation) | |||
| class MovieSceneSegmentationTrainer(EpochBasedTrainer): | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| def train(self, *args, **kwargs): | |||
| super().train(*args, **kwargs) | |||
| def evaluate(self, *args, **kwargs): | |||
| metric_values = super().evaluate(*args, **kwargs) | |||
| return metric_values | |||
| def prediction_step(self, model, inputs): | |||
| pass | |||
| @@ -62,6 +62,7 @@ class CVTasks(object): | |||
| video_embedding = 'video-embedding' | |||
| virtual_try_on = 'virtual-try-on' | |||
| crowd_counting = 'crowd-counting' | |||
| movie_scene_segmentation = 'movie-scene-segmentation' | |||
| # reid and tracking | |||
| video_single_object_tracking = 'video-single-object-tracking' | |||
| @@ -21,6 +21,7 @@ regex | |||
| scikit-image>=0.19.3 | |||
| scikit-learn>=0.20.1 | |||
| shapely | |||
| shotdetect_scenedetect_lgss | |||
| tensorflow-estimator>=1.15.1 | |||
| tf_slim | |||
| timm>=0.4.9 | |||
| @@ -31,6 +31,12 @@ class ImgPreprocessor(Preprocessor): | |||
| class MsDatasetTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_movie_scene_seg_toydata(self): | |||
| ms_ds_train = MsDataset.load('movie_scene_seg_toydata', split='train') | |||
| print(ms_ds_train._hf_ds.config_kwargs) | |||
| assert next(iter(ms_ds_train.config_kwargs['split_config'].values())) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_coco(self): | |||
| ms_ds_train = MsDataset.load( | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class MovieSceneSegmentationTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_movie_scene_segmentation(self): | |||
| input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4' | |||
| model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' | |||
| movie_scene_segmentation_pipeline = pipeline( | |||
| Tasks.movie_scene_segmentation, model=model_id) | |||
| result = movie_scene_segmentation_pipeline(input_location) | |||
| if result: | |||
| print(result) | |||
| else: | |||
| raise ValueError('process error') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_movie_scene_segmentation_with_default_task(self): | |||
| input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4' | |||
| movie_scene_segmentation_pipeline = pipeline( | |||
| Tasks.movie_scene_segmentation) | |||
| result = movie_scene_segmentation_pipeline(input_location) | |||
| if result: | |||
| print(result) | |||
| else: | |||
| raise ValueError('process error') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,109 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| import zipfile | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.cv.movie_scene_segmentation import \ | |||
| MovieSceneSegmentationModel | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.config import Config, ConfigDict | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestMovieSceneSegmentationTrainer(unittest.TestCase): | |||
| model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' | |||
| def setUp(self): | |||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
| cache_path = snapshot_download(self.model_id) | |||
| config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) | |||
| cfg = Config.from_file(config_path) | |||
| max_epochs = cfg.train.max_epochs | |||
| train_data_cfg = ConfigDict( | |||
| name='movie_scene_seg_toydata', | |||
| split='train', | |||
| cfg=cfg.preprocessor, | |||
| test_mode=False) | |||
| test_data_cfg = ConfigDict( | |||
| name='movie_scene_seg_toydata', | |||
| split='test', | |||
| cfg=cfg.preprocessor, | |||
| test_mode=True) | |||
| self.train_dataset = MsDataset.load( | |||
| dataset_name=train_data_cfg.name, | |||
| split=train_data_cfg.split, | |||
| namespace=train_data_cfg.namespace, | |||
| cfg=train_data_cfg.cfg, | |||
| test_mode=train_data_cfg.test_mode) | |||
| assert next( | |||
| iter(self.train_dataset.config_kwargs['split_config'].values())) | |||
| self.test_dataset = MsDataset.load( | |||
| dataset_name=test_data_cfg.name, | |||
| split=test_data_cfg.split, | |||
| namespace=test_data_cfg.namespace, | |||
| cfg=test_data_cfg.cfg, | |||
| test_mode=test_data_cfg.test_mode) | |||
| assert next( | |||
| iter(self.test_dataset.config_kwargs['split_config'].values())) | |||
| self.max_epochs = max_epochs | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| def tearDown(self): | |||
| shutil.rmtree(self.tmp_dir) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer( | |||
| name=Trainers.movie_scene_segmentation, default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(trainer.work_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| cache_path = snapshot_download(self.model_id) | |||
| model = MovieSceneSegmentationModel.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| work_dir=tmp_dir) | |||
| trainer = build_trainer( | |||
| name=Trainers.movie_scene_segmentation, default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(trainer.work_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||