From 26df8f198820c3c079e38c8fdb94c2fd4d836581 Mon Sep 17 00:00:00 2001 From: "wendi.hwd" Date: Tue, 27 Sep 2022 15:01:05 +0800 Subject: [PATCH] [to #42322933]add semantic-segmentation task output is numpy mask for demo-service Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10265856 --- modelscope/models/cv/salient_detection/salient_model.py | 3 ++- modelscope/outputs.py | 6 ++++++ .../pipelines/cv/image_salient_detection_pipeline.py | 8 ++------ modelscope/utils/constant.py | 1 + tests/pipelines/test_salient_detection.py | 5 ++--- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/modelscope/models/cv/salient_detection/salient_model.py b/modelscope/models/cv/salient_detection/salient_model.py index 6e617f58..73c3c3fb 100644 --- a/modelscope/models/cv/salient_detection/salient_model.py +++ b/modelscope/models/cv/salient_detection/salient_model.py @@ -14,7 +14,8 @@ from modelscope.utils.constant import ModelFile, Tasks from .models import U2NET -@MODELS.register_module(Tasks.image_segmentation, module_name=Models.detection) +@MODELS.register_module( + Tasks.semantic_segmentation, module_name=Models.detection) class SalientDetection(TorchModel): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 052d4f33..b19f7e43 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -151,6 +151,12 @@ TASK_OUTPUTS = { Tasks.image_segmentation: [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS], + # semantic segmentation result for single sample + # { + # "masks": [np.array # 2D array containing only 0, 255] + # } + Tasks.semantic_segmentation: [OutputKeys.MASKS], + # image matting result for single sample # { # "output_img": np.array with shape(h, w, 4) diff --git a/modelscope/pipelines/cv/image_salient_detection_pipeline.py b/modelscope/pipelines/cv/image_salient_detection_pipeline.py index 433275ba..3b145cf0 100644 --- a/modelscope/pipelines/cv/image_salient_detection_pipeline.py +++ b/modelscope/pipelines/cv/image_salient_detection_pipeline.py @@ -9,7 +9,7 @@ from modelscope.utils.constant import Tasks @PIPELINES.register_module( - Tasks.image_segmentation, module_name=Pipelines.salient_detection) + Tasks.semantic_segmentation, module_name=Pipelines.salient_detection) class ImageSalientDetectionPipeline(Pipeline): def __init__(self, model: str, **kwargs): @@ -39,9 +39,5 @@ class ImageSalientDetectionPipeline(Pipeline): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: data = self.model.postprocess(inputs) - outputs = { - OutputKeys.SCORES: None, - OutputKeys.LABELS: None, - OutputKeys.MASKS: data - } + outputs = {OutputKeys.MASKS: data} return outputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 4c5d2f41..de3d933f 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -38,6 +38,7 @@ class CVTasks(object): image_object_detection = 'image-object-detection' image_segmentation = 'image-segmentation' + semantic_segmentation = 'semantic-segmentation' portrait_matting = 'portrait-matting' text_driven_segmentation = 'text-driven-segmentation' shop_segmentation = 'shop-segmentation' diff --git a/tests/pipelines/test_salient_detection.py b/tests/pipelines/test_salient_detection.py index e87e9388..bcb904e6 100644 --- a/tests/pipelines/test_salient_detection.py +++ b/tests/pipelines/test_salient_detection.py @@ -11,17 +11,16 @@ from modelscope.utils.test_utils import test_level class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: - self.task = Tasks.image_segmentation + self.task = Tasks.semantic_segmentation self.model_id = 'damo/cv_u2net_salient-detection' @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_salient_detection(self): input_location = 'data/test/images/image_salient_detection.jpg' model_id = 'damo/cv_u2net_salient-detection' - salient_detect = pipeline(Tasks.image_segmentation, model=model_id) + salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id) result = salient_detect(input_location) import cv2 - # result[OutputKeys.MASKS] is salient map result,other keys are not used cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS]) @unittest.skip('demo compatibility test is only enabled on a needed-basis')