From 674e1a7878f63603aa3bbc669fbac6a8b8a5b8a5 Mon Sep 17 00:00:00 2001 From: "wendi.hwd" Date: Mon, 17 Oct 2022 14:06:07 +0800 Subject: [PATCH] [to #42322933]cv/cvdet_fix_outputs->master fix outputs Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10421413 * fix outputs --- .../pipelines/cv/image_detection_pipeline.py | 8 ++++++-- tests/pipelines/test_object_detection.py | 20 ++++--------------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/modelscope/pipelines/cv/image_detection_pipeline.py b/modelscope/pipelines/cv/image_detection_pipeline.py index f5554ca2..08633c35 100644 --- a/modelscope/pipelines/cv/image_detection_pipeline.py +++ b/modelscope/pipelines/cv/image_detection_pipeline.py @@ -43,11 +43,15 @@ class ImageDetectionPipeline(Pipeline): bboxes, scores, labels = self.model.postprocess(inputs['data']) if bboxes is None: - return None + outputs = { + OutputKeys.SCORES: [], + OutputKeys.LABELS: [], + OutputKeys.BOXES: [] + } + return outputs outputs = { OutputKeys.SCORES: scores, OutputKeys.LABELS: labels, OutputKeys.BOXES: bboxes } - return outputs diff --git a/tests/pipelines/test_object_detection.py b/tests/pipelines/test_object_detection.py index 2cb217d9..00a71371 100644 --- a/tests/pipelines/test_object_detection.py +++ b/tests/pipelines/test_object_detection.py @@ -19,20 +19,14 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): model_id = 'damo/cv_vit_object-detection_coco' object_detect = pipeline(Tasks.image_object_detection, model=model_id) result = object_detect(input_location) - if result: - print(result) - else: - raise ValueError('process error') + print(result) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_object_detection_with_default_task(self): input_location = 'data/test/images/image_detection.jpg' object_detect = pipeline(Tasks.image_object_detection) result = object_detect(input_location) - if result: - print(result) - else: - raise ValueError('process error') + print(result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_human_detection(self): @@ -40,20 +34,14 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): model_id = 'damo/cv_resnet18_human-detection' human_detect = pipeline(Tasks.human_detection, model=model_id) result = human_detect(input_location) - if result: - print(result) - else: - raise ValueError('process error') + print(result) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_human_detection_with_default_task(self): input_location = 'data/test/images/image_detection.jpg' human_detect = pipeline(Tasks.human_detection) result = human_detect(input_location) - if result: - print(result) - else: - raise ValueError('process error') + print(result) @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self):