Panoptic segmentation model integration
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9758389
Target branch: master
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a
+size 245864
@@ -20,6 +20,7 @@ class Models(object):
     product_retrieval_embedding = 'product-retrieval-embedding'
     body_2d_keypoints = 'body-2d-keypoints'
     crowd_counting = 'HRNetCrowdCounting'
+    panoptic_segmentation = 'swinL-panoptic-segmentation'
     image_reid_person = 'passvitb'
     video_summarization = 'pgl-video-summarization'
@@ -114,6 +115,7 @@ class Pipelines(object):
     tinynas_classification = 'tinynas-classification'
     crowd_counting = 'hrnet-crowd-counting'
     video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
+    image_panoptic_segmentation = 'image-panoptic-segmentation'
     video_summarization = 'googlenet_pgl_video_summarization'
     image_reid_person = 'passvitb-image-reid-person'
@@ -3,8 +3,9 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                cartoon, cmdssl_video_embedding, crowd_counting, face_detection,
                face_generation, image_classification, image_color_enhance,
                image_colorization, image_denoise, image_instance_segmentation,
-               image_portrait_enhancement, image_reid_person,
-               image_to_image_generation, image_to_image_translation,
-               object_detection, product_retrieval_embedding,
-               salient_detection, super_resolution,
-               video_single_object_tracking, video_summarization, virual_tryon)
+               image_panoptic_segmentation, image_portrait_enhancement,
+               image_reid_person, image_to_image_generation,
+               image_to_image_translation, object_detection,
+               product_retrieval_embedding, salient_detection,
+               super_resolution, video_single_object_tracking,
+               video_summarization, virual_tryon)
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .panseg_model import SwinLPanopticSegmentation
+
+else:
+    _import_structure = {
+        'panseg_model': ['SwinLPanopticSegmentation'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
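The LazyImportModule registration above follows the pattern used by the other cv model packages: `panseg_model` (and its heavy torch/mmdet dependencies) is only imported on first attribute access, not when the package itself is imported. A minimal sketch of the intended behaviour, under that assumption:

    # Importing the subpackage stays cheap; panseg_model is not loaded yet.
    from modelscope.models.cv import image_panoptic_segmentation

    # First attribute access triggers the real import of panseg_model.
    model_cls = image_panoptic_segmentation.SwinLPanopticSegmentation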
@@ -0,0 +1,54 @@
+import os.path as osp
+
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+
+
+@MODELS.register_module(
+    Tasks.image_segmentation, module_name=Models.panoptic_segmentation)
+class SwinLPanopticSegmentation(TorchModel):
+
+    def __init__(self, model_dir: str, **kwargs):
+        """Initialize the model from a local directory.
+
+        Args:
+            model_dir (str): the model file root.
+        """
+        super().__init__(model_dir, **kwargs)
+        import mmcv
+        from mmcv.runner import load_checkpoint
+        from mmdet.models import build_detector
+
+        config = osp.join(model_dir, 'config.py')
+        cfg = mmcv.Config.fromfile(config)
+        if 'pretrained' in cfg.model:
+            cfg.model.pretrained = None
+        elif 'init_cfg' in cfg.model.backbone:
+            cfg.model.backbone.init_cfg = None
+
+        # build the detector from the mmdet config
+        cfg.model.train_cfg = None
+        self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
+
+        # load the checkpoint weights
+        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+        checkpoint = load_checkpoint(
+            self.model, model_path, map_location='cpu')
+        self.CLASSES = checkpoint['meta']['CLASSES']
+        self.num_classes = len(self.CLASSES)
+        self.cfg = cfg
+
+    def inference(self, data):
+        """`data` is a dict containing `img` and `img_metas`, following
+        the mmdet convention."""
+        with torch.no_grad():
+            results = self.model(return_loss=False, rescale=True, **data)
+        return results
+
+    def forward(self, inputs):
+        return self.model(**inputs)
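Because the class registers itself under `Tasks.image_segmentation` with `module_name=Models.panoptic_segmentation`, it can also be constructed directly from a local model snapshot. A minimal sketch, assuming `./swinL_panoptic` is a hypothetical local directory containing `config.py` and the checkpoint file named by `ModelFile.TORCH_MODEL_FILE`:

    from modelscope.models.cv.image_panoptic_segmentation import \
        SwinLPanopticSegmentation

    # hypothetical local snapshot with config.py and the torch checkpoint
    model = SwinLPanopticSegmentation('./swinL_panoptic')
    print(model.num_classes)  # e.g. 133 for COCO panoptic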
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from .image_denoise_pipeline import ImageDenoisePipeline
     from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
     from .image_matting_pipeline import ImageMattingPipeline
+    from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline
     from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline
     from .image_reid_person_pipeline import ImageReidPersonPipeline
     from .image_style_transfer_pipeline import ImageStyleTransferPipeline
@@ -37,6 +38,7 @@ if TYPE_CHECKING:
     from .tinynas_classification_pipeline import TinynasClassificationPipeline
     from .video_category_pipeline import VideoCategoryPipeline
     from .virtual_try_on_pipeline import VirtualTryonPipeline
 else:
     _import_structure = {
         'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -59,6 +61,8 @@ else:
         'image_instance_segmentation_pipeline':
         ['ImageInstanceSegmentationPipeline'],
         'image_matting_pipeline': ['ImageMattingPipeline'],
+        'image_panoptic_segmentation_pipeline':
+        ['ImagePanopticSegmentationPipeline'],
         'image_portrait_enhancement_pipeline':
         ['ImagePortraitEnhancementPipeline'],
         'image_reid_person_pipeline': ['ImageReidPersonPipeline'],
@@ -0,0 +1,103 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import PIL.Image
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_segmentation,
+    module_name=Pipelines.image_panoptic_segmentation)
+class ImagePanopticSegmentationPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """Use `model` to create an image panoptic segmentation pipeline
+        for prediction.
+
+        Args:
+            model: model id on the ModelScope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        logger.info('panoptic segmentation model, pipeline init')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        from mmcv.parallel import collate, scatter
+        from mmdet.datasets import replace_ImageToTensor
+        from mmdet.datasets.pipelines import Compose
+
+        cfg = self.model.cfg
+        # build the data pipeline
+        if isinstance(input, str):
+            # input is a file name; the pipeline loads it from file
+            data = dict(img_info=dict(filename=input), img_prefix=None)
+        elif isinstance(input, PIL.Image.Image):
+            cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+            img = np.array(input.convert('RGB'))
+            data = dict(img=img)
+        elif isinstance(input, np.ndarray):
+            cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+            if len(input.shape) == 2:
+                img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            else:
+                img = input
+            img = img[:, :, ::-1]  # flip to RGB order
+            data = dict(img=img)
+        else:
+            raise TypeError(f'input should be either a str, PIL.Image or'
+                            f' np.ndarray, but got {type(input)}')
+
+        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+        test_pipeline = Compose(cfg.data.test.pipeline)
+        data = test_pipeline(data)
+
+        # collect data, following mmdet_model
+        data = collate([data], samples_per_gpu=1)
+        data['img_metas'] = [
+            img_metas.data[0] for img_metas in data['img_metas']
+        ]
+        data['img'] = [img.data[0] for img in data['img']]
+        if next(self.model.parameters()).is_cuda:
+            # scatter to the specified GPU
+            data = scatter(data, [next(self.model.parameters()).device])[0]
+
+        return data
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        results = self.model.inference(input)
+        return results
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        # batch size is fixed to 1
+        pan_results = inputs[0]['pan_results']
+        # mmdet encoding: pixel = label + instance_id * INSTANCE_OFFSET
+        INSTANCE_OFFSET = 1000
+        ids = np.unique(pan_results)[::-1]
+        legal_indices = ids != self.model.num_classes  # for VOID label
+        ids = ids[legal_indices]
+        labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)
+        segms = (pan_results[None] == ids[:, None, None])
+        masks = [it.astype(int) for it in segms]
+        labels_txt = np.array(self.model.CLASSES)[labels].tolist()
+
+        # panoptic results carry no per-segment scores, so a fixed
+        # placeholder score is returned for each segment
+        outputs = {
+            OutputKeys.MASKS: masks,
+            OutputKeys.LABELS: labels_txt,
+            OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))]
+        }
+        return outputs
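The postprocess step decodes mmdet's panoptic encoding, where every pixel stores `label + instance_id * INSTANCE_OFFSET` and the value `num_classes` marks VOID. A small worked example of that decoding, with illustrative values only:

    import numpy as np

    INSTANCE_OFFSET = 1000
    num_classes = 133  # COCO panoptic; 133 marks VOID here
    # 17, 1017 and 2017 all decode to label 17 (instance ids 0, 1, 2)
    pan_results = np.array([[17, 1017], [2017, 133]])

    ids = np.unique(pan_results)[::-1]
    ids = ids[ids != num_classes]        # drop VOID -> [2017, 1017, 17]
    labels = ids % INSTANCE_OFFSET       # -> [17, 17, 17]
    masks = pan_results[None] == ids[:, None, None]  # one boolean mask per id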
@@ -134,3 +134,22 @@ def show_video_tracking_result(video_in_path, bboxes, video_save_path):
         video_writer.write(frame)
     video_writer.release
     cap.release()
+
+
+def panoptic_seg_masks_to_image(masks):
+    from mmdet.core.visualization.image import _get_bias_color
+    from mmdet.core.visualization.palette import get_palette
+
+    # uint8 canvas so the result can be written directly with cv2.imwrite
+    draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3],
+                        dtype=np.uint8)
+    mask_palette = get_palette('coco', 133)
+
+    # never reuse a color; black is reserved for the background
+    taken_colors = {(0, 0, 0)}
+    for i, mask in enumerate(masks):
+        color_mask = mask_palette[i]
+        while tuple(color_mask) in taken_colors:
+            color_mask = _get_bias_color(color_mask)
+        taken_colors.add(tuple(color_mask))
+
+        mask = mask.astype(bool)
+        draw_img[mask] = color_mask
+    return draw_img
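The helper walks mmdet's COCO palette and, when a color has already been used, perturbs it with `_get_bias_color` so each segment stays visually distinct. A minimal usage sketch with hypothetical random masks:

    import numpy as np

    masks = [(np.random.rand(64, 64) > 0.5).astype(int) for _ in range(3)]
    vis = panoptic_seg_masks_to_image(masks)  # HxWx3 uint8 color image
    # cv2.imwrite('vis.jpg', vis)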
@@ -0,0 +1,40 @@
+import unittest
+
+import cv2
+import PIL.Image
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import panoptic_seg_masks_to_image
+from modelscope.utils.test_utils import test_level
+
+
+class ImagePanopticSegmentationTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_image_panoptic_segmentation(self):
+        input_location = 'data/test/images/image_panoptic_segmentation.jpg'
+        model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan'
+        pan_segmentor = pipeline(Tasks.image_segmentation, model=model_id)
+        result = pan_segmentor(input_location)
+
+        draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS])
+        cv2.imwrite('result.jpg', draw_img)
+        print('test_image_panoptic_segmentation done')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_image_panoptic_segmentation_from_PIL(self):
+        input_location = 'data/test/images/image_panoptic_segmentation.jpg'
+        model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan'
+        pan_segmentor = pipeline(Tasks.image_segmentation, model=model_id)
+        pil_img = PIL.Image.open(input_location)
+        result = pan_segmentor(pil_img)
+
+        draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS])
+        cv2.imwrite('result.jpg', draw_img)
+        print('test_image_panoptic_segmentation_from_PIL done')
+
+
+if __name__ == '__main__':
+    unittest.main()