Browse Source

update pipeline according to online demo requirements

根据在线demo前端的要求,多输出一个color图片用于展示
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10926624
master^2
qianmu.ywh yingda.chen 3 years ago
parent
commit
bca6da3b56
3 changed files with 9 additions and 4 deletions
  1. +1
    -0
      modelscope/outputs/outputs.py
  2. +6
    -1
      modelscope/pipelines/cv/image_depth_estimation_pipeline.py
  3. +2
    -3
      tests/pipelines/test_image_depth_estimation.py

+ 1
- 0
modelscope/outputs/outputs.py View File

@@ -20,6 +20,7 @@ class OutputKeys(object):
KEYPOINTS = 'keypoints' KEYPOINTS = 'keypoints'
MASKS = 'masks' MASKS = 'masks'
DEPTHS = 'depths' DEPTHS = 'depths'
DEPTHS_COLOR = 'depths_color'
TEXT = 'text' TEXT = 'text'
POLYGONS = 'polygons' POLYGONS = 'polygons'
OUTPUT = 'output' OUTPUT = 'output'


+ 6
- 1
modelscope/pipelines/cv/image_depth_estimation_pipeline.py View File

@@ -12,6 +12,7 @@ from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import depth_to_color
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger


logger = get_logger() logger = get_logger()
@@ -50,6 +51,10 @@ class ImageDepthEstimationPipeline(Pipeline):
depths = results[OutputKeys.DEPTHS] depths = results[OutputKeys.DEPTHS]
if isinstance(depths, torch.Tensor): if isinstance(depths, torch.Tensor):
depths = depths.detach().cpu().squeeze().numpy() depths = depths.detach().cpu().squeeze().numpy()
outputs = {OutputKeys.DEPTHS: depths}
depths_color = depth_to_color(depths)
outputs = {
OutputKeys.DEPTHS: depths,
OutputKeys.DEPTHS_COLOR: depths_color
}


return outputs return outputs

+ 2
- 3
tests/pipelines/test_image_depth_estimation.py View File

@@ -24,9 +24,8 @@ class ImageDepthEstimationTest(unittest.TestCase, DemoCompatibilityCheck):
input_location = 'data/test/images/image_depth_estimation.jpg' input_location = 'data/test/images/image_depth_estimation.jpg'
estimator = pipeline(Tasks.image_depth_estimation, model=self.model_id) estimator = pipeline(Tasks.image_depth_estimation, model=self.model_id)
result = estimator(input_location) result = estimator(input_location)
depths = result[OutputKeys.DEPTHS]
depth_viz = depth_to_color(depths)
cv2.imwrite('result.jpg', depth_viz)
depth_vis = result[OutputKeys.DEPTHS_COLOR]
cv2.imwrite('result.jpg', depth_vis)


print('test_image_depth_estimation DONE') print('test_image_depth_estimation DONE')




Loading…
Cancel
Save