Add a new action-detection task
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9898947
master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b7c3bc7c82ea5fee9d83130041df01046d89143ff77058b04577455ff6fdc92
size 3191059
@@ -133,6 +133,7 @@ class Pipelines(object):
    skin_retouching = 'unet-skin-retouching'
    tinynas_classification = 'tinynas-classification'
    crowd_counting = 'hrnet-crowd-counting'
    action_detection = 'ResNetC3D-action-detection'
    video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
    image_panoptic_segmentation = 'image-panoptic-segmentation'
    video_summarization = 'googlenet_pgl_video_summarization'
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .action_detection_onnx import ActionDetONNX
else:
    _import_structure = {'action_detection_onnx': ['ActionDetONNX']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
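A quick sketch of what this lazy-import indirection buys (import paths are the ones added in this PR): importing the subpackage stays cheap, because heavy dependencies such as onnxruntime are only pulled in when the symbol is first touched.

    # Importing the package does not yet import onnxruntime/cv2;
    # LazyImportModule resolves the symbol on first attribute access.
    from modelscope.models.cv import action_detection

    ActionDetONNX = action_detection.ActionDetONNX  # triggers the real import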
@@ -0,0 +1,177 @@
import os
import os.path as osp
import shutil
import subprocess

import cv2
import numpy as np
import onnxruntime as rt

from modelscope.models import Model
from modelscope.utils.constant import Devices
from modelscope.utils.device import verify_device
class ActionDetONNX(Model):

    def __init__(self, model_dir, config, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        model_file = config['model_file']
        device_type, device_id = verify_device(self._device_name)
        options = rt.SessionOptions()
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        if device_type == Devices.gpu:
            sess = rt.InferenceSession(
                model_file,
                providers=['CUDAExecutionProvider'],
                sess_options=options,
                provider_options=[{
                    'device_id': device_id
                }])
        else:
            sess = rt.InferenceSession(
                model_file,
                providers=['CPUExecutionProvider'],
                sess_options=options)
        self.input_name = sess.get_inputs()[0].name
        self.sess = sess
        self.num_stride = len(config['fpn_strides'])
        self.score_thresh = np.asarray(
            config['pre_nms_thresh'], dtype='float32').reshape((1, -1))
        self.size_divisibility = config['size_divisibility']
        self.nms_threshold = config['nms_thresh']
        self.tmp_dir = config['tmp_dir']
        self.temporal_stride = config['step']
        self.input_data_type = config['input_type']
        self.action_names = config['action_names']
        self.video_length_limit = config['video_length_limit']
    def resize_box(self, det, height, width, scale_h, scale_w):
        bboxs = det[0]
        bboxs[:, [0, 2]] *= scale_w
        bboxs[:, [1, 3]] *= scale_h
        bboxs[:, [0, 2]] = bboxs[:, [0, 2]].clip(0, width - 1)
        bboxs[:, [1, 3]] = bboxs[:, [1, 3]].clip(0, height - 1)
        result = {
            'boxes': bboxs.round().astype('int32').tolist(),
            'scores': det[1].tolist(),
            'labels': [self.action_names[i] for i in det[2].tolist()]
        }
        return result

    def parse_frames(self, frame_names):
        # cv2.imread returns BGR; reverse the channel axis to get RGB
        imgs = [cv2.imread(name)[:, :, ::-1] for name in frame_names]
        imgs = np.stack(imgs).astype(self.input_data_type).transpose(
            (3, 0, 1, 2))  # c,t,h,w
        imgs = imgs[None]  # add batch dim: n,c,t,h,w
        return imgs
    def forward_img(self, imgs, h, w):
        pred = self.sess.run(None, {
            self.input_name: imgs,
            'height': np.asarray(h),
            'width': np.asarray(w)
        })
        dets = self.post_nms(
            pred,
            score_threshold=self.score_thresh,
            nms_threshold=self.nms_threshold)
        return dets
    def forward_video(self, video_name, scale):
        min_size, max_size = self._get_sizes(scale)
        tmp_dir = osp.join(self.tmp_dir,
                           osp.splitext(osp.basename(video_name))[0])
        if osp.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        os.makedirs(tmp_dir)
        frame_rate = 2
        # sample the first `video_length_limit` seconds at 2 fps into jpg
        # frames; building the argv list directly keeps paths with spaces intact
        cmd = [
            'ffmpeg', '-y', '-loglevel', 'quiet', '-ss', '0', '-t',
            str(self.video_length_limit), '-i', video_name, '-r',
            str(frame_rate), '-f', 'image2', f'{tmp_dir}/%06d.jpg'
        ]
        subprocess.call(cmd)
        frame_names = [
            osp.join(tmp_dir, name) for name in sorted(os.listdir(tmp_dir))
            if name.endswith('.jpg')
        ]
        # sliding windows of frame_rate*2 = 4 frames (2 s of video), advanced
        # by `temporal_stride` seconds; each timestamp is its window's center
        frame_names = [
            frame_names[i:i + frame_rate * 2]
            for i in range(0,
                           len(frame_names) - frame_rate * 2 + 1,
                           frame_rate * self.temporal_stride)
        ]
        timestamp = list(
            range(1,
                  len(frame_names) * self.temporal_stride,
                  self.temporal_stride))
        batch_imgs = [self.parse_frames(names) for names in frame_names]
        N, _, T, H, W = batch_imgs[0].shape
        # scale the short side to min_size, cap both sides at max_size, then
        # round each side to a multiple of size_divisibility
        scale_min = min_size / min(H, W)
        h = min(int(scale_min * H), max_size)
        w = min(int(scale_min * W), max_size)
        h = round(h / self.size_divisibility) * self.size_divisibility
        w = round(w / self.size_divisibility) * self.size_divisibility
        scale_h, scale_w = H / h, W / w
        results = []
        for imgs in batch_imgs:
            det = self.forward_img(imgs, h, w)
            det = self.resize_box(det[0], H, W, scale_h, scale_w)
            results.append(det)
        results = [{
            'timestamp': t,
            'actions': res
        } for t, res in zip(timestamp, results)]
        shutil.rmtree(tmp_dir)
        return results
    def forward(self, video_name):
        return self.forward_video(video_name, scale=1)

    def post_nms(self, pred, score_threshold, nms_threshold=0.3):
        pred_bboxes, pred_scores = pred
        N = len(pred_bboxes)
        dets = []
        for i in range(N):
            bboxes, scores = pred_bboxes[i], pred_scores[i]
            # scores has shape (num_boxes, num_classes); nonzero() yields
            # (box indices, class indices) for entries above threshold
            candidate_inds = scores > score_threshold
            scores = scores[candidate_inds]
            candidate_nonzeros = candidate_inds.nonzero()
            bboxes = bboxes[candidate_nonzeros[0]]
            labels = candidate_nonzeros[1]
            keep = self._nms(bboxes, scores, labels, nms_threshold)
            bbox = bboxes[keep]
            score = scores[keep]
            label = labels[keep]
            dets.append((bbox, score, label))
        return dets
    def _nms(self, boxes, scores, idxs, nms_threshold):
        if len(boxes) == 0:
            return []
        # class-aware NMS: offset boxes by label so that boxes of different
        # classes never overlap and cannot suppress each other
        max_coordinate = boxes.max()
        offsets = idxs * (max_coordinate + 1)
        boxes_for_nms = boxes + offsets[:, None].astype('float32')
        # convert x1,y1,x2,y2 to the x,y,w,h layout expected by NMSBoxes
        boxes_for_nms[:, 2] = boxes_for_nms[:, 2] - boxes_for_nms[:, 0]
        boxes_for_nms[:, 3] = boxes_for_nms[:, 3] - boxes_for_nms[:, 1]
        keep = cv2.dnn.NMSBoxes(
            boxes_for_nms.tolist(),
            scores.tolist(),
            score_threshold=0,
            nms_threshold=nms_threshold)
        if len(keep.shape) == 2:
            # older OpenCV returns the kept indices as an (N, 1) array
            keep = np.squeeze(keep, 1)
        return keep

    def _get_sizes(self, scale):
        if scale == 1:
            min_size, max_size = 512, 896
        elif scale == 2:
            min_size, max_size = 768, 1280
        else:
            min_size, max_size = 1024, 1792
        return min_size, max_size
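A self-contained sketch of the class-aware NMS trick used in `_nms` above, with toy boxes (values are illustrative only): shifting each box by `label * (max_coordinate + 1)` puts every class in a disjoint coordinate range, so a single NMSBoxes call never suppresses one class's box with another's.

    import cv2
    import numpy as np

    boxes = np.array([[10., 10., 50., 50.], [12., 12., 52., 52.]],
                     dtype='float32')  # x1,y1,x2,y2; heavily overlapping
    scores = [0.9, 0.8]
    labels = np.array([0, 1])  # two different classes

    shifted = boxes + (labels * (boxes.max() + 1))[:, None].astype('float32')
    shifted[:, 2] -= shifted[:, 0]  # to x,y,w,h for NMSBoxes
    shifted[:, 3] -= shifted[:, 1]
    keep = cv2.dnn.NMSBoxes(
        shifted.tolist(), scores, score_threshold=0, nms_threshold=0.3)
    print(np.asarray(keep).reshape(-1))  # -> [0 1]: neither box is suppressed

With identical labels the second box would be dropped; the per-label offset is what preserves it here.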
@@ -35,6 +35,7 @@ class OutputKeys(object):
    UUID = 'uuid'
    WORD = 'word'
    KWS_LIST = 'kws_list'
    TIMESTAMPS = 'timestamps'
    SPLIT_VIDEO_NUM = 'split_video_num'
    SPLIT_META_DICT = 'split_meta_dict'
@@ -541,6 +542,19 @@ TASK_OUTPUTS = {
    # }
    Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

    # action detection result for single video
    # {
    #   'labels': ['吸烟', '打电话', '吸烟'],  (i.e. smoking, phone call, smoking)
    #   'scores': [0.7527753114700317, 0.753358006477356, 0.6880350708961487],
    #   'boxes': [[547, 2, 1225, 719], [529, 8, 1255, 719], [584, 0, 1269, 719]],
    #   'timestamps': [1, 3, 5]
    # }
    Tasks.action_detection: [
        OutputKeys.TIMESTAMPS,
        OutputKeys.LABELS,
        OutputKeys.SCORES,
        OutputKeys.BOXES,
    ],

    # {
    #   'output': [
    #     [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
@@ -551,6 +565,7 @@ TASK_OUTPUTS = {
    #     {'label': '13421097', 'score': 2.75914817393641e-06}]]
    # }
    Tasks.faq_question_answering: [OutputKeys.OUTPUT],

    # image person reid result for single sample
    # {
    #   "img_embedding": np.array with shape [1, D],
@@ -71,6 +71,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
    Tasks.action_detection: (Pipelines.action_detection,
                             'damo/cv_ResNetC3D_action-detection_detection2d'),
    Tasks.live_category: (Pipelines.live_category,
                          'damo/cv_resnet50_live-category'),
    Tasks.video_category: (Pipelines.video_category,
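This default mapping is what lets callers omit the model id; a minimal sketch of what the new registry entry enables:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # no model id given: resolves via DEFAULT_MODEL_FOR_PIPELINE to
    # 'damo/cv_ResNetC3D_action-detection_detection2d'
    p = pipeline(Tasks.action_detection)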
@@ -5,6 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .action_detection_pipeline import ActionDetectionPipeline
    from .animal_recognition_pipeline import AnimalRecognitionPipeline
    from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
    from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
        'action_detection_pipeline': ['ActionDetectionPipeline'],
        'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
        'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
        'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
@@ -0,0 +1,63 @@
import os.path as osp
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.action_detection import ActionDetONNX
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.action_detection, module_name=Pipelines.action_detection)
class ActionDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create an action detection pipeline for prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.ONNX_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        self.cfg.MODEL.model_file = model_path
        # replace the model-dir string with the loaded ONNX wrapper
        self.model = ActionDetONNX(self.model, self.cfg.MODEL,
                                   self.device_name)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            video_name = input
        else:
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
        result = {'video_name': video_name}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        preds = self.model.forward(input['video_name'])
        # flatten the per-clip results into parallel lists; each detection
        # repeats its clip's timestamp so the four lists stay aligned
        labels = sum([pred['actions']['labels'] for pred in preds], [])
        scores = sum([pred['actions']['scores'] for pred in preds], [])
        boxes = sum([pred['actions']['boxes'] for pred in preds], [])
        timestamps = sum([[pred['timestamp']] * len(pred['actions']['labels'])
                          for pred in preds], [])
        out = {
            OutputKeys.TIMESTAMPS: timestamps,
            OutputKeys.LABELS: labels,
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: boxes
        }
        return out

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
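A minimal consumer sketch for these flattened outputs (the video path is a placeholder): since the four lists are parallel, detections can be re-grouped per timestamp with a single zip.

    from collections import defaultdict

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    p = pipeline(
        Tasks.action_detection,
        model='damo/cv_ResNetC3D_action-detection_detection2d')
    out = p('some_video.mp4')  # placeholder path

    per_second = defaultdict(list)
    for t, label, score, box in zip(out['timestamps'], out['labels'],
                                    out['scores'], out['boxes']):
        per_second[t].append((label, score, box))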
@@ -58,6 +58,7 @@ class CVTasks(object):
    # video recognition
    live_category = 'live-category'
    action_recognition = 'action-recognition'
    action_detection = 'action-detection'
    video_category = 'video-category'
    video_embedding = 'video-embedding'
    virtual_try_on = 'virtual-try-on'
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ActionDetectionTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        action_detection_pipeline = pipeline(
            Tasks.action_detection,
            model='damo/cv_ResNetC3D_action-detection_detection2d')
        result = action_detection_pipeline(
            'data/test/videos/action_detection_test_video.mp4')
        print('action detection results:', result)


if __name__ == '__main__':
    unittest.main()