From c8b6030b8e0fc8de10e16a38412409ba67ef6bf4 Mon Sep 17 00:00:00 2001
From: "yongfei.zyf"
Date: Thu, 1 Sep 2022 14:20:04 +0800
Subject: [PATCH] [to #42322933] Add hicossl_video_embedding_pipeline to maas lib

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9969472
---
 modelscope/metainfo.py                         |   1 +
 .../models/cv/action_recognition/models.py    |  45 ++-
 .../models/cv/action_recognition/s3dg.py      | 301 ++++++++++++++++++
 modelscope/pipelines/cv/__init__.py           |   2 +
 .../cv/action_recognition_pipeline.py         |   1 +
 .../cv/hicossl_video_embedding_pipeline.py    |  75 +++++
 modelscope/preprocessors/video.py             | 119 +++++--
 .../pipelines/test_hicossl_video_embedding.py |  26 ++
 8 files changed, 538 insertions(+), 32 deletions(-)
 create mode 100644 modelscope/models/cv/action_recognition/s3dg.py
 create mode 100644 modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py
 create mode 100644 tests/pipelines/test_hicossl_video_embedding.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 4bb0857b..51fed99f 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -99,6 +99,7 @@ class Pipelines(object):
     animal_recognition = 'resnet101-animal-recognition'
     general_recognition = 'resnet101-general-recognition'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    hicossl_video_embedding = 'hicossl-s3dg-video_embedding'
     body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
     body_3d_keypoints = 'canonical_body-3d-keypoints_video'
     human_detection = 'resnet18-human-detection'
diff --git a/modelscope/models/cv/action_recognition/models.py b/modelscope/models/cv/action_recognition/models.py
index 48e75ae1..a5964e21 100644
--- a/modelscope/models/cv/action_recognition/models.py
+++ b/modelscope/models/cv/action_recognition/models.py
@@ -1,5 +1,6 @@
 import torch.nn as nn
 
+from .s3dg import Inception3D
 from .tada_convnext import TadaConvNeXt
 
 
@@ -26,11 +27,25 @@ class BaseVideoModel(nn.Module):
         super(BaseVideoModel, self).__init__()
         # the backbone is created according to meta-architectures
         # defined in models/base/backbone.py
-        self.backbone = TadaConvNeXt(cfg)
+        if cfg.MODEL.NAME == 'ConvNeXt_tiny':
+            self.backbone = TadaConvNeXt(cfg)
+        elif cfg.MODEL.NAME == 'S3DG':
+            self.backbone = Inception3D(cfg)
+        else:
+            error_str = 'backbone {} is not supported, ConvNeXt_tiny or S3DG is supported'.format(
+                cfg.MODEL.NAME)
+            raise NotImplementedError(error_str)
 
         # the head is created according to the heads
         # defined in models/module_zoo/heads
-        self.head = BaseHead(cfg)
+        if cfg.VIDEO.HEAD.NAME == 'BaseHead':
+            self.head = BaseHead(cfg)
+        elif cfg.VIDEO.HEAD.NAME == 'AvgHead':
+            self.head = AvgHead(cfg)
+        else:
+            error_str = 'head {} is not supported, BaseHead or AvgHead is supported'.format(
+                cfg.VIDEO.HEAD.NAME)
+            raise NotImplementedError(error_str)
 
     def forward(self, x):
         x = self.backbone(x)
@@ -88,3 +103,29 @@ class BaseHead(nn.Module):
         out = self.activation(out)
         out = out.view(out.shape[0], -1)
         return out, x.view(x.shape[0], -1)
+
+
+class AvgHead(nn.Module):
+    """
+    Constructs average pooling head.
+    """
+
+    def __init__(
+        self,
+        cfg,
+    ):
+        """
+        Args:
+            cfg (Config): global config object.
+        """
+        super(AvgHead, self).__init__()
+        self.cfg = cfg
+        self.global_avg_pool = nn.AdaptiveAvgPool3d(1)
+
+    def forward(self, x):
+        if len(x.shape) == 5:
+            x = self.global_avg_pool(x)
+            # (N, C, T, H, W) -> (N, T, H, W, C).
+            x = x.permute((0, 2, 3, 4, 1))
+        out = x.view(x.shape[0], -1)
+        return out, x.view(x.shape[0], -1)
diff --git a/modelscope/models/cv/action_recognition/s3dg.py b/modelscope/models/cv/action_recognition/s3dg.py
new file mode 100644
index 00000000..f258df16
--- /dev/null
+++ b/modelscope/models/cv/action_recognition/s3dg.py
@@ -0,0 +1,301 @@
+import torch
+import torch.nn as nn
+
+
+class InceptionBaseConv3D(nn.Module):
+    """
+    Constructs basic inception 3D conv.
+    Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
+    """
+
+    def __init__(self,
+                 cfg,
+                 in_planes,
+                 out_planes,
+                 kernel_size,
+                 stride,
+                 padding=0):
+        super(InceptionBaseConv3D, self).__init__()
+        self.conv = nn.Conv3d(
+            in_planes,
+            out_planes,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=False)
+        self.bn = nn.BatchNorm3d(out_planes)
+        self.relu = nn.ReLU(inplace=True)
+
+        # init
+        self.conv.weight.data.normal_(
+            mean=0, std=0.01)  # original s3d is truncated normal within 2 std
+        self.bn.weight.data.fill_(1)
+        self.bn.bias.data.zero_()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
+class InceptionBlock3D(nn.Module):
+    """
+    Element constructing the S3D/S3DG.
+    See models/base/backbone.py L99-186.
+
+    Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
+    """
+
+    def __init__(self, cfg, in_planes, out_planes):
+        super(InceptionBlock3D, self).__init__()
+
+        _gating = cfg.VIDEO.BACKBONE.BRANCH.GATING
+
+        assert len(out_planes) == 6
+        assert isinstance(out_planes, list)
+
+        [
+            num_out_0_0a, num_out_1_0a, num_out_1_0b, num_out_2_0a,
+            num_out_2_0b, num_out_3_0b
+        ] = out_planes
+
+        self.branch0 = nn.Sequential(
+            InceptionBaseConv3D(
+                cfg, in_planes, num_out_0_0a, kernel_size=1, stride=1), )
+        self.branch1 = nn.Sequential(
+            InceptionBaseConv3D(
+                cfg, in_planes, num_out_1_0a, kernel_size=1, stride=1),
+            STConv3d(
+                cfg,
+                num_out_1_0a,
+                num_out_1_0b,
+                kernel_size=3,
+                stride=1,
+                padding=1),
+        )
+        self.branch2 = nn.Sequential(
+            InceptionBaseConv3D(
+                cfg, in_planes, num_out_2_0a, kernel_size=1, stride=1),
+            STConv3d(
+                cfg,
+                num_out_2_0a,
+                num_out_2_0b,
+                kernel_size=3,
+                stride=1,
+                padding=1),
+        )
+        self.branch3 = nn.Sequential(
+            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
+            InceptionBaseConv3D(
+                cfg, in_planes, num_out_3_0b, kernel_size=1, stride=1),
+        )
+
+        self.out_channels = sum(
+            [num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b])
+
+        self.gating = _gating
+        if _gating:
+            self.gating_b0 = SelfGating(num_out_0_0a)
+            self.gating_b1 = SelfGating(num_out_1_0b)
+            self.gating_b2 = SelfGating(num_out_2_0b)
+            self.gating_b3 = SelfGating(num_out_3_0b)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        x3 = self.branch3(x)
+        if self.gating:
+            x0 = self.gating_b0(x0)
+            x1 = self.gating_b1(x1)
+            x2 = self.gating_b2(x2)
+            x3 = self.gating_b3(x3)
+
+        out = torch.cat((x0, x1, x2, x3), 1)
+
+        return out
+
+
+class SelfGating(nn.Module):
+
+    def __init__(self, input_dim):
+        super(SelfGating, self).__init__()
+        self.fc = nn.Linear(input_dim, input_dim)
+
+    def forward(self, input_tensor):
+        """Feature gating as used in S3D-G"""
+        spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4])
+        weights = self.fc(spatiotemporal_average)
+        weights = torch.sigmoid(weights)
+        return weights[:, :, None, None, None] * input_tensor
+
+
+class STConv3d(nn.Module):
+    """
+    Element constructing the S3D/S3DG.
+    See models/base/backbone.py L99-186.
+
+    Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
+    """
+
+    def __init__(self,
+                 cfg,
+                 in_planes,
+                 out_planes,
+                 kernel_size,
+                 stride,
+                 padding=0):
+        super(STConv3d, self).__init__()
+        if isinstance(stride, tuple):
+            t_stride = stride[0]
+            stride = stride[-1]
+        else:  # int
+            t_stride = stride
+
+        self.bn_mmt = cfg.BN.MOMENTUM
+        self.bn_eps = float(cfg.BN.EPS)
+        self._construct_branch(cfg, in_planes, out_planes, kernel_size, stride,
+                               t_stride, padding)
+
+    def _construct_branch(self,
+                          cfg,
+                          in_planes,
+                          out_planes,
+                          kernel_size,
+                          stride,
+                          t_stride,
+                          padding=0):
+        self.conv1 = nn.Conv3d(
+            in_planes,
+            out_planes,
+            kernel_size=(1, kernel_size, kernel_size),
+            stride=(1, stride, stride),
+            padding=(0, padding, padding),
+            bias=False)
+        self.conv2 = nn.Conv3d(
+            out_planes,
+            out_planes,
+            kernel_size=(kernel_size, 1, 1),
+            stride=(t_stride, 1, 1),
+            padding=(padding, 0, 0),
+            bias=False)
+
+        self.bn1 = nn.BatchNorm3d(
+            out_planes, eps=self.bn_eps, momentum=self.bn_mmt)
+        self.bn2 = nn.BatchNorm3d(
+            out_planes, eps=self.bn_eps, momentum=self.bn_mmt)
+        self.relu = nn.ReLU(inplace=True)
+
+        # init
+        self.conv1.weight.data.normal_(
+            mean=0, std=0.01)  # original s3d is truncated normal within 2 std
+        self.conv2.weight.data.normal_(
+            mean=0, std=0.01)  # original s3d is truncated normal within 2 std
+        self.bn1.weight.data.fill_(1)
+        self.bn1.bias.data.zero_()
+        self.bn2.weight.data.fill_(1)
+        self.bn2.bias.data.zero_()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        return x
+
+
+class Inception3D(nn.Module):
+    """
+    Backbone architecture for I3D/S3DG.
+    Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
+    """
+
+    def __init__(self, cfg):
+        """
+        Args:
+            cfg (Config): global config object.
+        """
+        super(Inception3D, self).__init__()
+        _input_channel = cfg.DATA.NUM_INPUT_CHANNELS
+        self._construct_backbone(cfg, _input_channel)
+
+    def _construct_backbone(self, cfg, input_channel):
+        # ------------------- Block 1 -------------------
+        self.Conv_1a = STConv3d(
+            cfg, input_channel, 64, kernel_size=7, stride=2, padding=3)
+
+        self.block1 = nn.Sequential(self.Conv_1a)  # (64, 32, 112, 112)
+
+        # ------------------- Block 2 -------------------
+        self.MaxPool_2a = nn.MaxPool3d(
+            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
+        self.Conv_2b = InceptionBaseConv3D(
+            cfg, 64, 64, kernel_size=1, stride=1)
+        self.Conv_2c = STConv3d(
+            cfg, 64, 192, kernel_size=3, stride=1, padding=1)
+
+        self.block2 = nn.Sequential(
+            self.MaxPool_2a,  # (64, 32, 56, 56)
+            self.Conv_2b,  # (64, 32, 56, 56)
+            self.Conv_2c)  # (192, 32, 56, 56)
+
+        # ------------------- Block 3 -------------------
+        self.MaxPool_3a = nn.MaxPool3d(
+            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
+        self.Mixed_3b = InceptionBlock3D(
+            cfg, in_planes=192, out_planes=[64, 96, 128, 16, 32, 32])
+        self.Mixed_3c = InceptionBlock3D(
+            cfg, in_planes=256, out_planes=[128, 128, 192, 32, 96, 64])
+
+        self.block3 = nn.Sequential(
+            self.MaxPool_3a,  # (192, 32, 28, 28)
+            self.Mixed_3b,  # (256, 32, 28, 28)
+            self.Mixed_3c)  # (480, 32, 28, 28)
+
+        # ------------------- Block 4 -------------------
+        self.MaxPool_4a = nn.MaxPool3d(
+            kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
+        self.Mixed_4b = InceptionBlock3D(
+            cfg, in_planes=480, out_planes=[192, 96, 208, 16, 48, 64])
+        self.Mixed_4c = InceptionBlock3D(
+            cfg, in_planes=512, out_planes=[160, 112, 224, 24, 64, 64])
+        self.Mixed_4d = InceptionBlock3D(
+            cfg, in_planes=512, out_planes=[128, 128, 256, 24, 64, 64])
+        self.Mixed_4e = InceptionBlock3D(
+            cfg, in_planes=512, out_planes=[112, 144, 288, 32, 64, 64])
+        self.Mixed_4f = InceptionBlock3D(
+            cfg, in_planes=528, out_planes=[256, 160, 320, 32, 128, 128])
+
+        self.block4 = nn.Sequential(
+            self.MaxPool_4a,  # (480, 16, 14, 14)
+            self.Mixed_4b,  # (512, 16, 14, 14)
+            self.Mixed_4c,  # (512, 16, 14, 14)
+            self.Mixed_4d,  # (512, 16, 14, 14)
+            self.Mixed_4e,  # (528, 16, 14, 14)
+            self.Mixed_4f)  # (832, 16, 14, 14)
+
+        # ------------------- Block 5 -------------------
+        self.MaxPool_5a = nn.MaxPool3d(
+            kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
+        self.Mixed_5b = InceptionBlock3D(
+            cfg, in_planes=832, out_planes=[256, 160, 320, 32, 128, 128])
+        self.Mixed_5c = InceptionBlock3D(
+            cfg, in_planes=832, out_planes=[384, 192, 384, 48, 128, 128])
+
+        self.block5 = nn.Sequential(
+            self.MaxPool_5a,  # (832, 8, 7, 7)
+            self.Mixed_5b,  # (832, 8, 7, 7)
+            self.Mixed_5c)  # (1024, 8, 7, 7)
+
+    def forward(self, x):
+        if isinstance(x, dict):
+            x = x['video']
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        x = self.block4(x)
+        x = self.block5(x)
+        return x
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index bd175578..01c69758 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
     from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
     from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
     from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
+    from .hicossl_video_embedding_pipeline import HICOSSLVideoEmbeddingPipeline
     from .crowd_counting_pipeline import CrowdCountingPipeline
     from .image_detection_pipeline import ImageDetectionPipeline
     from .image_salient_detection_pipeline import ImageSalientDetectionPipeline
@@ -51,6 +52,7 @@ else:
         'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
         'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
         'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
+        'hicossl_video_embedding_pipeline': ['HICOSSLVideoEmbeddingPipeline'],
         'crowd_counting_pipeline': ['CrowdCountingPipeline'],
         'image_detection_pipeline': ['ImageDetectionPipeline'],
         'image_salient_detection_pipeline': ['ImageSalientDetectionPipeline'],
diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py
index 087548f0..e3400ea7 100644
--- a/modelscope/pipelines/cv/action_recognition_pipeline.py
+++ b/modelscope/pipelines/cv/action_recognition_pipeline.py
@@ -33,6 +33,7 @@ class ActionRecognitionPipeline(Pipeline):
         config_path = osp.join(self.model, ModelFile.CONFIGURATION)
         logger.info(f'loading config from {config_path}')
         self.cfg = Config.from_file(config_path)
+
         self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
         self.infer_model.eval()
         self.infer_model.load_state_dict(
diff --git a/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py b/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py
new file mode 100644
index 00000000..5e4cd4c6
--- /dev/null
+++ b/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py
@@ -0,0 +1,75 @@
+import math
+import os.path as osp
+from typing import Any, Dict
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.action_recognition import BaseVideoModel
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import ReadVideoData
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.video_embedding, module_name=Pipelines.hicossl_video_embedding)
+class HICOSSLVideoEmbeddingPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        use `model` to create a hicossl video embedding pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
+        logger.info(f'loading model from {model_path}')
+        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
+        logger.info(f'loading config from {config_path}')
+        self.cfg = Config.from_file(config_path)
+        self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
+        self.infer_model.eval()
+        self.infer_model.load_state_dict(
+            torch.load(model_path, map_location=self.device)['model_state'],
+            strict=False)
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            video_input_data = ReadVideoData(
+                self.cfg, input, num_temporal_views_override=1).to(self.device)
+        else:
+            raise TypeError(f'input should be a str,'
+                            f' but got {type(input)}')
+        result = {'video_data': video_input_data}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        feature = self.perform_inference(input['video_data'])
+        return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()}
+
+    @torch.no_grad()
+    def perform_inference(self, data, max_bsz=4):
+        """ Perform feature extraction for a given video
+        Args:
+            data (Tensor): the preprocessed video clips.
+            max_bsz (int): the maximum batch size, limited by GPU memory.
+        Returns:
+            pred (Tensor): the extracted features for input video clips.
+        """
+        iter_num = math.ceil(data.size(0) / max_bsz)
+        preds_list = []
+        for i in range(iter_num):
+            preds_list.append(
+                self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0])
+        pred = torch.cat(preds_list, dim=0)
+        return pred
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/preprocessors/video.py b/modelscope/preprocessors/video.py
index 0d2e8c3e..f693cd9e 100644
--- a/modelscope/preprocessors/video.py
+++ b/modelscope/preprocessors/video.py
@@ -16,34 +16,49 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS
 
 
-def ReadVideoData(cfg, video_path):
+def ReadVideoData(cfg,
+                  video_path,
+                  num_spatial_crops_override=None,
+                  num_temporal_views_override=None):
     """ simple interface to load video frames from file
     Args:
         cfg (Config): The global config object.
         video_path (str): video file path
+        num_spatial_crops_override (int): the spatial crops per clip
+        num_temporal_views_override (int): the temporal clips per video
+    Returns:
+        data (Tensor): the normalized video clips for model inputs
     """
-    data = _decode_video(cfg, video_path)
-    transform = kinetics400_tranform(cfg)
+    data = _decode_video(cfg, video_path, num_temporal_views_override)
+    if num_spatial_crops_override is not None:
+        num_spatial_crops = num_spatial_crops_override
+        transform = kinetics400_tranform(cfg, num_spatial_crops_override)
+    else:
+        num_spatial_crops = cfg.TEST.NUM_SPATIAL_CROPS
+        transform = kinetics400_tranform(cfg, cfg.TEST.NUM_SPATIAL_CROPS)
     data_list = []
     for i in range(data.size(0)):
-        for j in range(cfg.TEST.NUM_SPATIAL_CROPS):
+        for j in range(num_spatial_crops):
             transform.transforms[1].set_spatial_index(j)
             data_list.append(transform(data[i]))
     return torch.stack(data_list, dim=0)
 
 
-def kinetics400_tranform(cfg):
+def kinetics400_tranform(cfg, num_spatial_crops):
     """ Configs the transform for the kinetics-400 dataset.
         We apply controlled spatial cropping and normalization.
     Args:
        cfg (Config): The global config object.
+       num_spatial_crops (int): the spatial crops per clip
+    Returns:
+        transform_function (Compose): the transform function for input clips
     """
     resize_video = KineticsResizedCrop(
         short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE],
         crop_size=cfg.DATA.TEST_CROP_SIZE,
-        num_spatial_crops=cfg.TEST.NUM_SPATIAL_CROPS)
+        num_spatial_crops=num_spatial_crops)
     std_transform_list = [
         transforms.ToTensorVideo(), resize_video,
         transforms.NormalizeVideo(
@@ -60,17 +75,17 @@
         vid_length (int): the length of the whole video (valid selection range).
         vid_fps (int): the original video fps
         target_fps (int): the normalized video fps
-        clip_idx (int): -1 for random temporal sampling, and positive values for
-            sampling specific clip from the video
+        clip_idx (int): -1 for random temporal sampling, and positive values for sampling specific
+            clip from the video
         num_clips (int): the total clips to be sampled from each video.
-            combined with clip_idx, the sampled video is the "clip_idx-th"
-            video from "num_clips" videos.
+            combined with clip_idx, the sampled video is the "clip_idx-th" video from
+            "num_clips" videos.
         num_frames (int): number of frames in each sampled clips.
         interval (int): the interval to sample each frame.
         minus_interval (bool): control the end index
     Returns:
         index (tensor): the sampled frame indexes
-    """
+    """
     if num_frames == 1:
         index = [random.randint(0, vid_length - 1)]
     else:
@@ -78,7 +93,10 @@
         clip_length = num_frames * interval * vid_fps / target_fps
 
         max_idx = max(vid_length - clip_length, 0)
-        start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
+        if num_clips == 1:
+            start_idx = max_idx / 2
+        else:
+            start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
         if minus_interval:
             end_idx = start_idx + clip_length - interval
         else:
@@ -90,59 +108,79 @@
     return index
 
 
-def _decode_video_frames_list(cfg, frames_list, vid_fps):
+def _decode_video_frames_list(cfg,
+                              frames_list,
+                              vid_fps,
+                              num_temporal_views_override=None):
     """ Decodes the video given the numpy frames.
     Args:
         cfg (Config): The global config object.
        frames_list (list): all frames for a video, the frames should be numpy array.
        vid_fps (int): the fps of this video.
+        num_temporal_views_override (int): the temporal clips per video
     Returns:
         frames (Tensor): video tensor data
     """
     assert isinstance(frames_list, list)
-    num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
+    if num_temporal_views_override is not None:
+        num_clips_per_video = num_temporal_views_override
+    else:
+        num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
     frame_list = []
     for clip_idx in range(num_clips_per_video):
         # for each clip in the video,
         # a list is generated before decoding the specified frames from the video
         list_ = _interval_based_sampling(
-            len(frames_list), vid_fps, cfg.DATA.TARGET_FPS, clip_idx,
-            num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
-            cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
+            len(frames_list),
+            vid_fps,
+            cfg.DATA.TARGET_FPS,
+            clip_idx,
+            num_clips_per_video,
+            cfg.DATA.NUM_INPUT_FRAMES,
+            cfg.DATA.SAMPLING_RATE,
+            cfg.DATA.MINUS_INTERVAL,
+        )
         frames = None
         frames = torch.from_numpy(
-            np.stack([frames_list[l_index] for l_index in list_.tolist()],
-                     axis=0))
+            np.stack([frames_list[index] for index in list_.tolist()], axis=0))
         frame_list.append(frames)
     frames = torch.stack(frame_list)
-    if num_clips_per_video == 1:
-        frames = frames.squeeze(0)
 
     return frames
 
 
-def _decode_video(cfg, path):
+def _decode_video(cfg, path, num_temporal_views_override=None):
     """
     Decodes the video given the numpy frames.
     Args:
+        cfg (Config): The global config object.
         path (str): video file path.
+        num_temporal_views_override (int): the temporal clips per video
     Returns:
         frames (Tensor): video tensor data
     """
     vr = VideoReader(path)
-
-    num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
+    if num_temporal_views_override is not None:
+        num_clips_per_video = num_temporal_views_override
+    else:
+        num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
     frame_list = []
     for clip_idx in range(num_clips_per_video):
         # for each clip in the video,
         # a list is generated before decoding the specified frames from the video
         list_ = _interval_based_sampling(
-            len(vr), vr.get_avg_fps(), cfg.DATA.TARGET_FPS, clip_idx,
-            num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
-            cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
+            len(vr),
+            vr.get_avg_fps(),
+            cfg.DATA.TARGET_FPS,
+            clip_idx,
+            num_clips_per_video,
+            cfg.DATA.NUM_INPUT_FRAMES,
+            cfg.DATA.SAMPLING_RATE,
+            cfg.DATA.MINUS_INTERVAL,
+        )
         frames = None
         if path.endswith('.avi'):
             append_list = torch.arange(0, list_[0], 4)
@@ -155,8 +193,6 @@
             vr.get_batch(list_).to_dlpack()).clone()
         frame_list.append(frames)
     frames = torch.stack(frame_list)
-    if num_clips_per_video == 1:
-        frames = frames.squeeze(0)
     del vr
     return frames
 
@@ -224,6 +260,29 @@
             y = y_max // 2
         return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]
 
+    def _get_random_crop(self, clip):
+        _, _, clip_height, clip_width = clip.shape
+
+        short_side = min(clip_height, clip_width)
+        long_side = max(clip_height, clip_width)
+        new_short_side = int(random.uniform(*self.short_side_range))
+        new_long_side = int(long_side / short_side * new_short_side)
+        if clip_height < clip_width:
+            new_clip_height = new_short_side
+            new_clip_width = new_long_side
+        else:
+            new_clip_height = new_long_side
+            new_clip_width = new_short_side
+
+        new_clip = torch.nn.functional.interpolate(
+            clip, size=(new_clip_height, new_clip_width), mode='bilinear')
+
+        x_max = int(new_clip_width - self.crop_size)
+        y_max = int(new_clip_height - self.crop_size)
+        x = int(random.uniform(0, x_max))
+        y = int(random.uniform(0, y_max))
+        return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]
+
     def set_spatial_index(self, idx):
         """Set the spatial cropping index for controlled cropping..
         Args:
diff --git a/tests/pipelines/test_hicossl_video_embedding.py b/tests/pipelines/test_hicossl_video_embedding.py
new file mode 100644
index 00000000..5615cef2
--- /dev/null
+++ b/tests/pipelines/test_hicossl_video_embedding.py
@@ -0,0 +1,26 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# !/usr/bin/env python
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class HICOSSLVideoEmbeddingTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_s3dg_video-embedding'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        videossl_pipeline = pipeline(
+            Tasks.video_embedding, model=self.model_id)
+        result = videossl_pipeline(
+            'data/test/videos/action_recognition_test_video.mp4')
+
+        print(f'video embedding output: {result}.')
+
+
+if __name__ == '__main__':
+    unittest.main()
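
A minimal usage sketch of the new pipeline, based on the test case and pipeline code above. The model id and video path are the ones used in the test; the shape of the returned embedding depends on the configuration shipped with the hub model and is not asserted here.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the HiCoSSL S3DG video embedding pipeline from the ModelScope hub.
    video_embedding = pipeline(
        Tasks.video_embedding, model='damo/cv_s3dg_video-embedding')

    # The pipeline decodes a single temporal view of the video, runs the spatial
    # crops through the model in small batches, and returns the clip features as
    # a numpy array under OutputKeys.VIDEO_EMBEDDING.
    result = video_embedding('data/test/videos/action_recognition_test_video.mp4')
    embedding = result[OutputKeys.VIDEO_EMBEDDING]
    print(embedding.shape)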