From d3fac4f5be227ed6c7150df5e5f672277d04e669 Mon Sep 17 00:00:00 2001
From: "wendi.hwd"
Date: Tue, 16 Aug 2022 09:15:53 +0800
Subject: [PATCH] [to #42322933] support salient detection

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9722903
---
 data/test/images/image_salient_detection.jpg  |   3 +
 modelscope/metainfo.py                        |   1 +
 modelscope/models/cv/__init__.py              |   3 +-
 .../models/cv/object_detection/mmdet_model.py |   4 +-
 .../models/cv/salient_detection/__init__.py   |  22 ++
 .../cv/salient_detection/models/__init__.py   |   1 +
 .../cv/salient_detection/models/u2net.py      | 300 ++++++++++++++++++
 .../cv/salient_detection/salient_model.py     |  63 ++++
 modelscope/pipelines/cv/__init__.py           |   2 +
 .../cv/image_salient_detection_pipeline.py    |  47 +++
 tests/pipelines/test_salient_detection.py     |  24 ++
 11 files changed, 467 insertions(+), 3 deletions(-)
 create mode 100644 data/test/images/image_salient_detection.jpg
 create mode 100644 modelscope/models/cv/salient_detection/__init__.py
 create mode 100644 modelscope/models/cv/salient_detection/models/__init__.py
 create mode 100644 modelscope/models/cv/salient_detection/models/u2net.py
 create mode 100644 modelscope/models/cv/salient_detection/salient_model.py
 create mode 100644 modelscope/pipelines/cv/image_salient_detection_pipeline.py
 create mode 100644 tests/pipelines/test_salient_detection.py

diff --git a/data/test/images/image_salient_detection.jpg b/data/test/images/image_salient_detection.jpg
new file mode 100644
index 00000000..9c0632d3
--- /dev/null
+++ b/data/test/images/image_salient_detection.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70ea0c06f9cfe3882253f7175221d47e394ab9c469076ab220e880b17dbcdd02
+size 48552
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index a0aab6d3..54109571 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -86,6 +86,7 @@ class Pipelines(object):
     body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
     human_detection = 'resnet18-human-detection'
     object_detection = 'vit-object-detection'
+    salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
     live_category = 'live-category'
diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py
index a05bc57d..2a790ffd 100644
--- a/modelscope/models/cv/__init__.py
+++ b/modelscope/models/cv/__init__.py
@@ -5,4 +5,5 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                image_colorization, image_denoise, image_instance_segmentation,
                image_portrait_enhancement, image_to_image_generation,
                image_to_image_translation, object_detection,
-               product_retrieval_embedding, super_resolution, virual_tryon)
+               product_retrieval_embedding, salient_detection,
+               super_resolution, virual_tryon)
diff --git a/modelscope/models/cv/object_detection/mmdet_model.py b/modelscope/models/cv/object_detection/mmdet_model.py
index 51f05e47..7bf81349 100644
--- a/modelscope/models/cv/object_detection/mmdet_model.py
+++ b/modelscope/models/cv/object_detection/mmdet_model.py
@@ -38,7 +38,7 @@ class DetectionModel(TorchModel):
             self.model, model_path, map_location='cpu')
         self.class_names = checkpoint['meta']['CLASSES']
         config.test_pipeline[0].type = 'LoadImageFromWebcam'
-        self.test_pipeline = Compose(
+        self.transform_input = Compose(
             replace_ImageToTensor(config.test_pipeline))
         self.model.cfg = config
         self.model.eval()
@@ -56,7 +56,7 @@ class DetectionModel(TorchModel):
         from mmcv.parallel import collate, scatter

         data = dict(img=image)
-        data = self.test_pipeline(data)
+        data = self.transform_input(data)
         data = collate([data], samples_per_gpu=1)
         data['img_metas'] = [
             img_metas.data[0] for img_metas in data['img_metas']
diff --git a/modelscope/models/cv/salient_detection/__init__.py b/modelscope/models/cv/salient_detection/__init__.py
new file mode 100644
index 00000000..b3b5b5fa
--- /dev/null
+++ b/modelscope/models/cv/salient_detection/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .salient_model import SalientDetection
+
+else:
+    _import_structure = {
+        'salient_model': ['SalientDetection'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/models/cv/salient_detection/models/__init__.py b/modelscope/models/cv/salient_detection/models/__init__.py
new file mode 100644
index 00000000..0850c33d
--- /dev/null
+++ b/modelscope/models/cv/salient_detection/models/__init__.py
@@ -0,0 +1 @@
+from .u2net import U2NET
diff --git a/modelscope/models/cv/salient_detection/models/u2net.py b/modelscope/models/cv/salient_detection/models/u2net.py
new file mode 100644
index 00000000..0a0a4511
--- /dev/null
+++ b/modelscope/models/cv/salient_detection/models/u2net.py
@@ -0,0 +1,300 @@
+# Implementation in this file is modified from source code available via https://github.com/xuebinqin/U-2-Net
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class REBNCONV(nn.Module):
+
+    def __init__(self, in_ch=3, out_ch=3, dirate=1):
+        super(REBNCONV, self).__init__()
+        self.conv_s1 = nn.Conv2d(
+            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
+        self.bn_s1 = nn.BatchNorm2d(out_ch)
+        self.relu_s1 = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        hx = x
+        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+        return xout
+
+
+def _upsample_like(src, tar):
+    """Upsample tensor 'src' to the same spatial size as tensor 'tar'."""
+    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
+    return src
+
+
+class RSU7(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU7, self).__init__()
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+        hxin = self.rebnconvin(hx)
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+        hx4 = self.rebnconv4(hx)
+        hx = self.pool4(hx4)
+        hx5 = self.rebnconv5(hx)
+        hx = self.pool5(hx5)
+        hx6 = self.rebnconv6(hx)
+        hx7 = self.rebnconv7(hx6)
+        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+        hx6dup = _upsample_like(hx6d, hx5)
+        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+        return hx1d + hxin
+
+
+class RSU6(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU6, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+        hxin = self.rebnconvin(hx)
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+        hx4 = self.rebnconv4(hx)
+        hx = self.pool4(hx4)
+        hx5 = self.rebnconv5(hx)
+        hx6 = self.rebnconv6(hx5)
+        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+        return hx1d + hxin
+
+
+class RSU5(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU5, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+        hxin = self.rebnconvin(hx)
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+        hx4 = self.rebnconv4(hx)
+        hx5 = self.rebnconv5(hx4)
+        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+        return hx1d + hxin
+
+
+class RSU4(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU4, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+
+        hx = x
+        hxin = self.rebnconvin(hx)
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+        hx3 = self.rebnconv3(hx)
+        hx4 = self.rebnconv4(hx3)
+        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+        return hx1d + hxin
+
+
+class RSU4F(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU4F, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+
+        hx = x
+        hxin = self.rebnconvin(hx)
+        hx1 = self.rebnconv1(hxin)
+        hx2 = self.rebnconv2(hx1)
+        hx3 = self.rebnconv3(hx2)
+        hx4 = self.rebnconv4(hx3)
+        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+        return hx1d + hxin
+
+
+class U2NET(nn.Module):
+
+    def __init__(self, in_ch=3, out_ch=1):
+        super(U2NET, self).__init__()
+
+        # encoder
+        self.stage1 = RSU7(in_ch, 32, 64)
+        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.stage2 = RSU6(64, 32, 128)
+        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.stage3 = RSU5(128, 64, 256)
+        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.stage4 = RSU4(256, 128, 512)
+        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.stage5 = RSU4F(512, 256, 512)
+        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+        self.stage6 = RSU4F(512, 256, 512)
+        # decoder
+        self.stage5d = RSU4F(1024, 256, 512)
+        self.stage4d = RSU4(1024, 128, 256)
+        self.stage3d = RSU5(512, 64, 128)
+        self.stage2d = RSU6(256, 32, 64)
+        self.stage1d = RSU7(128, 16, 64)
+        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+    def forward(self, x):
+
+        hx = x
+        hx1 = self.stage1(hx)
+        hx = self.pool12(hx1)
+        hx2 = self.stage2(hx)
+        hx = self.pool23(hx2)
+        hx3 = self.stage3(hx)
+        hx = self.pool34(hx3)
+        hx4 = self.stage4(hx)
+        hx = self.pool45(hx4)
+        hx5 = self.stage5(hx)
+        hx = self.pool56(hx5)
+        hx6 = self.stage6(hx)
+        hx6up = _upsample_like(hx6, hx5)
+
+        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+        d1 = self.side1(hx1d)
+        d2 = self.side2(hx2d)
+        d2 = _upsample_like(d2, d1)
+        d3 = self.side3(hx3d)
+        d3 = _upsample_like(d3, d1)
+        d4 = self.side4(hx4d)
+        d4 = _upsample_like(d4, d1)
+        d5 = self.side5(hx5d)
+        d5 = _upsample_like(d5, d1)
+        d6 = self.side6(hx6)
+        d6 = _upsample_like(d6, d1)
+        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(
+            d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(
+                d5), torch.sigmoid(d6)
diff --git a/modelscope/models/cv/salient_detection/salient_model.py b/modelscope/models/cv/salient_detection/salient_model.py
new file mode 100644
index 00000000..539d1f24
--- /dev/null
+++ b/modelscope/models/cv/salient_detection/salient_model.py
@@ -0,0 +1,63 @@
+import os.path as osp
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+from .models import U2NET
+
+
+@MODELS.register_module(Tasks.image_segmentation, module_name=Models.detection)
+class SalientDetection(TorchModel):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Load the U2NET model from model_dir, the model file root."""
+        super().__init__(model_dir, *args, **kwargs)
+        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+        self.model = U2NET(3, 1)
+        checkpoint = torch.load(model_path, map_location='cpu')
+        self.transform_input = transforms.Compose([
+            transforms.Resize((320, 320)),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+        self.model.load_state_dict(checkpoint)
+        self.model.eval()
+
+    def inference(self, data):
+        """Run the model on a 3 x H x W tensor; return an H x W saliency map."""
+        data = data.unsqueeze(0)
+        if next(self.model.parameters()).is_cuda:
+            device = next(self.model.parameters()).device
+            data = data.to(device)
+
+        with torch.no_grad():
+            results = self.model(data)
+
+        if next(self.model.parameters()).is_cuda:
+            return results[0][0, 0, :, :].cpu()
+        return results[0][0, 0, :, :]
+
+    def preprocess(self, image):
+        """Convert a numpy RGB image into a normalized 3 x 320 x 320 tensor."""
+        data = self.transform_input(Image.fromarray(image))
+        return data.float()
+
+    def postprocess(self, inputs):
+        """Normalize the map to [0, 255] and resize it back to the input size."""
+        data = inputs['data']
+        w = inputs['img_w']
+        h = inputs['img_h']
+        data_norm = (data - torch.min(data)) / (
+            torch.max(data) - torch.min(data))
+        data_norm_np = (data_norm.numpy() * 255).astype('uint8')
+        data_norm_rst = cv2.resize(data_norm_np, (w, h))
+
+        return data_norm_rst
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 91a2f1e0..cee91c8e 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -10,6 +10,7 @@ if TYPE_CHECKING:
     from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
     from .crowd_counting_pipeline import CrowdCountingPipeline
     from .image_detection_pipeline import ImageDetectionPipeline
+    from .image_salient_detection_pipeline import ImageSalientDetectionPipeline
    from .face_detection_pipeline import FaceDetectionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
     from .face_recognition_pipeline import FaceRecognitionPipeline
@@ -43,6 +44,7 @@ else:
         'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
         'crowd_counting_pipeline': ['CrowdCountingPipeline'],
         'image_detection_pipeline': ['ImageDetectionPipeline'],
+        'image_salient_detection_pipeline': ['ImageSalientDetectionPipeline'],
         'face_detection_pipeline': ['FaceDetectionPipeline'],
         'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
         'face_recognition_pipeline': ['FaceRecognitionPipeline'],
diff --git a/modelscope/pipelines/cv/image_salient_detection_pipeline.py b/modelscope/pipelines/cv/image_salient_detection_pipeline.py
new file mode 100644
index 00000000..433275ba
--- /dev/null
+++ b/modelscope/pipelines/cv/image_salient_detection_pipeline.py
@@ -0,0 +1,47 @@
+from typing import Any, Dict
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import Tasks
+
+
+@PIPELINES.register_module(
+    Tasks.image_segmentation, module_name=Pipelines.salient_detection)
+class ImageSalientDetectionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        model: model id on the ModelScope hub.
+ """ + super().__init__(model=model, auto_collate=False, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + + img = LoadImage.convert_to_ndarray(input) + img_h, img_w, _ = img.shape + img = self.model.preprocess(img) + result = {'img': img, 'img_w': img_w, 'img_h': img_h} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + outputs = self.model.inference(input['img']) + result = { + 'data': outputs, + 'img_w': input['img_w'], + 'img_h': input['img_h'] + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + data = self.model.postprocess(inputs) + outputs = { + OutputKeys.SCORES: None, + OutputKeys.LABELS: None, + OutputKeys.MASKS: data + } + return outputs diff --git a/tests/pipelines/test_salient_detection.py b/tests/pipelines/test_salient_detection.py new file mode 100644 index 00000000..ec010b17 --- /dev/null +++ b/tests/pipelines/test_salient_detection.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class SalientDetectionTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_salient_detection(self): + input_location = 'data/test/images/image_salient_detection.jpg' + model_id = 'damo/cv_u2net_salient-detection' + salient_detect = pipeline(Tasks.image_segmentation, model=model_id) + result = salient_detect(input_location) + import cv2 + # result[OutputKeys.MASKS] is salient map result,other keys are not used + cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS]) + + +if __name__ == '__main__': + unittest.main()