Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10265856master
| @@ -14,7 +14,8 @@ from modelscope.utils.constant import ModelFile, Tasks | |||||
| from .models import U2NET | from .models import U2NET | ||||
| @MODELS.register_module(Tasks.image_segmentation, module_name=Models.detection) | |||||
| @MODELS.register_module( | |||||
| Tasks.semantic_segmentation, module_name=Models.detection) | |||||
| class SalientDetection(TorchModel): | class SalientDetection(TorchModel): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -151,6 +151,12 @@ TASK_OUTPUTS = { | |||||
| Tasks.image_segmentation: | Tasks.image_segmentation: | ||||
| [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS], | [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS], | ||||
| # semantic segmentation result for single sample | |||||
| # { | |||||
| # "masks": [np.array # 2D array containing only 0, 255] | |||||
| # } | |||||
| Tasks.semantic_segmentation: [OutputKeys.MASKS], | |||||
| # image matting result for single sample | # image matting result for single sample | ||||
| # { | # { | ||||
| # "output_img": np.array with shape(h, w, 4) | # "output_img": np.array with shape(h, w, 4) | ||||
| @@ -9,7 +9,7 @@ from modelscope.utils.constant import Tasks | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_segmentation, module_name=Pipelines.salient_detection) | |||||
| Tasks.semantic_segmentation, module_name=Pipelines.salient_detection) | |||||
| class ImageSalientDetectionPipeline(Pipeline): | class ImageSalientDetectionPipeline(Pipeline): | ||||
| def __init__(self, model: str, **kwargs): | def __init__(self, model: str, **kwargs): | ||||
| @@ -39,9 +39,5 @@ class ImageSalientDetectionPipeline(Pipeline): | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| data = self.model.postprocess(inputs) | data = self.model.postprocess(inputs) | ||||
| outputs = { | |||||
| OutputKeys.SCORES: None, | |||||
| OutputKeys.LABELS: None, | |||||
| OutputKeys.MASKS: data | |||||
| } | |||||
| outputs = {OutputKeys.MASKS: data} | |||||
| return outputs | return outputs | ||||
| @@ -38,6 +38,7 @@ class CVTasks(object): | |||||
| image_object_detection = 'image-object-detection' | image_object_detection = 'image-object-detection' | ||||
| image_segmentation = 'image-segmentation' | image_segmentation = 'image-segmentation' | ||||
| semantic_segmentation = 'semantic-segmentation' | |||||
| portrait_matting = 'portrait-matting' | portrait_matting = 'portrait-matting' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
| shop_segmentation = 'shop-segmentation' | shop_segmentation = 'shop-segmentation' | ||||
| @@ -11,17 +11,16 @@ from modelscope.utils.test_utils import test_level | |||||
| class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.task = Tasks.image_segmentation | |||||
| self.task = Tasks.semantic_segmentation | |||||
| self.model_id = 'damo/cv_u2net_salient-detection' | self.model_id = 'damo/cv_u2net_salient-detection' | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_salient_detection(self): | def test_salient_detection(self): | ||||
| input_location = 'data/test/images/image_salient_detection.jpg' | input_location = 'data/test/images/image_salient_detection.jpg' | ||||
| model_id = 'damo/cv_u2net_salient-detection' | model_id = 'damo/cv_u2net_salient-detection' | ||||
| salient_detect = pipeline(Tasks.image_segmentation, model=model_id) | |||||
| salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id) | |||||
| result = salient_detect(input_location) | result = salient_detect(input_location) | ||||
| import cv2 | import cv2 | ||||
| # result[OutputKeys.MASKS] is salient map result,other keys are not used | |||||
| cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS]) | cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS]) | ||||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | @unittest.skip('demo compatibility test is only enabled on a needed-basis') | ||||