From 04516276265f27996b2ffb293f3ef6315055d0d7 Mon Sep 17 00:00:00 2001 From: "xingguang.zxg" Date: Sat, 3 Sep 2022 13:21:31 +0800 Subject: [PATCH] =?UTF-8?q?[to=20#42322933]=E5=95=86=E5=93=81=E6=98=BE?= =?UTF-8?q?=E8=91=97=E6=80=A7=E5=88=86=E5=89=B2v1.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 商品显著性检测模型,依赖opencv,mmcv-full Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9909897 --- data/test/images/shop_segmentation.jpg | 3 + modelscope/metainfo.py | 2 + modelscope/models/cv/__init__.py | 2 +- .../models/cv/shop_segmentation/__init__.py | 1 + .../models/cv/shop_segmentation/common.py | 59 ++ .../models/cv/shop_segmentation/head_fpn.py | 122 +++ .../models/cv/shop_segmentation/models.py | 901 ++++++++++++++++++ .../models/cv/shop_segmentation/neck_fpn.py | 217 +++++ .../cv/shop_segmentation/shop_seg_base.py | 157 +++ .../cv/shop_segmentation/shop_seg_model.py | 115 +++ .../models/cv/shop_segmentation/utils.py | 199 ++++ modelscope/outputs.py | 8 +- modelscope/pipelines/builder.py | 4 +- modelscope/pipelines/cv/__init__.py | 3 +- .../cv/shop_segmentation_pipleline.py | 51 + modelscope/utils/constant.py | 1 + tests/pipelines/test_shop_segmentation.py | 24 + 17 files changed, 1865 insertions(+), 4 deletions(-) create mode 100644 data/test/images/shop_segmentation.jpg create mode 100644 modelscope/models/cv/shop_segmentation/__init__.py create mode 100644 modelscope/models/cv/shop_segmentation/common.py create mode 100644 modelscope/models/cv/shop_segmentation/head_fpn.py create mode 100644 modelscope/models/cv/shop_segmentation/models.py create mode 100644 modelscope/models/cv/shop_segmentation/neck_fpn.py create mode 100644 modelscope/models/cv/shop_segmentation/shop_seg_base.py create mode 100644 modelscope/models/cv/shop_segmentation/shop_seg_model.py create mode 100644 modelscope/models/cv/shop_segmentation/utils.py create mode 100644 modelscope/pipelines/cv/shop_segmentation_pipleline.py create mode 100644 tests/pipelines/test_shop_segmentation.py diff --git a/data/test/images/shop_segmentation.jpg b/data/test/images/shop_segmentation.jpg new file mode 100644 index 00000000..ec02881d --- /dev/null +++ b/data/test/images/shop_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5ecc371c8b0ca09d0e11df89bc549000937eafc451929586426fe657ade25a0 +size 238607 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 06b5a476..b1bf9600 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -32,6 +32,7 @@ class Models(object): vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' text_driven_segmentation = 'text-driven-segmentation' resnet50_bert = 'resnet50-bert' + shop_segmentation = 'shop-segmentation' # EasyCV models yolox = 'YOLOX' @@ -148,6 +149,7 @@ class Pipelines(object): image_reid_person = 'passvitb-image-reid-person' text_driven_segmentation = 'text-driven-segmentation' movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' + shop_segmentation = 'shop-segmentation' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 4db43d17..f2798b59 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -11,7 +11,7 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, image_to_image_generation, image_to_image_translation, movie_scene_segmentation, object_detection, product_retrieval_embedding, realtime_object_detection, - salient_detection, super_resolution, + salient_detection, shop_segmentation, super_resolution, video_single_object_tracking, video_summarization, virual_tryon) # yapf: enable diff --git a/modelscope/models/cv/shop_segmentation/__init__.py b/modelscope/models/cv/shop_segmentation/__init__.py new file mode 100644 index 00000000..b40a0760 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/__init__.py @@ -0,0 +1 @@ +from .shop_seg_base import SHOPSEG diff --git a/modelscope/models/cv/shop_segmentation/common.py b/modelscope/models/cv/shop_segmentation/common.py new file mode 100644 index 00000000..00ba9996 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/common.py @@ -0,0 +1,59 @@ +""" +Base modules are adapted from https://github.com/open-mmlab/mmcv/, +originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +https://github.com/open-mmlab/mmsegmentation/, +originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +and adapted from https://github.com/raoyongming/DenseCLIP/, +originally MIT License, Copyright (c) 2022 Rao, Yongming. +""" + +import warnings + +import torch.nn as nn +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > input_w: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super(Upsample, self).__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/modelscope/models/cv/shop_segmentation/head_fpn.py b/modelscope/models/cv/shop_segmentation/head_fpn.py new file mode 100644 index 00000000..b3faa9b8 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/head_fpn.py @@ -0,0 +1,122 @@ +""" FPNHead +Base modules are adapted from https://github.com/open-mmlab/mmcv/, +originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +https://github.com/open-mmlab/mmsegmentation/, +originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +and adapted from https://github.com/raoyongming/DenseCLIP/, +originally MIT License, Copyright (c) 2022 Rao, Yongming. +""" + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from timm.models.layers import drop, drop_path, trunc_normal_ + +from .common import Upsample, resize + + +class FPNHead(nn.Module): + """Panoptic Feature Pyramid Networks. + This head is the implementation of `Semantic FPN + `_. + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. + """ + + def __init__(self, + channels, + num_classes, + dropout_ratio=0.1, + feature_strides=[4, 8, 16, 32], + align_corners=False, + **kwargs): + super(FPNHead, self).__init__() + self.act_cfg = dict(type='ReLU') + self.channels = channels + self.conv_cfg = None + self.norm_cfg = None + self.norm_cfg = dict(type='BN2d', requires_grad=True) + self.align_corners = align_corners + self.dropout_ratio = dropout_ratio + self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.in_index = [0, 1, 2, 3] + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + head_length = max( + 1, + int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + scale_head = [] + for k in range(head_length): + scale_head.append( + ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if feature_strides[i] != feature_strides[0]: + scale_head.append( + Upsample( + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + self.apply(self._init_weights) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + inputs = [inputs[i] for i in self.in_index] + return inputs + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + def forward(self, inputs): + x = self._transform_inputs(inputs) + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + output = self.cls_seg(output) + return output + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) diff --git a/modelscope/models/cv/shop_segmentation/models.py b/modelscope/models/cv/shop_segmentation/models.py new file mode 100644 index 00000000..8b82d1d1 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/models.py @@ -0,0 +1,901 @@ +""" +Base modules are adapted from https://github.com/open-mmlab/mmcv/, +originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +https://github.com/open-mmlab/mmsegmentation/, +originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +and adapted from https://github.com/raoyongming/DenseCLIP/, +originally MIT License, Copyright (c) 2022 Rao, Yongming. +""" + +import math +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop, drop_path, trunc_normal_ +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + self.embed_dim = embed_dim + self.spacial_dim = spacial_dim + + def forward(self, x): + B, C, H, W = x.shape + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + + cls_pos = self.positional_embedding[0:1, :] + spatial_pos = F.interpolate( + self.positional_embedding[1:, ].reshape(1, self.spacial_dim, + self.spacial_dim, + self.embed_dim).permute( + 0, 3, 1, 2), + size=(H, W), + mode='bilinear') + spatial_pos = spatial_pos.reshape(self.embed_dim, H * W).permute(1, 0) + positional_embedding = torch.cat([cls_pos, spatial_pos], dim=0) + + x = x + positional_embedding[:, None, :] + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + x = x.permute(1, 2, 0) + global_feat = x[:, :, 0] + feature_map = x[:, :, 1:].reshape(B, -1, H, W) + return global_feat, feature_map + + +class CLIPResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim=512, + input_resolution=224, + width=64, + pretrained=None, + **kwargs): + super().__init__() + self.pretrained = pretrained + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('visual.'): + new_k = k.replace('visual.', '') + state_dict[new_k] = checkpoint[k] + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in CLIPResNet') + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + + outs = [] + x = self.layer1(x) + outs.append(x) + x = self.layer2(x) + outs.append(x) + x = self.layer3(x) + outs.append(x) + x = self.layer4(x) + outs.append(x) + + return tuple(outs) + + +class CLIPResNetWithAttention(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim=1024, + input_resolution=224, + width=64, + pretrained=None, + **kwargs): + super().__init__() + self.pretrained = pretrained + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, 32, + output_dim) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('visual.'): + new_k = k.replace('visual.', '') + state_dict[new_k] = checkpoint[k] + + if 'positional_embedding' in new_k: + if self.attnpool.positional_embedding.shape != state_dict[ + new_k].shape: + print( + f'Resize the pos_embed shape from {state_dict[new_k].shape}' + f' to {self.attnpool.positional_embedding.shape}' + ) + cls_pos = state_dict[new_k][0:1, :] + H = W = self.input_resolution // 32 + old_h = int( + math.sqrt(state_dict[new_k][1:, ].shape[0])) + spatial_pos = F.interpolate( + state_dict[new_k][1:, ].reshape( + 1, old_h, old_h, + cls_pos.shape[1]).permute(0, 3, 1, 2), + size=(H, W), + mode='bilinear') + spatial_pos = spatial_pos.reshape( + cls_pos.shape[1], H * W).permute(1, 0) + positional_embedding = torch.cat( + [cls_pos, spatial_pos], dim=0) + state_dict[new_k] = positional_embedding + assert self.attnpool.positional_embedding.shape == state_dict[ + new_k].shape + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in CLIPResNet') + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + + outs = [] + x = self.layer1(x) + outs.append(x) + x = self.layer2(x) + outs.append(x) + x = self.layer3(x) + outs.append(x) + x = self.layer4(x) + outs.append(x) + + x_global, x_local = self.attnpool(x) + outs.append([x_global, x_local]) + + return tuple(outs) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None, + drop_path=0.): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.drop_path(self.attention(self.ln_1(x))) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + drop_path_rate=0.): + super().__init__() + self.width = width + self.layers = layers + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers) + ] # stochastic depth decay rule + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) + for i in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, k, v): + B, N, C = q.shape + assert k.shape == v.shape + B, M, C = k.shape + q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads) + k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads) + v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads) + + attn = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale + + attn = attn.softmax(dim=-1) + + x = torch.einsum('bknm,bmkc->bnkc', attn, v).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class TransformerDecoderLayer(nn.Module): + + def __init__( + self, + d_model, + nhead, + dropout=0.1, + ): + super().__init__() + self.self_attn = Attention(d_model, nhead, proj_drop=dropout) + self.cross_attn = Attention(d_model, nhead, proj_drop=dropout) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.mlp = nn.Sequential( + nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), + nn.Linear(d_model * 4, d_model)) + + def forward(self, x, mem): + q = k = v = self.norm1(x) + x = x + self.self_attn(q, k, v) + q = self.norm2(x) + x = x + self.cross_attn(q, mem, mem) + x = x + self.dropout(self.mlp(self.norm3(x))) + return x + + +class CLIPVisionTransformer(nn.Module): + + def __init__(self, + input_resolution=224, + patch_size=32, + width=768, + layers=12, + heads=12, + output_dim=512, + drop_path_rate=0.0, + out_indices=[3, 5, 7, 11], + pretrained=None, + get_embeddings=False, + **kwargs): + super().__init__() + self.pretrained = pretrained + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.spatial_size = input_resolution // patch_size + self.ln_pre = LayerNorm(width) + self.get_embeddings = get_embeddings + + self.transformer = Transformer( + width, layers, heads, drop_path_rate=drop_path_rate) + + self.out_indices = out_indices + + if get_embeddings: + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + embed_dim = width + + if patch_size == 16: + self.fpn1 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + nn.SyncBatchNorm(embed_dim), + nn.GELU(), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn3 = nn.GroupNorm(1, embed_dim) + + self.fpn4 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.MaxPool2d(kernel_size=2, stride=2)) + + elif patch_size == 8: + self.fpn1 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.GroupNorm(1, embed_dim) + + self.fpn3 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + + self.fpn4 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.MaxPool2d(kernel_size=4, stride=4), + ) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('visual.'): + new_k = k.replace('visual.', '') + state_dict[new_k] = checkpoint[k] + + if 'positional_embedding' in state_dict.keys(): + if self.positional_embedding.shape != state_dict[ + 'positional_embedding'].shape: + print( + f'Resize the pos_embed shape from {state_dict["positional_embedding"].shape} to' + f' {self.positional_embedding.shape}') + cls_pos = state_dict['positional_embedding'][0:1, :] + spatial_pos = F.interpolate( + state_dict['positional_embedding'][1:, ].reshape( + 1, 14, 14, 768).permute(0, 3, 1, 2), + size=(self.spatial_size, self.spatial_size), + mode='bilinear') + spatial_pos = spatial_pos.reshape( + 768, + self.spatial_size * self.spatial_size).permute(1, 0) + positional_embedding = torch.cat([cls_pos, spatial_pos], + dim=0) + state_dict['positional_embedding'] = positional_embedding + assert self.positional_embedding.shape == state_dict[ + 'positional_embedding'].shape + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in vision transformer') + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + B, C, H, W = x.shape + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x1 = self.class_embedding.to(x.dtype) + x2 = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([x1 + x2, x], dim=1) + pos = self.positional_embedding.to(x.dtype) + cls_pos = pos[0, :] + self.class_embedding.to(x.dtype) + spatial_pos = F.interpolate( + pos[1:, ].reshape(1, self.spatial_size, self.spatial_size, + C).permute(0, 3, 1, 2), + size=(H, W), + mode='bilinear') + spatial_pos = spatial_pos.reshape(1, C, H * W).permute(0, 2, 1) + pos = torch.cat([cls_pos.reshape(1, 1, C), spatial_pos], dim=1) + x = x + pos + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + + gradientcheckpoint = False + + features = [] + for i, blk in enumerate(self.transformer.resblocks): + if gradientcheckpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + if i in self.out_indices: + xp = x.permute(1, 0, 2)[:, + 1:, :].permute(0, 2, + 1).reshape(B, -1, H, W) + features.append(xp.contiguous()) + + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(features)): + features[i] = ops[i](features[i]) + + if self.get_embeddings: + x = x.permute(1, 0, 2) + x = self.ln_post(x) + x = x @ self.proj + + global_embedding = x[:, 0] + visual_embedding = x[:, 1:].reshape(B, H, W, + -1).permute(0, 3, 1, + 2) # B C H W + + features.append([global_embedding, visual_embedding]) + + return tuple(features) + + +class CLIPTextEncoder(nn.Module): + + def __init__(self, + context_length=77, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + embed_dim=1024, + out_dim=256, + pretrained=None, + **kwargs): + super().__init__() + + self.pretrained = pretrained + + self.context_length = context_length + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('transformer.'): + state_dict[k] = checkpoint[k] + + if k == 'positional_embedding' or k == 'text_projection' or k.startswith( + 'token_embedding') or k.startswith('ln_final'): + if k == 'positional_embedding' and checkpoint[k].size( + 0) > self.context_length: + checkpoint[k] = checkpoint[k][:self.context_length] + print('positional_embedding is tuncated from 77 to', + self.context_length) + state_dict[k] = checkpoint[k] + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in text encoder') + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text): + x = self.token_embedding(text) + x = x + self.positional_embedding + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_final(x) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1), ...] @ self.text_projection + return x + + +class CLIPTextContextEncoder(nn.Module): + + def __init__(self, + context_length=22, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + embed_dim=1024, + out_dim=256, + pretrained=None, + **kwargs): + super().__init__() + + self.pretrained = pretrained + + self.context_length = context_length + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.embed_dim = embed_dim + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('transformer.'): + state_dict[k] = checkpoint[k] + + if k == 'positional_embedding' or k == 'text_projection' or k.startswith( + 'token_embedding') or k.startswith('ln_final'): + if k == 'positional_embedding' and checkpoint[k].size( + 0) > self.context_length: + checkpoint[k] = checkpoint[k][:self.context_length] + print('positional_embedding is tuncated from 77 to', + self.context_length) + state_dict[k] = checkpoint[k] + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in text encoder') + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text, context=None): + x_text = self.token_embedding(text) # n_clas, n_text, C + K, N1, C = x_text.shape # 150类 * 5??? * 512 + B, N2, C = context.shape # 1 * 8 * 512 + + eos_indx = text.argmax(dim=-1) + N2 + eos_indx = eos_indx.reshape(1, K).expand(B, K).reshape(-1) + + x_text = x_text.reshape(1, K, N1, C).expand(B, K, N1, C) + context = context.reshape(B, 1, N2, C).expand(B, K, N2, C) + + x = torch.cat([x_text[:, :, 0:1], context, x_text[:, :, 1:]], + dim=2).reshape(B * K, N1 + N2, C) + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + x = x[torch.arange(x.shape[0]), eos_indx] @ self.text_projection + x = x.reshape(B, K, self.embed_dim) + return x + + +class ContextDecoder(nn.Module): + + def __init__(self, + transformer_width=256, + transformer_heads=4, + transformer_layers=6, + visual_dim=1024, + dropout=0.1, + **kwargs): + super().__init__() + + self.memory_proj = nn.Sequential( + nn.LayerNorm(visual_dim), + nn.Linear(visual_dim, transformer_width), + nn.LayerNorm(transformer_width), + ) + + self.text_proj = nn.Sequential( + nn.LayerNorm(visual_dim), + nn.Linear(visual_dim, transformer_width), + ) + + self.decoder = nn.ModuleList([ + TransformerDecoderLayer(transformer_width, transformer_heads, + dropout) for _ in range(transformer_layers) + ]) + + self.out_proj = nn.Sequential( + nn.LayerNorm(transformer_width), + nn.Linear(transformer_width, visual_dim)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, text, visual): + B, N, C = visual.shape + visual = self.memory_proj(visual) + x = self.text_proj(text) + + for layer in self.decoder: + x = layer(x, visual) + + return self.out_proj(x) diff --git a/modelscope/models/cv/shop_segmentation/neck_fpn.py b/modelscope/models/cv/shop_segmentation/neck_fpn.py new file mode 100644 index 00000000..108cb043 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/neck_fpn.py @@ -0,0 +1,217 @@ +""" FPNneck +Base modules are adapted from https://github.com/open-mmlab/mmcv/, +originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +https://github.com/open-mmlab/mmsegmentation/, +originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +and adapted from https://github.com/raoyongming/DenseCLIP/, +originally MIT License, Copyright (c) 2022 Rao, Yongming. +""" + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from timm.models.layers import drop, drop_path, trunc_normal_ + +from .common import resize + + +class FPN(nn.Module): + """Feature Pyramid Network. + + This neck is the implementation of `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: dict(mode='nearest'). + init_cfg (dict or list[dict], optional): Initialization config dict. + + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest')): + super(FPN, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + self.apply(self._init_weights) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) diff --git a/modelscope/models/cv/shop_segmentation/shop_seg_base.py b/modelscope/models/cv/shop_segmentation/shop_seg_base.py new file mode 100644 index 00000000..e3ae0d54 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/shop_seg_base.py @@ -0,0 +1,157 @@ +""" +Base modules are adapted from https://github.com/open-mmlab/mmcv/, +originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +https://github.com/open-mmlab/mmsegmentation/, +originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +and adapted from https://github.com/raoyongming/DenseCLIP/, +originally MIT License, Copyright (c) 2022 Rao, Yongming. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .head_fpn import FPNHead +from .models import (CLIPTextContextEncoder, CLIPVisionTransformer, + ContextDecoder) +from .neck_fpn import FPN +from .utils import SimpleTokenizer, tokenize + + +class SHOPSEG(nn.Module): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def __init__(self, + model_dir, + context_length=22, + context_feature='attention', + score_concat_index=2, + tau=0.07, + token_embed_dim=512, + text_dim=512, + **args): + super(SHOPSEG, self).__init__() + + self.model_dir = model_dir + self.tokenizer = SimpleTokenizer(model_dir + + '/bpe_simple_vocab_16e6.txt.gz') + + backbone = CLIPVisionTransformer( + input_resolution=1024, + patch_size=16, + width=768, + layers=12, + output_dim=512, + drop_path_rate=0.1, + pretrained=False, + get_embeddings=True) + + text_encoder = CLIPTextContextEncoder( + context_length=30, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + embed_dim=512, + pretrained=False) + + context_decoder = ContextDecoder( + transformer_width=256, + transformer_heads=4, + transformer_layers=3, + visual_dim=512, + dropout=0.1) + neck = FPN( + in_channels=[768, 768, 768 + 2, 768], out_channels=256, num_outs=4) + head_fpd = FPNHead(channels=256, num_classes=2) + + self.backbone = backbone + self.text_encoder = text_encoder + self.context_decoder = context_decoder + self.context_length = context_length + self.score_concat_index = score_concat_index + + self.context_feature = context_feature + self.tau = tau + context_length = self.text_encoder.context_length - self.context_length + self.contexts = nn.Parameter( + torch.randn(1, context_length, token_embed_dim)) + nn.init.trunc_normal_(self.contexts) + self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4) + + self.neck = neck + self.head_fpn = head_fpd + + self.tau = 0.07 + + def encode_text(self, text, context_length): + output = tokenize(self.tokenizer, text, context_length, True) + return output + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + return x + + def after_extract_feat(self, x, name_list): + x_orig = list(x[0:4]) + global_feat, visual_embeddings = x[4] + B, C, H, W = visual_embeddings.shape + if self.context_feature == 'attention': + x1 = global_feat.reshape(B, C, 1) + x2 = visual_embeddings.reshape(B, C, H * W) + visual_context = torch.cat([x1, x2], dim=2).permute(0, 2, 1) + texts = torch.cat([ + self.encode_text(c, context_length=self.context_length) + for c in name_list + ]) + x1 = texts.to(global_feat.device) + x1 = self.text_encoder(x1, self.contexts) + text_embeddings = x1.expand(B, -1, -1) + # update text_embeddings by visual_context! + # (B, 1, C) + text_diff = self.context_decoder(text_embeddings, visual_context) + # (B, K, C) + text_embeddings = text_embeddings + self.gamma * text_diff + + # compute score map and concat + B, K, C = text_embeddings.shape + visual_embeddings = F.normalize(visual_embeddings, dim=1, p=2) + text = F.normalize(text_embeddings, dim=2, p=2) + score_map_list = [] + bsz = B + for i in range(bsz): + ind = 2 * i + sub_text = torch.cat( + [text[i:i + 1, ind:ind + 1], text[i:i + 1, ind + 1:ind + 2]], + dim=1) # 1 * 2 * h * w + + sub_score_map = torch.einsum('bchw,bkc->bkhw', + visual_embeddings[i:i + 1], + sub_text) # 1 * 2 * h * w + score_map_list.append(sub_score_map) + score_map = torch.cat(score_map_list, dim=0) # b * 2 * h * w + x_orig[self.score_concat_index] = torch.cat( + [x_orig[self.score_concat_index], score_map], dim=1) + return x_orig, score_map + + def forward(self, img, text_list=None): + if text_list is None: + bsz = img.size()[0] + text_list = ['foregeound'] * bsz + x = self.extract_feat(img) + _x_orig = [x[i] for i in range(4)] + name_list = [] + for name in text_list: + name_list.append('others') + name_list.append(name[0:20]) + x_orig, score_map = self.after_extract_feat(x, name_list) + x_orig = list(self.neck(x_orig)) + _x_orig = x_orig + pred = self.head_fpn(_x_orig) + return pred diff --git a/modelscope/models/cv/shop_segmentation/shop_seg_model.py b/modelscope/models/cv/shop_segmentation/shop_seg_model.py new file mode 100644 index 00000000..409c583b --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/shop_seg_model.py @@ -0,0 +1,115 @@ +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.shop_segmentation import SHOPSEG +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ShopSegmentation'] + + +@MODELS.register_module( + Tasks.shop_segmentation, module_name=Models.shop_segmentation) +class ShopSegmentation(TorchModel): + """ shop segmentation model. + """ + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = SHOPSEG(model_dir=model_dir) + pretrained_params = torch.load('{}/{}'.format( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) + + self.model.load_state_dict(pretrained_params) + self.model.eval() + self.device_id = device_id + if self.device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + + def preprocess(self, img, size=1024): + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + h, w, c = img.shape + max_hw = max(h, w) + ratio = 1.0 * size / max_hw + crop_h, crop_w = int(ratio * h), int(ratio * w) + pil_img = Image.fromarray(img) + pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR) + np_img = np.array(pil_img, dtype=np.float32) / 255. + + for j in range(3): + np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j] + + img_pad = np.zeros((size, size, 3), dtype=np.float32) + img_pad[:crop_h, :crop_w] = np_img + + img_pad = torch.from_numpy(img_pad).permute(2, 0, + 1).unsqueeze(0).float() + return img_pad, h, w, crop_h, crop_w + + def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w): + output = np.clip(tensors * 255., a_min=0, a_max=255.) + crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8) + + pil_output = Image.fromarray(crop_output) + pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR) + np_output = np.array(pil_output, dtype=np.uint8) + + np_output[np_output < 128] = 0 + np_output[np_output >= 128] = 255 + np_output = np.uint8(np_output) + return np_output + + def forward(self, image): + """ + image should be numpy array, dtype=np.uint8, shape: height*width*3 + """ + image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess( + image, size=1024) + pred = self.inference(image_tensor) + msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w, size=1024) + + outputs = {OutputKeys.MASKS: msk} + return outputs + + def inference(self, image): + """ + image should be tensor, 1 * 3 * 1024 * 1024 + """ + with torch.no_grad(): + if self.device_id == -1: + output = self.model(image) + else: + device = torch.device('cuda', self.device_id) + output = self.model(image.to(device)) + output = F.interpolate(output, size=(1024, 1024), mode='bilinear') + output = F.softmax(output, dim=1) + output = torch.argmax(output, dim=1) + output = output[0] + if self.device_id == -1: + pred = output.data.numpy() + else: + pred = output.data.cpu().numpy() + + del output + return pred diff --git a/modelscope/models/cv/shop_segmentation/utils.py b/modelscope/models/cv/shop_segmentation/utils.py new file mode 100644 index 00000000..c41f8a65 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/utils.py @@ -0,0 +1,199 @@ +""" CLIP Tokenizer +Adapted from https://github.com/openai/CLIP. +Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import gzip +import html +import os +from functools import lru_cache +from typing import Any, List, Union + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + error_list = [] + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception as err: + error_list.append(err) + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + +def tokenize(tokenizer, + texts, + context_length: int = 77, + truncate: bool = False) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = tokenizer.encoder['<|startoftext|>'] + eot_token = tokenizer.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f'Input {texts[i]} is too long for context length {context_length}' + ) + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/modelscope/outputs.py b/modelscope/outputs.py index e84c8dcc..8fe71ec2 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -259,7 +259,13 @@ TASK_OUTPUTS = { # ] # } Tasks.text_driven_segmentation: [OutputKeys.MASKS], - + # shop segmentation result for single sample + # { + # "masks": [ + # np.array # 2D array containing only 0, 255 + # ] + # } + Tasks.shop_segmentation: [OutputKeys.MASKS], # movide scene segmentation result for a single video # { # "split_video_num":3, diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index f43d152b..f6381857 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -156,7 +156,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_vitl16_segmentation_text-driven-seg'), Tasks.movie_scene_segmentation: (Pipelines.movie_scene_segmentation, - 'damo/cv_resnet50-bert_video-scene-segmentation_movienet') + 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), + Tasks.shop_segmentation: (Pipelines.shop_segmentation, + 'damo/cv_vitb16_segmentation_shop-seg'), } diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 9e7d80ee..d3dba978 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -43,10 +43,10 @@ if TYPE_CHECKING: from .tinynas_classification_pipeline import TinynasClassificationPipeline from .video_category_pipeline import VideoCategoryPipeline from .virtual_try_on_pipeline import VirtualTryonPipeline + from .shop_segmentation_pipleline import ShopSegmentationPipeline from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipleline from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline - else: _import_structure = { 'action_recognition_pipeline': ['ActionRecognitionPipeline'], @@ -96,6 +96,7 @@ else: 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], 'video_category_pipeline': ['VideoCategoryPipeline'], 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], + 'shop_segmentation_pipleline': ['ShopSegmentationPipeline'], 'easycv_pipeline': [ 'EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline', 'Face2DKeypointsPipeline' diff --git a/modelscope/pipelines/cv/shop_segmentation_pipleline.py b/modelscope/pipelines/cv/shop_segmentation_pipleline.py new file mode 100644 index 00000000..b7fd90b4 --- /dev/null +++ b/modelscope/pipelines/cv/shop_segmentation_pipleline.py @@ -0,0 +1,51 @@ +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.shop_segmentation, module_name=Pipelines.shop_segmentation) +class ShopSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img) + result = { + 'img': img_tensor, + 'ori_h': ori_h, + 'ori_w': ori_w, + 'crop_h': crop_h, + 'crop_w': crop_w + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + outputs = self.model.inference(input['img']) + result = { + 'data': outputs, + 'ori_h': input['ori_h'], + 'ori_w': input['ori_w'], + 'crop_h': input['crop_h'], + 'crop_w': input['crop_w'], + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + data = self.model.postprocess(inputs['data'], inputs['crop_h'], + inputs['crop_w'], inputs['ori_h'], + inputs['ori_w']) + outputs = {OutputKeys.MASKS: data} + return outputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 86808ea1..1b738bfe 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -38,6 +38,7 @@ class CVTasks(object): image_segmentation = 'image-segmentation' portrait_matting = 'portrait-matting' text_driven_segmentation = 'text-driven-segmentation' + shop_segmentation = 'shop-segmentation' # image editing skin_retouching = 'skin-retouching' diff --git a/tests/pipelines/test_shop_segmentation.py b/tests/pipelines/test_shop_segmentation.py new file mode 100644 index 00000000..58c56dd7 --- /dev/null +++ b/tests/pipelines/test_shop_segmentation.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ShopSegmentationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_shop_segmentation(self): + input_location = 'data/test/images/shop_segmentation.jpg' + model_id = 'damo/cv_vitb16_segmentation_shop-seg' + shop_seg = pipeline(Tasks.shop_segmentation, model=model_id) + result = shop_seg(input_location) + import cv2 + # result[OutputKeys.MASKS] is segment map result,other keys are not used + cv2.imwrite(input_location + '_shopseg.jpg', result[OutputKeys.MASKS]) + + +if __name__ == '__main__': + unittest.main()