wendi.hwd (committed by yingda.chen) · 3 years ago · commit ff55bd9436
11 changed files with 600 additions and 6 deletions
  1. data/test/images/image_camouflag_detection.jpg (+3, -0)
  2. modelscope/metainfo.py (+2, -0)
  3. modelscope/models/cv/salient_detection/models/__init__.py (+1, -0)
  4. modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py (+187, -0)
  5. modelscope/models/cv/salient_detection/models/backbone/__init__.py (+6, -0)
  6. modelscope/models/cv/salient_detection/models/modules.py (+178, -0)
  7. modelscope/models/cv/salient_detection/models/senet.py (+74, -0)
  8. modelscope/models/cv/salient_detection/models/utils.py (+105, -0)
  9. modelscope/models/cv/salient_detection/salient_model.py (+18, -6)
  10. modelscope/pipelines/cv/image_salient_detection_pipeline.py (+5, -0)
  11. tests/pipelines/test_salient_detection.py (+21, -0)

data/test/images/image_camouflag_detection.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c713215f7fb4da5382c9137347ee52956a7a44d5979c4cffd3c9b6d1d7e878f
size 19445

modelscope/metainfo.py (+2, -0)

@@ -165,6 +165,8 @@ class Pipelines(object):
     easycv_segmentation = 'easycv-segmentation'
     face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment'
     salient_detection = 'u2net-salient-detection'
+    salient_boudary_detection = 'res2net-salient-detection'
+    camouflaged_detection = 'res2net-camouflaged-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
     card_detection = 'resnet-card-detection-scrfd34gkps'


modelscope/models/cv/salient_detection/models/__init__.py (+1, -0)

@@ -1,3 +1,4 @@
 # The implementation is adopted from U-2-Net, made publicly available under the Apache 2.0 License
 # source code available via https://github.com/xuebinqin/U-2-Net
+from .senet import SENet
 from .u2net import U2NET

modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py (+187, -0)

@@ -0,0 +1,187 @@
# Implementation in this file is modified based on Res2Net-PretrainedModels
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
# publicly available at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py
import math

import torch
import torch.nn as nn

__all__ = ['Res2Net', 'res2net50_v1b_26w_4s']


class Bottle2neck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 baseWidth=26,
                 scale=4,
                 stype='normal'):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            downsample: None when stride = 1
            baseWidth: basic width of conv3x3
            scale: number of scales.
            stype: 'normal': normal set. 'stage': first block of a new stage.
        """
        super(Bottle2neck, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(
            inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale - 1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(
                nn.Conv2d(
                    width,
                    width,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv2d(
            width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0 or self.stype == 'stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        if self.scale != 1 and self.stype == 'normal':
            out = torch.cat((out, spx[self.nums]), 1)
        elif self.scale != 1 and self.stype == 'stage':
            out = torch.cat((out, self.pool(spx[self.nums])), 1)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class Res2Net(nn.Module):

    def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
        self.inplanes = 64
        super(Res2Net, self).__init__()
        self.baseWidth = baseWidth
        self.scale = scale
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True), nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False))
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.AvgPool2d(
                    kernel_size=stride,
                    stride=stride,
                    ceil_mode=True,
                    count_include_pad=False),
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=1,
                    bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample=downsample,
                stype='stage',
                baseWidth=self.baseWidth,
                scale=self.scale))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    baseWidth=self.baseWidth,
                    scale=self.scale))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def res2net50_v1b_26w_4s(backbone_path, pretrained=False, **kwargs):
    """Constructs a Res2Net-50_v1b_26w_4s model.
    Args:
        backbone_path (str): path to the pretrained backbone weights
        pretrained (bool): If True, loads ImageNet-pretrained weights
            from backbone_path
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model_state = torch.load(backbone_path)
        model.load_state_dict(model_state)
    return model
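Reviewer note: a minimal smoke test for this backbone, assuming no checkpoint is at hand (pretrained=False, so weights stay random):

import torch

from modelscope.models.cv.salient_detection.models.backbone import \
    res2net50_v1b_26w_4s

# Build with random weights; backbone_path is only read when pretrained=True.
model = res2net50_v1b_26w_4s(backbone_path=None, pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])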

modelscope/models/cv/salient_detection/models/backbone/__init__.py (+6, -0)

@@ -0,0 +1,6 @@
# Implementation in this file is modified based on Res2Net-PretrainedModels
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
# publicly available at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py
from .Res2Net_v1b import res2net50_v1b_26w_4s

__all__ = ['res2net50_v1b_26w_4s']

modelscope/models/cv/salient_detection/models/modules.py (+178, -0)

@@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import ConvBNReLU


class AreaLayer(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(AreaLayer, self).__init__()
        self.lbody = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
        self.hbody = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True))
        self.body = nn.Sequential(
            nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, 1, 1))

    def forward(self, xl, xh):
        xl1 = self.lbody(xl)
        xl1 = F.interpolate(
            xl1, size=xh.size()[2:], mode='bilinear', align_corners=True)
        xh1 = self.hbody(xh)
        x = torch.cat((xl1, xh1), dim=1)
        x_out = self.body(x)
        return x_out


class EdgeLayer(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(EdgeLayer, self).__init__()
        self.lbody = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
        self.hbody = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True))
        self.bodye = nn.Sequential(
            nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, 1, 1))

    def forward(self, xl, xh):
        xl1 = self.lbody(xl)
        xh1 = self.hbody(xh)
        xh1 = F.interpolate(
            xh1, size=xl.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((xl1, xh1), dim=1)
        x_out = self.bodye(x)
        return x_out


class EBlock(nn.Module):

    def __init__(self, inchs, outchs):
        super(EBlock, self).__init__()
        self.elayer = nn.Sequential(
            ConvBNReLU(inchs + 1, outchs, kernel_size=3, padding=1, stride=1),
            ConvBNReLU(outchs, outchs, 1))
        self.salayer = nn.Sequential(
            nn.Conv2d(2, 1, 3, 1, 1, bias=False),
            nn.BatchNorm2d(1, momentum=0.01), nn.Sigmoid())

    def forward(self, x, edgeAtten):
        x = torch.cat((x, edgeAtten), dim=1)
        ex = self.elayer(x)
        ex_max = torch.max(ex, 1, keepdim=True)[0]
        ex_mean = torch.mean(ex, dim=1, keepdim=True)
        xei_compress = torch.cat((ex_max, ex_mean), dim=1)

        scale = self.salayer(xei_compress)
        x_out = ex * scale
        return x_out


class StructureE(nn.Module):

    def __init__(self, inchs, outchs, EM):
        super(StructureE, self).__init__()
        self.ne_modules = int(inchs / EM)
        NM = int(outchs / self.ne_modules)
        elayes = []
        for i in range(self.ne_modules):
            emblock = EBlock(EM, NM)
            elayes.append(emblock)
        self.emlayes = nn.ModuleList(elayes)
        self.body = nn.Sequential(
            ConvBNReLU(outchs, outchs, 3, 1, 1), ConvBNReLU(outchs, outchs, 1))

    def forward(self, x, edgeAtten):
        if edgeAtten.size() != x.size():
            edgeAtten = F.interpolate(
                edgeAtten, x.size()[2:], mode='bilinear', align_corners=False)
        xx = torch.chunk(x, self.ne_modules, dim=1)
        efeas = []
        for i in range(self.ne_modules):
            xei = self.emlayes[i](xx[i], edgeAtten)
            efeas.append(xei)
        efeas = torch.cat(efeas, dim=1)
        x_out = self.body(efeas)
        return x_out


class ABlock(nn.Module):

    def __init__(self, inchs, outchs, k):
        super(ABlock, self).__init__()
        self.alayer = nn.Sequential(
            ConvBNReLU(inchs, outchs, k, 1, k // 2),
            ConvBNReLU(outchs, outchs, 1))
        self.arlayer = nn.Sequential(
            ConvBNReLU(inchs, outchs, k, 1, k // 2),
            ConvBNReLU(outchs, outchs, 1))
        self.fusion = ConvBNReLU(2 * outchs, outchs, 1)

    def forward(self, x, areaAtten):
        # Split the feature by the area attention into foreground (xa) and
        # background (xra) parts, then fuse; note that alayer and arlayer are
        # defined above but not applied in this forward pass.
        xa = x * areaAtten
        xra = x * (1 - areaAtten)
        xout = self.fusion(torch.cat((xa, xra), dim=1))
        return xout


class AMFusion(nn.Module):

    def __init__(self, inchs, outchs, AM):
        super(AMFusion, self).__init__()
        self.k = [3, 3, 5, 5]
        self.conv_up = ConvBNReLU(inchs, outchs, 3, 1, 1)
        self.up = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=True)
        self.na_modules = int(outchs / AM)
        alayers = []
        for i in range(self.na_modules):
            layer = ABlock(AM, AM, self.k[i])
            alayers.append(layer)
        self.alayers = nn.ModuleList(alayers)
        self.fusion_0 = ConvBNReLU(outchs, outchs, 3, 1, 1)
        self.fusion_e = nn.Sequential(
            nn.Conv2d(
                outchs, outchs, kernel_size=(3, 1), padding=(1, 0),
                bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True),
            nn.Conv2d(
                outchs, outchs, kernel_size=(1, 3), padding=(0, 1),
                bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True))
        self.fusion_e1 = nn.Sequential(
            nn.Conv2d(
                outchs, outchs, kernel_size=(5, 1), padding=(2, 0),
                bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True),
            nn.Conv2d(
                outchs, outchs, kernel_size=(1, 5), padding=(0, 2),
                bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True))
        self.fusion = ConvBNReLU(3 * outchs, outchs, 1)

    def forward(self, xl, xh, xhm):
        xh1 = self.up(self.conv_up(xh))
        x = xh1 + xl
        xm = self.up(torch.sigmoid(xhm))
        xx = torch.chunk(x, self.na_modules, dim=1)
        xxmids = []
        for i in range(self.na_modules):
            xi = self.alayers[i](xx[i], xm)
            xxmids.append(xi)
        xfea = torch.cat(xxmids, dim=1)
        x0 = self.fusion_0(xfea)
        x1 = self.fusion_e(xfea)
        x2 = self.fusion_e1(xfea)
        x_out = self.fusion(torch.cat((x0, x1, x2), dim=1))
        return x_out
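Reviewer note: to see how StructureE chunks channels, a toy shape check (a sketch; the 256/128/88 numbers are illustrative and mirror how SENet below instantiates layer1_enhance):

import torch

from modelscope.models.cv.salient_detection.models.modules import StructureE

# 256 input channels with EM=128 gives 256 / 128 = 2 EBlocks; each block sees
# one 128-channel chunk concatenated with the 1-channel edge attention map.
enhance = StructureE(256, 128, 128).eval()
x = torch.randn(1, 256, 88, 88)
edge_atten = torch.rand(1, 1, 88, 88)
with torch.no_grad():
    print(enhance(x, edge_atten).shape)  # torch.Size([1, 128, 88, 88])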

modelscope/models/cv/salient_detection/models/senet.py (+74, -0)

@@ -0,0 +1,74 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import res2net50_v1b_26w_4s as res2net
from .modules import AMFusion, AreaLayer, EdgeLayer, StructureE
from .utils import ASPP, CBAM, ConvBNReLU


class SENet(nn.Module):

    def __init__(self, backbone_path=None, pretrained=False):
        super(SENet, self).__init__()
        resnet50 = res2net(backbone_path, pretrained)
        self.layer0_1 = nn.Sequential(resnet50.conv1, resnet50.bn1,
                                      resnet50.relu)
        self.maxpool = resnet50.maxpool
        self.layer1 = resnet50.layer1
        self.layer2 = resnet50.layer2
        self.layer3 = resnet50.layer3
        self.layer4 = resnet50.layer4
        self.aspp3 = ASPP(1024, 256)
        self.aspp4 = ASPP(2048, 256)
        self.cbblock3 = CBAM(inchs=256, kernel_size=5)
        self.cbblock4 = CBAM(inchs=256, kernel_size=5)
        self.up = nn.Upsample(
            mode='bilinear', scale_factor=2, align_corners=False)
        self.conv_up = ConvBNReLU(512, 512, 1)
        self.aux_edge = EdgeLayer(512, 256)
        self.aux_area = AreaLayer(512, 256)
        self.layer1_enhance = StructureE(256, 128, 128)
        self.layer2_enhance = StructureE(512, 256, 128)
        self.layer3_decoder = AMFusion(512, 256, 128)
        self.layer2_decoder = AMFusion(256, 128, 128)
        self.out_conv_8 = nn.Conv2d(256, 1, 1)
        self.out_conv_4 = nn.Conv2d(128, 1, 1)

    def forward(self, x):
        layer0 = self.layer0_1(x)
        layer0s = self.maxpool(layer0)
        layer1 = self.layer1(layer0s)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        layer3_eh = self.cbblock3(self.aspp3(layer3))
        layer4_eh = self.cbblock4(self.aspp4(layer4))
        layer34 = self.conv_up(
            torch.cat((self.up(layer4_eh), layer3_eh), dim=1))
        edge_atten = self.aux_edge(layer1, layer34)
        area_atten = self.aux_area(layer1, layer34)
        edge_atten_ = torch.sigmoid(edge_atten)
        layer1_eh = self.layer1_enhance(layer1, edge_atten_)
        layer2_eh = self.layer2_enhance(layer2, edge_atten_)
        layer2_fu = self.layer3_decoder(layer2_eh, layer34, area_atten)
        out_8 = self.out_conv_8(layer2_fu)
        layer1_fu = self.layer2_decoder(layer1_eh, layer2_fu, out_8)
        out_4 = self.out_conv_4(layer1_fu)
        out_16 = F.interpolate(
            area_atten,
            size=x.size()[2:],
            mode='bilinear',
            align_corners=False)
        out_8 = F.interpolate(
            out_8, size=x.size()[2:], mode='bilinear', align_corners=False)
        out_4 = F.interpolate(
            out_4, size=x.size()[2:], mode='bilinear', align_corners=False)
        edge_out = F.interpolate(
            edge_atten_,
            size=x.size()[2:],
            mode='bilinear',
            align_corners=False)

        return out_4.sigmoid(), out_8.sigmoid(), out_16.sigmoid(), edge_out
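Reviewer note: a forward-pass sketch for the assembled network (random weights; the 352x352 input size is an assumption — any side divisible by 32 keeps the upsampled skip connections aligned):

import torch

from modelscope.models.cv.salient_detection.models.senet import SENet

net = SENet(backbone_path=None, pretrained=False).eval()
with torch.no_grad():
    out_4, out_8, out_16, edge = net(torch.randn(1, 3, 352, 352))
# All four outputs are single-channel maps in [0, 1] at the input resolution.
print(out_4.shape)  # torch.Size([1, 1, 352, 352])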

modelscope/models/cv/salient_detection/models/utils.py (+105, -0)

@@ -0,0 +1,105 @@
# Implementation in this file is modified based on deeplabv3
# Originally MIT license, publicly available at https://github.com/fregu856/deeplabv3/blob/master/model/aspp.py
# Implementation in this file is modified based on attention-module
# Originally MIT license, publicly available at https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
import torch
import torch.nn as nn


class ConvBNReLU(nn.Module):

    def __init__(self,
                 inplanes,
                 planes,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False):
        super(ConvBNReLU, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                inplanes,
                planes,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=bias), nn.BatchNorm2d(planes), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.block(x)


class ASPP(nn.Module):

    def __init__(self, in_dim, out_dim):
        super(ASPP, self).__init__()
        mid_dim = 128
        self.conv1 = ConvBNReLU(in_dim, mid_dim, kernel_size=1)
        self.conv2 = ConvBNReLU(
            in_dim, mid_dim, kernel_size=3, padding=2, dilation=2)
        self.conv3 = ConvBNReLU(
            in_dim, mid_dim, kernel_size=3, padding=5, dilation=5)
        self.conv4 = ConvBNReLU(
            in_dim, mid_dim, kernel_size=3, padding=7, dilation=7)
        self.conv5 = ConvBNReLU(in_dim, mid_dim, kernel_size=1, padding=0)
        self.fuse = ConvBNReLU(5 * mid_dim, out_dim, 3, 1, 1)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(x)
        conv3 = self.conv3(x)
        conv4 = self.conv4(x)
        xg = self.conv5(self.global_pooling(x))
        conv5 = nn.Upsample((x.shape[2], x.shape[3]), mode='nearest')(xg)
        return self.fuse(torch.cat((conv1, conv2, conv3, conv4, conv5), 1))


class ChannelAttention(nn.Module):

    def __init__(self, inchs, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Use the ratio argument for the bottleneck width (was hard-coded 16).
        self.fc = nn.Sequential(
            nn.Conv2d(inchs, inchs // ratio, 1, bias=False), nn.ReLU(),
            nn.Conv2d(inchs // ratio, inchs, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAM(nn.Module):

    def __init__(self, inchs, kernel_size=7):
        super().__init__()
        self.calayer = ChannelAttention(inchs=inchs)
        self.saLayer = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        xca = self.calayer(x) * x
        xsa = self.saLayer(xca) * xca
        return xsa
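Reviewer note: a quick shape check for the two helpers, mirroring how senet.py chains them on the layer3 features (tensor sizes are illustrative):

import torch

from modelscope.models.cv.salient_detection.models.utils import ASPP, CBAM

x = torch.randn(2, 1024, 22, 22)
aspp = ASPP(1024, 256).eval()                 # multi-dilation context + global pooling branch
cbam = CBAM(inchs=256, kernel_size=5).eval()  # channel attention, then spatial attention
with torch.no_grad():
    y = cbam(aspp(x))
print(y.shape)  # torch.Size([2, 256, 22, 22])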

modelscope/models/cv/salient_detection/salient_model.py (+18, -6)

@@ -2,7 +2,6 @@
 import os.path as osp

 import cv2
-import numpy as np
 import torch
 from PIL import Image
 from torchvision import transforms
@@ -10,8 +9,9 @@ from torchvision import transforms
 from modelscope.metainfo import Models
 from modelscope.models.base.base_torch_model import TorchModel
 from modelscope.models.builder import MODELS
+from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
-from .models import U2NET
+from .models import U2NET, SENet


 @MODELS.register_module(
@@ -22,13 +22,25 @@ class SalientDetection(TorchModel):
         """str -- model file root."""
         super().__init__(model_dir, *args, **kwargs)
         model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
-        self.model = U2NET(3, 1)
+        self.norm_mean = [0.485, 0.456, 0.406]
+        self.norm_std = [0.229, 0.224, 0.225]
+        self.norm_size = (320, 320)
+
+        config_path = osp.join(model_dir, 'config.py')
+        if osp.exists(config_path) is False:
+            self.model = U2NET(3, 1)
+        else:
+            self.model = SENet(backbone_path=None, pretrained=False)
+            config = Config.from_file(config_path)
+            self.norm_mean = config.norm_mean
+            self.norm_std = config.norm_std
+            self.norm_size = config.norm_size
         checkpoint = torch.load(model_path, map_location='cpu')
         self.transform_input = transforms.Compose([
-            transforms.Resize((320, 320)),
+            transforms.Resize(self.norm_size),
             transforms.ToTensor(),
-            transforms.Normalize(
-                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+            transforms.Normalize(mean=self.norm_mean, std=self.norm_std)
         ])
         self.model.load_state_dict(checkpoint)
         self.model.eval()
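Reviewer note: the new branch keys off a config.py next to the checkpoint — if present, SENet is built and the normalization settings are read from it; otherwise the U2NET defaults apply. A minimal sketch of such a file (field values here are illustrative, not taken from the commit):

# config.py in the model directory (hypothetical example values)
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
norm_size = (352, 352)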


modelscope/pipelines/cv/image_salient_detection_pipeline.py (+5, -0)

@@ -12,6 +12,11 @@ from modelscope.utils.constant import Tasks


 @PIPELINES.register_module(
     Tasks.semantic_segmentation, module_name=Pipelines.salient_detection)
+@PIPELINES.register_module(
+    Tasks.semantic_segmentation,
+    module_name=Pipelines.salient_boudary_detection)
+@PIPELINES.register_module(
+    Tasks.semantic_segmentation, module_name=Pipelines.camouflaged_detection)
 class ImageSalientDetectionPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
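Reviewer note: since all three module names register the same class under Tasks.semantic_segmentation, the concrete pipeline is resolved from the model's own configuration, so usage stays unchanged (model id as in the tests below):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detect = pipeline(
    Tasks.semantic_segmentation, model='damo/cv_res2net_camouflaged-detection')
mask = detect('data/test/images/image_camouflag_detection.jpg')[OutputKeys.MASKS]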


tests/pipelines/test_salient_detection.py (+21, -0)

@@ -23,6 +23,27 @@ class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
         import cv2
         cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS])

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_salient_boudary_detection(self):
+        input_location = 'data/test/images/image_salient_detection.jpg'
+        model_id = 'damo/cv_res2net_salient-detection'
+        salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id)
+        result = salient_detect(input_location)
+        import cv2
+        cv2.imwrite(input_location + '_boudary_salient.jpg',
+                    result[OutputKeys.MASKS])
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_camouflag_detection(self):
+        input_location = 'data/test/images/image_camouflag_detection.jpg'
+        model_id = 'damo/cv_res2net_camouflaged-detection'
+        camouflag_detect = pipeline(
+            Tasks.semantic_segmentation, model=model_id)
+        result = camouflag_detect(input_location)
+        import cv2
+        cv2.imwrite(input_location + '_camouflag.jpg',
+                    result[OutputKeys.MASKS])
+
     @unittest.skip('demo compatibility test is only enabled on a needed-basis')
     def test_demo_compatibility(self):
         self.compatibility_check()

