商品显著性检测模型,依赖opencv,mmcv-full
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9909897
master
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:f5ecc371c8b0ca09d0e11df89bc549000937eafc451929586426fe657ade25a0 | |||||
| size 238607 | |||||
| @@ -32,6 +32,7 @@ class Models(object): | |||||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
| resnet50_bert = 'resnet50-bert' | resnet50_bert = 'resnet50-bert' | ||||
| shop_segmentation = 'shop-segmentation' | |||||
| # EasyCV models | # EasyCV models | ||||
| yolox = 'YOLOX' | yolox = 'YOLOX' | ||||
| @@ -148,6 +149,7 @@ class Pipelines(object): | |||||
| image_reid_person = 'passvitb-image-reid-person' | image_reid_person = 'passvitb-image-reid-person' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
| movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | ||||
| shop_segmentation = 'shop-segmentation' | |||||
| # nlp tasks | # nlp tasks | ||||
| sentence_similarity = 'sentence-similarity' | sentence_similarity = 'sentence-similarity' | ||||
| @@ -11,7 +11,7 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||||
| image_to_image_generation, image_to_image_translation, | image_to_image_generation, image_to_image_translation, | ||||
| movie_scene_segmentation, object_detection, | movie_scene_segmentation, object_detection, | ||||
| product_retrieval_embedding, realtime_object_detection, | product_retrieval_embedding, realtime_object_detection, | ||||
| salient_detection, super_resolution, | |||||
| salient_detection, shop_segmentation, super_resolution, | |||||
| video_single_object_tracking, video_summarization, virual_tryon) | video_single_object_tracking, video_summarization, virual_tryon) | ||||
| # yapf: enable | # yapf: enable | ||||
| @@ -0,0 +1 @@ | |||||
| from .shop_seg_base import SHOPSEG | |||||
| @@ -0,0 +1,59 @@ | |||||
| """ | |||||
| Base modules are adapted from https://github.com/open-mmlab/mmcv/, | |||||
| originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, | |||||
| https://github.com/open-mmlab/mmsegmentation/, | |||||
| originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, | |||||
| and adapted from https://github.com/raoyongming/DenseCLIP/, | |||||
| originally MIT License, Copyright (c) 2022 Rao, Yongming. | |||||
| """ | |||||
| import warnings | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| def resize(input, | |||||
| size=None, | |||||
| scale_factor=None, | |||||
| mode='nearest', | |||||
| align_corners=None, | |||||
| warning=True): | |||||
| if warning: | |||||
| if size is not None and align_corners: | |||||
| input_h, input_w = tuple(int(x) for x in input.shape[2:]) | |||||
| output_h, output_w = tuple(int(x) for x in size) | |||||
| if output_h > input_h or output_w > input_w: | |||||
| if ((output_h > 1 and output_w > 1 and input_h > 1 | |||||
| and input_w > 1) and (output_h - 1) % (input_h - 1) | |||||
| and (output_w - 1) % (input_w - 1)): | |||||
| warnings.warn( | |||||
| f'When align_corners={align_corners}, ' | |||||
| 'the output would more aligned if ' | |||||
| f'input size {(input_h, input_w)} is `x+1` and ' | |||||
| f'out size {(output_h, output_w)} is `nx+1`') | |||||
| return F.interpolate(input, size, scale_factor, mode, align_corners) | |||||
| class Upsample(nn.Module): | |||||
| def __init__(self, | |||||
| size=None, | |||||
| scale_factor=None, | |||||
| mode='nearest', | |||||
| align_corners=None): | |||||
| super(Upsample, self).__init__() | |||||
| self.size = size | |||||
| if isinstance(scale_factor, tuple): | |||||
| self.scale_factor = tuple(float(factor) for factor in scale_factor) | |||||
| else: | |||||
| self.scale_factor = float(scale_factor) if scale_factor else None | |||||
| self.mode = mode | |||||
| self.align_corners = align_corners | |||||
| def forward(self, x): | |||||
| if not self.size: | |||||
| size = [int(t * self.scale_factor) for t in x.shape[-2:]] | |||||
| else: | |||||
| size = self.size | |||||
| return resize(x, size, None, self.mode, self.align_corners) | |||||
| @@ -0,0 +1,122 @@ | |||||
| """ FPNHead | |||||
| Base modules are adapted from https://github.com/open-mmlab/mmcv/, | |||||
| originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, | |||||
| https://github.com/open-mmlab/mmsegmentation/, | |||||
| originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, | |||||
| and adapted from https://github.com/raoyongming/DenseCLIP/, | |||||
| originally MIT License, Copyright (c) 2022 Rao, Yongming. | |||||
| """ | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from mmcv.cnn import ConvModule | |||||
| from timm.models.layers import drop, drop_path, trunc_normal_ | |||||
| from .common import Upsample, resize | |||||
| class FPNHead(nn.Module): | |||||
| """Panoptic Feature Pyramid Networks. | |||||
| This head is the implementation of `Semantic FPN | |||||
| <https://arxiv.org/abs/1901.02446>`_. | |||||
| Args: | |||||
| feature_strides (tuple[int]): The strides for input feature maps. | |||||
| stack_lateral. All strides suppose to be power of 2. The first | |||||
| one is of largest resolution. | |||||
| """ | |||||
| def __init__(self, | |||||
| channels, | |||||
| num_classes, | |||||
| dropout_ratio=0.1, | |||||
| feature_strides=[4, 8, 16, 32], | |||||
| align_corners=False, | |||||
| **kwargs): | |||||
| super(FPNHead, self).__init__() | |||||
| self.act_cfg = dict(type='ReLU') | |||||
| self.channels = channels | |||||
| self.conv_cfg = None | |||||
| self.norm_cfg = None | |||||
| self.norm_cfg = dict(type='BN2d', requires_grad=True) | |||||
| self.align_corners = align_corners | |||||
| self.dropout_ratio = dropout_ratio | |||||
| self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) | |||||
| if dropout_ratio > 0: | |||||
| self.dropout = nn.Dropout2d(dropout_ratio) | |||||
| else: | |||||
| self.dropout = None | |||||
| self.in_index = [0, 1, 2, 3] | |||||
| assert min(feature_strides) == feature_strides[0] | |||||
| self.feature_strides = feature_strides | |||||
| self.scale_heads = nn.ModuleList() | |||||
| for i in range(len(feature_strides)): | |||||
| head_length = max( | |||||
| 1, | |||||
| int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) | |||||
| scale_head = [] | |||||
| for k in range(head_length): | |||||
| scale_head.append( | |||||
| ConvModule( | |||||
| self.channels, | |||||
| self.channels, | |||||
| 3, | |||||
| padding=1, | |||||
| conv_cfg=self.conv_cfg, | |||||
| norm_cfg=self.norm_cfg, | |||||
| act_cfg=self.act_cfg)) | |||||
| if feature_strides[i] != feature_strides[0]: | |||||
| scale_head.append( | |||||
| Upsample( | |||||
| scale_factor=2, | |||||
| mode='bilinear', | |||||
| align_corners=self.align_corners)) | |||||
| self.scale_heads.append(nn.Sequential(*scale_head)) | |||||
| self.apply(self._init_weights) | |||||
| def _transform_inputs(self, inputs): | |||||
| """Transform inputs for decoder. | |||||
| Args: | |||||
| inputs (list[Tensor]): List of multi-level img features. | |||||
| Returns: | |||||
| Tensor: The transformed inputs | |||||
| """ | |||||
| inputs = [inputs[i] for i in self.in_index] | |||||
| return inputs | |||||
| def cls_seg(self, feat): | |||||
| """Classify each pixel.""" | |||||
| if self.dropout is not None: | |||||
| feat = self.dropout(feat) | |||||
| output = self.conv_seg(feat) | |||||
| return output | |||||
| def forward(self, inputs): | |||||
| x = self._transform_inputs(inputs) | |||||
| output = self.scale_heads[0](x[0]) | |||||
| for i in range(1, len(self.feature_strides)): | |||||
| # non inplace | |||||
| output = output + resize( | |||||
| self.scale_heads[i](x[i]), | |||||
| size=output.shape[2:], | |||||
| mode='bilinear', | |||||
| align_corners=self.align_corners) | |||||
| output = self.cls_seg(output) | |||||
| return output | |||||
| def _init_weights(self, m): | |||||
| if isinstance(m, nn.Linear): | |||||
| trunc_normal_(m.weight, std=.02) | |||||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||||
| nn.init.constant_(m.bias, 0) | |||||
| elif isinstance(m, nn.LayerNorm): | |||||
| nn.init.constant_(m.bias, 0) | |||||
| nn.init.constant_(m.weight, 1.0) | |||||
| elif isinstance(m, nn.Conv2d): | |||||
| nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') | |||||
| if m.bias is not None: | |||||
| nn.init.constant_(m.bias.data, 0) | |||||
| @@ -0,0 +1,901 @@ | |||||
| """ | |||||
| Base modules are adapted from https://github.com/open-mmlab/mmcv/, | |||||
| originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, | |||||
| https://github.com/open-mmlab/mmsegmentation/, | |||||
| originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, | |||||
| and adapted from https://github.com/raoyongming/DenseCLIP/, | |||||
| originally MIT License, Copyright (c) 2022 Rao, Yongming. | |||||
| """ | |||||
| import math | |||||
| from collections import OrderedDict | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| import torch.utils.checkpoint as checkpoint | |||||
| from timm.models.layers import drop, drop_path, trunc_normal_ | |||||
| from torch import nn | |||||
| class Bottleneck(nn.Module): | |||||
| expansion = 4 | |||||
| def __init__(self, inplanes, planes, stride=1): | |||||
| super().__init__() | |||||
| # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 | |||||
| self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(planes) | |||||
| self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) | |||||
| self.bn2 = nn.BatchNorm2d(planes) | |||||
| self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() | |||||
| self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) | |||||
| self.bn3 = nn.BatchNorm2d(planes * self.expansion) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.downsample = None | |||||
| self.stride = stride | |||||
| if stride > 1 or inplanes != planes * Bottleneck.expansion: | |||||
| # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 | |||||
| self.downsample = nn.Sequential( | |||||
| OrderedDict([('-1', nn.AvgPool2d(stride)), | |||||
| ('0', | |||||
| nn.Conv2d( | |||||
| inplanes, | |||||
| planes * self.expansion, | |||||
| 1, | |||||
| stride=1, | |||||
| bias=False)), | |||||
| ('1', nn.BatchNorm2d(planes * self.expansion))])) | |||||
| def forward(self, x: torch.Tensor): | |||||
| identity = x | |||||
| out = self.relu(self.bn1(self.conv1(x))) | |||||
| out = self.relu(self.bn2(self.conv2(out))) | |||||
| out = self.avgpool(out) | |||||
| out = self.bn3(self.conv3(out)) | |||||
| if self.downsample is not None: | |||||
| identity = self.downsample(x) | |||||
| out += identity | |||||
| out = self.relu(out) | |||||
| return out | |||||
| class AttentionPool2d(nn.Module): | |||||
| def __init__(self, | |||||
| spacial_dim: int, | |||||
| embed_dim: int, | |||||
| num_heads: int, | |||||
| output_dim: int = None): | |||||
| super().__init__() | |||||
| self.positional_embedding = nn.Parameter( | |||||
| torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) | |||||
| self.k_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.q_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.v_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) | |||||
| self.num_heads = num_heads | |||||
| self.embed_dim = embed_dim | |||||
| self.spacial_dim = spacial_dim | |||||
| def forward(self, x): | |||||
| B, C, H, W = x.shape | |||||
| x = x.reshape(x.shape[0], x.shape[1], | |||||
| x.shape[2] * x.shape[3]).permute(2, 0, | |||||
| 1) # NCHW -> (HW)NC | |||||
| x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC | |||||
| cls_pos = self.positional_embedding[0:1, :] | |||||
| spatial_pos = F.interpolate( | |||||
| self.positional_embedding[1:, ].reshape(1, self.spacial_dim, | |||||
| self.spacial_dim, | |||||
| self.embed_dim).permute( | |||||
| 0, 3, 1, 2), | |||||
| size=(H, W), | |||||
| mode='bilinear') | |||||
| spatial_pos = spatial_pos.reshape(self.embed_dim, H * W).permute(1, 0) | |||||
| positional_embedding = torch.cat([cls_pos, spatial_pos], dim=0) | |||||
| x = x + positional_embedding[:, None, :] | |||||
| x, _ = F.multi_head_attention_forward( | |||||
| query=x, | |||||
| key=x, | |||||
| value=x, | |||||
| embed_dim_to_check=x.shape[-1], | |||||
| num_heads=self.num_heads, | |||||
| q_proj_weight=self.q_proj.weight, | |||||
| k_proj_weight=self.k_proj.weight, | |||||
| v_proj_weight=self.v_proj.weight, | |||||
| in_proj_weight=None, | |||||
| in_proj_bias=torch.cat( | |||||
| [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), | |||||
| bias_k=None, | |||||
| bias_v=None, | |||||
| add_zero_attn=False, | |||||
| dropout_p=0, | |||||
| out_proj_weight=self.c_proj.weight, | |||||
| out_proj_bias=self.c_proj.bias, | |||||
| use_separate_proj_weight=True, | |||||
| training=self.training, | |||||
| need_weights=False) | |||||
| x = x.permute(1, 2, 0) | |||||
| global_feat = x[:, :, 0] | |||||
| feature_map = x[:, :, 1:].reshape(B, -1, H, W) | |||||
| return global_feat, feature_map | |||||
| class CLIPResNet(nn.Module): | |||||
| """ | |||||
| A ResNet class that is similar to torchvision's but contains the following changes: | |||||
| - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |||||
| - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |||||
| - The final pooling layer is a QKV attention instead of an average pool | |||||
| """ | |||||
| def __init__(self, | |||||
| layers, | |||||
| output_dim=512, | |||||
| input_resolution=224, | |||||
| width=64, | |||||
| pretrained=None, | |||||
| **kwargs): | |||||
| super().__init__() | |||||
| self.pretrained = pretrained | |||||
| self.output_dim = output_dim | |||||
| self.input_resolution = input_resolution | |||||
| # the 3-layer stem | |||||
| self.conv1 = nn.Conv2d( | |||||
| 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(width // 2) | |||||
| self.conv2 = nn.Conv2d( | |||||
| width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |||||
| self.bn2 = nn.BatchNorm2d(width // 2) | |||||
| self.conv3 = nn.Conv2d( | |||||
| width // 2, width, kernel_size=3, padding=1, bias=False) | |||||
| self.bn3 = nn.BatchNorm2d(width) | |||||
| self.avgpool = nn.AvgPool2d(2) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| # residual layers | |||||
| self._inplanes = width # this is a *mutable* variable used during construction | |||||
| self.layer1 = self._make_layer(width, layers[0]) | |||||
| self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |||||
| self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |||||
| self.layer4 = self._make_layer(width * 8, layers[3], stride=2) | |||||
| def init_weights(self, pretrained=None): | |||||
| pretrained = pretrained or self.pretrained | |||||
| if isinstance(pretrained, str): | |||||
| checkpoint = torch.jit.load( | |||||
| pretrained, map_location='cpu').float().state_dict() | |||||
| state_dict = {} | |||||
| for k in checkpoint.keys(): | |||||
| if k.startswith('visual.'): | |||||
| new_k = k.replace('visual.', '') | |||||
| state_dict[new_k] = checkpoint[k] | |||||
| u, w = self.load_state_dict(state_dict, False) | |||||
| print(u, w, 'are misaligned params in CLIPResNet') | |||||
| def _make_layer(self, planes, blocks, stride=1): | |||||
| layers = [Bottleneck(self._inplanes, planes, stride)] | |||||
| self._inplanes = planes * Bottleneck.expansion | |||||
| for _ in range(1, blocks): | |||||
| layers.append(Bottleneck(self._inplanes, planes)) | |||||
| return nn.Sequential(*layers) | |||||
| def forward(self, x): | |||||
| def stem(x): | |||||
| for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), | |||||
| (self.conv3, self.bn3)]: | |||||
| x = self.relu(bn(conv(x))) | |||||
| x = self.avgpool(x) | |||||
| return x | |||||
| x = x.type(self.conv1.weight.dtype) | |||||
| x = stem(x) | |||||
| outs = [] | |||||
| x = self.layer1(x) | |||||
| outs.append(x) | |||||
| x = self.layer2(x) | |||||
| outs.append(x) | |||||
| x = self.layer3(x) | |||||
| outs.append(x) | |||||
| x = self.layer4(x) | |||||
| outs.append(x) | |||||
| return tuple(outs) | |||||
| class CLIPResNetWithAttention(nn.Module): | |||||
| """ | |||||
| A ResNet class that is similar to torchvision's but contains the following changes: | |||||
| - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |||||
| - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |||||
| - The final pooling layer is a QKV attention instead of an average pool | |||||
| """ | |||||
| def __init__(self, | |||||
| layers, | |||||
| output_dim=1024, | |||||
| input_resolution=224, | |||||
| width=64, | |||||
| pretrained=None, | |||||
| **kwargs): | |||||
| super().__init__() | |||||
| self.pretrained = pretrained | |||||
| self.output_dim = output_dim | |||||
| self.input_resolution = input_resolution | |||||
| # the 3-layer stem | |||||
| self.conv1 = nn.Conv2d( | |||||
| 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(width // 2) | |||||
| self.conv2 = nn.Conv2d( | |||||
| width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |||||
| self.bn2 = nn.BatchNorm2d(width // 2) | |||||
| self.conv3 = nn.Conv2d( | |||||
| width // 2, width, kernel_size=3, padding=1, bias=False) | |||||
| self.bn3 = nn.BatchNorm2d(width) | |||||
| self.avgpool = nn.AvgPool2d(2) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| # residual layers | |||||
| self._inplanes = width # this is a *mutable* variable used during construction | |||||
| self.layer1 = self._make_layer(width, layers[0]) | |||||
| self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |||||
| self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |||||
| self.layer4 = self._make_layer(width * 8, layers[3], stride=2) | |||||
| embed_dim = width * 32 # the ResNet feature dimension | |||||
| self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, 32, | |||||
| output_dim) | |||||
| def init_weights(self, pretrained=None): | |||||
| pretrained = pretrained or self.pretrained | |||||
| if isinstance(pretrained, str): | |||||
| checkpoint = torch.jit.load( | |||||
| pretrained, map_location='cpu').float().state_dict() | |||||
| state_dict = {} | |||||
| for k in checkpoint.keys(): | |||||
| if k.startswith('visual.'): | |||||
| new_k = k.replace('visual.', '') | |||||
| state_dict[new_k] = checkpoint[k] | |||||
| if 'positional_embedding' in new_k: | |||||
| if self.attnpool.positional_embedding.shape != state_dict[ | |||||
| new_k].shape: | |||||
| print( | |||||
| f'Resize the pos_embed shape from {state_dict[new_k].shape}' | |||||
| f' to {self.attnpool.positional_embedding.shape}' | |||||
| ) | |||||
| cls_pos = state_dict[new_k][0:1, :] | |||||
| H = W = self.input_resolution // 32 | |||||
| old_h = int( | |||||
| math.sqrt(state_dict[new_k][1:, ].shape[0])) | |||||
| spatial_pos = F.interpolate( | |||||
| state_dict[new_k][1:, ].reshape( | |||||
| 1, old_h, old_h, | |||||
| cls_pos.shape[1]).permute(0, 3, 1, 2), | |||||
| size=(H, W), | |||||
| mode='bilinear') | |||||
| spatial_pos = spatial_pos.reshape( | |||||
| cls_pos.shape[1], H * W).permute(1, 0) | |||||
| positional_embedding = torch.cat( | |||||
| [cls_pos, spatial_pos], dim=0) | |||||
| state_dict[new_k] = positional_embedding | |||||
| assert self.attnpool.positional_embedding.shape == state_dict[ | |||||
| new_k].shape | |||||
| u, w = self.load_state_dict(state_dict, False) | |||||
| print(u, w, 'are misaligned params in CLIPResNet') | |||||
| def _make_layer(self, planes, blocks, stride=1): | |||||
| layers = [Bottleneck(self._inplanes, planes, stride)] | |||||
| self._inplanes = planes * Bottleneck.expansion | |||||
| for _ in range(1, blocks): | |||||
| layers.append(Bottleneck(self._inplanes, planes)) | |||||
| return nn.Sequential(*layers) | |||||
| def forward(self, x): | |||||
| def stem(x): | |||||
| for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), | |||||
| (self.conv3, self.bn3)]: | |||||
| x = self.relu(bn(conv(x))) | |||||
| x = self.avgpool(x) | |||||
| return x | |||||
| x = x.type(self.conv1.weight.dtype) | |||||
| x = stem(x) | |||||
| outs = [] | |||||
| x = self.layer1(x) | |||||
| outs.append(x) | |||||
| x = self.layer2(x) | |||||
| outs.append(x) | |||||
| x = self.layer3(x) | |||||
| outs.append(x) | |||||
| x = self.layer4(x) | |||||
| outs.append(x) | |||||
| x_global, x_local = self.attnpool(x) | |||||
| outs.append([x_global, x_local]) | |||||
| return tuple(outs) | |||||
| class LayerNorm(nn.LayerNorm): | |||||
| """Subclass torch's LayerNorm to handle fp16.""" | |||||
| def forward(self, x: torch.Tensor): | |||||
| orig_type = x.dtype | |||||
| ret = super().forward(x.type(torch.float32)) | |||||
| return ret.type(orig_type) | |||||
| class QuickGELU(nn.Module): | |||||
| def forward(self, x: torch.Tensor): | |||||
| return x * torch.sigmoid(1.702 * x) | |||||
| class DropPath(nn.Module): | |||||
| """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||||
| """ | |||||
| def __init__(self, drop_prob=None): | |||||
| super(DropPath, self).__init__() | |||||
| self.drop_prob = drop_prob | |||||
| def forward(self, x): | |||||
| return drop_path(x, self.drop_prob, self.training) | |||||
| def extra_repr(self) -> str: | |||||
| return 'p={}'.format(self.drop_prob) | |||||
| class ResidualAttentionBlock(nn.Module): | |||||
| def __init__(self, | |||||
| d_model: int, | |||||
| n_head: int, | |||||
| attn_mask: torch.Tensor = None, | |||||
| drop_path=0.): | |||||
| super().__init__() | |||||
| self.attn = nn.MultiheadAttention(d_model, n_head) | |||||
| self.ln_1 = LayerNorm(d_model) | |||||
| self.mlp = nn.Sequential( | |||||
| OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), | |||||
| ('gelu', QuickGELU()), | |||||
| ('c_proj', nn.Linear(d_model * 4, d_model))])) | |||||
| self.ln_2 = LayerNorm(d_model) | |||||
| self.attn_mask = attn_mask | |||||
| self.drop_path = DropPath( | |||||
| drop_path) if drop_path > 0. else nn.Identity() | |||||
| def attention(self, x: torch.Tensor): | |||||
| self.attn_mask = self.attn_mask.to( | |||||
| dtype=x.dtype, | |||||
| device=x.device) if self.attn_mask is not None else None | |||||
| return self.attn( | |||||
| x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] | |||||
| def forward(self, x: torch.Tensor): | |||||
| x = x + self.drop_path(self.attention(self.ln_1(x))) | |||||
| x = x + self.drop_path(self.mlp(self.ln_2(x))) | |||||
| return x | |||||
| class Transformer(nn.Module): | |||||
| def __init__(self, | |||||
| width: int, | |||||
| layers: int, | |||||
| heads: int, | |||||
| attn_mask: torch.Tensor = None, | |||||
| drop_path_rate=0.): | |||||
| super().__init__() | |||||
| self.width = width | |||||
| self.layers = layers | |||||
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers) | |||||
| ] # stochastic depth decay rule | |||||
| self.resblocks = nn.Sequential(*[ | |||||
| ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) | |||||
| for i in range(layers) | |||||
| ]) | |||||
| def forward(self, x: torch.Tensor): | |||||
| return self.resblocks(x) | |||||
| class Attention(nn.Module): | |||||
| def __init__(self, | |||||
| dim, | |||||
| num_heads=8, | |||||
| qkv_bias=False, | |||||
| qk_scale=None, | |||||
| attn_drop=0., | |||||
| proj_drop=0.): | |||||
| super().__init__() | |||||
| self.num_heads = num_heads | |||||
| head_dim = dim // num_heads | |||||
| # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights | |||||
| self.scale = qk_scale or head_dim**-0.5 | |||||
| self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) | |||||
| self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) | |||||
| self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) | |||||
| self.attn_drop = nn.Dropout(attn_drop) | |||||
| self.proj = nn.Linear(dim, dim) | |||||
| self.proj_drop = nn.Dropout(proj_drop) | |||||
| def forward(self, q, k, v): | |||||
| B, N, C = q.shape | |||||
| assert k.shape == v.shape | |||||
| B, M, C = k.shape | |||||
| q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads) | |||||
| k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads) | |||||
| v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads) | |||||
| attn = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale | |||||
| attn = attn.softmax(dim=-1) | |||||
| x = torch.einsum('bknm,bmkc->bnkc', attn, v).reshape(B, N, C) | |||||
| x = self.proj(x) | |||||
| x = self.proj_drop(x) | |||||
| return x | |||||
| class TransformerDecoderLayer(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_model, | |||||
| nhead, | |||||
| dropout=0.1, | |||||
| ): | |||||
| super().__init__() | |||||
| self.self_attn = Attention(d_model, nhead, proj_drop=dropout) | |||||
| self.cross_attn = Attention(d_model, nhead, proj_drop=dropout) | |||||
| self.norm1 = nn.LayerNorm(d_model) | |||||
| self.norm2 = nn.LayerNorm(d_model) | |||||
| self.norm3 = nn.LayerNorm(d_model) | |||||
| self.dropout = nn.Dropout(dropout) | |||||
| self.mlp = nn.Sequential( | |||||
| nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), | |||||
| nn.Linear(d_model * 4, d_model)) | |||||
| def forward(self, x, mem): | |||||
| q = k = v = self.norm1(x) | |||||
| x = x + self.self_attn(q, k, v) | |||||
| q = self.norm2(x) | |||||
| x = x + self.cross_attn(q, mem, mem) | |||||
| x = x + self.dropout(self.mlp(self.norm3(x))) | |||||
| return x | |||||
| class CLIPVisionTransformer(nn.Module): | |||||
| def __init__(self, | |||||
| input_resolution=224, | |||||
| patch_size=32, | |||||
| width=768, | |||||
| layers=12, | |||||
| heads=12, | |||||
| output_dim=512, | |||||
| drop_path_rate=0.0, | |||||
| out_indices=[3, 5, 7, 11], | |||||
| pretrained=None, | |||||
| get_embeddings=False, | |||||
| **kwargs): | |||||
| super().__init__() | |||||
| self.pretrained = pretrained | |||||
| self.input_resolution = input_resolution | |||||
| self.output_dim = output_dim | |||||
| self.conv1 = nn.Conv2d( | |||||
| in_channels=3, | |||||
| out_channels=width, | |||||
| kernel_size=patch_size, | |||||
| stride=patch_size, | |||||
| bias=False) | |||||
| scale = width**-0.5 | |||||
| self.class_embedding = nn.Parameter(scale * torch.randn(width)) | |||||
| self.positional_embedding = nn.Parameter(scale * torch.randn( | |||||
| (input_resolution // patch_size)**2 + 1, width)) | |||||
| self.spatial_size = input_resolution // patch_size | |||||
| self.ln_pre = LayerNorm(width) | |||||
| self.get_embeddings = get_embeddings | |||||
| self.transformer = Transformer( | |||||
| width, layers, heads, drop_path_rate=drop_path_rate) | |||||
| self.out_indices = out_indices | |||||
| if get_embeddings: | |||||
| self.ln_post = LayerNorm(width) | |||||
| self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) | |||||
| embed_dim = width | |||||
| if patch_size == 16: | |||||
| self.fpn1 = nn.Sequential( | |||||
| nn.GroupNorm(1, embed_dim), | |||||
| nn.ConvTranspose2d( | |||||
| embed_dim, embed_dim, kernel_size=2, stride=2), | |||||
| nn.SyncBatchNorm(embed_dim), | |||||
| nn.GELU(), | |||||
| nn.ConvTranspose2d( | |||||
| embed_dim, embed_dim, kernel_size=2, stride=2), | |||||
| ) | |||||
| self.fpn2 = nn.Sequential( | |||||
| nn.GroupNorm(1, embed_dim), | |||||
| nn.ConvTranspose2d( | |||||
| embed_dim, embed_dim, kernel_size=2, stride=2), | |||||
| ) | |||||
| self.fpn3 = nn.GroupNorm(1, embed_dim) | |||||
| self.fpn4 = nn.Sequential( | |||||
| nn.GroupNorm(1, embed_dim), | |||||
| nn.MaxPool2d(kernel_size=2, stride=2)) | |||||
| elif patch_size == 8: | |||||
| self.fpn1 = nn.Sequential( | |||||
| nn.GroupNorm(1, embed_dim), | |||||
| nn.ConvTranspose2d( | |||||
| embed_dim, embed_dim, kernel_size=2, stride=2), | |||||
| ) | |||||
| self.fpn2 = nn.GroupNorm(1, embed_dim) | |||||
| self.fpn3 = nn.Sequential( | |||||
| nn.GroupNorm(1, embed_dim), | |||||
| nn.MaxPool2d(kernel_size=2, stride=2), | |||||
| ) | |||||
| self.fpn4 = nn.Sequential( | |||||
| nn.GroupNorm(1, embed_dim), | |||||
| nn.MaxPool2d(kernel_size=4, stride=4), | |||||
| ) | |||||
| def init_weights(self, pretrained=None): | |||||
| pretrained = pretrained or self.pretrained | |||||
| if isinstance(pretrained, str): | |||||
| checkpoint = torch.jit.load( | |||||
| pretrained, map_location='cpu').float().state_dict() | |||||
| state_dict = {} | |||||
| for k in checkpoint.keys(): | |||||
| if k.startswith('visual.'): | |||||
| new_k = k.replace('visual.', '') | |||||
| state_dict[new_k] = checkpoint[k] | |||||
| if 'positional_embedding' in state_dict.keys(): | |||||
| if self.positional_embedding.shape != state_dict[ | |||||
| 'positional_embedding'].shape: | |||||
| print( | |||||
| f'Resize the pos_embed shape from {state_dict["positional_embedding"].shape} to' | |||||
| f' {self.positional_embedding.shape}') | |||||
| cls_pos = state_dict['positional_embedding'][0:1, :] | |||||
| spatial_pos = F.interpolate( | |||||
| state_dict['positional_embedding'][1:, ].reshape( | |||||
| 1, 14, 14, 768).permute(0, 3, 1, 2), | |||||
| size=(self.spatial_size, self.spatial_size), | |||||
| mode='bilinear') | |||||
| spatial_pos = spatial_pos.reshape( | |||||
| 768, | |||||
| self.spatial_size * self.spatial_size).permute(1, 0) | |||||
| positional_embedding = torch.cat([cls_pos, spatial_pos], | |||||
| dim=0) | |||||
| state_dict['positional_embedding'] = positional_embedding | |||||
| assert self.positional_embedding.shape == state_dict[ | |||||
| 'positional_embedding'].shape | |||||
| u, w = self.load_state_dict(state_dict, False) | |||||
| print(u, w, 'are misaligned params in vision transformer') | |||||
| def forward(self, x: torch.Tensor): | |||||
| x = self.conv1(x) # shape = [*, width, grid, grid] | |||||
| B, C, H, W = x.shape | |||||
| x = x.reshape(x.shape[0], x.shape[1], | |||||
| -1) # shape = [*, width, grid ** 2] | |||||
| x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||||
| x1 = self.class_embedding.to(x.dtype) | |||||
| x2 = torch.zeros( | |||||
| x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) | |||||
| x = torch.cat([x1 + x2, x], dim=1) | |||||
| pos = self.positional_embedding.to(x.dtype) | |||||
| cls_pos = pos[0, :] + self.class_embedding.to(x.dtype) | |||||
| spatial_pos = F.interpolate( | |||||
| pos[1:, ].reshape(1, self.spatial_size, self.spatial_size, | |||||
| C).permute(0, 3, 1, 2), | |||||
| size=(H, W), | |||||
| mode='bilinear') | |||||
| spatial_pos = spatial_pos.reshape(1, C, H * W).permute(0, 2, 1) | |||||
| pos = torch.cat([cls_pos.reshape(1, 1, C), spatial_pos], dim=1) | |||||
| x = x + pos | |||||
| x = self.ln_pre(x) | |||||
| x = x.permute(1, 0, 2) # NLD -> LND | |||||
| gradientcheckpoint = False | |||||
| features = [] | |||||
| for i, blk in enumerate(self.transformer.resblocks): | |||||
| if gradientcheckpoint: | |||||
| x = checkpoint.checkpoint(blk, x) | |||||
| else: | |||||
| x = blk(x) | |||||
| if i in self.out_indices: | |||||
| xp = x.permute(1, 0, 2)[:, | |||||
| 1:, :].permute(0, 2, | |||||
| 1).reshape(B, -1, H, W) | |||||
| features.append(xp.contiguous()) | |||||
| ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] | |||||
| for i in range(len(features)): | |||||
| features[i] = ops[i](features[i]) | |||||
| if self.get_embeddings: | |||||
| x = x.permute(1, 0, 2) | |||||
| x = self.ln_post(x) | |||||
| x = x @ self.proj | |||||
| global_embedding = x[:, 0] | |||||
| visual_embedding = x[:, 1:].reshape(B, H, W, | |||||
| -1).permute(0, 3, 1, | |||||
| 2) # B C H W | |||||
| features.append([global_embedding, visual_embedding]) | |||||
| return tuple(features) | |||||
| class CLIPTextEncoder(nn.Module): | |||||
| def __init__(self, | |||||
| context_length=77, | |||||
| vocab_size=49408, | |||||
| transformer_width=512, | |||||
| transformer_heads=8, | |||||
| transformer_layers=12, | |||||
| embed_dim=1024, | |||||
| out_dim=256, | |||||
| pretrained=None, | |||||
| **kwargs): | |||||
| super().__init__() | |||||
| self.pretrained = pretrained | |||||
| self.context_length = context_length | |||||
| self.transformer = Transformer( | |||||
| width=transformer_width, | |||||
| layers=transformer_layers, | |||||
| heads=transformer_heads, | |||||
| attn_mask=self.build_attention_mask()) | |||||
| self.vocab_size = vocab_size | |||||
| self.token_embedding = nn.Embedding(vocab_size, transformer_width) | |||||
| self.positional_embedding = nn.Parameter( | |||||
| torch.empty(self.context_length, transformer_width)) | |||||
| self.ln_final = LayerNorm(transformer_width) | |||||
| self.text_projection = nn.Parameter( | |||||
| torch.empty(transformer_width, embed_dim)) | |||||
| def init_weights(self, pretrained=None): | |||||
| pretrained = pretrained or self.pretrained | |||||
| if isinstance(pretrained, str): | |||||
| checkpoint = torch.jit.load( | |||||
| pretrained, map_location='cpu').float().state_dict() | |||||
| state_dict = {} | |||||
| for k in checkpoint.keys(): | |||||
| if k.startswith('transformer.'): | |||||
| state_dict[k] = checkpoint[k] | |||||
| if k == 'positional_embedding' or k == 'text_projection' or k.startswith( | |||||
| 'token_embedding') or k.startswith('ln_final'): | |||||
| if k == 'positional_embedding' and checkpoint[k].size( | |||||
| 0) > self.context_length: | |||||
| checkpoint[k] = checkpoint[k][:self.context_length] | |||||
| print('positional_embedding is tuncated from 77 to', | |||||
| self.context_length) | |||||
| state_dict[k] = checkpoint[k] | |||||
| u, w = self.load_state_dict(state_dict, False) | |||||
| print(u, w, 'are misaligned params in text encoder') | |||||
| def build_attention_mask(self): | |||||
| # lazily create causal attention mask, with full attention between the vision tokens | |||||
| # pytorch uses additive attention mask; fill with -inf | |||||
| mask = torch.empty(self.context_length, self.context_length) | |||||
| mask.fill_(float('-inf')) | |||||
| mask.triu_(1) # zero out the lower diagonal | |||||
| return mask | |||||
| def forward(self, text): | |||||
| x = self.token_embedding(text) | |||||
| x = x + self.positional_embedding | |||||
| x = x.permute(1, 0, 2) | |||||
| x = self.transformer(x) | |||||
| x = x.permute(1, 0, 2) | |||||
| x = self.ln_final(x) | |||||
| x = x[torch.arange(x.shape[0]), | |||||
| text.argmax(dim=-1), ...] @ self.text_projection | |||||
| return x | |||||
| class CLIPTextContextEncoder(nn.Module): | |||||
| def __init__(self, | |||||
| context_length=22, | |||||
| vocab_size=49408, | |||||
| transformer_width=512, | |||||
| transformer_heads=8, | |||||
| transformer_layers=12, | |||||
| embed_dim=1024, | |||||
| out_dim=256, | |||||
| pretrained=None, | |||||
| **kwargs): | |||||
| super().__init__() | |||||
| self.pretrained = pretrained | |||||
| self.context_length = context_length | |||||
| self.transformer = Transformer( | |||||
| width=transformer_width, | |||||
| layers=transformer_layers, | |||||
| heads=transformer_heads, | |||||
| attn_mask=self.build_attention_mask()) | |||||
| self.embed_dim = embed_dim | |||||
| self.vocab_size = vocab_size | |||||
| self.token_embedding = nn.Embedding(vocab_size, transformer_width) | |||||
| self.positional_embedding = nn.Parameter( | |||||
| torch.empty(self.context_length, transformer_width)) | |||||
| self.ln_final = LayerNorm(transformer_width) | |||||
| self.text_projection = nn.Parameter( | |||||
| torch.empty(transformer_width, embed_dim)) | |||||
| def init_weights(self, pretrained=None): | |||||
| pretrained = pretrained or self.pretrained | |||||
| if isinstance(pretrained, str): | |||||
| checkpoint = torch.jit.load( | |||||
| pretrained, map_location='cpu').float().state_dict() | |||||
| state_dict = {} | |||||
| for k in checkpoint.keys(): | |||||
| if k.startswith('transformer.'): | |||||
| state_dict[k] = checkpoint[k] | |||||
| if k == 'positional_embedding' or k == 'text_projection' or k.startswith( | |||||
| 'token_embedding') or k.startswith('ln_final'): | |||||
| if k == 'positional_embedding' and checkpoint[k].size( | |||||
| 0) > self.context_length: | |||||
| checkpoint[k] = checkpoint[k][:self.context_length] | |||||
| print('positional_embedding is tuncated from 77 to', | |||||
| self.context_length) | |||||
| state_dict[k] = checkpoint[k] | |||||
| u, w = self.load_state_dict(state_dict, False) | |||||
| print(u, w, 'are misaligned params in text encoder') | |||||
| def build_attention_mask(self): | |||||
| # lazily create causal attention mask, with full attention between the vision tokens | |||||
| # pytorch uses additive attention mask; fill with -inf | |||||
| mask = torch.empty(self.context_length, self.context_length) | |||||
| mask.fill_(float('-inf')) | |||||
| mask.triu_(1) # zero out the lower diagonal | |||||
| return mask | |||||
| def forward(self, text, context=None): | |||||
| x_text = self.token_embedding(text) # n_clas, n_text, C | |||||
| K, N1, C = x_text.shape # 150类 * 5??? * 512 | |||||
| B, N2, C = context.shape # 1 * 8 * 512 | |||||
| eos_indx = text.argmax(dim=-1) + N2 | |||||
| eos_indx = eos_indx.reshape(1, K).expand(B, K).reshape(-1) | |||||
| x_text = x_text.reshape(1, K, N1, C).expand(B, K, N1, C) | |||||
| context = context.reshape(B, 1, N2, C).expand(B, K, N2, C) | |||||
| x = torch.cat([x_text[:, :, 0:1], context, x_text[:, :, 1:]], | |||||
| dim=2).reshape(B * K, N1 + N2, C) | |||||
| x = x + self.positional_embedding | |||||
| x = x.permute(1, 0, 2) # NLD -> LND | |||||
| x = self.transformer(x) | |||||
| x = x.permute(1, 0, 2) # LND -> NLD | |||||
| x = self.ln_final(x) | |||||
| x = x[torch.arange(x.shape[0]), eos_indx] @ self.text_projection | |||||
| x = x.reshape(B, K, self.embed_dim) | |||||
| return x | |||||
| class ContextDecoder(nn.Module): | |||||
| def __init__(self, | |||||
| transformer_width=256, | |||||
| transformer_heads=4, | |||||
| transformer_layers=6, | |||||
| visual_dim=1024, | |||||
| dropout=0.1, | |||||
| **kwargs): | |||||
| super().__init__() | |||||
| self.memory_proj = nn.Sequential( | |||||
| nn.LayerNorm(visual_dim), | |||||
| nn.Linear(visual_dim, transformer_width), | |||||
| nn.LayerNorm(transformer_width), | |||||
| ) | |||||
| self.text_proj = nn.Sequential( | |||||
| nn.LayerNorm(visual_dim), | |||||
| nn.Linear(visual_dim, transformer_width), | |||||
| ) | |||||
| self.decoder = nn.ModuleList([ | |||||
| TransformerDecoderLayer(transformer_width, transformer_heads, | |||||
| dropout) for _ in range(transformer_layers) | |||||
| ]) | |||||
| self.out_proj = nn.Sequential( | |||||
| nn.LayerNorm(transformer_width), | |||||
| nn.Linear(transformer_width, visual_dim)) | |||||
| self.apply(self._init_weights) | |||||
| def _init_weights(self, m): | |||||
| if isinstance(m, nn.Linear): | |||||
| trunc_normal_(m.weight, std=.02) | |||||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||||
| nn.init.constant_(m.bias, 0) | |||||
| elif isinstance(m, nn.LayerNorm): | |||||
| nn.init.constant_(m.bias, 0) | |||||
| nn.init.constant_(m.weight, 1.0) | |||||
| def forward(self, text, visual): | |||||
| B, N, C = visual.shape | |||||
| visual = self.memory_proj(visual) | |||||
| x = self.text_proj(text) | |||||
| for layer in self.decoder: | |||||
| x = layer(x, visual) | |||||
| return self.out_proj(x) | |||||
| @@ -0,0 +1,217 @@ | |||||
| """ FPNneck | |||||
| Base modules are adapted from https://github.com/open-mmlab/mmcv/, | |||||
| originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, | |||||
| https://github.com/open-mmlab/mmsegmentation/, | |||||
| originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, | |||||
| and adapted from https://github.com/raoyongming/DenseCLIP/, | |||||
| originally MIT License, Copyright (c) 2022 Rao, Yongming. | |||||
| """ | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from mmcv.cnn import ConvModule | |||||
| from timm.models.layers import drop, drop_path, trunc_normal_ | |||||
| from .common import resize | |||||
| class FPN(nn.Module): | |||||
| """Feature Pyramid Network. | |||||
| This neck is the implementation of `Feature Pyramid Networks for Object | |||||
| Detection <https://arxiv.org/abs/1612.03144>`_. | |||||
| Args: | |||||
| in_channels (list[int]): Number of input channels per scale. | |||||
| out_channels (int): Number of output channels (used at each scale). | |||||
| num_outs (int): Number of output scales. | |||||
| start_level (int): Index of the start input backbone level used to | |||||
| build the feature pyramid. Default: 0. | |||||
| end_level (int): Index of the end input backbone level (exclusive) to | |||||
| build the feature pyramid. Default: -1, which means the last level. | |||||
| add_extra_convs (bool | str): If bool, it decides whether to add conv | |||||
| layers on top of the original feature maps. Default to False. | |||||
| If True, its actual mode is specified by `extra_convs_on_inputs`. | |||||
| If str, it specifies the source feature map of the extra convs. | |||||
| Only the following options are allowed | |||||
| - 'on_input': Last feat map of neck inputs (i.e. backbone feature). | |||||
| - 'on_lateral': Last feature map after lateral convs. | |||||
| - 'on_output': The last output feature map after fpn convs. | |||||
| extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs | |||||
| on the original feature from the backbone. If True, | |||||
| it is equivalent to `add_extra_convs='on_input'`. If False, it is | |||||
| equivalent to set `add_extra_convs='on_output'`. Default to True. | |||||
| relu_before_extra_convs (bool): Whether to apply relu before the extra | |||||
| conv. Default: False. | |||||
| no_norm_on_lateral (bool): Whether to apply norm on lateral. | |||||
| Default: False. | |||||
| conv_cfg (dict): Config dict for convolution layer. Default: None. | |||||
| norm_cfg (dict): Config dict for normalization layer. Default: None. | |||||
| act_cfg (dict): Config dict for activation layer in ConvModule. | |||||
| Default: None. | |||||
| upsample_cfg (dict): Config dict for interpolate layer. | |||||
| Default: dict(mode='nearest'). | |||||
| init_cfg (dict or list[dict], optional): Initialization config dict. | |||||
| """ | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| num_outs, | |||||
| start_level=0, | |||||
| end_level=-1, | |||||
| add_extra_convs=False, | |||||
| extra_convs_on_inputs=False, | |||||
| relu_before_extra_convs=False, | |||||
| no_norm_on_lateral=False, | |||||
| conv_cfg=None, | |||||
| norm_cfg=None, | |||||
| act_cfg=None, | |||||
| upsample_cfg=dict(mode='nearest')): | |||||
| super(FPN, self).__init__() | |||||
| assert isinstance(in_channels, list) | |||||
| self.in_channels = in_channels | |||||
| self.out_channels = out_channels | |||||
| self.num_ins = len(in_channels) | |||||
| self.num_outs = num_outs | |||||
| self.relu_before_extra_convs = relu_before_extra_convs | |||||
| self.no_norm_on_lateral = no_norm_on_lateral | |||||
| self.fp16_enabled = False | |||||
| self.upsample_cfg = upsample_cfg.copy() | |||||
| if end_level == -1: | |||||
| self.backbone_end_level = self.num_ins | |||||
| assert num_outs >= self.num_ins - start_level | |||||
| else: | |||||
| # if end_level < inputs, no extra level is allowed | |||||
| self.backbone_end_level = end_level | |||||
| assert end_level <= len(in_channels) | |||||
| assert num_outs == end_level - start_level | |||||
| self.start_level = start_level | |||||
| self.end_level = end_level | |||||
| self.add_extra_convs = add_extra_convs | |||||
| assert isinstance(add_extra_convs, (str, bool)) | |||||
| if isinstance(add_extra_convs, str): | |||||
| # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' | |||||
| assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') | |||||
| elif add_extra_convs: # True | |||||
| if extra_convs_on_inputs: | |||||
| # For compatibility with previous release | |||||
| # TODO: deprecate `extra_convs_on_inputs` | |||||
| self.add_extra_convs = 'on_input' | |||||
| else: | |||||
| self.add_extra_convs = 'on_output' | |||||
| self.lateral_convs = nn.ModuleList() | |||||
| self.fpn_convs = nn.ModuleList() | |||||
| for i in range(self.start_level, self.backbone_end_level): | |||||
| l_conv = ConvModule( | |||||
| in_channels[i], | |||||
| out_channels, | |||||
| 1, | |||||
| conv_cfg=conv_cfg, | |||||
| norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, | |||||
| act_cfg=act_cfg, | |||||
| inplace=False) | |||||
| fpn_conv = ConvModule( | |||||
| out_channels, | |||||
| out_channels, | |||||
| 3, | |||||
| padding=1, | |||||
| conv_cfg=conv_cfg, | |||||
| norm_cfg=norm_cfg, | |||||
| act_cfg=act_cfg, | |||||
| inplace=False) | |||||
| self.lateral_convs.append(l_conv) | |||||
| self.fpn_convs.append(fpn_conv) | |||||
| # add extra conv layers (e.g., RetinaNet) | |||||
| extra_levels = num_outs - self.backbone_end_level + self.start_level | |||||
| if self.add_extra_convs and extra_levels >= 1: | |||||
| for i in range(extra_levels): | |||||
| if i == 0 and self.add_extra_convs == 'on_input': | |||||
| in_channels = self.in_channels[self.backbone_end_level - 1] | |||||
| else: | |||||
| in_channels = out_channels | |||||
| extra_fpn_conv = ConvModule( | |||||
| in_channels, | |||||
| out_channels, | |||||
| 3, | |||||
| stride=2, | |||||
| padding=1, | |||||
| conv_cfg=conv_cfg, | |||||
| norm_cfg=norm_cfg, | |||||
| act_cfg=act_cfg, | |||||
| inplace=False) | |||||
| self.fpn_convs.append(extra_fpn_conv) | |||||
| self.apply(self._init_weights) | |||||
| def forward(self, inputs): | |||||
| assert len(inputs) == len(self.in_channels) | |||||
| # build laterals | |||||
| laterals = [ | |||||
| lateral_conv(inputs[i + self.start_level]) | |||||
| for i, lateral_conv in enumerate(self.lateral_convs) | |||||
| ] | |||||
| # build top-down path | |||||
| used_backbone_levels = len(laterals) | |||||
| for i in range(used_backbone_levels - 1, 0, -1): | |||||
| # In some cases, fixing `scale factor` (e.g. 2) is preferred, but | |||||
| # it cannot co-exist with `size` in `F.interpolate`. | |||||
| if 'scale_factor' in self.upsample_cfg: | |||||
| laterals[i - 1] = laterals[i - 1] + resize( | |||||
| laterals[i], **self.upsample_cfg) | |||||
| else: | |||||
| prev_shape = laterals[i - 1].shape[2:] | |||||
| laterals[i - 1] = laterals[i - 1] + resize( | |||||
| laterals[i], size=prev_shape, **self.upsample_cfg) | |||||
| # build outputs | |||||
| # part 1: from original levels | |||||
| outs = [ | |||||
| self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) | |||||
| ] | |||||
| # part 2: add extra levels | |||||
| if self.num_outs > len(outs): | |||||
| # use max pool to get more levels on top of outputs | |||||
| # (e.g., Faster R-CNN, Mask R-CNN) | |||||
| if not self.add_extra_convs: | |||||
| for i in range(self.num_outs - used_backbone_levels): | |||||
| outs.append(F.max_pool2d(outs[-1], 1, stride=2)) | |||||
| # add conv layers on top of original feature maps (RetinaNet) | |||||
| else: | |||||
| if self.add_extra_convs == 'on_input': | |||||
| extra_source = inputs[self.backbone_end_level - 1] | |||||
| elif self.add_extra_convs == 'on_lateral': | |||||
| extra_source = laterals[-1] | |||||
| elif self.add_extra_convs == 'on_output': | |||||
| extra_source = outs[-1] | |||||
| else: | |||||
| raise NotImplementedError | |||||
| outs.append(self.fpn_convs[used_backbone_levels](extra_source)) | |||||
| for i in range(used_backbone_levels + 1, self.num_outs): | |||||
| if self.relu_before_extra_convs: | |||||
| outs.append(self.fpn_convs[i](F.relu(outs[-1]))) | |||||
| else: | |||||
| outs.append(self.fpn_convs[i](outs[-1])) | |||||
| return tuple(outs) | |||||
| def _init_weights(self, m): | |||||
| if isinstance(m, nn.Linear): | |||||
| trunc_normal_(m.weight, std=.02) | |||||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||||
| nn.init.constant_(m.bias, 0) | |||||
| elif isinstance(m, nn.LayerNorm): | |||||
| nn.init.constant_(m.bias, 0) | |||||
| nn.init.constant_(m.weight, 1.0) | |||||
| elif isinstance(m, nn.Conv2d): | |||||
| nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') | |||||
| if m.bias is not None: | |||||
| nn.init.constant_(m.bias.data, 0) | |||||
| @@ -0,0 +1,157 @@ | |||||
| """ | |||||
| Base modules are adapted from https://github.com/open-mmlab/mmcv/, | |||||
| originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, | |||||
| https://github.com/open-mmlab/mmsegmentation/, | |||||
| originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, | |||||
| and adapted from https://github.com/raoyongming/DenseCLIP/, | |||||
| originally MIT License, Copyright (c) 2022 Rao, Yongming. | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from .head_fpn import FPNHead | |||||
| from .models import (CLIPTextContextEncoder, CLIPVisionTransformer, | |||||
| ContextDecoder) | |||||
| from .neck_fpn import FPN | |||||
| from .utils import SimpleTokenizer, tokenize | |||||
| class SHOPSEG(nn.Module): | |||||
| """Encoder Decoder segmentors. | |||||
| EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. | |||||
| Note that auxiliary_head is only used for deep supervision during training, | |||||
| which could be dumped during inference. | |||||
| """ | |||||
| def __init__(self, | |||||
| model_dir, | |||||
| context_length=22, | |||||
| context_feature='attention', | |||||
| score_concat_index=2, | |||||
| tau=0.07, | |||||
| token_embed_dim=512, | |||||
| text_dim=512, | |||||
| **args): | |||||
| super(SHOPSEG, self).__init__() | |||||
| self.model_dir = model_dir | |||||
| self.tokenizer = SimpleTokenizer(model_dir | |||||
| + '/bpe_simple_vocab_16e6.txt.gz') | |||||
| backbone = CLIPVisionTransformer( | |||||
| input_resolution=1024, | |||||
| patch_size=16, | |||||
| width=768, | |||||
| layers=12, | |||||
| output_dim=512, | |||||
| drop_path_rate=0.1, | |||||
| pretrained=False, | |||||
| get_embeddings=True) | |||||
| text_encoder = CLIPTextContextEncoder( | |||||
| context_length=30, | |||||
| vocab_size=49408, | |||||
| transformer_width=512, | |||||
| transformer_heads=8, | |||||
| transformer_layers=12, | |||||
| embed_dim=512, | |||||
| pretrained=False) | |||||
| context_decoder = ContextDecoder( | |||||
| transformer_width=256, | |||||
| transformer_heads=4, | |||||
| transformer_layers=3, | |||||
| visual_dim=512, | |||||
| dropout=0.1) | |||||
| neck = FPN( | |||||
| in_channels=[768, 768, 768 + 2, 768], out_channels=256, num_outs=4) | |||||
| head_fpd = FPNHead(channels=256, num_classes=2) | |||||
| self.backbone = backbone | |||||
| self.text_encoder = text_encoder | |||||
| self.context_decoder = context_decoder | |||||
| self.context_length = context_length | |||||
| self.score_concat_index = score_concat_index | |||||
| self.context_feature = context_feature | |||||
| self.tau = tau | |||||
| context_length = self.text_encoder.context_length - self.context_length | |||||
| self.contexts = nn.Parameter( | |||||
| torch.randn(1, context_length, token_embed_dim)) | |||||
| nn.init.trunc_normal_(self.contexts) | |||||
| self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4) | |||||
| self.neck = neck | |||||
| self.head_fpn = head_fpd | |||||
| self.tau = 0.07 | |||||
| def encode_text(self, text, context_length): | |||||
| output = tokenize(self.tokenizer, text, context_length, True) | |||||
| return output | |||||
| def extract_feat(self, img): | |||||
| """Extract features from images.""" | |||||
| x = self.backbone(img) | |||||
| return x | |||||
| def after_extract_feat(self, x, name_list): | |||||
| x_orig = list(x[0:4]) | |||||
| global_feat, visual_embeddings = x[4] | |||||
| B, C, H, W = visual_embeddings.shape | |||||
| if self.context_feature == 'attention': | |||||
| x1 = global_feat.reshape(B, C, 1) | |||||
| x2 = visual_embeddings.reshape(B, C, H * W) | |||||
| visual_context = torch.cat([x1, x2], dim=2).permute(0, 2, 1) | |||||
| texts = torch.cat([ | |||||
| self.encode_text(c, context_length=self.context_length) | |||||
| for c in name_list | |||||
| ]) | |||||
| x1 = texts.to(global_feat.device) | |||||
| x1 = self.text_encoder(x1, self.contexts) | |||||
| text_embeddings = x1.expand(B, -1, -1) | |||||
| # update text_embeddings by visual_context! | |||||
| # (B, 1, C) | |||||
| text_diff = self.context_decoder(text_embeddings, visual_context) | |||||
| # (B, K, C) | |||||
| text_embeddings = text_embeddings + self.gamma * text_diff | |||||
| # compute score map and concat | |||||
| B, K, C = text_embeddings.shape | |||||
| visual_embeddings = F.normalize(visual_embeddings, dim=1, p=2) | |||||
| text = F.normalize(text_embeddings, dim=2, p=2) | |||||
| score_map_list = [] | |||||
| bsz = B | |||||
| for i in range(bsz): | |||||
| ind = 2 * i | |||||
| sub_text = torch.cat( | |||||
| [text[i:i + 1, ind:ind + 1], text[i:i + 1, ind + 1:ind + 2]], | |||||
| dim=1) # 1 * 2 * h * w | |||||
| sub_score_map = torch.einsum('bchw,bkc->bkhw', | |||||
| visual_embeddings[i:i + 1], | |||||
| sub_text) # 1 * 2 * h * w | |||||
| score_map_list.append(sub_score_map) | |||||
| score_map = torch.cat(score_map_list, dim=0) # b * 2 * h * w | |||||
| x_orig[self.score_concat_index] = torch.cat( | |||||
| [x_orig[self.score_concat_index], score_map], dim=1) | |||||
| return x_orig, score_map | |||||
| def forward(self, img, text_list=None): | |||||
| if text_list is None: | |||||
| bsz = img.size()[0] | |||||
| text_list = ['foregeound'] * bsz | |||||
| x = self.extract_feat(img) | |||||
| _x_orig = [x[i] for i in range(4)] | |||||
| name_list = [] | |||||
| for name in text_list: | |||||
| name_list.append('others') | |||||
| name_list.append(name[0:20]) | |||||
| x_orig, score_map = self.after_extract_feat(x, name_list) | |||||
| x_orig = list(self.neck(x_orig)) | |||||
| _x_orig = x_orig | |||||
| pred = self.head_fpn(_x_orig) | |||||
| return pred | |||||
| @@ -0,0 +1,115 @@ | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from PIL import Image | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.cv.shop_segmentation import SHOPSEG | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| __all__ = ['ShopSegmentation'] | |||||
| @MODELS.register_module( | |||||
| Tasks.shop_segmentation, module_name=Models.shop_segmentation) | |||||
| class ShopSegmentation(TorchModel): | |||||
| """ shop segmentation model. | |||||
| """ | |||||
| def __init__(self, model_dir, device_id=0, *args, **kwargs): | |||||
| super().__init__( | |||||
| model_dir=model_dir, device_id=device_id, *args, **kwargs) | |||||
| self.model = SHOPSEG(model_dir=model_dir) | |||||
| pretrained_params = torch.load('{}/{}'.format( | |||||
| model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) | |||||
| self.model.load_state_dict(pretrained_params) | |||||
| self.model.eval() | |||||
| self.device_id = device_id | |||||
| if self.device_id >= 0 and torch.cuda.is_available(): | |||||
| self.model.to('cuda:{}'.format(self.device_id)) | |||||
| logger.info('Use GPU: {}'.format(self.device_id)) | |||||
| else: | |||||
| self.device_id = -1 | |||||
| logger.info('Use CPU for inference') | |||||
| def preprocess(self, img, size=1024): | |||||
| mean = [0.48145466, 0.4578275, 0.40821073] | |||||
| std = [0.26862954, 0.26130258, 0.27577711] | |||||
| h, w, c = img.shape | |||||
| max_hw = max(h, w) | |||||
| ratio = 1.0 * size / max_hw | |||||
| crop_h, crop_w = int(ratio * h), int(ratio * w) | |||||
| pil_img = Image.fromarray(img) | |||||
| pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR) | |||||
| np_img = np.array(pil_img, dtype=np.float32) / 255. | |||||
| for j in range(3): | |||||
| np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j] | |||||
| img_pad = np.zeros((size, size, 3), dtype=np.float32) | |||||
| img_pad[:crop_h, :crop_w] = np_img | |||||
| img_pad = torch.from_numpy(img_pad).permute(2, 0, | |||||
| 1).unsqueeze(0).float() | |||||
| return img_pad, h, w, crop_h, crop_w | |||||
| def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w): | |||||
| output = np.clip(tensors * 255., a_min=0, a_max=255.) | |||||
| crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8) | |||||
| pil_output = Image.fromarray(crop_output) | |||||
| pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR) | |||||
| np_output = np.array(pil_output, dtype=np.uint8) | |||||
| np_output[np_output < 128] = 0 | |||||
| np_output[np_output >= 128] = 255 | |||||
| np_output = np.uint8(np_output) | |||||
| return np_output | |||||
| def forward(self, image): | |||||
| """ | |||||
| image should be numpy array, dtype=np.uint8, shape: height*width*3 | |||||
| """ | |||||
| image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess( | |||||
| image, size=1024) | |||||
| pred = self.inference(image_tensor) | |||||
| msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w, size=1024) | |||||
| outputs = {OutputKeys.MASKS: msk} | |||||
| return outputs | |||||
| def inference(self, image): | |||||
| """ | |||||
| image should be tensor, 1 * 3 * 1024 * 1024 | |||||
| """ | |||||
| with torch.no_grad(): | |||||
| if self.device_id == -1: | |||||
| output = self.model(image) | |||||
| else: | |||||
| device = torch.device('cuda', self.device_id) | |||||
| output = self.model(image.to(device)) | |||||
| output = F.interpolate(output, size=(1024, 1024), mode='bilinear') | |||||
| output = F.softmax(output, dim=1) | |||||
| output = torch.argmax(output, dim=1) | |||||
| output = output[0] | |||||
| if self.device_id == -1: | |||||
| pred = output.data.numpy() | |||||
| else: | |||||
| pred = output.data.cpu().numpy() | |||||
| del output | |||||
| return pred | |||||
| @@ -0,0 +1,199 @@ | |||||
| """ CLIP Tokenizer | |||||
| Adapted from https://github.com/openai/CLIP. | |||||
| Originally MIT License, Copyright (c) 2021 OpenAI. | |||||
| """ | |||||
| import gzip | |||||
| import html | |||||
| import os | |||||
| from functools import lru_cache | |||||
| from typing import Any, List, Union | |||||
| import ftfy | |||||
| import regex as re | |||||
| import torch | |||||
| @lru_cache() | |||||
| def default_bpe(): | |||||
| return os.path.join( | |||||
| os.path.dirname(os.path.abspath(__file__)), | |||||
| 'bpe_simple_vocab_16e6.txt.gz') | |||||
| @lru_cache() | |||||
| def bytes_to_unicode(): | |||||
| """ | |||||
| Returns list of utf-8 byte and a corresponding list of unicode strings. | |||||
| The reversible bpe codes work on unicode strings. | |||||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||||
| This is a signficant percentage of your normal, say, 32K bpe vocab. | |||||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||||
| """ | |||||
| bs = list(range(ord('!'), | |||||
| ord('~') + 1)) + list(range( | |||||
| ord('¡'), | |||||
| ord('¬') + 1)) + list(range(ord('®'), | |||||
| ord('ÿ') + 1)) | |||||
| cs = bs[:] | |||||
| n = 0 | |||||
| for b in range(2**8): | |||||
| if b not in bs: | |||||
| bs.append(b) | |||||
| cs.append(2**8 + n) | |||||
| n += 1 | |||||
| cs = [chr(n) for n in cs] | |||||
| return dict(zip(bs, cs)) | |||||
| def get_pairs(word): | |||||
| """Return set of symbol pairs in a word. | |||||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||||
| """ | |||||
| pairs = set() | |||||
| prev_char = word[0] | |||||
| for char in word[1:]: | |||||
| pairs.add((prev_char, char)) | |||||
| prev_char = char | |||||
| return pairs | |||||
| def basic_clean(text): | |||||
| text = ftfy.fix_text(text) | |||||
| text = html.unescape(html.unescape(text)) | |||||
| return text.strip() | |||||
| def whitespace_clean(text): | |||||
| text = re.sub(r'\s+', ' ', text) | |||||
| text = text.strip() | |||||
| return text | |||||
| class SimpleTokenizer(object): | |||||
| def __init__(self, bpe_path: str = default_bpe()): | |||||
| self.byte_encoder = bytes_to_unicode() | |||||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||||
| merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') | |||||
| merges = merges[1:49152 - 256 - 2 + 1] | |||||
| merges = [tuple(merge.split()) for merge in merges] | |||||
| vocab = list(bytes_to_unicode().values()) | |||||
| vocab = vocab + [v + '</w>' for v in vocab] | |||||
| for merge in merges: | |||||
| vocab.append(''.join(merge)) | |||||
| vocab.extend(['<|startoftext|>', '<|endoftext|>']) | |||||
| self.encoder = dict(zip(vocab, range(len(vocab)))) | |||||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||||
| self.bpe_ranks = dict(zip(merges, range(len(merges)))) | |||||
| self.cache = { | |||||
| '<|startoftext|>': '<|startoftext|>', | |||||
| '<|endoftext|>': '<|endoftext|>' | |||||
| } | |||||
| self.pat = re.compile( | |||||
| r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", | |||||
| re.IGNORECASE) | |||||
| def bpe(self, token): | |||||
| if token in self.cache: | |||||
| return self.cache[token] | |||||
| word = tuple(token[:-1]) + (token[-1] + '</w>', ) | |||||
| pairs = get_pairs(word) | |||||
| if not pairs: | |||||
| return token + '</w>' | |||||
| error_list = [] | |||||
| while True: | |||||
| bigram = min( | |||||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||||
| if bigram not in self.bpe_ranks: | |||||
| break | |||||
| first, second = bigram | |||||
| new_word = [] | |||||
| i = 0 | |||||
| while i < len(word): | |||||
| try: | |||||
| j = word.index(first, i) | |||||
| new_word.extend(word[i:j]) | |||||
| i = j | |||||
| except Exception as err: | |||||
| error_list.append(err) | |||||
| new_word.extend(word[i:]) | |||||
| break | |||||
| if word[i] == first and i < len(word) - 1 and word[ | |||||
| i + 1] == second: | |||||
| new_word.append(first + second) | |||||
| i += 2 | |||||
| else: | |||||
| new_word.append(word[i]) | |||||
| i += 1 | |||||
| new_word = tuple(new_word) | |||||
| word = new_word | |||||
| if len(word) == 1: | |||||
| break | |||||
| else: | |||||
| pairs = get_pairs(word) | |||||
| word = ' '.join(word) | |||||
| self.cache[token] = word | |||||
| return word | |||||
| def encode(self, text): | |||||
| bpe_tokens = [] | |||||
| text = whitespace_clean(basic_clean(text)).lower() | |||||
| for token in re.findall(self.pat, text): | |||||
| token = ''.join(self.byte_encoder[b] | |||||
| for b in token.encode('utf-8')) | |||||
| bpe_tokens.extend(self.encoder[bpe_token] | |||||
| for bpe_token in self.bpe(token).split(' ')) | |||||
| return bpe_tokens | |||||
| def decode(self, tokens): | |||||
| text = ''.join([self.decoder[token] for token in tokens]) | |||||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||||
| 'utf-8', errors='replace').replace('</w>', ' ') | |||||
| return text | |||||
| def tokenize(tokenizer, | |||||
| texts, | |||||
| context_length: int = 77, | |||||
| truncate: bool = False) -> torch.LongTensor: | |||||
| """ | |||||
| Returns the tokenized representation of given input string(s) | |||||
| Parameters | |||||
| ---------- | |||||
| texts : Union[str, List[str]] | |||||
| An input string or a list of input strings to tokenize | |||||
| context_length : int | |||||
| The context length to use; all CLIP models use 77 as the context length | |||||
| truncate: bool | |||||
| Whether to truncate the text in case its encoding is longer than the context length | |||||
| Returns | |||||
| ------- | |||||
| A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] | |||||
| """ | |||||
| if isinstance(texts, str): | |||||
| texts = [texts] | |||||
| sot_token = tokenizer.encoder['<|startoftext|>'] | |||||
| eot_token = tokenizer.encoder['<|endoftext|>'] | |||||
| all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] | |||||
| for text in texts] | |||||
| result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |||||
| for i, tokens in enumerate(all_tokens): | |||||
| if len(tokens) > context_length: | |||||
| if truncate: | |||||
| tokens = tokens[:context_length] | |||||
| tokens[-1] = eot_token | |||||
| else: | |||||
| raise RuntimeError( | |||||
| f'Input {texts[i]} is too long for context length {context_length}' | |||||
| ) | |||||
| result[i, :len(tokens)] = torch.tensor(tokens) | |||||
| return result | |||||
| @@ -259,7 +259,13 @@ TASK_OUTPUTS = { | |||||
| # ] | # ] | ||||
| # } | # } | ||||
| Tasks.text_driven_segmentation: [OutputKeys.MASKS], | Tasks.text_driven_segmentation: [OutputKeys.MASKS], | ||||
| # shop segmentation result for single sample | |||||
| # { | |||||
| # "masks": [ | |||||
| # np.array # 2D array containing only 0, 255 | |||||
| # ] | |||||
| # } | |||||
| Tasks.shop_segmentation: [OutputKeys.MASKS], | |||||
| # movide scene segmentation result for a single video | # movide scene segmentation result for a single video | ||||
| # { | # { | ||||
| # "split_video_num":3, | # "split_video_num":3, | ||||
| @@ -156,7 +156,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_vitl16_segmentation_text-driven-seg'), | 'damo/cv_vitl16_segmentation_text-driven-seg'), | ||||
| Tasks.movie_scene_segmentation: | Tasks.movie_scene_segmentation: | ||||
| (Pipelines.movie_scene_segmentation, | (Pipelines.movie_scene_segmentation, | ||||
| 'damo/cv_resnet50-bert_video-scene-segmentation_movienet') | |||||
| 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), | |||||
| Tasks.shop_segmentation: (Pipelines.shop_segmentation, | |||||
| 'damo/cv_vitb16_segmentation_shop-seg'), | |||||
| } | } | ||||
| @@ -43,10 +43,10 @@ if TYPE_CHECKING: | |||||
| from .tinynas_classification_pipeline import TinynasClassificationPipeline | from .tinynas_classification_pipeline import TinynasClassificationPipeline | ||||
| from .video_category_pipeline import VideoCategoryPipeline | from .video_category_pipeline import VideoCategoryPipeline | ||||
| from .virtual_try_on_pipeline import VirtualTryonPipeline | from .virtual_try_on_pipeline import VirtualTryonPipeline | ||||
| from .shop_segmentation_pipleline import ShopSegmentationPipeline | |||||
| from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline | from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline | ||||
| from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipleline | from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipleline | ||||
| from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | ||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'action_recognition_pipeline': ['ActionRecognitionPipeline'], | 'action_recognition_pipeline': ['ActionRecognitionPipeline'], | ||||
| @@ -96,6 +96,7 @@ else: | |||||
| 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], | 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], | ||||
| 'video_category_pipeline': ['VideoCategoryPipeline'], | 'video_category_pipeline': ['VideoCategoryPipeline'], | ||||
| 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | ||||
| 'shop_segmentation_pipleline': ['ShopSegmentationPipeline'], | |||||
| 'easycv_pipeline': [ | 'easycv_pipeline': [ | ||||
| 'EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline', | 'EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline', | ||||
| 'Face2DKeypointsPipeline' | 'Face2DKeypointsPipeline' | ||||
| @@ -0,0 +1,51 @@ | |||||
| from typing import Any, Dict | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.constant import Tasks | |||||
| @PIPELINES.register_module( | |||||
| Tasks.shop_segmentation, module_name=Pipelines.shop_segmentation) | |||||
| class ShopSegmentationPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, auto_collate=False, **kwargs) | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_ndarray(input) | |||||
| img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img) | |||||
| result = { | |||||
| 'img': img_tensor, | |||||
| 'ori_h': ori_h, | |||||
| 'ori_w': ori_w, | |||||
| 'crop_h': crop_h, | |||||
| 'crop_w': crop_w | |||||
| } | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| outputs = self.model.inference(input['img']) | |||||
| result = { | |||||
| 'data': outputs, | |||||
| 'ori_h': input['ori_h'], | |||||
| 'ori_w': input['ori_w'], | |||||
| 'crop_h': input['crop_h'], | |||||
| 'crop_w': input['crop_w'], | |||||
| } | |||||
| return result | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| data = self.model.postprocess(inputs['data'], inputs['crop_h'], | |||||
| inputs['crop_w'], inputs['ori_h'], | |||||
| inputs['ori_w']) | |||||
| outputs = {OutputKeys.MASKS: data} | |||||
| return outputs | |||||
| @@ -38,6 +38,7 @@ class CVTasks(object): | |||||
| image_segmentation = 'image-segmentation' | image_segmentation = 'image-segmentation' | ||||
| portrait_matting = 'portrait-matting' | portrait_matting = 'portrait-matting' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
| shop_segmentation = 'shop-segmentation' | |||||
| # image editing | # image editing | ||||
| skin_retouching = 'skin-retouching' | skin_retouching = 'skin-retouching' | ||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class ShopSegmentationTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_shop_segmentation(self): | |||||
| input_location = 'data/test/images/shop_segmentation.jpg' | |||||
| model_id = 'damo/cv_vitb16_segmentation_shop-seg' | |||||
| shop_seg = pipeline(Tasks.shop_segmentation, model=model_id) | |||||
| result = shop_seg(input_location) | |||||
| import cv2 | |||||
| # result[OutputKeys.MASKS] is segment map result,other keys are not used | |||||
| cv2.imwrite(input_location + '_shopseg.jpg', result[OutputKeys.MASKS]) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||