panoptic segmentation 模型接入
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9758389
master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a | |||
| size 245864 | |||
| @@ -20,6 +20,7 @@ class Models(object): | |||
| product_retrieval_embedding = 'product-retrieval-embedding' | |||
| body_2d_keypoints = 'body-2d-keypoints' | |||
| crowd_counting = 'HRNetCrowdCounting' | |||
| panoptic_segmentation = 'swinL-panoptic-segmentation' | |||
| image_reid_person = 'passvitb' | |||
| video_summarization = 'pgl-video-summarization' | |||
| @@ -114,6 +115,7 @@ class Pipelines(object): | |||
| tinynas_classification = 'tinynas-classification' | |||
| crowd_counting = 'hrnet-crowd-counting' | |||
| video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' | |||
| image_panoptic_segmentation = 'image-panoptic-segmentation' | |||
| video_summarization = 'googlenet_pgl_video_summarization' | |||
| image_reid_person = 'passvitb-image-reid-person' | |||
| @@ -3,8 +3,9 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
| cartoon, cmdssl_video_embedding, crowd_counting, face_detection, | |||
| face_generation, image_classification, image_color_enhance, | |||
| image_colorization, image_denoise, image_instance_segmentation, | |||
| image_portrait_enhancement, image_reid_person, | |||
| image_to_image_generation, image_to_image_translation, | |||
| object_detection, product_retrieval_embedding, | |||
| salient_detection, super_resolution, | |||
| video_single_object_tracking, video_summarization, virual_tryon) | |||
| image_panoptic_segmentation, image_portrait_enhancement, | |||
| image_reid_person, image_to_image_generation, | |||
| image_to_image_translation, object_detection, | |||
| product_retrieval_embedding, salient_detection, | |||
| super_resolution, video_single_object_tracking, | |||
| video_summarization, virual_tryon) | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .panseg_model import SwinLPanopticSegmentation | |||
| else: | |||
| _import_structure = { | |||
| 'panseg_model': ['SwinLPanopticSegmentation'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,54 @@ | |||
| import os.path as osp | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base.base_torch_model import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| @MODELS.register_module( | |||
| Tasks.image_segmentation, module_name=Models.panoptic_segmentation) | |||
| class SwinLPanopticSegmentation(TorchModel): | |||
| def __init__(self, model_dir: str, **kwargs): | |||
| """str -- model file root.""" | |||
| super().__init__(model_dir, **kwargs) | |||
| from mmcv.runner import load_checkpoint | |||
| import mmcv | |||
| from mmdet.models import build_detector | |||
| config = osp.join(model_dir, 'config.py') | |||
| cfg = mmcv.Config.fromfile(config) | |||
| if 'pretrained' in cfg.model: | |||
| cfg.model.pretrained = None | |||
| elif 'init_cfg' in cfg.model.backbone: | |||
| cfg.model.backbone.init_cfg = None | |||
| # build model | |||
| cfg.model.train_cfg = None | |||
| self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) | |||
| # load model | |||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
| checkpoint = load_checkpoint( | |||
| self.model, model_path, map_location='cpu') | |||
| self.CLASSES = checkpoint['meta']['CLASSES'] | |||
| self.num_classes = len(self.CLASSES) | |||
| self.cfg = cfg | |||
| def inference(self, data): | |||
| """data is dict,contain img and img_metas,follow with mmdet.""" | |||
| with torch.no_grad(): | |||
| results = self.model(return_loss=False, rescale=True, **data) | |||
| return results | |||
| def forward(self, Inputs): | |||
| import pdb | |||
| pdb.set_trace() | |||
| return self.model(**Inputs) | |||
| @@ -23,6 +23,7 @@ if TYPE_CHECKING: | |||
| from .image_denoise_pipeline import ImageDenoisePipeline | |||
| from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | |||
| from .image_matting_pipeline import ImageMattingPipeline | |||
| from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline | |||
| from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | |||
| from .image_reid_person_pipeline import ImageReidPersonPipeline | |||
| from .image_style_transfer_pipeline import ImageStyleTransferPipeline | |||
| @@ -37,6 +38,7 @@ if TYPE_CHECKING: | |||
| from .tinynas_classification_pipeline import TinynasClassificationPipeline | |||
| from .video_category_pipeline import VideoCategoryPipeline | |||
| from .virtual_try_on_pipeline import VirtualTryonPipeline | |||
| else: | |||
| _import_structure = { | |||
| 'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
| @@ -59,6 +61,8 @@ else: | |||
| 'image_instance_segmentation_pipeline': | |||
| ['ImageInstanceSegmentationPipeline'], | |||
| 'image_matting_pipeline': ['ImageMattingPipeline'], | |||
| 'image_panoptic_segmentation_pipeline': | |||
| ['ImagePanopticSegmentationPipeline'], | |||
| 'image_portrait_enhancement_pipeline': | |||
| ['ImagePortraitEnhancementPipeline'], | |||
| 'image_reid_person_pipeline': ['ImageReidPersonPipeline'], | |||
| @@ -0,0 +1,103 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Union | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_segmentation, | |||
| module_name=Pipelines.image_panoptic_segmentation) | |||
| class ImagePanopticSegmentationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a image panoptic segmentation pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, **kwargs) | |||
| logger.info('panoptic segmentation model, pipeline init') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| from mmdet.datasets.pipelines import Compose | |||
| from mmcv.parallel import collate, scatter | |||
| from mmdet.datasets import replace_ImageToTensor | |||
| cfg = self.model.cfg | |||
| # build the data pipeline | |||
| if isinstance(input, str): | |||
| # input is str, file names, pipeline loadimagefromfile | |||
| # collect data | |||
| data = dict(img_info=dict(filename=input), img_prefix=None) | |||
| elif isinstance(input, PIL.Image.Image): | |||
| cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | |||
| img = np.array(input.convert('RGB')) | |||
| # collect data | |||
| data = dict(img=img) | |||
| elif isinstance(input, np.ndarray): | |||
| cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | |||
| if len(input.shape) == 2: | |||
| img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
| else: | |||
| img = input | |||
| img = img[:, :, ::-1] # in rgb order | |||
| # collect data | |||
| data = dict(img=img) | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' np.array, but got {type(input)}') | |||
| cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) | |||
| test_pipeline = Compose(cfg.data.test.pipeline) | |||
| data = test_pipeline(data) | |||
| # copy from mmdet_model collect data | |||
| data = collate([data], samples_per_gpu=1) | |||
| data['img_metas'] = [ | |||
| img_metas.data[0] for img_metas in data['img_metas'] | |||
| ] | |||
| data['img'] = [img.data[0] for img in data['img']] | |||
| if next(self.model.parameters()).is_cuda: | |||
| # scatter to specified GPU | |||
| data = scatter(data, [next(self.model.parameters()).device])[0] | |||
| return data | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| results = self.model.inference(input) | |||
| return results | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| # bz=1, tcguo | |||
| pan_results = inputs[0]['pan_results'] | |||
| INSTANCE_OFFSET = 1000 | |||
| ids = np.unique(pan_results)[::-1] | |||
| legal_indices = ids != self.model.num_classes # for VOID label | |||
| ids = ids[legal_indices] | |||
| labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) | |||
| segms = (pan_results[None] == ids[:, None, None]) | |||
| masks = [it.astype(np.int) for it in segms] | |||
| labels_txt = np.array(self.model.CLASSES)[labels].tolist() | |||
| outputs = { | |||
| OutputKeys.MASKS: masks, | |||
| OutputKeys.LABELS: labels_txt, | |||
| OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] | |||
| } | |||
| return outputs | |||
| @@ -134,3 +134,22 @@ def show_video_tracking_result(video_in_path, bboxes, video_save_path): | |||
| video_writer.write(frame) | |||
| video_writer.release | |||
| cap.release() | |||
| def panoptic_seg_masks_to_image(masks): | |||
| draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3]) | |||
| from mmdet.core.visualization.palette import get_palette | |||
| mask_palette = get_palette('coco', 133) | |||
| from mmdet.core.visualization.image import _get_bias_color | |||
| taken_colors = set([0, 0, 0]) | |||
| for i, mask in enumerate(masks): | |||
| color_mask = mask_palette[i] | |||
| while tuple(color_mask) in taken_colors: | |||
| color_mask = _get_bias_color(color_mask) | |||
| taken_colors.add(tuple(color_mask)) | |||
| mask = mask.astype(bool) | |||
| draw_img[mask] = color_mask | |||
| return draw_img | |||
| @@ -0,0 +1,40 @@ | |||
| import unittest | |||
| import cv2 | |||
| import PIL | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.cv.image_utils import panoptic_seg_masks_to_image | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImagePanopticSegmentationTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_image_panoptic_segmentation(self): | |||
| input_location = 'data/test/images/image_panoptic_segmentation.jpg' | |||
| model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan' | |||
| pan_segmentor = pipeline(Tasks.image_segmentation, model=model_id) | |||
| result = pan_segmentor(input_location) | |||
| draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
| cv2.imwrite('result.jpg', draw_img) | |||
| print('print test_image_panoptic_segmentation return success') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_image_panoptic_segmentation_from_PIL(self): | |||
| input_location = 'data/test/images/image_panoptic_segmentation.jpg' | |||
| model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan' | |||
| pan_segmentor = pipeline(Tasks.image_segmentation, model=model_id) | |||
| PIL_array = PIL.Image.open(input_location) | |||
| result = pan_segmentor(PIL_array) | |||
| draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
| cv2.imwrite('result.jpg', draw_img) | |||
| print('print test_image_panoptic_segmentation from PIL return success') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||