tianchu.gtc yingda.chen 3 years ago
parent
commit
bd4127bc27
2 changed files with 41 additions and 15 deletions
  1. +24
    -0
      modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py
  2. +17
    -15
      tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py

+ 24
- 0
modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py View File

@@ -1,5 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any

import numpy as np

from modelscope.metainfo import Pipelines from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
from .base import EasyCVPipeline from .base import EasyCVPipeline
@@ -21,3 +26,22 @@ class EasyCVSegmentationPipeline(EasyCVPipeline):
model_file_pattern=model_file_pattern, model_file_pattern=model_file_pattern,
*args, *args,
**kwargs) **kwargs)

def __call__(self, inputs) -> Any:
outputs = self.predict_op(inputs)

semantic_result = outputs[0]['seg_pred']

ids = np.unique(semantic_result)[::-1]
legal_indices = ids != len(self.predict_op.CLASSES) # for VOID label
ids = ids[legal_indices]
segms = (semantic_result[None] == ids[:, None, None])
masks = [it.astype(np.int) for it in segms]
labels_txt = np.array(self.predict_op.CLASSES)[ids].tolist()

results = {
OutputKeys.MASKS: masks,
OutputKeys.LABELS: labels_txt,
OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))]
}
return results

+ 17
- 15
tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py View File

@@ -2,30 +2,34 @@
import unittest import unittest
from distutils.version import LooseVersion from distutils.version import LooseVersion


import cv2
import easycv import easycv
import numpy as np import numpy as np
from PIL import Image from PIL import Image


from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level from modelscope.utils.test_utils import test_level




class EasyCVSegmentationPipelineTest(unittest.TestCase):
class EasyCVSegmentationPipelineTest(unittest.TestCase,
DemoCompatibilityCheck):
img_path = 'data/test/images/image_segmentation.jpg' img_path = 'data/test/images/image_segmentation.jpg'


def _internal_test_(self, model_id):
img = np.asarray(Image.open(self.img_path))
def setUp(self) -> None:
self.task = Tasks.image_segmentation
self.model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k'


def _internal_test_(self, model_id):
semantic_seg = pipeline(task=Tasks.image_segmentation, model=model_id) semantic_seg = pipeline(task=Tasks.image_segmentation, model=model_id)
outputs = semantic_seg(self.img_path) outputs = semantic_seg(self.img_path)


self.assertEqual(len(outputs), 1)

results = outputs[0]
self.assertListEqual(
list(img.shape)[:2], list(results['seg_pred'].shape))
draw_img = semantic_seg_masks_to_image(outputs[OutputKeys.MASKS])
cv2.imwrite('result.jpg', draw_img)
print('test ' + model_id + ' DONE')


def _internal_test_batch_(self, model_id, num_samples=2, batch_size=2): def _internal_test_batch_(self, model_id, num_samples=2, batch_size=2):
# TODO: support in the future # TODO: support in the future
@@ -49,37 +53,35 @@ class EasyCVSegmentationPipelineTest(unittest.TestCase):
def test_segformer_b0(self): def test_segformer_b0(self):
model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k' model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id) self._internal_test_(model_id)
self._internal_test_batch_(model_id)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b1(self): def test_segformer_b1(self):
model_id = 'damo/cv_segformer-b1_image_semantic-segmentation_coco-stuff164k' model_id = 'damo/cv_segformer-b1_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id) self._internal_test_(model_id)
self._internal_test_batch_(model_id)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b2(self): def test_segformer_b2(self):
model_id = 'damo/cv_segformer-b2_image_semantic-segmentation_coco-stuff164k' model_id = 'damo/cv_segformer-b2_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id) self._internal_test_(model_id)
self._internal_test_batch_(model_id)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b3(self): def test_segformer_b3(self):
model_id = 'damo/cv_segformer-b3_image_semantic-segmentation_coco-stuff164k' model_id = 'damo/cv_segformer-b3_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id) self._internal_test_(model_id)
self._internal_test_batch_(model_id)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b4(self): def test_segformer_b4(self):
model_id = 'damo/cv_segformer-b4_image_semantic-segmentation_coco-stuff164k' model_id = 'damo/cv_segformer-b4_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id) self._internal_test_(model_id)
self._internal_test_batch_(model_id)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b5(self): def test_segformer_b5(self):
model_id = 'damo/cv_segformer-b5_image_semantic-segmentation_coco-stuff164k' model_id = 'damo/cv_segformer-b5_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id) self._internal_test_(model_id)
self._internal_test_batch_(model_id)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_demo_compatibility(self):
self.compatibility_check()




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save