From 4d3716cf4ebd0efc814818709234f93eef8e73c5 Mon Sep 17 00:00:00 2001
From: "xingguang.zxg"
Date: Fri, 2 Sep 2022 14:14:47 +0800
Subject: [PATCH] [to #42322933] Text-guided semantic segmentation model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Text-guided semantic segmentation model: given an input text prompt, it segments the object described by that text out of the image.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9942863
---
 data/test/images/text_driven_segmentation.jpg     |   3 +
 modelscope/metainfo.py                            |   2 +
 .../cv/text_driven_segmentation/__init__.py       |   1 +
 .../cv/text_driven_segmentation/clip.py           | 170 ++++
 .../cv/text_driven_segmentation/lseg_base.py      |  28 +
 .../text_driven_segmentation/lseg_blocks.py       | 334 +++++++++++
 .../cv/text_driven_segmentation/lseg_model.py     | 107 ++++
 .../cv/text_driven_segmentation/lseg_net.py       | 197 +++++++
 .../cv/text_driven_segmentation/lseg_vit.py       | 543 ++++++++++++++++++
 .../cv/text_driven_segmentation/model.py          | 458 +++++++++++++++
 .../simple_tokenizer.py                           | 156 +++++
 modelscope/outputs.py                             |   7 +
 modelscope/pipelines/builder.py                   |   3 +
 modelscope/pipelines/cv/__init__.py               |   3 +
 .../cv/text_driven_segmentation_pipleline.py      |  51 ++
 modelscope/utils/constant.py                      |   1 +
 tests/pipelines/test_text_driven_segmentation.py  |  28 +
 17 files changed, 2092 insertions(+)
 create mode 100644 data/test/images/text_driven_segmentation.jpg
 create mode 100644 modelscope/models/cv/text_driven_segmentation/__init__.py
 create mode 100644 modelscope/models/cv/text_driven_segmentation/clip.py
 create mode 100644 modelscope/models/cv/text_driven_segmentation/lseg_base.py
 create mode 100644 modelscope/models/cv/text_driven_segmentation/lseg_blocks.py
 create mode 100644 modelscope/models/cv/text_driven_segmentation/lseg_model.py
 create mode 100644 modelscope/models/cv/text_driven_segmentation/lseg_net.py
 create mode 100644 modelscope/models/cv/text_driven_segmentation/lseg_vit.py
 create mode 100644 modelscope/models/cv/text_driven_segmentation/model.py
 create mode 100644 modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py
 create mode 100644 modelscope/pipelines/cv/text_driven_segmentation_pipleline.py
 create mode 100644 tests/pipelines/test_text_driven_segmentation.py

diff --git a/data/test/images/text_driven_segmentation.jpg b/data/test/images/text_driven_segmentation.jpg
new file mode 100644
index 00000000..e3320b1f
--- /dev/null
+++ b/data/test/images/text_driven_segmentation.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c7d2f279e3b317f1d0de18410a0585e122166fa2464c17b88a0c813f6c58bd4
+size 67861
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index fd653bac..3225710a 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -29,6 +29,7 @@ class Models(object):
     video_summarization = 'pgl-video-summarization'
     swinL_semantic_segmentation = 'swinL-semantic-segmentation'
     vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
+    text_driven_segmentation = 'text-driven-segmentation'
     resnet50_bert = 'resnet50-bert'

     # EasyCV models
@@ -143,6 +144,7 @@ class Pipelines(object):
     video_summarization = 'googlenet_pgl_video_summarization'
     image_semantic_segmentation = 'image-semantic-segmentation'
     image_reid_person = 'passvitb-image-reid-person'
+    text_driven_segmentation = 'text-driven-segmentation'
     movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'

     # nlp tasks
diff --git a/modelscope/models/cv/text_driven_segmentation/__init__.py
b/modelscope/models/cv/text_driven_segmentation/__init__.py new file mode 100644 index 00000000..46daad78 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/__init__.py @@ -0,0 +1 @@ +from .lseg_base import TextDrivenSegmentation diff --git a/modelscope/models/cv/text_driven_segmentation/clip.py b/modelscope/models/cv/text_driven_segmentation/clip.py new file mode 100644 index 00000000..440cccea --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/clip.py @@ -0,0 +1,170 @@ +""" CLIP +Adapted from https://github.com/openai/CLIP. +Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import hashlib +import os +import urllib +import warnings +from typing import Any, List, Union + +import torch +from PIL import Image +from pkg_resources import packaging +from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, + ToTensor) +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + +if packaging.version.parse( + torch.__version__) < packaging.version.parse('1.7.1'): + warnings.warn('PyTorch version 1.7.1 or higher is recommended') +__all__ = ['load', 'tokenize'] + + +def _convert_image_to_rgb(image): + return image.convert('RGB') + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def load(name: str, + device: Union[str, torch.device] = 'cuda' + if torch.cuda.is_available() else 'cpu', + jit: bool = False, + root: str = None): + + if not jit: + model = build_model().to(device) + if str(device) == 'cpu': + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [ + n for n in device_holder.graph.findAllNodes('prim::Constant') + if 'Device' in repr(n) + ][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, 'graph') else [] + except RuntimeError: + graphs = [] + + if hasattr(module, 'forward1'): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes('prim::Constant'): + if 'value' in node.attributeNames() and str( + node['value']).startswith('cuda'): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == 'cpu': + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode('aten::to').inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, 'graph') else [] + except RuntimeError: + graphs = [] + + if hasattr(module, 'forward1'): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes('aten::to'): + inputs = list(node.inputs()) + for i in [ + 1, 2 + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()['value'] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + 
+ return model, _transform(model.input_resolution.item()) + + +def tokenize( + _tokenizer, + texts: Union[str, List[str]], + context_length: int = 77, + truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder['<|startoftext|>'] + eot_token = _tokenizer.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] + for text in texts] + if packaging.version.parse( + torch.__version__) < packaging.version.parse('1.8.0'): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f'Input {texts[i]} is too long for context length {context_length}' + ) + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_base.py b/modelscope/models/cv/text_driven_segmentation/lseg_base.py new file mode 100644 index 00000000..20915396 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_base.py @@ -0,0 +1,28 @@ +""" +Adapted from https://github.com/isl-org/lang-seg. +Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. +""" + +import torch +import torch.nn as nn + +from .lseg_net import LSeg + + +class TextDrivenSegmentation(nn.Module): + + def __init__(self, model_dir): + super(TextDrivenSegmentation, self).__init__() + self.net = LSeg(model_dir=model_dir) + self.model_dir = model_dir + + def forward(self, img, txt_list): + b = img.size()[0] + batch_name_list = txt_list + xout_list = [] + for i in range(b): + labelset = ['others', batch_name_list[i]] + xout = self.net(img[i:i + 1], labelset=labelset) + xout_list.append(xout) + score_map = torch.cat(xout_list, dim=0) + return score_map diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_blocks.py b/modelscope/models/cv/text_driven_segmentation/lseg_blocks.py new file mode 100644 index 00000000..cb550ab7 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_blocks.py @@ -0,0 +1,334 @@ +""" +Adapted from https://github.com/isl-org/lang-seg. +Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. 
+""" + +import torch +import torch.nn as nn + +from .lseg_vit import _make_pretrained_clip_vitl16_384, forward_vit + + +def _make_encoder( + backbone, + features, + use_pretrained=True, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout='ignore', + enable_attention_hooks=False, +): + if backbone == 'clip_vitl16_384': + clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch([256, 512, 1024, 1024], + features, + groups=groups, + expand=expand) + else: + raise NotImplementedError(f"Backbone '{backbone}' not implemented") + + return clip_pretrained, pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand is True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + return scratch + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode='bilinear', align_corners=True) + + return output + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn is True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn is True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn is True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand is True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners) + + output = self.out_conv(output) + return output diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_model.py b/modelscope/models/cv/text_driven_segmentation/lseg_model.py new file mode 100644 index 00000000..1d7ebdd1 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_model.py @@ -0,0 +1,107 @@ +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.text_driven_segmentation import \ + TextDrivenSegmentation +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() +__all__ = ['TextDrivenSeg'] + + +@MODELS.register_module( + Tasks.text_driven_segmentation, + module_name=Models.text_driven_segmentation) +class TextDrivenSeg(TorchModel): + """ text driven segmentation model. + """ + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.model = TextDrivenSegmentation(model_dir=model_dir) + pretrained_params = torch.load('{}/{}'.format( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) + self.model.load_state_dict(pretrained_params) + self.model.eval() + if device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(device_id)) + logger.info('Use GPU: {}'.format(device_id)) + else: + device_id = -1 + logger.info('Use CPU for inference') + self.device_id = device_id + + def preprocess(self, img, size=640): + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + h, w, c = img.shape + max_hw = max(h, w) + ratio = 1.0 * size / max_hw + crop_h, crop_w = int(ratio * h), int(ratio * w) + pil_img = Image.fromarray(img) + pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR) + np_img = np.array(pil_img, dtype=np.float32) / 255. + for j in range(3): + np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j] + img_pad = np.zeros((size, size, 3), dtype=np.float32) + img_pad[:crop_h, :crop_w] = np_img + img_pad = torch.from_numpy(img_pad).permute(2, 0, + 1).unsqueeze(0).float() + return img_pad, h, w, crop_h, crop_w + + def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w): + output = np.clip(tensors * 255., a_min=0, a_max=255.) 
+ crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8) + pil_output = Image.fromarray(crop_output) + pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR) + np_output = np.array(pil_output, dtype=np.uint8) + np_output[np_output < 128] = 0 + np_output[np_output >= 128] = 255 + np_output = np.uint8(np_output) + return np_output + + def forward(self, image, text): + """ + image should be numpy array, dtype=np.uint8, shape: height*width*3 + """ + image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess( + image, size=640) + pred = self.inference(image_tensor, text) + msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w, size=640) + outputs = {OutputKeys.MASKS: msk} + return outputs + + def inference(self, image, text): + """ + image should be tensor, 1 * 3 * 640 * 640 + """ + with torch.no_grad(): + if self.device_id == -1: + output = self.model(image) + else: + device = torch.device('cuda', self.device_id) + output = self.model(image.to(device), [text]) + output = F.interpolate(output, size=(640, 640), mode='bilinear') + output = F.softmax(output, dim=1) + output = torch.argmax(output, dim=1) + output = output[0] + if self.device_id == -1: + pred = output.data.numpy() + else: + pred = output.data.cpu().numpy() + del output + return pred diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_net.py b/modelscope/models/cv/text_driven_segmentation/lseg_net.py new file mode 100644 index 00000000..1a558c5c --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_net.py @@ -0,0 +1,197 @@ +""" +Adapted from https://github.com/isl-org/lang-seg. +Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. +""" + +import numpy as np +import torch +import torch.nn as nn + +from . import clip +from .lseg_blocks import (FeatureFusionBlock, FeatureFusionBlock_custom, + Interpolate, _make_encoder, forward_vit) +from .simple_tokenizer import SimpleTokenizer + + +class depthwise_clipseg_conv(nn.Module): + + def __init__(self): + super(depthwise_clipseg_conv, self).__init__() + self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1) + + def depthwise_clipseg(self, x, channels): + x = torch.cat( + [self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], + dim=1) + return x + + def forward(self, x): + channels = x.shape[1] + out = self.depthwise_clipseg(x, channels) + return out + + +class depthwise_conv(nn.Module): + + def __init__(self, kernel_size=3, stride=1, padding=1): + super(depthwise_conv, self).__init__() + self.depthwise = nn.Conv2d( + 1, 1, kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x): + # support for 4D tensor with NCHW + C, H, W = x.shape[1:] + x = x.reshape(-1, 1, H, W) + x = self.depthwise(x) + x = x.view(-1, C, H, W) + return x + + +class depthwise_block(nn.Module): + + def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): + super(depthwise_block, self).__init__() + self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'lrelu': + self.activation = nn.LeakyReLU() + elif activation == 'tanh': + self.activation = nn.Tanh() + + def forward(self, x, act=True): + x = self.depthwise(x) + if act: + x = self.activation(x) + return x + + +class bottleneck_block(nn.Module): + + def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): + super(bottleneck_block, self).__init__() + self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) + if activation == 
'relu': + self.activation = nn.ReLU() + elif activation == 'lrelu': + self.activation = nn.LeakyReLU() + elif activation == 'tanh': + self.activation = nn.Tanh() + + def forward(self, x, act=True): + sum_layer = x.max(dim=1, keepdim=True)[0] + x = self.depthwise(x) + x = x + sum_layer + if act: + x = self.activation(x) + return x + + +class BaseModel(torch.nn.Module): + + def load(self, path): + """Load model from file. + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if 'optimizer' in parameters: + parameters = parameters['model'] + + self.load_state_dict(parameters) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + activation=nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class LSeg(BaseModel): + + def __init__( + self, + features=256, + backbone='clip_vitl16_384', + readout='project', + use_bn=True, + model_dir=None, + ): + super(LSeg, self).__init__() + hooks = { + 'clip_vitl16_384': [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.clip_pretrained, self.pretrained, self.scratch = _make_encoder( + backbone, + features, + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.logit_scale = nn.Parameter(torch.ones([]) + * np.log(1 / 0.07)).exp() + self.out_c = 512 + self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1) + + self.scratch.output_conv = nn.Sequential( + Interpolate(scale_factor=2, mode='bilinear', align_corners=True), ) + + self.tau = 0.07 + self.model_dir = model_dir + self.tokenizer = SimpleTokenizer(model_dir + + '/bpe_simple_vocab_16e6.txt.gz') + + def forward(self, x, labelset=''): + text = clip.tokenize(self.tokenizer, labelset) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + text = text.to(x.device) + text_features = self.clip_pretrained.encode_text(text) + + image_features = self.scratch.head1(path_1) + + imshape = image_features.shape + image_features = image_features.permute(0, 2, 3, + 1).reshape(-1, self.out_c) + + # normalized features + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + logits_per_image = image_features @ text_features.t() / self.tau + + out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], + -1).permute(0, 3, 1, 2) + + out = self.scratch.output_conv(out) + + return out diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_vit.py b/modelscope/models/cv/text_driven_segmentation/lseg_vit.py new file mode 100644 index 00000000..be2813c2 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_vit.py @@ -0,0 +1,543 @@ +""" +Adapted from https://github.com/isl-org/lang-seg. 
+Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. +""" + +import math +import types + +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from . import clip + +activations = {} + + +def get_activation(name): + + def hook(model, input, output): + activations[name] = output + + return hook + + +attention = {} + + +def get_attention(name): + + def hook(module, input, output): + x = input[0] + B, N, C = x.shape + qkv = ( + module.qkv(x).reshape(B, N, 3, module.num_heads, + C // module.num_heads).permute( + 2, 0, 3, 1, 4)) + q, k, _ = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * module.scale + + attn = attn.softmax(dim=-1) # [:,:,1,1:] + attention[name] = attn + + return hook + + +def get_mean_attention_map(attn, token, shape): + attn = attn[:, :, token, 1:] + attn = attn.unflatten(2, torch.Size([shape[2] // 16, + shape[3] // 16])).float() + attn = torch.nn.functional.interpolate( + attn, size=shape[2:], mode='bicubic', align_corners=False).squeeze(0) + + all_attn = torch.mean(attn, 0) + + return all_attn + + +class Slice(nn.Module): + + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential( + nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + # encoder + _ = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations['1'] + layer_2 = pretrained.activations['2'] + layer_3 = pretrained.activations['3'] + layer_4 = pretrained.activations['4'] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size([ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ]), + )) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)]( + layer_1) + layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)]( + layer_2) + layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)]( + layer_3) + layer_4 = 
pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)]( + layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, :self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, + -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate( + posemb_grid, size=(gs_h, gs_w), mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], + w // self.patch_size[0]) + + B = x.shape[0] + + if hasattr(self.patch_embed, 'backbone'): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[ + -1] # last feature if backbone outputs list/tuple of features + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, 'dist_token', None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + gradient_checkpoint = False + for blk in self.blocks: + if gradient_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.norm(x) + + return x + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == 'ignore': + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == 'add': + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == 'project': + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float( + ) # Some weights are in torch.half, ensure it's float for sum on CPU + O, II, J, K = conv_weight.shape + if in_chans == 1: + if II > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, II // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if II != 3: + raise NotImplementedError( + 'Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. 
+ repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, + 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +@torch.no_grad() +def _load_weights(model, checkpoint_path, prefix=''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_( + adapt_input_conv(stem.conv.weight.shape[1], + _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_( + _n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_( + _n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_( + _n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_( + _n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_( + _n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_( + _n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], + _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p( + w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_prefix_tokens', + 1), + model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance( + model.head, nn.Linear + ) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + 
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_( + torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T + for n in ('query', 'key', 'value') + ])) + block.attn.qkv.bias.copy_( + torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) + for n in ('query', 'key', 'value') + ])) + block.attn.proj.weight.copy_( + _n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_( + _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_( + _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + ntok_new = posemb_new.shape[1] + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[ + 0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens + else: + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, + -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate( + posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, + 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) + return posemb + + +def _make_pretrained_clip_vitl16_384(pretrained, + use_readout='ignore', + hooks=None, + enable_attention_hooks=False): + clip_pretrained, _ = clip.load('ViT-B/32', device='cpu', jit=False) + + # model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + model = timm.create_model('vit_large_patch16_384', pretrained=False) + hooks = [5, 11, 17, 23] if hooks is None else hooks + pretrained = _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + return clip_pretrained, pretrained + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout='ignore', + start_index=1, + enable_attention_hooks=False, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook( + get_activation('1')) + pretrained.model.blocks[hooks[1]].register_forward_hook( + get_activation('2')) + pretrained.model.blocks[hooks[2]].register_forward_hook( + get_activation('3')) + pretrained.model.blocks[hooks[3]].register_forward_hook( + get_activation('4')) + + pretrained.activations = activations + + if enable_attention_hooks: + pretrained.model.blocks[hooks[0]].attn.register_forward_hook( + get_attention('attn_1')) + pretrained.model.blocks[hooks[1]].attn.register_forward_hook( + get_attention('attn_2')) + pretrained.model.blocks[hooks[2]].attn.register_forward_hook( + 
get_attention('attn_3')) + pretrained.model.blocks[hooks[3]].attn.register_forward_hook( + get_attention('attn_4')) + pretrained.attention = attention + + readout_oper = get_readout_oper(vit_features, features, use_readout, + start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, + pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model) + + return pretrained diff --git a/modelscope/models/cv/text_driven_segmentation/model.py b/modelscope/models/cv/text_driven_segmentation/model.py new file mode 100644 index 00000000..ece10bab --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/model.py @@ -0,0 +1,458 @@ +""" +Adapted from https://github.com/isl-org/lang-seg. +Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. +""" + +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, width, layers, heads, attn_mask=None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): 
+ return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x1 = self.class_embedding.to(x.dtype) + x2 = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([x1 + x2, x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block 
in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith('bn3.weight'): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1)] @ self.text_projection + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm( + dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(ll): + if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Linear)): + ll.weight.data = ll.weight.data.half() + if ll.bias is not None: + ll.bias.data = ll.bias.data.half() + + if isinstance(ll, nn.MultiheadAttention): + for attr in [ + *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], + 'in_proj_bias', 'bias_k', 'bias_v' + ]: + tensor = getattr(ll, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ['text_projection', 'proj']: + if hasattr(ll, name): + attr = getattr(ll, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(): + model = CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12) + convert_weights(model) + return model.eval() diff --git a/modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py b/modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py new file mode 100644 index 00000000..250d680f --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py @@ -0,0 +1,156 @@ +""" CLIP +Adapted from https://github.com/openai/CLIP. +Originally MIT License, Copyright (c) 2021 OpenAI. 
+""" + +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + error_list = [] + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception as err: + new_word.extend(word[i:]) + error_list.append(err) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def 
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b]
+                            for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token]
+                              for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            'utf-8', errors='replace').replace('</w>', ' ')
+        return text
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 7d6cdb59..6fada2b0 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -243,6 +243,13 @@ TASK_OUTPUTS = {
     #   "output_img": np.ndarray with shape [height, width, 3]
     # }
     Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG],
+    # text-driven segmentation result for a single sample
+    # {
+    #   "masks": [
+    #       np.array  # 2D array containing only 0 and 255
+    #   ]
+    # }
+    Tasks.text_driven_segmentation: [OutputKeys.MASKS],
 
     # movide scene segmentation result for a single video
     # {
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index c9f0c252..40c237c8 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -149,6 +149,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
         'damo/cv_vitb_video-single-object-tracking_ostrack'),
     Tasks.image_reid_person: (Pipelines.image_reid_person,
                               'damo/cv_passvitb_image-reid-person_market'),
+    Tasks.text_driven_segmentation:
+    (Pipelines.text_driven_segmentation,
+     'damo/cv_vitl16_segmentation_text-driven-seg'),
     Tasks.movie_scene_segmentation:
     (Pipelines.movie_scene_segmentation,
      'damo/cv_resnet50-bert_video-scene-segmentation_movienet')
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index f4e6792b..c8cb0c6a 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
     from .video_category_pipeline import VideoCategoryPipeline
     from .virtual_try_on_pipeline import VirtualTryonPipeline
     from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline
+    from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline
     from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline
 
 else:
@@ -97,6 +98,8 @@ else:
         'virtual_try_on_pipeline': ['VirtualTryonPipeline'],
         'easycv_pipeline':
         ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'],
+        'text_driven_segmentation_pipleline':
+        ['TextDrivenSegmentationPipeline'],
         'movie_scene_segmentation_pipeline':
         ['MovieSceneSegmentationPipeline'],
     }
diff --git a/modelscope/pipelines/cv/text_driven_segmentation_pipleline.py b/modelscope/pipelines/cv/text_driven_segmentation_pipleline.py
new file mode 100644
index 00000000..0985b835
--- /dev/null
+++ b/modelscope/pipelines/cv/text_driven_segmentation_pipleline.py
@@ -0,0 +1,51 @@
+from typing import Any, Dict
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import Tasks
+
+
+@PIPELINES.register_module(
+    Tasks.text_driven_segmentation,
+    module_name=Pipelines.text_driven_segmentation)
+class TextDrivenSegmentationPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        model: model id on the ModelScope hub.
+        """
+        super().__init__(model=model, auto_collate=False, **kwargs)
+
+    def preprocess(self, input: Dict) -> Dict[str, Any]:
+        img = LoadImage.convert_to_ndarray(input['image'])
+        img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img)
+        result = {
+            'img': img_tensor,
+            'ori_h': ori_h,
+            'ori_w': ori_w,
+            'crop_h': crop_h,
+            'crop_w': crop_w,
+            'text': input['text'],
+        }
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        outputs = self.model.inference(input['img'], input['text'])
+        result = {
+            'data': outputs,
+            'ori_h': input['ori_h'],
+            'ori_w': input['ori_w'],
+            'crop_h': input['crop_h'],
+            'crop_w': input['crop_w'],
+        }
+        return result
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        data = self.model.postprocess(inputs['data'], inputs['crop_h'],
+                                      inputs['crop_w'], inputs['ori_h'],
+                                      inputs['ori_w'])
+        outputs = {OutputKeys.MASKS: data}
+        return outputs
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 2265ef5a..ed1ec798 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -36,6 +36,7 @@ class CVTasks(object):
 
     image_segmentation = 'image-segmentation'
     portrait_matting = 'portrait-matting'
+    text_driven_segmentation = 'text-driven-segmentation'
 
     # image editing
     skin_retouching = 'skin-retouching'
diff --git a/tests/pipelines/test_text_driven_segmentation.py b/tests/pipelines/test_text_driven_segmentation.py
new file mode 100644
index 00000000..741787d9
--- /dev/null
+++ b/tests/pipelines/test_text_driven_segmentation.py
@@ -0,0 +1,28 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class TextDrivenSegmentationTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_text_driven_segmentation(self):
+        input_location = 'data/test/images/text_driven_segmentation.jpg'
+        test_input = {
+            'image': input_location,
+            'text': 'bear',
+        }
+        model_id = 'damo/cv_vitl16_segmentation_text-driven-seg'
+        text_seg = pipeline(Tasks.text_driven_segmentation, model=model_id)
+        result = text_seg(test_input)
+        import cv2
+        # result[OutputKeys.MASKS] holds the mask; other keys are unused
+        cv2.imwrite(input_location + '_lseg.jpg', result[OutputKeys.MASKS])
+
+
+if __name__ == '__main__':
+    unittest.main()
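For reference, a minimal usage sketch of the new pipeline outside the test harness. It assumes the 'damo/cv_vitl16_segmentation_text-driven-seg' model registered above is published on the ModelScope hub and that the value under OutputKeys.MASKS is a single-channel uint8 array containing only 0 and 255, as the output spec suggests; variable names and the output file name are illustrative only.

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline from the task name and the default model id registered above.
text_seg = pipeline(
    Tasks.text_driven_segmentation,
    model='damo/cv_vitl16_segmentation_text-driven-seg')

# The pipeline takes a dict with an 'image' (path or ndarray) and a 'text' prompt.
result = text_seg({
    'image': 'data/test/images/text_driven_segmentation.jpg',
    'text': 'bear',
})

# Assumed: the mask is a 2D uint8 array with values 0/255, so it can be used
# directly as an OpenCV mask to cut out the object described by the text.
mask = result[OutputKeys.MASKS]
img = cv2.imread('data/test/images/text_driven_segmentation.jpg')
cutout = cv2.bitwise_and(img, img, mask=mask)
cv2.imwrite('text_driven_segmentation_cutout.jpg', cutout)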