From ff55bd94364addd74c00d016d94b7bb0babbde56 Mon Sep 17 00:00:00 2001 From: "wendi.hwd" Date: Thu, 24 Nov 2022 10:24:05 +0800 Subject: [PATCH] support camouflaged-detection Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10834768 --- .../test/images/image_camouflag_detection.jpg | 3 + modelscope/metainfo.py | 2 + .../cv/salient_detection/models/__init__.py | 1 + .../models/backbone/Res2Net_v1b.py | 187 ++++++++++++++++++ .../models/backbone/__init__.py | 6 + .../cv/salient_detection/models/modules.py | 178 +++++++++++++++++ .../cv/salient_detection/models/senet.py | 74 +++++++ .../cv/salient_detection/models/utils.py | 105 ++++++++++ .../cv/salient_detection/salient_model.py | 24 ++- .../cv/image_salient_detection_pipeline.py | 5 + tests/pipelines/test_salient_detection.py | 21 ++ 11 files changed, 600 insertions(+), 6 deletions(-) create mode 100644 data/test/images/image_camouflag_detection.jpg create mode 100644 modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py create mode 100644 modelscope/models/cv/salient_detection/models/backbone/__init__.py create mode 100644 modelscope/models/cv/salient_detection/models/modules.py create mode 100644 modelscope/models/cv/salient_detection/models/senet.py create mode 100644 modelscope/models/cv/salient_detection/models/utils.py diff --git a/data/test/images/image_camouflag_detection.jpg b/data/test/images/image_camouflag_detection.jpg new file mode 100644 index 00000000..5029067d --- /dev/null +++ b/data/test/images/image_camouflag_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c713215f7fb4da5382c9137347ee52956a7a44d5979c4cffd3c9b6d1d7e878f +size 19445 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 371cfd34..33b1b3a3 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -165,6 +165,8 @@ class Pipelines(object): easycv_segmentation = 'easycv-segmentation' face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment' salient_detection = 'u2net-salient-detection' + salient_boudary_detection = 'res2net-salient-detection' + camouflaged_detection = 'res2net-camouflaged-detection' image_classification = 'image-classification' face_detection = 'resnet-face-detection-scrfd10gkps' card_detection = 'resnet-card-detection-scrfd34gkps' diff --git a/modelscope/models/cv/salient_detection/models/__init__.py b/modelscope/models/cv/salient_detection/models/__init__.py index 8ea7a5d3..6df5101a 100644 --- a/modelscope/models/cv/salient_detection/models/__init__.py +++ b/modelscope/models/cv/salient_detection/models/__init__.py @@ -1,3 +1,4 @@ # The implementation is adopted from U-2-Net, made publicly available under the Apache 2.0 License # source code avaiable via https://github.com/xuebinqin/U-2-Net +from .senet import SENet from .u2net import U2NET diff --git a/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py b/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py new file mode 100644 index 00000000..40c55773 --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py @@ -0,0 +1,187 @@ +# Implementation in this file is modified based on Res2Net-PretrainedModels +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License +# publicly avaialbe at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py +import math + +import torch +import torch.nn as nn + +__all__ = ['Res2Net', 'res2net50_v1b_26w_4s'] + + +class Bottle2neck(nn.Module): + 
expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + baseWidth=26, + scale=4, + stype='normal'): + """ Constructor + Args: + inplanes: input channel dimensionality + planes: output channel dimensionality + stride: conv stride. Replaces pooling layer. + downsample: None when stride = 1 + baseWidth: basic width of conv3x3 + scale: number of scale. + type: 'normal': normal set. 'stage': first block of a new stage. + """ + super(Bottle2neck, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = nn.Conv2d( + inplanes, width * scale, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(width * scale) + if scale == 1: + self.nums = 1 + else: + self.nums = scale - 1 + if stype == 'stage': + self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) + convs = [] + bns = [] + for i in range(self.nums): + convs.append( + nn.Conv2d( + width, + width, + kernel_size=3, + stride=stride, + padding=1, + bias=False)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.conv3 = nn.Conv2d( + width * scale, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stype = stype + self.scale = scale + self.width = width + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0 or self.stype == 'stage': + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + if self.scale != 1 and self.stype == 'normal': + out = torch.cat((out, spx[self.nums]), 1) + elif self.scale != 1 and self.stype == 'stage': + out = torch.cat((out, self.pool(spx[self.nums])), 1) + out = self.conv3(out) + out = self.bn3(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class Res2Net(nn.Module): + + def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): + self.inplanes = 64 + super(Res2Net, self).__init__() + self.baseWidth = baseWidth + self.scale = scale + self.conv1 = nn.Sequential( + nn.Conv2d(3, 32, 3, 2, 1, bias=False), nn.BatchNorm2d(32), + nn.ReLU(inplace=True), nn.Conv2d(32, 32, 3, 1, 1, bias=False), + nn.BatchNorm2d(32), nn.ReLU(inplace=True), + nn.Conv2d(32, 64, 3, 1, 1, bias=False)) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + 
count_include_pad=False), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=1, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample=downsample, + stype='stage', + baseWidth=self.baseWidth, + scale=self.scale)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + baseWidth=self.baseWidth, + scale=self.scale)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + +def res2net50_v1b_26w_4s(backbone_path, pretrained=False, **kwargs): + """Constructs a Res2Net-50_v1b_26w_4s lib. + Args: + pretrained (bool): If True, returns a lib pre-trained on ImageNet + """ + model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) + if pretrained: + model_state = torch.load(backbone_path) + model.load_state_dict(model_state) + return model diff --git a/modelscope/models/cv/salient_detection/models/backbone/__init__.py b/modelscope/models/cv/salient_detection/models/backbone/__init__.py new file mode 100644 index 00000000..52d5ded1 --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/backbone/__init__.py @@ -0,0 +1,6 @@ +# Implementation in this file is modified based on Res2Net-PretrainedModels +# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License +# publicly avaialbe at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py +from .Res2Net_v1b import res2net50_v1b_26w_4s + +__all__ = ['res2net50_v1b_26w_4s'] diff --git a/modelscope/models/cv/salient_detection/models/modules.py b/modelscope/models/cv/salient_detection/models/modules.py new file mode 100644 index 00000000..09796bd3 --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/modules.py @@ -0,0 +1,178 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import ConvBNReLU + + +class AreaLayer(nn.Module): + + def __init__(self, in_channel, out_channel): + super(AreaLayer, self).__init__() + self.lbody = nn.Sequential( + nn.Conv2d(out_channel, out_channel, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True)) + self.hbody = nn.Sequential( + nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel), + nn.ReLU(inplace=True)) + self.body = nn.Sequential( + nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), + nn.Conv2d(out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), + nn.Conv2d(out_channel, 1, 1)) + + def forward(self, xl, xh): + xl1 = self.lbody(xl) + xl1 = F.interpolate( + xl1, size=xh.size()[2:], mode='bilinear', align_corners=True) + xh1 = self.hbody(xh) + x = torch.cat((xl1, xh1), dim=1) + x_out = self.body(x) + return x_out + + +class EdgeLayer(nn.Module): + + def __init__(self, in_channel, out_channel): + super(EdgeLayer, self).__init__() + self.lbody = nn.Sequential( + nn.Conv2d(out_channel, out_channel, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True)) + self.hbody = nn.Sequential( + nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel), + nn.ReLU(inplace=True)) + self.bodye = nn.Sequential( + nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), + nn.Conv2d(out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), + nn.Conv2d(out_channel, 1, 1)) + + def forward(self, xl, xh): + xl1 = self.lbody(xl) + xh1 = self.hbody(xh) + xh1 = F.interpolate( + xh1, size=xl.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((xl1, xh1), dim=1) + x_out = self.bodye(x) + return x_out + + +class EBlock(nn.Module): + + def __init__(self, inchs, outchs): + super(EBlock, self).__init__() + self.elayer = nn.Sequential( + ConvBNReLU(inchs + 1, outchs, kernel_size=3, padding=1, stride=1), + ConvBNReLU(outchs, outchs, 1)) + self.salayer = nn.Sequential( + nn.Conv2d(2, 1, 3, 1, 1, bias=False), + nn.BatchNorm2d(1, momentum=0.01), nn.Sigmoid()) + + def forward(self, x, edgeAtten): + x = torch.cat((x, edgeAtten), dim=1) + ex = self.elayer(x) + ex_max = torch.max(ex, 1, keepdim=True)[0] + ex_mean = torch.mean(ex, dim=1, keepdim=True) + xei_compress = torch.cat((ex_max, ex_mean), dim=1) + + scale = self.salayer(xei_compress) + x_out = ex * scale + return x_out + + +class StructureE(nn.Module): + + def __init__(self, inchs, outchs, EM): + super(StructureE, self).__init__() + self.ne_modules = int(inchs / EM) + NM = int(outchs / self.ne_modules) + elayes = [] + for i in range(self.ne_modules): + emblock = EBlock(EM, NM) + elayes.append(emblock) + self.emlayes = nn.ModuleList(elayes) + self.body = nn.Sequential( + ConvBNReLU(outchs, outchs, 3, 1, 1), ConvBNReLU(outchs, outchs, 1)) + + def forward(self, x, edgeAtten): + if edgeAtten.size() != x.size(): + edgeAtten = F.interpolate( + edgeAtten, x.size()[2:], mode='bilinear', align_corners=False) + xx = torch.chunk(x, self.ne_modules, dim=1) + efeas = [] + for i in range(self.ne_modules): + xei = self.emlayes[i](xx[i], edgeAtten) + efeas.append(xei) + efeas = torch.cat(efeas, dim=1) + x_out = self.body(efeas) + return x_out + + +class ABlock(nn.Module): + + def __init__(self, inchs, outchs, k): + super(ABlock, self).__init__() + self.alayer = nn.Sequential( + ConvBNReLU(inchs, outchs, k, 1, k // 2), + ConvBNReLU(outchs, 
outchs, 1)) + self.arlayer = nn.Sequential( + ConvBNReLU(inchs, outchs, k, 1, k // 2), + ConvBNReLU(outchs, outchs, 1)) + self.fusion = ConvBNReLU(2 * outchs, outchs, 1) + + def forward(self, x, areaAtten): + xa = x * areaAtten + xra = x * (1 - areaAtten) + xout = self.fusion(torch.cat((xa, xra), dim=1)) + return xout + + +class AMFusion(nn.Module): + + def __init__(self, inchs, outchs, AM): + super(AMFusion, self).__init__() + self.k = [3, 3, 5, 5] + self.conv_up = ConvBNReLU(inchs, outchs, 3, 1, 1) + self.up = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=True) + self.na_modules = int(outchs / AM) + alayers = [] + for i in range(self.na_modules): + layer = ABlock(AM, AM, self.k[i]) + alayers.append(layer) + self.alayers = nn.ModuleList(alayers) + self.fusion_0 = ConvBNReLU(outchs, outchs, 3, 1, 1) + self.fusion_e = nn.Sequential( + nn.Conv2d( + outchs, outchs, kernel_size=(3, 1), padding=(1, 0), + bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True), + nn.Conv2d( + outchs, outchs, kernel_size=(1, 3), padding=(0, 1), + bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True)) + self.fusion_e1 = nn.Sequential( + nn.Conv2d( + outchs, outchs, kernel_size=(5, 1), padding=(2, 0), + bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True), + nn.Conv2d( + outchs, outchs, kernel_size=(1, 5), padding=(0, 2), + bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True)) + self.fusion = ConvBNReLU(3 * outchs, outchs, 1) + + def forward(self, xl, xh, xhm): + xh1 = self.up(self.conv_up(xh)) + x = xh1 + xl + xm = self.up(torch.sigmoid(xhm)) + xx = torch.chunk(x, self.na_modules, dim=1) + xxmids = [] + for i in range(self.na_modules): + xi = self.alayers[i](xx[i], xm) + xxmids.append(xi) + xfea = torch.cat(xxmids, dim=1) + x0 = self.fusion_0(xfea) + x1 = self.fusion_e(xfea) + x2 = self.fusion_e1(xfea) + x_out = self.fusion(torch.cat((x0, x1, x2), dim=1)) + return x_out diff --git a/modelscope/models/cv/salient_detection/models/senet.py b/modelscope/models/cv/salient_detection/models/senet.py new file mode 100644 index 00000000..37cf42be --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/senet.py @@ -0,0 +1,74 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import res2net50_v1b_26w_4s as res2net +from .modules import AMFusion, AreaLayer, EdgeLayer, StructureE +from .utils import ASPP, CBAM, ConvBNReLU + + +class SENet(nn.Module): + + def __init__(self, backbone_path=None, pretrained=False): + super(SENet, self).__init__() + resnet50 = res2net(backbone_path, pretrained) + self.layer0_1 = nn.Sequential(resnet50.conv1, resnet50.bn1, + resnet50.relu) + self.maxpool = resnet50.maxpool + self.layer1 = resnet50.layer1 + self.layer2 = resnet50.layer2 + self.layer3 = resnet50.layer3 + self.layer4 = resnet50.layer4 + self.aspp3 = ASPP(1024, 256) + self.aspp4 = ASPP(2048, 256) + self.cbblock3 = CBAM(inchs=256, kernel_size=5) + self.cbblock4 = CBAM(inchs=256, kernel_size=5) + self.up = nn.Upsample( + mode='bilinear', scale_factor=2, align_corners=False) + self.conv_up = ConvBNReLU(512, 512, 1) + self.aux_edge = EdgeLayer(512, 256) + self.aux_area = AreaLayer(512, 256) + self.layer1_enhance = StructureE(256, 128, 128) + self.layer2_enhance = StructureE(512, 256, 128) + self.layer3_decoder = AMFusion(512, 256, 128) + self.layer2_decoder = AMFusion(256, 128, 128) + self.out_conv_8 = nn.Conv2d(256, 1, 1) + self.out_conv_4 = nn.Conv2d(128, 1, 1) + + def forward(self, x): + layer0 = self.layer0_1(x) + layer0s = self.maxpool(layer0) + layer1 = self.layer1(layer0s) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + layer3_eh = self.cbblock3(self.aspp3(layer3)) + layer4_eh = self.cbblock4(self.aspp4(layer4)) + layer34 = self.conv_up( + torch.cat((self.up(layer4_eh), layer3_eh), dim=1)) + edge_atten = self.aux_edge(layer1, layer34) + area_atten = self.aux_area(layer1, layer34) + edge_atten_ = torch.sigmoid(edge_atten) + layer1_eh = self.layer1_enhance(layer1, edge_atten_) + layer2_eh = self.layer2_enhance(layer2, edge_atten_) + layer2_fu = self.layer3_decoder(layer2_eh, layer34, area_atten) + out_8 = self.out_conv_8(layer2_fu) + layer1_fu = self.layer2_decoder(layer1_eh, layer2_fu, out_8) + out_4 = self.out_conv_4(layer1_fu) + out_16 = F.interpolate( + area_atten, + size=x.size()[2:], + mode='bilinear', + align_corners=False) + out_8 = F.interpolate( + out_8, size=x.size()[2:], mode='bilinear', align_corners=False) + out_4 = F.interpolate( + out_4, size=x.size()[2:], mode='bilinear', align_corners=False) + edge_out = F.interpolate( + edge_atten_, + size=x.size()[2:], + mode='bilinear', + align_corners=False) + + return out_4.sigmoid(), out_8.sigmoid(), out_16.sigmoid(), edge_out diff --git a/modelscope/models/cv/salient_detection/models/utils.py b/modelscope/models/cv/salient_detection/models/utils.py new file mode 100644 index 00000000..292ee914 --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/utils.py @@ -0,0 +1,105 @@ +# Implementation in this file is modified based on deeplabv3 +# Originally MIT license,publicly avaialbe at https://github.com/fregu856/deeplabv3/blob/master/model/aspp.py +# Implementation in this file is modified based on attention-module +# Originally MIT license,publicly avaialbe at https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py +import torch +import torch.nn as nn + + +class ConvBNReLU(nn.Module): + + def __init__(self, + inplanes, + planes, + kernel_size=3, + stride=1, + padding=0, + dilation=1, + bias=False): + super(ConvBNReLU, self).__init__() + self.block = nn.Sequential( + nn.Conv2d( + inplanes, + planes, + kernel_size, + stride=stride, + padding=padding, + 
dilation=dilation, + bias=bias), nn.BatchNorm2d(planes), nn.ReLU(inplace=True)) + + def forward(self, x): + return self.block(x) + + +class ASPP(nn.Module): + + def __init__(self, in_dim, out_dim): + super(ASPP, self).__init__() + mid_dim = 128 + self.conv1 = ConvBNReLU(in_dim, mid_dim, kernel_size=1) + self.conv2 = ConvBNReLU( + in_dim, mid_dim, kernel_size=3, padding=2, dilation=2) + self.conv3 = ConvBNReLU( + in_dim, mid_dim, kernel_size=3, padding=5, dilation=5) + self.conv4 = ConvBNReLU( + in_dim, mid_dim, kernel_size=3, padding=7, dilation=7) + self.conv5 = ConvBNReLU(in_dim, mid_dim, kernel_size=1, padding=0) + self.fuse = ConvBNReLU(5 * mid_dim, out_dim, 3, 1, 1) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + + def forward(self, x): + conv1 = self.conv1(x) + conv2 = self.conv2(x) + conv3 = self.conv3(x) + conv4 = self.conv4(x) + xg = self.conv5(self.global_pooling(x)) + conv5 = nn.Upsample((x.shape[2], x.shape[3]), mode='nearest')(xg) + return self.fuse(torch.cat((conv1, conv2, conv3, conv4, conv5), 1)) + + +class ChannelAttention(nn.Module): + + def __init__(self, inchs, ratio=16): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.fc = nn.Sequential( + nn.Conv2d(inchs, inchs // 16, 1, bias=False), nn.ReLU(), + nn.Conv2d(inchs // 16, inchs, 1, bias=False)) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc(self.avg_pool(x)) + max_out = self.fc(self.max_pool(x)) + out = avg_out + max_out + return self.sigmoid(out) + + +class SpatialAttention(nn.Module): + + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + + self.conv1 = nn.Conv2d( + 2, 1, kernel_size, padding=kernel_size // 2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv1(x) + return self.sigmoid(x) + + +class CBAM(nn.Module): + + def __init__(self, inchs, kernel_size=7): + super().__init__() + self.calayer = ChannelAttention(inchs=inchs) + self.saLayer = SpatialAttention(kernel_size=kernel_size) + + def forward(self, x): + xca = self.calayer(x) * x + xsa = self.saLayer(xca) * xca + return xsa diff --git a/modelscope/models/cv/salient_detection/salient_model.py b/modelscope/models/cv/salient_detection/salient_model.py index 73c3c3fb..e25166c8 100644 --- a/modelscope/models/cv/salient_detection/salient_model.py +++ b/modelscope/models/cv/salient_detection/salient_model.py @@ -2,7 +2,6 @@ import os.path as osp import cv2 -import numpy as np import torch from PIL import Image from torchvision import transforms @@ -10,8 +9,9 @@ from torchvision import transforms from modelscope.metainfo import Models from modelscope.models.base.base_torch_model import TorchModel from modelscope.models.builder import MODELS +from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks -from .models import U2NET +from .models import U2NET, SENet @MODELS.register_module( @@ -22,13 +22,25 @@ class SalientDetection(TorchModel): """str -- model file root.""" super().__init__(model_dir, *args, **kwargs) model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) - self.model = U2NET(3, 1) + + self.norm_mean = [0.485, 0.456, 0.406] + self.norm_std = [0.229, 0.224, 0.225] + self.norm_size = (320, 320) + + config_path = osp.join(model_dir, 'config.py') + if osp.exists(config_path) is False: + self.model = 
U2NET(3, 1) + else: + self.model = SENet(backbone_path=None, pretrained=False) + config = Config.from_file(config_path) + self.norm_mean = config.norm_mean + self.norm_std = config.norm_std + self.norm_size = config.norm_size checkpoint = torch.load(model_path, map_location='cpu') self.transform_input = transforms.Compose([ - transforms.Resize((320, 320)), + transforms.Resize(self.norm_size), transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transforms.Normalize(mean=self.norm_mean, std=self.norm_std) ]) self.model.load_state_dict(checkpoint) self.model.eval() diff --git a/modelscope/pipelines/cv/image_salient_detection_pipeline.py b/modelscope/pipelines/cv/image_salient_detection_pipeline.py index 4a3eaa65..4b4df52c 100644 --- a/modelscope/pipelines/cv/image_salient_detection_pipeline.py +++ b/modelscope/pipelines/cv/image_salient_detection_pipeline.py @@ -12,6 +12,11 @@ from modelscope.utils.constant import Tasks @PIPELINES.register_module( Tasks.semantic_segmentation, module_name=Pipelines.salient_detection) +@PIPELINES.register_module( + Tasks.semantic_segmentation, + module_name=Pipelines.salient_boudary_detection) +@PIPELINES.register_module( + Tasks.semantic_segmentation, module_name=Pipelines.camouflaged_detection) class ImageSalientDetectionPipeline(Pipeline): def __init__(self, model: str, **kwargs): diff --git a/tests/pipelines/test_salient_detection.py b/tests/pipelines/test_salient_detection.py index bcb904e6..3101213c 100644 --- a/tests/pipelines/test_salient_detection.py +++ b/tests/pipelines/test_salient_detection.py @@ -23,6 +23,27 @@ class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck): import cv2 cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS]) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_salient_boudary_detection(self): + input_location = 'data/test/images/image_salient_detection.jpg' + model_id = 'damo/cv_res2net_salient-detection' + salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id) + result = salient_detect(input_location) + import cv2 + cv2.imwrite(input_location + '_boudary_salient.jpg', + result[OutputKeys.MASKS]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_camouflag_detection(self): + input_location = 'data/test/images/image_camouflag_detection.jpg' + model_id = 'damo/cv_res2net_camouflaged-detection' + camouflag_detect = pipeline( + Tasks.semantic_segmentation, model=model_id) + result = camouflag_detect(input_location) + import cv2 + cv2.imwrite(input_location + '_camouflag.jpg', + result[OutputKeys.MASKS]) + @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): self.compatibility_check()
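
Usage sketch, mirroring the tests added above: the two new registrations are called exactly like the existing salient-detection pipeline. The model ids below are the ones referenced in tests/pipelines/test_salient_detection.py and are assumed to be published on the model hub.

    import cv2

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Camouflaged-object detection, registered as Pipelines.camouflaged_detection.
    camouflag_detect = pipeline(
        Tasks.semantic_segmentation,
        model='damo/cv_res2net_camouflaged-detection')
    result = camouflag_detect('data/test/images/image_camouflag_detection.jpg')

    # The pipeline returns a single-channel mask under OutputKeys.MASKS.
    cv2.imwrite('camouflag_mask.jpg', result[OutputKeys.MASKS])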
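
On the loader change in salient_model.py: SalientDetection now branches on whether a config.py exists in the model directory. Without it, the original U2NET path with the hard-coded 320x320 / ImageNet normalization is kept; with it, SENet is built and norm_mean, norm_std and norm_size are read via Config.from_file. A minimal sketch of such a config.py is below; the concrete values ship with each released model, so the numbers here are placeholders, not the published settings.

    # config.py placed next to the torch checkpoint (ModelFile.TORCH_MODEL_FILE)
    # in the model directory. Values are illustrative placeholders only.
    norm_mean = [0.485, 0.456, 0.406]   # per-channel mean for transforms.Normalize
    norm_std = [0.229, 0.224, 0.225]    # per-channel std for transforms.Normalize
    norm_size = (352, 352)              # input resolution for transforms.Resize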
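
SENet itself can also be smoke-tested without the pipeline wrapper. As defined in senet.py above, forward returns four sigmoid maps (the 1/4 and 1/8 decoder outputs, the area attention, and the edge map), each upsampled back to the input resolution. A rough sketch with random weights and a dummy input; the 352x352 size is an assumption chosen so that every stride-2 stage divides evenly:

    import torch

    from modelscope.models.cv.salient_detection.models import SENet

    # pretrained=False skips loading Res2Net ImageNet weights, so no backbone_path is needed.
    model = SENet(backbone_path=None, pretrained=False).eval()

    dummy = torch.randn(1, 3, 352, 352)
    with torch.no_grad():
        out_4, out_8, out_16, edge = model(dummy)

    # All four outputs are (1, 1, 352, 352) probability maps in [0, 1].
    print(out_4.shape, out_8.shape, out_16.shape, edge.shape)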