Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10319306master
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:76bf84536edbaf192a8a699efc62ba2b06056bac12c426ecfcc2e003d91fbd32 | |||||
| size 53219 | |||||
| @@ -1,3 +0,0 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d | |||||
| size 1702339 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:dec0fbb931cb609bf481e56b89cd2fbbab79839f22832c3bbe69a8fae2769cdd | |||||
| size 167407 | |||||
| @@ -40,6 +40,7 @@ class Models(object): | |||||
| mtcnn = 'mtcnn' | mtcnn = 'mtcnn' | ||||
| ulfd = 'ulfd' | ulfd = 'ulfd' | ||||
| video_inpainting = 'video-inpainting' | video_inpainting = 'video-inpainting' | ||||
| human_wholebody_keypoint = 'human-wholebody-keypoint' | |||||
| hand_static = 'hand-static' | hand_static = 'hand-static' | ||||
| face_human_hand_detection = 'face-human-hand-detection' | face_human_hand_detection = 'face-human-hand-detection' | ||||
| face_emotion = 'face-emotion' | face_emotion = 'face-emotion' | ||||
| @@ -49,6 +50,7 @@ class Models(object): | |||||
| # EasyCV models | # EasyCV models | ||||
| yolox = 'YOLOX' | yolox = 'YOLOX' | ||||
| segformer = 'Segformer' | segformer = 'Segformer' | ||||
| image_object_detection_auto = 'image-object-detection-auto' | |||||
| # nlp models | # nlp models | ||||
| bert = 'bert' | bert = 'bert' | ||||
| @@ -170,6 +172,7 @@ class Pipelines(object): | |||||
| ocr_recognition = 'convnextTiny-ocr-recognition' | ocr_recognition = 'convnextTiny-ocr-recognition' | ||||
| image_portrait_enhancement = 'gpen-image-portrait-enhancement' | image_portrait_enhancement = 'gpen-image-portrait-enhancement' | ||||
| image_to_image_generation = 'image-to-image-generation' | image_to_image_generation = 'image-to-image-generation' | ||||
| image_object_detection_auto = 'yolox_image-object-detection-auto' | |||||
| skin_retouching = 'unet-skin-retouching' | skin_retouching = 'unet-skin-retouching' | ||||
| tinynas_classification = 'tinynas-classification' | tinynas_classification = 'tinynas-classification' | ||||
| tinynas_detection = 'tinynas-detection' | tinynas_detection = 'tinynas-detection' | ||||
| @@ -185,6 +188,7 @@ class Pipelines(object): | |||||
| movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | ||||
| shop_segmentation = 'shop-segmentation' | shop_segmentation = 'shop-segmentation' | ||||
| video_inpainting = 'video-inpainting' | video_inpainting = 'video-inpainting' | ||||
| human_wholebody_keypoint = 'hrnetw48_human-wholebody-keypoint_image' | |||||
| pst_action_recognition = 'patchshift-action-recognition' | pst_action_recognition = 'patchshift-action-recognition' | ||||
| hand_static = 'hand-static' | hand_static = 'hand-static' | ||||
| face_human_hand_detection = 'face-human-hand-detection' | face_human_hand_detection = 'face-human-hand-detection' | ||||
| @@ -427,6 +431,7 @@ class Datasets(object): | |||||
| """ | """ | ||||
| ClsDataset = 'ClsDataset' | ClsDataset = 'ClsDataset' | ||||
| Face2dKeypointsDataset = 'Face2dKeypointsDataset' | Face2dKeypointsDataset = 'Face2dKeypointsDataset' | ||||
| HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' | |||||
| SegDataset = 'SegDataset' | SegDataset = 'SegDataset' | ||||
| DetDataset = 'DetDataset' | DetDataset = 'DetDataset' | ||||
| DetImagesMixDataset = 'DetImagesMixDataset' | DetImagesMixDataset = 'DetImagesMixDataset' | ||||
| @@ -4,15 +4,15 @@ | |||||
| from . import (action_recognition, animal_recognition, body_2d_keypoints, | from . import (action_recognition, animal_recognition, body_2d_keypoints, | ||||
| body_3d_keypoints, cartoon, cmdssl_video_embedding, | body_3d_keypoints, cartoon, cmdssl_video_embedding, | ||||
| crowd_counting, face_2d_keypoints, face_detection, | crowd_counting, face_2d_keypoints, face_detection, | ||||
| face_generation, image_classification, image_color_enhance, | |||||
| image_colorization, image_denoise, image_inpainting, | |||||
| image_instance_segmentation, image_panoptic_segmentation, | |||||
| image_portrait_enhancement, image_reid_person, | |||||
| image_semantic_segmentation, image_to_image_generation, | |||||
| image_to_image_translation, movie_scene_segmentation, | |||||
| object_detection, product_retrieval_embedding, | |||||
| realtime_object_detection, salient_detection, shop_segmentation, | |||||
| super_resolution, video_single_object_tracking, | |||||
| video_summarization, virual_tryon) | |||||
| face_generation, human_wholebody_keypoint, image_classification, | |||||
| image_color_enhance, image_colorization, image_denoise, | |||||
| image_inpainting, image_instance_segmentation, | |||||
| image_panoptic_segmentation, image_portrait_enhancement, | |||||
| image_reid_person, image_semantic_segmentation, | |||||
| image_to_image_generation, image_to_image_translation, | |||||
| movie_scene_segmentation, object_detection, | |||||
| product_retrieval_embedding, realtime_object_detection, | |||||
| salient_detection, shop_segmentation, super_resolution, | |||||
| video_single_object_tracking, video_summarization, virual_tryon) | |||||
| # yapf: enable | # yapf: enable | ||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING:
    # Evaluated by static type checkers only; this branch is skipped at
    # runtime because TYPE_CHECKING is False there.
    from .human_wholebody_keypoint import HumanWholeBodyKeypoint
else:
    import sys

    # Submodule name -> public names it provides; handed to LazyImportModule.
    _import_structure = {
        'human_wholebody_keypoint': ['HumanWholeBodyKeypoint'],
    }

    # Replace this package's entry in sys.modules with the lazy module object.
    # NOTE(review): presumably the submodule is only imported on first
    # attribute access -- confirm against LazyImportModule.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
| @@ -0,0 +1,17 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from easycv.models.pose.top_down import TopDown | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.cv.easycv_base import EasyCVBaseModel | |||||
| from modelscope.utils.constant import Tasks | |||||
@MODELS.register_module(
    group_key=Tasks.human_wholebody_keypoint,
    module_name=Models.human_wholebody_keypoint)
class HumanWholeBodyKeypoint(EasyCVBaseModel, TopDown):
    """EasyCV ``TopDown`` model registered for the human whole-body keypoint task.

    Args:
        model_dir: Forwarded to ``EasyCVBaseModel.__init__``; defaults to None.
        *args: Positional arguments forwarded to ``TopDown.__init__``.
        **kwargs: Keyword arguments forwarded to ``TopDown.__init__``.
    """

    def __init__(self, model_dir=None, *args, **kwargs):
        # The base model receives the extra arguments packed (the tuple and
        # the dict themselves), while TopDown receives them unpacked. The
        # call order is kept as-is: base model first, then TopDown.
        EasyCVBaseModel.__init__(self, model_dir, args, kwargs)
        TopDown.__init__(self, *args, **kwargs)
| @@ -10,7 +10,7 @@ if TYPE_CHECKING: | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'mmdet_model': ['DetectionModel'], | 'mmdet_model': ['DetectionModel'], | ||||
| 'yolox_pai': ['YOLOX'] | |||||
| 'yolox_pai': ['YOLOX'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -9,6 +9,9 @@ from modelscope.utils.constant import Tasks | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| group_key=Tasks.image_object_detection, module_name=Models.yolox) | group_key=Tasks.image_object_detection, module_name=Models.yolox) | ||||
| @MODELS.register_module( | |||||
| group_key=Tasks.image_object_detection, | |||||
| module_name=Models.image_object_detection_auto) | |||||
| class YOLOX(EasyCVBaseModel, _YOLOX): | class YOLOX(EasyCVBaseModel, _YOLOX): | ||||
| def __init__(self, model_dir=None, *args, **kwargs): | def __init__(self, model_dir=None, *args, **kwargs): | ||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING:
    # Evaluated by static type checkers only; this branch is skipped at
    # runtime because TYPE_CHECKING is False there.
    from .human_wholebody_keypoint_dataset import WholeBodyCocoTopDownDataset
else:
    import sys

    # Submodule name -> public names it provides; handed to LazyImportModule.
    _import_structure = {
        'human_wholebody_keypoint_dataset': ['WholeBodyCocoTopDownDataset'],
    }

    # Replace this package's entry in sys.modules with the lazy module object.
    # NOTE(review): presumably the submodule is only imported on first
    # attribute access -- confirm against LazyImportModule.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
| @@ -0,0 +1,39 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from easycv.datasets.pose import \ | |||||
| WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset | |||||
| from modelscope.metainfo import Datasets | |||||
| from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset | |||||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||||
| from modelscope.utils.constant import Tasks | |||||
@TASK_DATASETS.register_module(
    group_key=Tasks.human_wholebody_keypoint,
    module_name=Datasets.HumanWholeBodyKeypointDataset)
class WholeBodyCocoTopDownDataset(EasyCVBaseDataset,
                                  _WholeBodyCocoTopDownDataset):
    """EasyCV dataset for human whole-body 2D keypoints.

    Args:
        split_config (dict): Dataset root path from MSDataset, e.g.
            {"train": "local cache path"} or
            {"evaluation": "local cache path"}.
        preprocessor (Preprocessor): Optional preprocessor instance; make
            sure it fits the model if supplied. Not supported yet.
        mode: Training or Evaluation.
    """

    def __init__(self,
                 split_config=None,
                 preprocessor=None,
                 mode=None,
                 *args,
                 **kwargs) -> None:
        # EasyCVBaseDataset takes the extra arguments packed, under the
        # explicit ``args`` / ``kwargs`` keywords.
        base_options = {
            'split_config': split_config,
            'preprocessor': preprocessor,
            'mode': mode,
            'args': args,
            'kwargs': kwargs,
        }
        EasyCVBaseDataset.__init__(self, **base_options)
        # The EasyCV dataset receives the same extra arguments unpacked.
        _WholeBodyCocoTopDownDataset.__init__(self, *args, **kwargs)
| @@ -203,7 +203,7 @@ TASK_OUTPUTS = { | |||||
| # human body keypoints detection result for single sample | # human body keypoints detection result for single sample | ||||
| # { | # { | ||||
| # "poses": [ | |||||
| # "keypoints": [ | |||||
| # [[x, y]*15], | # [[x, y]*15], | ||||
| # [[x, y]*15], | # [[x, y]*15], | ||||
| # [[x, y]*15] | # [[x, y]*15] | ||||
| @@ -220,7 +220,7 @@ TASK_OUTPUTS = { | |||||
| # ] | # ] | ||||
| # } | # } | ||||
| Tasks.body_2d_keypoints: | Tasks.body_2d_keypoints: | ||||
| [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES], | |||||
| [OutputKeys.KEYPOINTS, OutputKeys.SCORES, OutputKeys.BOXES], | |||||
| # 3D human body keypoints detection result for single sample | # 3D human body keypoints detection result for single sample | ||||
| # { | # { | ||||
| @@ -339,6 +339,21 @@ TASK_OUTPUTS = { | |||||
| OutputKeys.SCENE_META_LIST | OutputKeys.SCENE_META_LIST | ||||
| ], | ], | ||||
| # human whole body keypoints detection result for single sample | |||||
| # { | |||||
| # "keypoints": [ | |||||
| # [[x, y]*133], | |||||
| # [[x, y]*133], | |||||
| # [[x, y]*133] | |||||
| # ] | |||||
| # "boxes": [ | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # ] | |||||
| # } | |||||
| Tasks.human_wholebody_keypoint: [OutputKeys.KEYPOINTS, OutputKeys.BOXES], | |||||
| # video summarization result for a single video | # video summarization result for a single video | ||||
| # { | # { | ||||
| # "output": | # "output": | ||||
| @@ -75,8 +75,6 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/nlp_bart_text-error-correction_chinese'), | 'damo/nlp_bart_text-error-correction_chinese'), | ||||
| Tasks.image_captioning: (Pipelines.image_captioning, | Tasks.image_captioning: (Pipelines.image_captioning, | ||||
| 'damo/ofa_image-caption_coco_large_en'), | 'damo/ofa_image-caption_coco_large_en'), | ||||
| Tasks.image_body_reshaping: (Pipelines.image_body_reshaping, | |||||
| 'damo/cv_flow-based-body-reshaping_damo'), | |||||
| Tasks.image_portrait_stylization: | Tasks.image_portrait_stylization: | ||||
| (Pipelines.person_image_cartoon, | (Pipelines.person_image_cartoon, | ||||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | 'damo/cv_unet_person-image-cartoon_compound-models'), | ||||
| @@ -159,6 +157,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.image_classification: | Tasks.image_classification: | ||||
| (Pipelines.daily_image_classification, | (Pipelines.daily_image_classification, | ||||
| 'damo/cv_vit-base_image-classification_Dailylife-labels'), | 'damo/cv_vit-base_image-classification_Dailylife-labels'), | ||||
| Tasks.image_object_detection: | |||||
| (Pipelines.image_object_detection_auto, | |||||
| 'damo/cv_yolox_image-object-detection-auto'), | |||||
| Tasks.ocr_recognition: | Tasks.ocr_recognition: | ||||
| (Pipelines.ocr_recognition, | (Pipelines.ocr_recognition, | ||||
| 'damo/cv_convnextTiny_ocr-recognition-general_damo'), | 'damo/cv_convnextTiny_ocr-recognition-general_damo'), | ||||
| @@ -186,6 +187,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_fft_inpainting_lama'), | 'damo/cv_fft_inpainting_lama'), | ||||
| Tasks.video_inpainting: (Pipelines.video_inpainting, | Tasks.video_inpainting: (Pipelines.video_inpainting, | ||||
| 'damo/cv_video-inpainting'), | 'damo/cv_video-inpainting'), | ||||
| Tasks.human_wholebody_keypoint: | |||||
| (Pipelines.human_wholebody_keypoint, | |||||
| 'damo/cv_hrnetw48_human-wholebody-keypoint_image'), | |||||
| Tasks.hand_static: (Pipelines.hand_static, | Tasks.hand_static: (Pipelines.hand_static, | ||||
| 'damo/cv_mobileface_hand-static'), | 'damo/cv_mobileface_hand-static'), | ||||
| Tasks.face_human_hand_detection: | Tasks.face_human_hand_detection: | ||||
| @@ -46,7 +46,10 @@ if TYPE_CHECKING: | |||||
| from .video_category_pipeline import VideoCategoryPipeline | from .video_category_pipeline import VideoCategoryPipeline | ||||
| from .virtual_try_on_pipeline import VirtualTryonPipeline | from .virtual_try_on_pipeline import VirtualTryonPipeline | ||||
| from .shop_segmentation_pipleline import ShopSegmentationPipeline | from .shop_segmentation_pipleline import ShopSegmentationPipeline | ||||
| from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline | |||||
| from .easycv_pipelines import (EasyCVDetectionPipeline, | |||||
| EasyCVSegmentationPipeline, | |||||
| Face2DKeypointsPipeline, | |||||
| HumanWholebodyKeypointsPipeline) | |||||
| from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline | from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline | ||||
| from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | ||||
| from .mog_face_detection_pipeline import MogFaceDetectionPipeline | from .mog_face_detection_pipeline import MogFaceDetectionPipeline | ||||
| @@ -109,8 +112,10 @@ else: | |||||
| 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | ||||
| 'shop_segmentation_pipleline': ['ShopSegmentationPipeline'], | 'shop_segmentation_pipleline': ['ShopSegmentationPipeline'], | ||||
| 'easycv_pipeline': [ | 'easycv_pipeline': [ | ||||
| 'EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline', | |||||
| 'Face2DKeypointsPipeline' | |||||
| 'EasyCVDetectionPipeline', | |||||
| 'EasyCVSegmentationPipeline', | |||||
| 'Face2DKeypointsPipeline', | |||||
| 'HumanWholebodyKeypointsPipeline', | |||||
| ], | ], | ||||
| 'text_driven_segmentation_pipeline': | 'text_driven_segmentation_pipeline': | ||||
| ['TextDrivenSegmentationPipeline'], | ['TextDrivenSegmentationPipeline'], | ||||
| @@ -73,7 +73,7 @@ class Body2DKeypointsPipeline(Pipeline): | |||||
| if input[0] is None or input[1] is None: | if input[0] is None or input[1] is None: | ||||
| return { | return { | ||||
| OutputKeys.BOXES: [], | OutputKeys.BOXES: [], | ||||
| OutputKeys.POSES: [], | |||||
| OutputKeys.KEYPOINTS: [], | |||||
| OutputKeys.SCORES: [] | OutputKeys.SCORES: [] | ||||
| } | } | ||||
| @@ -83,7 +83,7 @@ class Body2DKeypointsPipeline(Pipeline): | |||||
| result_boxes.append([box[0][0], box[0][1], box[1][0], box[1][1]]) | result_boxes.append([box[0][0], box[0][1], box[1][0], box[1][1]]) | ||||
| return { | return { | ||||
| OutputKeys.BOXES: result_boxes, | OutputKeys.BOXES: result_boxes, | ||||
| OutputKeys.POSES: poses, | |||||
| OutputKeys.KEYPOINTS: poses, | |||||
| OutputKeys.SCORES: scores | OutputKeys.SCORES: scores | ||||
| } | } | ||||
| @@ -145,7 +145,7 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| kps_2d = self.human_body_2d_kps_detector(frame) | kps_2d = self.human_body_2d_kps_detector(frame) | ||||
| box = kps_2d['boxes'][ | box = kps_2d['boxes'][ | ||||
| 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox | 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox | ||||
| pose = kps_2d['poses'][0] # keypoints: [15, 2] | |||||
| pose = kps_2d['keypoints'][0] # keypoints: [15, 2] | |||||
| score = kps_2d['scores'][0] # keypoints: [15, 2] | score = kps_2d['scores'][0] # keypoints: [15, 2] | ||||
| all_2d_poses.append(pose) | all_2d_poses.append(pose) | ||||
| all_boxes_with_socre.append( | all_boxes_with_socre.append( | ||||
| @@ -7,11 +7,14 @@ if TYPE_CHECKING: | |||||
| from .detection_pipeline import EasyCVDetectionPipeline | from .detection_pipeline import EasyCVDetectionPipeline | ||||
| from .segmentation_pipeline import EasyCVSegmentationPipeline | from .segmentation_pipeline import EasyCVSegmentationPipeline | ||||
| from .face_2d_keypoints_pipeline import Face2DKeypointsPipeline | from .face_2d_keypoints_pipeline import Face2DKeypointsPipeline | ||||
| from .human_wholebody_keypoint_pipeline import HumanWholebodyKeypointsPipeline | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'detection_pipeline': ['EasyCVDetectionPipeline'], | 'detection_pipeline': ['EasyCVDetectionPipeline'], | ||||
| 'segmentation_pipeline': ['EasyCVSegmentationPipeline'], | 'segmentation_pipeline': ['EasyCVSegmentationPipeline'], | ||||
| 'face_2d_keypoints_pipeline': ['Face2DKeypointsPipeline'] | |||||
| 'face_2d_keypoints_pipeline': ['Face2DKeypointsPipeline'], | |||||
| 'human_wholebody_keypoint_pipeline': | |||||
| ['HumanWholebodyKeypointsPipeline'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -1,16 +1,28 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.cv.image_utils import \ | |||||
| show_image_object_detection_auto_result | |||||
| from .base import EasyCVPipeline | from .base import EasyCVPipeline | ||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_object_detection, module_name=Pipelines.easycv_detection) | Tasks.image_object_detection, module_name=Pipelines.easycv_detection) | ||||
| @PIPELINES.register_module( | |||||
| Tasks.image_object_detection, | |||||
| module_name=Pipelines.image_object_detection_auto) | |||||
| class EasyCVDetectionPipeline(EasyCVPipeline): | class EasyCVDetectionPipeline(EasyCVPipeline): | ||||
| """Pipeline for easycv detection task.""" | """Pipeline for easycv detection task.""" | ||||
| def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): | |||||
| def __init__(self, | |||||
| model: str, | |||||
| model_file_pattern=ModelFile.TORCH_MODEL_FILE, | |||||
| *args, | |||||
| **kwargs): | |||||
| """ | """ | ||||
| model (str): model id on modelscope hub or local model path. | model (str): model id on modelscope hub or local model path. | ||||
| model_file_pattern (str): model file pattern. | model_file_pattern (str): model file pattern. | ||||
| @@ -21,3 +33,28 @@ class EasyCVDetectionPipeline(EasyCVPipeline): | |||||
| model_file_pattern=model_file_pattern, | model_file_pattern=model_file_pattern, | ||||
| *args, | *args, | ||||
| **kwargs) | **kwargs) | ||||
def show_result(self, img_path, result, save_path=None):
    """Draw one detection result onto the image at ``img_path``.

    Thin wrapper around ``show_image_object_detection_auto_result``; the
    helper's return value is discarded, so this method returns None.

    Args:
        img_path: Path of the image to draw on.
        result: Detection result passed to the helper as ``detection_result``.
        save_path: Passed through to the helper; defaults to None.
    """
    show_image_object_detection_auto_result(
        img_path, detection_result=result, save_path=save_path)
def __call__(self, inputs) -> Any:
    """Run the predictor and return one result dict per predictor output.

    Args:
        inputs: Passed unchanged to ``self.predict_op``.

    Returns:
        A list with one dict per predictor output. Each dict maps
        ``OutputKeys.SCORES``, ``OutputKeys.LABELS`` and ``OutputKeys.BOXES``
        to parallel lists that hold only that output's detections. Labels
        are looked up in ``self.cfg.CLASSES`` by class index and every box
        is converted to a plain list.
    """
    outputs = self.predict_op(inputs)

    results = []
    for output in outputs:
        # Fresh lists for every output: with multi-image input each result
        # must contain only its own detections and must not share list
        # objects with the other results.
        scores = []
        labels = []
        boxes = []
        for score, label, box in zip(output['detection_scores'],
                                     output['detection_classes'],
                                     output['detection_boxes']):
            scores.append(score)
            labels.append(self.cfg.CLASSES[label])
            boxes.append(list(box))
        results.append({
            OutputKeys.SCORES: scores,
            OutputKeys.LABELS: labels,
            OutputKeys.BOXES: boxes,
        })

    return results
| @@ -0,0 +1,65 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path | |||||
| from typing import Any | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from .base import EasyCVPipeline | |||||
@PIPELINES.register_module(
    Tasks.human_wholebody_keypoint,
    module_name=Pipelines.human_wholebody_keypoint)
class HumanWholebodyKeypointsPipeline(EasyCVPipeline):
    """Pipeline for human wholebody 2d keypoints detection."""

    def __init__(self,
                 model: str,
                 model_file_pattern=ModelFile.TORCH_MODEL_FILE,
                 *args,
                 **kwargs):
        """
            model (str): model id on modelscope hub or local model path.
            model_file_pattern (str): model file pattern.
        """
        # Stored before the base constructor runs because
        # _build_predict_op reads it.
        # NOTE(review): presumably the base __init__ is what calls
        # _build_predict_op -- confirm.
        # NOTE(review): this keeps the raw ``model`` argument; if it is a hub
        # id rather than a local directory, the os.path.join calls in
        # _build_predict_op depend on something else resolving it -- confirm.
        self.model_dir = model
        super(HumanWholebodyKeypointsPipeline, self).__init__(
            model=model,
            model_file_pattern=model_file_pattern,
            *args,
            **kwargs)

    def _build_predict_op(self, **kwargs):
        """Build EasyCV predictor."""
        from easycv.predictors.builder import build_predictor

        # Settings of the detector, taken from the DETECTION config section;
        # its model and config files are resolved relative to the model dir.
        detection_section = self.cfg['DETECTION']
        detector_options = dict(
            type=detection_section['type'],
            model_path=os.path.join(self.model_dir,
                                    detection_section['model_path']),
            config_file=os.path.join(self.model_dir,
                                     detection_section['config_file']),
            score_threshold=detection_section['score_threshold'])
        self.cfg.pipeline.predictor_config[
            'detection_predictor_config'] = detector_options

        easycv_config = self._to_easycv_config()
        # Caller-supplied kwargs come last so they override the defaults.
        default_args = {
            'model_path': self.model_path,
            'config_file': easycv_config,
            **kwargs
        }
        return build_predictor(self.cfg.pipeline.predictor_config,
                               default_args)

    def __call__(self, inputs) -> Any:
        """Run the predictor and keep only keypoints and boxes per output."""
        predictions = self.predict_op(inputs)

        results = []
        for prediction in predictions:
            results.append({
                OutputKeys.KEYPOINTS: prediction['keypoints'],
                OutputKeys.BOXES: prediction['boxes'],
            })
        return results
| @@ -29,6 +29,7 @@ class CVTasks(object): | |||||
| body_3d_keypoints = 'body-3d-keypoints' | body_3d_keypoints = 'body-3d-keypoints' | ||||
| hand_2d_keypoints = 'hand-2d-keypoints' | hand_2d_keypoints = 'hand-2d-keypoints' | ||||
| general_recognition = 'general-recognition' | general_recognition = 'general-recognition' | ||||
| human_wholebody_keypoint = 'human-wholebody-keypoint' | |||||
| image_classification = 'image-classification' | image_classification = 'image-classification' | ||||
| image_multilabel_classification = 'image-multilabel-classification' | image_multilabel_classification = 'image-multilabel-classification' | ||||
| @@ -80,7 +80,7 @@ def realtime_object_detection_bbox_vis(image, bboxes): | |||||
| def draw_keypoints(output, original_image): | def draw_keypoints(output, original_image): | ||||
| poses = np.array(output[OutputKeys.POSES]) | |||||
| poses = np.array(output[OutputKeys.KEYPOINTS]) | |||||
| scores = np.array(output[OutputKeys.SCORES]) | scores = np.array(output[OutputKeys.SCORES]) | ||||
| boxes = np.array(output[OutputKeys.BOXES]) | boxes = np.array(output[OutputKeys.BOXES]) | ||||
| assert len(poses) == len(scores) and len(poses) == len(boxes) | assert len(poses) == len(scores) and len(poses) == len(boxes) | ||||
| @@ -234,3 +234,35 @@ def show_video_summarization_result(video_in_path, result, video_save_path): | |||||
| video_writer.write(frame) | video_writer.write(frame) | ||||
| video_writer.release() | video_writer.release() | ||||
| cap.release() | cap.release() | ||||
def show_image_object_detection_auto_result(img_path,
                                            detection_result,
                                            save_path=None):
    """Draw detection boxes, scores and labels onto an image.

    Args:
        img_path: Path of the image, read with ``cv2.imread``.
        detection_result: Mapping holding ``OutputKeys.SCORES``,
            ``OutputKeys.LABELS`` and ``OutputKeys.BOXES``; the three entries
            are iterated in parallel. Each box is indexed 0..3, with
            (box[0], box[1]) and (box[2], box[3]) used as the two rectangle
            corners.
        save_path: If not None, the annotated image is written there with
            ``cv2.imwrite``.

    Returns:
        The image object returned by ``cv2.imread``, after drawing on it.

    Raises:
        AssertionError: If ``cv2.imread`` returns None for ``img_path``.
    """
    scores = detection_result[OutputKeys.SCORES]
    labels = detection_result[OutputKeys.LABELS]
    bboxes = detection_result[OutputKeys.BOXES]

    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"

    box_color = (0, 0, 255)
    text_color = (0, 255, 0)
    for score, label, box in zip(scores, labels, bboxes):
        left, top = int(box[0]), int(box[1])
        right, bottom = int(box[2]), int(box[3])
        cv2.rectangle(img, (left, top), (right, bottom), box_color, 2)

        # Score is anchored at the first corner of the box.
        cv2.putText(
            img,
            f'{score:.2f}', (left, top),
            1,
            1.0,
            text_color,
            thickness=1,
            lineType=8)

        # Label starts at the horizontal midpoint, on the same line.
        label_x = int((box[0] + box[2]) * 0.5)
        cv2.putText(
            img,
            label, (label_x, top),
            1,
            1.0,
            text_color,
            thickness=1,
            lineType=8)

    if save_path is not None:
        cv2.imwrite(save_path, img)
    return img
| @@ -0,0 +1,40 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| import cv2 | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase):
    """End-to-end test of the human whole-body keypoint pipeline.

    NOTE(review): the class name mentions Face2D keypoints although the test
    exercises the whole-body pipeline; the name is left unchanged here.
    """

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_human_wholebody_keypoint(self):
        model_id = 'damo/cv_hrnetw48_human-wholebody-keypoint_image'
        img_path = 'data/test/images/keypoints_detect/img_test_wholebody.jpg'

        wholebody_pipeline = pipeline(
            task=Tasks.human_wholebody_keypoint, model=model_id)
        first_result = wholebody_pipeline(img_path)[0]
        keypoints = first_result[OutputKeys.KEYPOINTS]
        boxes = first_result[OutputKeys.BOXES]

        # Render the result through the underlying predictor.
        wholebody_pipeline.predict_op.show_result(
            img_path,
            keypoints,
            boxes,
            scale=1,
            save_path='human_wholebody_keypoint_ret.jpg')

        # 133 entries are expected per keypoint set and 4 values per box.
        for person_keypoints in keypoints:
            self.assertEqual(person_keypoints.shape[0], 133)
        for box in boxes:
            self.assertEqual(box.shape[0], 4)
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -59,6 +59,18 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||
| self.compatibility_check() | self.compatibility_check() | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_image_object_detection_auto_pipeline(self):
    """Run the auto object-detection model on a demo image and draw the result."""
    test_image = 'data/test/images/auto_demo.jpg'
    detector = pipeline(
        Tasks.image_object_detection,
        model='damo/cv_yolox_image-object-detection-auto')

    # The pipeline returns a list; only the first entry is visualised.
    first_result = detector(test_image)[0]
    detector.show_result(test_image, first_result, 'auto_demo_ret.jpg')
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||