diff --git a/modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py index 2182e3b3..bd09fc9b 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py +++ b/modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py @@ -1,5 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any + +import numpy as np + from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys from modelscope.pipelines.builder import PIPELINES from modelscope.utils.constant import Tasks from .base import EasyCVPipeline @@ -21,3 +26,22 @@ class EasyCVSegmentationPipeline(EasyCVPipeline): model_file_pattern=model_file_pattern, *args, **kwargs) + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + semantic_result = outputs[0]['seg_pred'] + + ids = np.unique(semantic_result)[::-1] + legal_indices = ids != len(self.predict_op.CLASSES) # for VOID label + ids = ids[legal_indices] + segms = (semantic_result[None] == ids[:, None, None]) + masks = [it.astype(np.int) for it in segms] + labels_txt = np.array(self.predict_op.CLASSES)[ids].tolist() + + results = { + OutputKeys.MASKS: masks, + OutputKeys.LABELS: labels_txt, + OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] + } + return results diff --git a/tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py b/tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py index 80ab36a6..5f6dac4b 100644 --- a/tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py +++ b/tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py @@ -2,30 +2,34 @@ import unittest from distutils.version import LooseVersion +import cv2 import easycv import numpy as np from PIL import Image +from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image +from modelscope.utils.demo_utils import DemoCompatibilityCheck from modelscope.utils.test_utils import test_level -class EasyCVSegmentationPipelineTest(unittest.TestCase): - +class EasyCVSegmentationPipelineTest(unittest.TestCase, + DemoCompatibilityCheck): img_path = 'data/test/images/image_segmentation.jpg' - def _internal_test_(self, model_id): - img = np.asarray(Image.open(self.img_path)) + def setUp(self) -> None: + self.task = Tasks.image_segmentation + self.model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k' + def _internal_test_(self, model_id): semantic_seg = pipeline(task=Tasks.image_segmentation, model=model_id) outputs = semantic_seg(self.img_path) - self.assertEqual(len(outputs), 1) - - results = outputs[0] - self.assertListEqual( - list(img.shape)[:2], list(results['seg_pred'].shape)) + draw_img = semantic_seg_masks_to_image(outputs[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test ' + model_id + ' DONE') def _internal_test_batch_(self, model_id, num_samples=2, batch_size=2): # TODO: support in the future @@ -49,37 +53,35 @@ class EasyCVSegmentationPipelineTest(unittest.TestCase): def test_segformer_b0(self): model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k' self._internal_test_(model_id) - self._internal_test_batch_(model_id) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_segformer_b1(self): model_id = 'damo/cv_segformer-b1_image_semantic-segmentation_coco-stuff164k' self._internal_test_(model_id) - self._internal_test_batch_(model_id) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_segformer_b2(self): model_id = 'damo/cv_segformer-b2_image_semantic-segmentation_coco-stuff164k' self._internal_test_(model_id) - self._internal_test_batch_(model_id) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_segformer_b3(self): model_id = 'damo/cv_segformer-b3_image_semantic-segmentation_coco-stuff164k' self._internal_test_(model_id) - self._internal_test_batch_(model_id) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_segformer_b4(self): model_id = 'damo/cv_segformer-b4_image_semantic-segmentation_coco-stuff164k' self._internal_test_(model_id) - self._internal_test_batch_(model_id) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_segformer_b5(self): model_id = 'damo/cv_segformer-b5_image_semantic-segmentation_coco-stuff164k' self._internal_test_(model_id) - self._internal_test_batch_(model_id) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() if __name__ == '__main__':