tianchu.gtc yingda.chen 3 years ago
parent
commit
bd4127bc27
2 changed files with 41 additions and 15 deletions
  1. +24
    -0
      modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py
  2. +17
    -15
      tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py

+ 24
- 0
modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py View File

@@ -1,5 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from .base import EasyCVPipeline
@@ -21,3 +26,22 @@ class EasyCVSegmentationPipeline(EasyCVPipeline):
model_file_pattern=model_file_pattern,
*args,
**kwargs)

def __call__(self, inputs) -> Any:
outputs = self.predict_op(inputs)

semantic_result = outputs[0]['seg_pred']

ids = np.unique(semantic_result)[::-1]
legal_indices = ids != len(self.predict_op.CLASSES) # for VOID label
ids = ids[legal_indices]
segms = (semantic_result[None] == ids[:, None, None])
masks = [it.astype(np.int) for it in segms]
labels_txt = np.array(self.predict_op.CLASSES)[ids].tolist()

results = {
OutputKeys.MASKS: masks,
OutputKeys.LABELS: labels_txt,
OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))]
}
return results

+ 17
- 15
tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py View File

@@ -2,30 +2,34 @@
import unittest
from distutils.version import LooseVersion

import cv2
import easycv
import numpy as np
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class EasyCVSegmentationPipelineTest(unittest.TestCase):
class EasyCVSegmentationPipelineTest(unittest.TestCase,
DemoCompatibilityCheck):
img_path = 'data/test/images/image_segmentation.jpg'

def _internal_test_(self, model_id):
img = np.asarray(Image.open(self.img_path))
def setUp(self) -> None:
self.task = Tasks.image_segmentation
self.model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k'

def _internal_test_(self, model_id):
semantic_seg = pipeline(task=Tasks.image_segmentation, model=model_id)
outputs = semantic_seg(self.img_path)

self.assertEqual(len(outputs), 1)

results = outputs[0]
self.assertListEqual(
list(img.shape)[:2], list(results['seg_pred'].shape))
draw_img = semantic_seg_masks_to_image(outputs[OutputKeys.MASKS])
cv2.imwrite('result.jpg', draw_img)
print('test ' + model_id + ' DONE')

def _internal_test_batch_(self, model_id, num_samples=2, batch_size=2):
# TODO: support in the future
@@ -49,37 +53,35 @@ class EasyCVSegmentationPipelineTest(unittest.TestCase):
def test_segformer_b0(self):
model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id)
self._internal_test_batch_(model_id)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b1(self):
model_id = 'damo/cv_segformer-b1_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id)
self._internal_test_batch_(model_id)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b2(self):
model_id = 'damo/cv_segformer-b2_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id)
self._internal_test_batch_(model_id)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b3(self):
model_id = 'damo/cv_segformer-b3_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id)
self._internal_test_batch_(model_id)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b4(self):
model_id = 'damo/cv_segformer-b4_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id)
self._internal_test_batch_(model_id)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_segformer_b5(self):
model_id = 'damo/cv_segformer-b5_image_semantic-segmentation_coco-stuff164k'
self._internal_test_(model_id)
self._internal_test_batch_(model_id)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_demo_compatibility(self):
self.compatibility_check()


if __name__ == '__main__':


Loading…
Cancel
Save