文本指导的语义分割模型,根据输入的文本信息,讲图像中对应文本描述的物体分割出来。
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9942863
master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:2c7d2f279e3b317f1d0de18410a0585e122166fa2464c17b88a0c813f6c58bd4 | |||
| size 67861 | |||
| @@ -29,6 +29,7 @@ class Models(object): | |||
| video_summarization = 'pgl-video-summarization' | |||
| swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
| text_driven_segmentation = 'text-driven-segmentation' | |||
| resnet50_bert = 'resnet50-bert' | |||
| # EasyCV models | |||
| @@ -143,6 +144,7 @@ class Pipelines(object): | |||
| video_summarization = 'googlenet_pgl_video_summarization' | |||
| image_semantic_segmentation = 'image-semantic-segmentation' | |||
| image_reid_person = 'passvitb-image-reid-person' | |||
| text_driven_segmentation = 'text-driven-segmentation' | |||
| movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | |||
| # nlp tasks | |||
| @@ -0,0 +1 @@ | |||
| from .lseg_base import TextDrivenSegmentation | |||
| @@ -0,0 +1,170 @@ | |||
| """ CLIP | |||
| Adapted from https://github.com/openai/CLIP. | |||
| Originally MIT License, Copyright (c) 2021 OpenAI. | |||
| """ | |||
| import hashlib | |||
| import os | |||
| import urllib | |||
| import warnings | |||
| from typing import Any, List, Union | |||
| import torch | |||
| from PIL import Image | |||
| from pkg_resources import packaging | |||
| from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, | |||
| ToTensor) | |||
| from tqdm import tqdm | |||
| from .model import build_model | |||
| from .simple_tokenizer import SimpleTokenizer as _Tokenizer | |||
| try: | |||
| from torchvision.transforms import InterpolationMode | |||
| BICUBIC = InterpolationMode.BICUBIC | |||
| except ImportError: | |||
| BICUBIC = Image.BICUBIC | |||
| if packaging.version.parse( | |||
| torch.__version__) < packaging.version.parse('1.7.1'): | |||
| warnings.warn('PyTorch version 1.7.1 or higher is recommended') | |||
| __all__ = ['load', 'tokenize'] | |||
| def _convert_image_to_rgb(image): | |||
| return image.convert('RGB') | |||
| def _transform(n_px): | |||
| return Compose([ | |||
| Resize(n_px, interpolation=BICUBIC), | |||
| CenterCrop(n_px), | |||
| _convert_image_to_rgb, | |||
| ToTensor(), | |||
| Normalize((0.48145466, 0.4578275, 0.40821073), | |||
| (0.26862954, 0.26130258, 0.27577711)), | |||
| ]) | |||
| def load(name: str, | |||
| device: Union[str, torch.device] = 'cuda' | |||
| if torch.cuda.is_available() else 'cpu', | |||
| jit: bool = False, | |||
| root: str = None): | |||
| if not jit: | |||
| model = build_model().to(device) | |||
| if str(device) == 'cpu': | |||
| model.float() | |||
| return model, _transform(model.visual.input_resolution) | |||
| # patch the device names | |||
| device_holder = torch.jit.trace( | |||
| lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) | |||
| device_node = [ | |||
| n for n in device_holder.graph.findAllNodes('prim::Constant') | |||
| if 'Device' in repr(n) | |||
| ][-1] | |||
| def patch_device(module): | |||
| try: | |||
| graphs = [module.graph] if hasattr(module, 'graph') else [] | |||
| except RuntimeError: | |||
| graphs = [] | |||
| if hasattr(module, 'forward1'): | |||
| graphs.append(module.forward1.graph) | |||
| for graph in graphs: | |||
| for node in graph.findAllNodes('prim::Constant'): | |||
| if 'value' in node.attributeNames() and str( | |||
| node['value']).startswith('cuda'): | |||
| node.copyAttributes(device_node) | |||
| model.apply(patch_device) | |||
| patch_device(model.encode_image) | |||
| patch_device(model.encode_text) | |||
| # patch dtype to float32 on CPU | |||
| if str(device) == 'cpu': | |||
| float_holder = torch.jit.trace( | |||
| lambda: torch.ones([]).float(), example_inputs=[]) | |||
| float_input = list(float_holder.graph.findNode('aten::to').inputs())[1] | |||
| float_node = float_input.node() | |||
| def patch_float(module): | |||
| try: | |||
| graphs = [module.graph] if hasattr(module, 'graph') else [] | |||
| except RuntimeError: | |||
| graphs = [] | |||
| if hasattr(module, 'forward1'): | |||
| graphs.append(module.forward1.graph) | |||
| for graph in graphs: | |||
| for node in graph.findAllNodes('aten::to'): | |||
| inputs = list(node.inputs()) | |||
| for i in [ | |||
| 1, 2 | |||
| ]: # dtype can be the second or third argument to aten::to() | |||
| if inputs[i].node()['value'] == 5: | |||
| inputs[i].node().copyAttributes(float_node) | |||
| model.apply(patch_float) | |||
| patch_float(model.encode_image) | |||
| patch_float(model.encode_text) | |||
| model.float() | |||
| return model, _transform(model.input_resolution.item()) | |||
| def tokenize( | |||
| _tokenizer, | |||
| texts: Union[str, List[str]], | |||
| context_length: int = 77, | |||
| truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: | |||
| """ | |||
| Returns the tokenized representation of given input string(s) | |||
| Parameters | |||
| ---------- | |||
| texts : Union[str, List[str]] | |||
| An input string or a list of input strings to tokenize | |||
| context_length : int | |||
| The context length to use; all CLIP models use 77 as the context length | |||
| truncate: bool | |||
| Whether to truncate the text in case its encoding is longer than the context length | |||
| Returns | |||
| ------- | |||
| A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. | |||
| We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. | |||
| """ | |||
| if isinstance(texts, str): | |||
| texts = [texts] | |||
| sot_token = _tokenizer.encoder['<|startoftext|>'] | |||
| eot_token = _tokenizer.encoder['<|endoftext|>'] | |||
| all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] | |||
| for text in texts] | |||
| if packaging.version.parse( | |||
| torch.__version__) < packaging.version.parse('1.8.0'): | |||
| result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |||
| else: | |||
| result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) | |||
| for i, tokens in enumerate(all_tokens): | |||
| if len(tokens) > context_length: | |||
| if truncate: | |||
| tokens = tokens[:context_length] | |||
| tokens[-1] = eot_token | |||
| else: | |||
| raise RuntimeError( | |||
| f'Input {texts[i]} is too long for context length {context_length}' | |||
| ) | |||
| result[i, :len(tokens)] = torch.tensor(tokens) | |||
| return result | |||
| @@ -0,0 +1,28 @@ | |||
| """ | |||
| Adapted from https://github.com/isl-org/lang-seg. | |||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| from .lseg_net import LSeg | |||
| class TextDrivenSegmentation(nn.Module): | |||
| def __init__(self, model_dir): | |||
| super(TextDrivenSegmentation, self).__init__() | |||
| self.net = LSeg(model_dir=model_dir) | |||
| self.model_dir = model_dir | |||
| def forward(self, img, txt_list): | |||
| b = img.size()[0] | |||
| batch_name_list = txt_list | |||
| xout_list = [] | |||
| for i in range(b): | |||
| labelset = ['others', batch_name_list[i]] | |||
| xout = self.net(img[i:i + 1], labelset=labelset) | |||
| xout_list.append(xout) | |||
| score_map = torch.cat(xout_list, dim=0) | |||
| return score_map | |||
| @@ -0,0 +1,334 @@ | |||
| """ | |||
| Adapted from https://github.com/isl-org/lang-seg. | |||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| from .lseg_vit import _make_pretrained_clip_vitl16_384, forward_vit | |||
| def _make_encoder( | |||
| backbone, | |||
| features, | |||
| use_pretrained=True, | |||
| groups=1, | |||
| expand=False, | |||
| exportable=True, | |||
| hooks=None, | |||
| use_vit_only=False, | |||
| use_readout='ignore', | |||
| enable_attention_hooks=False, | |||
| ): | |||
| if backbone == 'clip_vitl16_384': | |||
| clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384( | |||
| use_pretrained, | |||
| hooks=hooks, | |||
| use_readout=use_readout, | |||
| enable_attention_hooks=enable_attention_hooks, | |||
| ) | |||
| scratch = _make_scratch([256, 512, 1024, 1024], | |||
| features, | |||
| groups=groups, | |||
| expand=expand) | |||
| else: | |||
| raise NotImplementedError(f"Backbone '{backbone}' not implemented") | |||
| return clip_pretrained, pretrained, scratch | |||
| def _make_scratch(in_shape, out_shape, groups=1, expand=False): | |||
| scratch = nn.Module() | |||
| out_shape1 = out_shape | |||
| out_shape2 = out_shape | |||
| out_shape3 = out_shape | |||
| out_shape4 = out_shape | |||
| if expand is True: | |||
| out_shape1 = out_shape | |||
| out_shape2 = out_shape * 2 | |||
| out_shape3 = out_shape * 4 | |||
| out_shape4 = out_shape * 8 | |||
| scratch.layer1_rn = nn.Conv2d( | |||
| in_shape[0], | |||
| out_shape1, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=False, | |||
| groups=groups, | |||
| ) | |||
| scratch.layer2_rn = nn.Conv2d( | |||
| in_shape[1], | |||
| out_shape2, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=False, | |||
| groups=groups, | |||
| ) | |||
| scratch.layer3_rn = nn.Conv2d( | |||
| in_shape[2], | |||
| out_shape3, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=False, | |||
| groups=groups, | |||
| ) | |||
| scratch.layer4_rn = nn.Conv2d( | |||
| in_shape[3], | |||
| out_shape4, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=False, | |||
| groups=groups, | |||
| ) | |||
| return scratch | |||
| class Interpolate(nn.Module): | |||
| """Interpolation module.""" | |||
| def __init__(self, scale_factor, mode, align_corners=False): | |||
| """Init. | |||
| Args: | |||
| scale_factor (float): scaling | |||
| mode (str): interpolation mode | |||
| """ | |||
| super(Interpolate, self).__init__() | |||
| self.interp = nn.functional.interpolate | |||
| self.scale_factor = scale_factor | |||
| self.mode = mode | |||
| self.align_corners = align_corners | |||
| def forward(self, x): | |||
| """Forward pass. | |||
| Args: | |||
| x (tensor): input | |||
| Returns: | |||
| tensor: interpolated data | |||
| """ | |||
| x = self.interp( | |||
| x, | |||
| scale_factor=self.scale_factor, | |||
| mode=self.mode, | |||
| align_corners=self.align_corners, | |||
| ) | |||
| return x | |||
| class ResidualConvUnit(nn.Module): | |||
| """Residual convolution module.""" | |||
| def __init__(self, features): | |||
| """Init. | |||
| Args: | |||
| features (int): number of features | |||
| """ | |||
| super().__init__() | |||
| self.conv1 = nn.Conv2d( | |||
| features, features, kernel_size=3, stride=1, padding=1, bias=True) | |||
| self.conv2 = nn.Conv2d( | |||
| features, features, kernel_size=3, stride=1, padding=1, bias=True) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| def forward(self, x): | |||
| """Forward pass. | |||
| Args: | |||
| x (tensor): input | |||
| Returns: | |||
| tensor: output | |||
| """ | |||
| out = self.relu(x) | |||
| out = self.conv1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| return out + x | |||
| class FeatureFusionBlock(nn.Module): | |||
| """Feature fusion block.""" | |||
| def __init__(self, features): | |||
| """Init. | |||
| Args: | |||
| features (int): number of features | |||
| """ | |||
| super(FeatureFusionBlock, self).__init__() | |||
| self.resConfUnit1 = ResidualConvUnit(features) | |||
| self.resConfUnit2 = ResidualConvUnit(features) | |||
| def forward(self, *xs): | |||
| """Forward pass. | |||
| Returns: | |||
| tensor: output | |||
| """ | |||
| output = xs[0] | |||
| if len(xs) == 2: | |||
| output += self.resConfUnit1(xs[1]) | |||
| output = self.resConfUnit2(output) | |||
| output = nn.functional.interpolate( | |||
| output, scale_factor=2, mode='bilinear', align_corners=True) | |||
| return output | |||
| class ResidualConvUnit_custom(nn.Module): | |||
| """Residual convolution module.""" | |||
| def __init__(self, features, activation, bn): | |||
| """Init. | |||
| Args: | |||
| features (int): number of features | |||
| """ | |||
| super().__init__() | |||
| self.bn = bn | |||
| self.groups = 1 | |||
| self.conv1 = nn.Conv2d( | |||
| features, | |||
| features, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=not self.bn, | |||
| groups=self.groups, | |||
| ) | |||
| self.conv2 = nn.Conv2d( | |||
| features, | |||
| features, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=not self.bn, | |||
| groups=self.groups, | |||
| ) | |||
| if self.bn is True: | |||
| self.bn1 = nn.BatchNorm2d(features) | |||
| self.bn2 = nn.BatchNorm2d(features) | |||
| self.activation = activation | |||
| self.skip_add = nn.quantized.FloatFunctional() | |||
| def forward(self, x): | |||
| """Forward pass. | |||
| Args: | |||
| x (tensor): input | |||
| Returns: | |||
| tensor: output | |||
| """ | |||
| out = self.activation(x) | |||
| out = self.conv1(out) | |||
| if self.bn is True: | |||
| out = self.bn1(out) | |||
| out = self.activation(out) | |||
| out = self.conv2(out) | |||
| if self.bn is True: | |||
| out = self.bn2(out) | |||
| if self.groups > 1: | |||
| out = self.conv_merge(out) | |||
| return self.skip_add.add(out, x) | |||
| class FeatureFusionBlock_custom(nn.Module): | |||
| """Feature fusion block.""" | |||
| def __init__( | |||
| self, | |||
| features, | |||
| activation, | |||
| deconv=False, | |||
| bn=False, | |||
| expand=False, | |||
| align_corners=True, | |||
| ): | |||
| """Init. | |||
| Args: | |||
| features (int): number of features | |||
| """ | |||
| super(FeatureFusionBlock_custom, self).__init__() | |||
| self.deconv = deconv | |||
| self.align_corners = align_corners | |||
| self.groups = 1 | |||
| self.expand = expand | |||
| out_features = features | |||
| if self.expand is True: | |||
| out_features = features // 2 | |||
| self.out_conv = nn.Conv2d( | |||
| features, | |||
| out_features, | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| bias=True, | |||
| groups=1, | |||
| ) | |||
| self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) | |||
| self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) | |||
| self.skip_add = nn.quantized.FloatFunctional() | |||
| def forward(self, *xs): | |||
| """Forward pass. | |||
| Returns: | |||
| tensor: output | |||
| """ | |||
| output = xs[0] | |||
| if len(xs) == 2: | |||
| res = self.resConfUnit1(xs[1]) | |||
| output = self.skip_add.add(output, res) | |||
| output = self.resConfUnit2(output) | |||
| output = nn.functional.interpolate( | |||
| output, | |||
| scale_factor=2, | |||
| mode='bilinear', | |||
| align_corners=self.align_corners) | |||
| output = self.out_conv(output) | |||
| return output | |||
| @@ -0,0 +1,107 @@ | |||
| import os.path as osp | |||
| from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from PIL import Image | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.cv.text_driven_segmentation import \ | |||
| TextDrivenSegmentation | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| __all__ = ['TextDrivenSeg'] | |||
| @MODELS.register_module( | |||
| Tasks.text_driven_segmentation, | |||
| module_name=Models.text_driven_segmentation) | |||
| class TextDrivenSeg(TorchModel): | |||
| """ text driven segmentation model. | |||
| """ | |||
| def __init__(self, model_dir, device_id=0, *args, **kwargs): | |||
| super().__init__( | |||
| model_dir=model_dir, device_id=device_id, *args, **kwargs) | |||
| self.model = TextDrivenSegmentation(model_dir=model_dir) | |||
| pretrained_params = torch.load('{}/{}'.format( | |||
| model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) | |||
| self.model.load_state_dict(pretrained_params) | |||
| self.model.eval() | |||
| if device_id >= 0 and torch.cuda.is_available(): | |||
| self.model.to('cuda:{}'.format(device_id)) | |||
| logger.info('Use GPU: {}'.format(device_id)) | |||
| else: | |||
| device_id = -1 | |||
| logger.info('Use CPU for inference') | |||
| self.device_id = device_id | |||
| def preprocess(self, img, size=640): | |||
| mean = [0.48145466, 0.4578275, 0.40821073] | |||
| std = [0.26862954, 0.26130258, 0.27577711] | |||
| h, w, c = img.shape | |||
| max_hw = max(h, w) | |||
| ratio = 1.0 * size / max_hw | |||
| crop_h, crop_w = int(ratio * h), int(ratio * w) | |||
| pil_img = Image.fromarray(img) | |||
| pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR) | |||
| np_img = np.array(pil_img, dtype=np.float32) / 255. | |||
| for j in range(3): | |||
| np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j] | |||
| img_pad = np.zeros((size, size, 3), dtype=np.float32) | |||
| img_pad[:crop_h, :crop_w] = np_img | |||
| img_pad = torch.from_numpy(img_pad).permute(2, 0, | |||
| 1).unsqueeze(0).float() | |||
| return img_pad, h, w, crop_h, crop_w | |||
| def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w): | |||
| output = np.clip(tensors * 255., a_min=0, a_max=255.) | |||
| crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8) | |||
| pil_output = Image.fromarray(crop_output) | |||
| pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR) | |||
| np_output = np.array(pil_output, dtype=np.uint8) | |||
| np_output[np_output < 128] = 0 | |||
| np_output[np_output >= 128] = 255 | |||
| np_output = np.uint8(np_output) | |||
| return np_output | |||
| def forward(self, image, text): | |||
| """ | |||
| image should be numpy array, dtype=np.uint8, shape: height*width*3 | |||
| """ | |||
| image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess( | |||
| image, size=640) | |||
| pred = self.inference(image_tensor, text) | |||
| msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w, size=640) | |||
| outputs = {OutputKeys.MASKS: msk} | |||
| return outputs | |||
| def inference(self, image, text): | |||
| """ | |||
| image should be tensor, 1 * 3 * 640 * 640 | |||
| """ | |||
| with torch.no_grad(): | |||
| if self.device_id == -1: | |||
| output = self.model(image) | |||
| else: | |||
| device = torch.device('cuda', self.device_id) | |||
| output = self.model(image.to(device), [text]) | |||
| output = F.interpolate(output, size=(640, 640), mode='bilinear') | |||
| output = F.softmax(output, dim=1) | |||
| output = torch.argmax(output, dim=1) | |||
| output = output[0] | |||
| if self.device_id == -1: | |||
| pred = output.data.numpy() | |||
| else: | |||
| pred = output.data.cpu().numpy() | |||
| del output | |||
| return pred | |||
| @@ -0,0 +1,197 @@ | |||
| """ | |||
| Adapted from https://github.com/isl-org/lang-seg. | |||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| from . import clip | |||
| from .lseg_blocks import (FeatureFusionBlock, FeatureFusionBlock_custom, | |||
| Interpolate, _make_encoder, forward_vit) | |||
| from .simple_tokenizer import SimpleTokenizer | |||
| class depthwise_clipseg_conv(nn.Module): | |||
| def __init__(self): | |||
| super(depthwise_clipseg_conv, self).__init__() | |||
| self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1) | |||
| def depthwise_clipseg(self, x, channels): | |||
| x = torch.cat( | |||
| [self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], | |||
| dim=1) | |||
| return x | |||
| def forward(self, x): | |||
| channels = x.shape[1] | |||
| out = self.depthwise_clipseg(x, channels) | |||
| return out | |||
| class depthwise_conv(nn.Module): | |||
| def __init__(self, kernel_size=3, stride=1, padding=1): | |||
| super(depthwise_conv, self).__init__() | |||
| self.depthwise = nn.Conv2d( | |||
| 1, 1, kernel_size=kernel_size, stride=stride, padding=padding) | |||
| def forward(self, x): | |||
| # support for 4D tensor with NCHW | |||
| C, H, W = x.shape[1:] | |||
| x = x.reshape(-1, 1, H, W) | |||
| x = self.depthwise(x) | |||
| x = x.view(-1, C, H, W) | |||
| return x | |||
| class depthwise_block(nn.Module): | |||
| def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): | |||
| super(depthwise_block, self).__init__() | |||
| self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) | |||
| if activation == 'relu': | |||
| self.activation = nn.ReLU() | |||
| elif activation == 'lrelu': | |||
| self.activation = nn.LeakyReLU() | |||
| elif activation == 'tanh': | |||
| self.activation = nn.Tanh() | |||
| def forward(self, x, act=True): | |||
| x = self.depthwise(x) | |||
| if act: | |||
| x = self.activation(x) | |||
| return x | |||
| class bottleneck_block(nn.Module): | |||
| def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): | |||
| super(bottleneck_block, self).__init__() | |||
| self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) | |||
| if activation == 'relu': | |||
| self.activation = nn.ReLU() | |||
| elif activation == 'lrelu': | |||
| self.activation = nn.LeakyReLU() | |||
| elif activation == 'tanh': | |||
| self.activation = nn.Tanh() | |||
| def forward(self, x, act=True): | |||
| sum_layer = x.max(dim=1, keepdim=True)[0] | |||
| x = self.depthwise(x) | |||
| x = x + sum_layer | |||
| if act: | |||
| x = self.activation(x) | |||
| return x | |||
| class BaseModel(torch.nn.Module): | |||
| def load(self, path): | |||
| """Load model from file. | |||
| Args: | |||
| path (str): file path | |||
| """ | |||
| parameters = torch.load(path, map_location=torch.device('cpu')) | |||
| if 'optimizer' in parameters: | |||
| parameters = parameters['model'] | |||
| self.load_state_dict(parameters) | |||
| def _make_fusion_block(features, use_bn): | |||
| return FeatureFusionBlock_custom( | |||
| features, | |||
| activation=nn.ReLU(False), | |||
| deconv=False, | |||
| bn=use_bn, | |||
| expand=False, | |||
| align_corners=True, | |||
| ) | |||
| class LSeg(BaseModel): | |||
| def __init__( | |||
| self, | |||
| features=256, | |||
| backbone='clip_vitl16_384', | |||
| readout='project', | |||
| use_bn=True, | |||
| model_dir=None, | |||
| ): | |||
| super(LSeg, self).__init__() | |||
| hooks = { | |||
| 'clip_vitl16_384': [5, 11, 17, 23], | |||
| } | |||
| # Instantiate backbone and reassemble blocks | |||
| self.clip_pretrained, self.pretrained, self.scratch = _make_encoder( | |||
| backbone, | |||
| features, | |||
| groups=1, | |||
| expand=False, | |||
| exportable=False, | |||
| hooks=hooks[backbone], | |||
| use_readout=readout, | |||
| ) | |||
| self.scratch.refinenet1 = _make_fusion_block(features, use_bn) | |||
| self.scratch.refinenet2 = _make_fusion_block(features, use_bn) | |||
| self.scratch.refinenet3 = _make_fusion_block(features, use_bn) | |||
| self.scratch.refinenet4 = _make_fusion_block(features, use_bn) | |||
| self.logit_scale = nn.Parameter(torch.ones([]) | |||
| * np.log(1 / 0.07)).exp() | |||
| self.out_c = 512 | |||
| self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1) | |||
| self.scratch.output_conv = nn.Sequential( | |||
| Interpolate(scale_factor=2, mode='bilinear', align_corners=True), ) | |||
| self.tau = 0.07 | |||
| self.model_dir = model_dir | |||
| self.tokenizer = SimpleTokenizer(model_dir | |||
| + '/bpe_simple_vocab_16e6.txt.gz') | |||
| def forward(self, x, labelset=''): | |||
| text = clip.tokenize(self.tokenizer, labelset) | |||
| layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) | |||
| layer_1_rn = self.scratch.layer1_rn(layer_1) | |||
| layer_2_rn = self.scratch.layer2_rn(layer_2) | |||
| layer_3_rn = self.scratch.layer3_rn(layer_3) | |||
| layer_4_rn = self.scratch.layer4_rn(layer_4) | |||
| path_4 = self.scratch.refinenet4(layer_4_rn) | |||
| path_3 = self.scratch.refinenet3(path_4, layer_3_rn) | |||
| path_2 = self.scratch.refinenet2(path_3, layer_2_rn) | |||
| path_1 = self.scratch.refinenet1(path_2, layer_1_rn) | |||
| text = text.to(x.device) | |||
| text_features = self.clip_pretrained.encode_text(text) | |||
| image_features = self.scratch.head1(path_1) | |||
| imshape = image_features.shape | |||
| image_features = image_features.permute(0, 2, 3, | |||
| 1).reshape(-1, self.out_c) | |||
| # normalized features | |||
| image_features = image_features / image_features.norm( | |||
| dim=-1, keepdim=True) | |||
| text_features = text_features / text_features.norm( | |||
| dim=-1, keepdim=True) | |||
| logits_per_image = image_features @ text_features.t() / self.tau | |||
| out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], | |||
| -1).permute(0, 3, 1, 2) | |||
| out = self.scratch.output_conv(out) | |||
| return out | |||
| @@ -0,0 +1,543 @@ | |||
| """ | |||
| Adapted from https://github.com/isl-org/lang-seg. | |||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
| """ | |||
| import math | |||
| import types | |||
| import timm | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torch.utils.checkpoint as checkpoint | |||
| from . import clip | |||
| activations = {} | |||
| def get_activation(name): | |||
| def hook(model, input, output): | |||
| activations[name] = output | |||
| return hook | |||
| attention = {} | |||
| def get_attention(name): | |||
| def hook(module, input, output): | |||
| x = input[0] | |||
| B, N, C = x.shape | |||
| qkv = ( | |||
| module.qkv(x).reshape(B, N, 3, module.num_heads, | |||
| C // module.num_heads).permute( | |||
| 2, 0, 3, 1, 4)) | |||
| q, k, _ = ( | |||
| qkv[0], | |||
| qkv[1], | |||
| qkv[2], | |||
| ) # make torchscript happy (cannot use tensor as tuple) | |||
| attn = (q @ k.transpose(-2, -1)) * module.scale | |||
| attn = attn.softmax(dim=-1) # [:,:,1,1:] | |||
| attention[name] = attn | |||
| return hook | |||
| def get_mean_attention_map(attn, token, shape): | |||
| attn = attn[:, :, token, 1:] | |||
| attn = attn.unflatten(2, torch.Size([shape[2] // 16, | |||
| shape[3] // 16])).float() | |||
| attn = torch.nn.functional.interpolate( | |||
| attn, size=shape[2:], mode='bicubic', align_corners=False).squeeze(0) | |||
| all_attn = torch.mean(attn, 0) | |||
| return all_attn | |||
| class Slice(nn.Module): | |||
| def __init__(self, start_index=1): | |||
| super(Slice, self).__init__() | |||
| self.start_index = start_index | |||
| def forward(self, x): | |||
| return x[:, self.start_index:] | |||
| class AddReadout(nn.Module): | |||
| def __init__(self, start_index=1): | |||
| super(AddReadout, self).__init__() | |||
| self.start_index = start_index | |||
| def forward(self, x): | |||
| if self.start_index == 2: | |||
| readout = (x[:, 0] + x[:, 1]) / 2 | |||
| else: | |||
| readout = x[:, 0] | |||
| return x[:, self.start_index:] + readout.unsqueeze(1) | |||
| class ProjectReadout(nn.Module): | |||
| def __init__(self, in_features, start_index=1): | |||
| super(ProjectReadout, self).__init__() | |||
| self.start_index = start_index | |||
| self.project = nn.Sequential( | |||
| nn.Linear(2 * in_features, in_features), nn.GELU()) | |||
| def forward(self, x): | |||
| readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) | |||
| features = torch.cat((x[:, self.start_index:], readout), -1) | |||
| return self.project(features) | |||
| class Transpose(nn.Module): | |||
| def __init__(self, dim0, dim1): | |||
| super(Transpose, self).__init__() | |||
| self.dim0 = dim0 | |||
| self.dim1 = dim1 | |||
| def forward(self, x): | |||
| x = x.transpose(self.dim0, self.dim1) | |||
| return x | |||
| def forward_vit(pretrained, x): | |||
| b, c, h, w = x.shape | |||
| # encoder | |||
| _ = pretrained.model.forward_flex(x) | |||
| layer_1 = pretrained.activations['1'] | |||
| layer_2 = pretrained.activations['2'] | |||
| layer_3 = pretrained.activations['3'] | |||
| layer_4 = pretrained.activations['4'] | |||
| layer_1 = pretrained.act_postprocess1[0:2](layer_1) | |||
| layer_2 = pretrained.act_postprocess2[0:2](layer_2) | |||
| layer_3 = pretrained.act_postprocess3[0:2](layer_3) | |||
| layer_4 = pretrained.act_postprocess4[0:2](layer_4) | |||
| unflatten = nn.Sequential( | |||
| nn.Unflatten( | |||
| 2, | |||
| torch.Size([ | |||
| h // pretrained.model.patch_size[1], | |||
| w // pretrained.model.patch_size[0], | |||
| ]), | |||
| )) | |||
| if layer_1.ndim == 3: | |||
| layer_1 = unflatten(layer_1) | |||
| if layer_2.ndim == 3: | |||
| layer_2 = unflatten(layer_2) | |||
| if layer_3.ndim == 3: | |||
| layer_3 = unflatten(layer_3) | |||
| if layer_4.ndim == 3: | |||
| layer_4 = unflatten(layer_4) | |||
| layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)]( | |||
| layer_1) | |||
| layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)]( | |||
| layer_2) | |||
| layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)]( | |||
| layer_3) | |||
| layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)]( | |||
| layer_4) | |||
| return layer_1, layer_2, layer_3, layer_4 | |||
| def _resize_pos_embed(self, posemb, gs_h, gs_w): | |||
| posemb_tok, posemb_grid = ( | |||
| posemb[:, :self.start_index], | |||
| posemb[0, self.start_index:], | |||
| ) | |||
| gs_old = int(math.sqrt(len(posemb_grid))) | |||
| posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, | |||
| -1).permute(0, 3, 1, 2) | |||
| posemb_grid = F.interpolate( | |||
| posemb_grid, size=(gs_h, gs_w), mode='bilinear') | |||
| posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) | |||
| posemb = torch.cat([posemb_tok, posemb_grid], dim=1) | |||
| return posemb | |||
| def forward_flex(self, x): | |||
| b, c, h, w = x.shape | |||
| pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], | |||
| w // self.patch_size[0]) | |||
| B = x.shape[0] | |||
| if hasattr(self.patch_embed, 'backbone'): | |||
| x = self.patch_embed.backbone(x) | |||
| if isinstance(x, (list, tuple)): | |||
| x = x[ | |||
| -1] # last feature if backbone outputs list/tuple of features | |||
| x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) | |||
| if getattr(self, 'dist_token', None) is not None: | |||
| cls_tokens = self.cls_token.expand( | |||
| B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||
| dist_token = self.dist_token.expand(B, -1, -1) | |||
| x = torch.cat((cls_tokens, dist_token, x), dim=1) | |||
| else: | |||
| cls_tokens = self.cls_token.expand( | |||
| B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||
| x = torch.cat((cls_tokens, x), dim=1) | |||
| x = x + pos_embed | |||
| x = self.pos_drop(x) | |||
| gradient_checkpoint = False | |||
| for blk in self.blocks: | |||
| if gradient_checkpoint: | |||
| x = checkpoint.checkpoint(blk, x) | |||
| else: | |||
| x = blk(x) | |||
| x = self.norm(x) | |||
| return x | |||
| def get_readout_oper(vit_features, features, use_readout, start_index=1): | |||
| if use_readout == 'ignore': | |||
| readout_oper = [Slice(start_index)] * len(features) | |||
| elif use_readout == 'add': | |||
| readout_oper = [AddReadout(start_index)] * len(features) | |||
| elif use_readout == 'project': | |||
| readout_oper = [ | |||
| ProjectReadout(vit_features, start_index) for out_feat in features | |||
| ] | |||
| else: | |||
| assert ( | |||
| False | |||
| ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" | |||
| return readout_oper | |||
| def adapt_input_conv(in_chans, conv_weight): | |||
| conv_type = conv_weight.dtype | |||
| conv_weight = conv_weight.float( | |||
| ) # Some weights are in torch.half, ensure it's float for sum on CPU | |||
| O, II, J, K = conv_weight.shape | |||
| if in_chans == 1: | |||
| if II > 3: | |||
| assert conv_weight.shape[1] % 3 == 0 | |||
| # For models with space2depth stems | |||
| conv_weight = conv_weight.reshape(O, II // 3, 3, J, K) | |||
| conv_weight = conv_weight.sum(dim=2, keepdim=False) | |||
| else: | |||
| conv_weight = conv_weight.sum(dim=1, keepdim=True) | |||
| elif in_chans != 3: | |||
| if II != 3: | |||
| raise NotImplementedError( | |||
| 'Weight format not supported by conversion.') | |||
| else: | |||
| # NOTE this strategy should be better than random init, but there could be other combinations of | |||
| # the original RGB input layer weights that'd work better for specific cases. | |||
| repeat = int(math.ceil(in_chans / 3)) | |||
| conv_weight = conv_weight.repeat(1, repeat, 1, | |||
| 1)[:, :in_chans, :, :] | |||
| conv_weight *= (3 / float(in_chans)) | |||
| conv_weight = conv_weight.to(conv_type) | |||
| return conv_weight | |||
| @torch.no_grad() | |||
| def _load_weights(model, checkpoint_path, prefix=''): | |||
| """ Load weights from .npz checkpoints for official Google Brain Flax implementation | |||
| """ | |||
| import numpy as np | |||
| def _n2p(w, t=True): | |||
| if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: | |||
| w = w.flatten() | |||
| if t: | |||
| if w.ndim == 4: | |||
| w = w.transpose([3, 2, 0, 1]) | |||
| elif w.ndim == 3: | |||
| w = w.transpose([2, 0, 1]) | |||
| elif w.ndim == 2: | |||
| w = w.transpose([1, 0]) | |||
| return torch.from_numpy(w) | |||
| w = np.load(checkpoint_path) | |||
| if not prefix and 'opt/target/embedding/kernel' in w: | |||
| prefix = 'opt/target/' | |||
| if hasattr(model.patch_embed, 'backbone'): | |||
| # hybrid | |||
| backbone = model.patch_embed.backbone | |||
| stem_only = not hasattr(backbone, 'stem') | |||
| stem = backbone if stem_only else backbone.stem | |||
| stem.conv.weight.copy_( | |||
| adapt_input_conv(stem.conv.weight.shape[1], | |||
| _n2p(w[f'{prefix}conv_root/kernel']))) | |||
| stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) | |||
| stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) | |||
| if not stem_only: | |||
| for i, stage in enumerate(backbone.stages): | |||
| for j, block in enumerate(stage.blocks): | |||
| bp = f'{prefix}block{i + 1}/unit{j + 1}/' | |||
| for r in range(3): | |||
| getattr(block, f'conv{r + 1}').weight.copy_( | |||
| _n2p(w[f'{bp}conv{r + 1}/kernel'])) | |||
| getattr(block, f'norm{r + 1}').weight.copy_( | |||
| _n2p(w[f'{bp}gn{r + 1}/scale'])) | |||
| getattr(block, f'norm{r + 1}').bias.copy_( | |||
| _n2p(w[f'{bp}gn{r + 1}/bias'])) | |||
| if block.downsample is not None: | |||
| block.downsample.conv.weight.copy_( | |||
| _n2p(w[f'{bp}conv_proj/kernel'])) | |||
| block.downsample.norm.weight.copy_( | |||
| _n2p(w[f'{bp}gn_proj/scale'])) | |||
| block.downsample.norm.bias.copy_( | |||
| _n2p(w[f'{bp}gn_proj/bias'])) | |||
| embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) | |||
| else: | |||
| embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], | |||
| _n2p(w[f'{prefix}embedding/kernel'])) | |||
| model.patch_embed.proj.weight.copy_(embed_conv_w) | |||
| model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) | |||
| model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) | |||
| pos_embed_w = _n2p( | |||
| w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) | |||
| if pos_embed_w.shape != model.pos_embed.shape: | |||
| pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights | |||
| pos_embed_w, model.pos_embed, getattr(model, 'num_prefix_tokens', | |||
| 1), | |||
| model.patch_embed.grid_size) | |||
| model.pos_embed.copy_(pos_embed_w) | |||
| model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) | |||
| model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) | |||
| if isinstance( | |||
| model.head, nn.Linear | |||
| ) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: | |||
| model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) | |||
| model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) | |||
| # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights | |||
| # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: | |||
| # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) | |||
| # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) | |||
| for i, block in enumerate(model.blocks.children()): | |||
| block_prefix = f'{prefix}Transformer/encoderblock_{i}/' | |||
| mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' | |||
| block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) | |||
| block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) | |||
| block.attn.qkv.weight.copy_( | |||
| torch.cat([ | |||
| _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T | |||
| for n in ('query', 'key', 'value') | |||
| ])) | |||
| block.attn.qkv.bias.copy_( | |||
| torch.cat([ | |||
| _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) | |||
| for n in ('query', 'key', 'value') | |||
| ])) | |||
| block.attn.proj.weight.copy_( | |||
| _n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) | |||
| block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) | |||
| for r in range(2): | |||
| getattr(block.mlp, f'fc{r + 1}').weight.copy_( | |||
| _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) | |||
| getattr(block.mlp, f'fc{r + 1}').bias.copy_( | |||
| _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) | |||
| block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) | |||
| block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) | |||
| def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): | |||
| # Rescale the grid of position embeddings when loading from state_dict. Adapted from | |||
| # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 | |||
| ntok_new = posemb_new.shape[1] | |||
| if num_prefix_tokens: | |||
| posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[ | |||
| 0, num_prefix_tokens:] | |||
| ntok_new -= num_prefix_tokens | |||
| else: | |||
| posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] | |||
| gs_old = int(math.sqrt(len(posemb_grid))) | |||
| if not len(gs_new): # backwards compatibility | |||
| gs_new = [int(math.sqrt(ntok_new))] * 2 | |||
| assert len(gs_new) >= 2 | |||
| posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, | |||
| -1).permute(0, 3, 1, 2) | |||
| posemb_grid = F.interpolate( | |||
| posemb_grid, size=gs_new, mode='bicubic', align_corners=False) | |||
| posemb_grid = posemb_grid.permute(0, 2, 3, | |||
| 1).reshape(1, gs_new[0] * gs_new[1], -1) | |||
| posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) | |||
| return posemb | |||
| def _make_pretrained_clip_vitl16_384(pretrained, | |||
| use_readout='ignore', | |||
| hooks=None, | |||
| enable_attention_hooks=False): | |||
| clip_pretrained, _ = clip.load('ViT-B/32', device='cpu', jit=False) | |||
| # model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) | |||
| model = timm.create_model('vit_large_patch16_384', pretrained=False) | |||
| hooks = [5, 11, 17, 23] if hooks is None else hooks | |||
| pretrained = _make_vit_b16_backbone( | |||
| model, | |||
| features=[256, 512, 1024, 1024], | |||
| hooks=hooks, | |||
| vit_features=1024, | |||
| use_readout=use_readout, | |||
| enable_attention_hooks=enable_attention_hooks, | |||
| ) | |||
| return clip_pretrained, pretrained | |||
| def _make_vit_b16_backbone( | |||
| model, | |||
| features=[96, 192, 384, 768], | |||
| size=[384, 384], | |||
| hooks=[2, 5, 8, 11], | |||
| vit_features=768, | |||
| use_readout='ignore', | |||
| start_index=1, | |||
| enable_attention_hooks=False, | |||
| ): | |||
| pretrained = nn.Module() | |||
| pretrained.model = model | |||
| pretrained.model.blocks[hooks[0]].register_forward_hook( | |||
| get_activation('1')) | |||
| pretrained.model.blocks[hooks[1]].register_forward_hook( | |||
| get_activation('2')) | |||
| pretrained.model.blocks[hooks[2]].register_forward_hook( | |||
| get_activation('3')) | |||
| pretrained.model.blocks[hooks[3]].register_forward_hook( | |||
| get_activation('4')) | |||
| pretrained.activations = activations | |||
| if enable_attention_hooks: | |||
| pretrained.model.blocks[hooks[0]].attn.register_forward_hook( | |||
| get_attention('attn_1')) | |||
| pretrained.model.blocks[hooks[1]].attn.register_forward_hook( | |||
| get_attention('attn_2')) | |||
| pretrained.model.blocks[hooks[2]].attn.register_forward_hook( | |||
| get_attention('attn_3')) | |||
| pretrained.model.blocks[hooks[3]].attn.register_forward_hook( | |||
| get_attention('attn_4')) | |||
| pretrained.attention = attention | |||
| readout_oper = get_readout_oper(vit_features, features, use_readout, | |||
| start_index) | |||
| # 32, 48, 136, 384 | |||
| pretrained.act_postprocess1 = nn.Sequential( | |||
| readout_oper[0], | |||
| Transpose(1, 2), | |||
| nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||
| nn.Conv2d( | |||
| in_channels=vit_features, | |||
| out_channels=features[0], | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| ), | |||
| nn.ConvTranspose2d( | |||
| in_channels=features[0], | |||
| out_channels=features[0], | |||
| kernel_size=4, | |||
| stride=4, | |||
| padding=0, | |||
| bias=True, | |||
| dilation=1, | |||
| groups=1, | |||
| ), | |||
| ) | |||
| pretrained.act_postprocess2 = nn.Sequential( | |||
| readout_oper[1], | |||
| Transpose(1, 2), | |||
| nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||
| nn.Conv2d( | |||
| in_channels=vit_features, | |||
| out_channels=features[1], | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| ), | |||
| nn.ConvTranspose2d( | |||
| in_channels=features[1], | |||
| out_channels=features[1], | |||
| kernel_size=2, | |||
| stride=2, | |||
| padding=0, | |||
| bias=True, | |||
| dilation=1, | |||
| groups=1, | |||
| ), | |||
| ) | |||
| pretrained.act_postprocess3 = nn.Sequential( | |||
| readout_oper[2], | |||
| Transpose(1, 2), | |||
| nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||
| nn.Conv2d( | |||
| in_channels=vit_features, | |||
| out_channels=features[2], | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| ), | |||
| ) | |||
| pretrained.act_postprocess4 = nn.Sequential( | |||
| readout_oper[3], | |||
| Transpose(1, 2), | |||
| nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||
| nn.Conv2d( | |||
| in_channels=vit_features, | |||
| out_channels=features[3], | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| ), | |||
| nn.Conv2d( | |||
| in_channels=features[3], | |||
| out_channels=features[3], | |||
| kernel_size=3, | |||
| stride=2, | |||
| padding=1, | |||
| ), | |||
| ) | |||
| pretrained.model.start_index = start_index | |||
| pretrained.model.patch_size = [16, 16] | |||
| # We inject this function into the VisionTransformer instances so that | |||
| # we can use it with interpolated position embeddings without modifying the library source. | |||
| pretrained.model.forward_flex = types.MethodType(forward_flex, | |||
| pretrained.model) | |||
| pretrained.model._resize_pos_embed = types.MethodType( | |||
| _resize_pos_embed, pretrained.model) | |||
| return pretrained | |||
| @@ -0,0 +1,458 @@ | |||
| """ | |||
| Adapted from https://github.com/isl-org/lang-seg. | |||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||
| """ | |||
| from collections import OrderedDict | |||
| from typing import Tuple, Union | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| class Bottleneck(nn.Module): | |||
| expansion = 4 | |||
| def __init__(self, inplanes, planes, stride=1): | |||
| super().__init__() | |||
| # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 | |||
| self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) | |||
| self.bn1 = nn.BatchNorm2d(planes) | |||
| self.relu1 = nn.ReLU(inplace=True) | |||
| self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) | |||
| self.bn2 = nn.BatchNorm2d(planes) | |||
| self.relu2 = nn.ReLU(inplace=True) | |||
| self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() | |||
| self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) | |||
| self.bn3 = nn.BatchNorm2d(planes * self.expansion) | |||
| self.relu3 = nn.ReLU(inplace=True) | |||
| self.downsample = None | |||
| self.stride = stride | |||
| if stride > 1 or inplanes != planes * Bottleneck.expansion: | |||
| # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 | |||
| self.downsample = nn.Sequential( | |||
| OrderedDict([('-1', nn.AvgPool2d(stride)), | |||
| ('0', | |||
| nn.Conv2d( | |||
| inplanes, | |||
| planes * self.expansion, | |||
| 1, | |||
| stride=1, | |||
| bias=False)), | |||
| ('1', nn.BatchNorm2d(planes * self.expansion))])) | |||
| def forward(self, x: torch.Tensor): | |||
| identity = x | |||
| out = self.relu1(self.bn1(self.conv1(x))) | |||
| out = self.relu2(self.bn2(self.conv2(out))) | |||
| out = self.avgpool(out) | |||
| out = self.bn3(self.conv3(out)) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out += identity | |||
| out = self.relu3(out) | |||
| return out | |||
| class AttentionPool2d(nn.Module): | |||
| def __init__(self, | |||
| spacial_dim: int, | |||
| embed_dim: int, | |||
| num_heads: int, | |||
| output_dim: int = None): | |||
| super().__init__() | |||
| self.positional_embedding = nn.Parameter( | |||
| torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) | |||
| self.k_proj = nn.Linear(embed_dim, embed_dim) | |||
| self.q_proj = nn.Linear(embed_dim, embed_dim) | |||
| self.v_proj = nn.Linear(embed_dim, embed_dim) | |||
| self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) | |||
| self.num_heads = num_heads | |||
| def forward(self, x): | |||
| x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC | |||
| x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC | |||
| x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC | |||
| x, _ = F.multi_head_attention_forward( | |||
| query=x[:1], | |||
| key=x, | |||
| value=x, | |||
| embed_dim_to_check=x.shape[-1], | |||
| num_heads=self.num_heads, | |||
| q_proj_weight=self.q_proj.weight, | |||
| k_proj_weight=self.k_proj.weight, | |||
| v_proj_weight=self.v_proj.weight, | |||
| in_proj_weight=None, | |||
| in_proj_bias=torch.cat( | |||
| [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), | |||
| bias_k=None, | |||
| bias_v=None, | |||
| add_zero_attn=False, | |||
| dropout_p=0, | |||
| out_proj_weight=self.c_proj.weight, | |||
| out_proj_bias=self.c_proj.bias, | |||
| use_separate_proj_weight=True, | |||
| training=self.training, | |||
| need_weights=False) | |||
| return x.squeeze(0) | |||
| class ModifiedResNet(nn.Module): | |||
| """ | |||
| A ResNet class that is similar to torchvision's but contains the following changes: | |||
| - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |||
| - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |||
| - The final pooling layer is a QKV attention instead of an average pool | |||
| """ | |||
| def __init__(self, | |||
| layers, | |||
| output_dim, | |||
| heads, | |||
| input_resolution=224, | |||
| width=64): | |||
| super().__init__() | |||
| self.output_dim = output_dim | |||
| self.input_resolution = input_resolution | |||
| # the 3-layer stem | |||
| self.conv1 = nn.Conv2d( | |||
| 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |||
| self.bn1 = nn.BatchNorm2d(width // 2) | |||
| self.relu1 = nn.ReLU(inplace=True) | |||
| self.conv2 = nn.Conv2d( | |||
| width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |||
| self.bn2 = nn.BatchNorm2d(width // 2) | |||
| self.relu2 = nn.ReLU(inplace=True) | |||
| self.conv3 = nn.Conv2d( | |||
| width // 2, width, kernel_size=3, padding=1, bias=False) | |||
| self.bn3 = nn.BatchNorm2d(width) | |||
| self.relu3 = nn.ReLU(inplace=True) | |||
| self.avgpool = nn.AvgPool2d(2) | |||
| # residual layers | |||
| self._inplanes = width # this is a *mutable* variable used during construction | |||
| self.layer1 = self._make_layer(width, layers[0]) | |||
| self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |||
| self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |||
| self.layer4 = self._make_layer(width * 8, layers[3], stride=2) | |||
| embed_dim = width * 32 # the ResNet feature dimension | |||
| self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, | |||
| heads, output_dim) | |||
| def _make_layer(self, planes, blocks, stride=1): | |||
| layers = [Bottleneck(self._inplanes, planes, stride)] | |||
| self._inplanes = planes * Bottleneck.expansion | |||
| for _ in range(1, blocks): | |||
| layers.append(Bottleneck(self._inplanes, planes)) | |||
| return nn.Sequential(*layers) | |||
| def forward(self, x): | |||
| def stem(x): | |||
| x = self.relu1(self.bn1(self.conv1(x))) | |||
| x = self.relu2(self.bn2(self.conv2(x))) | |||
| x = self.relu3(self.bn3(self.conv3(x))) | |||
| x = self.avgpool(x) | |||
| return x | |||
| x = x.type(self.conv1.weight.dtype) | |||
| x = stem(x) | |||
| x = self.layer1(x) | |||
| x = self.layer2(x) | |||
| x = self.layer3(x) | |||
| x = self.layer4(x) | |||
| x = self.attnpool(x) | |||
| return x | |||
| class LayerNorm(nn.LayerNorm): | |||
| """Subclass torch's LayerNorm to handle fp16.""" | |||
| def forward(self, x: torch.Tensor): | |||
| orig_type = x.dtype | |||
| ret = super().forward(x.type(torch.float32)) | |||
| return ret.type(orig_type) | |||
| class QuickGELU(nn.Module): | |||
| def forward(self, x: torch.Tensor): | |||
| return x * torch.sigmoid(1.702 * x) | |||
| class ResidualAttentionBlock(nn.Module): | |||
| def __init__(self, | |||
| d_model: int, | |||
| n_head: int, | |||
| attn_mask: torch.Tensor = None): | |||
| super().__init__() | |||
| self.attn = nn.MultiheadAttention(d_model, n_head) | |||
| self.ln_1 = LayerNorm(d_model) | |||
| self.mlp = nn.Sequential( | |||
| OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), | |||
| ('gelu', QuickGELU()), | |||
| ('c_proj', nn.Linear(d_model * 4, d_model))])) | |||
| self.ln_2 = LayerNorm(d_model) | |||
| self.attn_mask = attn_mask | |||
| def attention(self, x: torch.Tensor): | |||
| self.attn_mask = self.attn_mask.to( | |||
| dtype=x.dtype, | |||
| device=x.device) if self.attn_mask is not None else None | |||
| return self.attn( | |||
| x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] | |||
| def forward(self, x: torch.Tensor): | |||
| x = x + self.attention(self.ln_1(x)) | |||
| x = x + self.mlp(self.ln_2(x)) | |||
| return x | |||
| class Transformer(nn.Module): | |||
| def __init__(self, width, layers, heads, attn_mask=None): | |||
| super().__init__() | |||
| self.width = width | |||
| self.layers = layers | |||
| self.resblocks = nn.Sequential(*[ | |||
| ResidualAttentionBlock(width, heads, attn_mask) | |||
| for _ in range(layers) | |||
| ]) | |||
| def forward(self, x: torch.Tensor): | |||
| return self.resblocks(x) | |||
| class VisionTransformer(nn.Module): | |||
| def __init__(self, input_resolution: int, patch_size: int, width: int, | |||
| layers: int, heads: int, output_dim: int): | |||
| super().__init__() | |||
| self.input_resolution = input_resolution | |||
| self.output_dim = output_dim | |||
| self.conv1 = nn.Conv2d( | |||
| in_channels=3, | |||
| out_channels=width, | |||
| kernel_size=patch_size, | |||
| stride=patch_size, | |||
| bias=False) | |||
| scale = width**-0.5 | |||
| self.class_embedding = nn.Parameter(scale * torch.randn(width)) | |||
| self.positional_embedding = nn.Parameter(scale * torch.randn( | |||
| (input_resolution // patch_size)**2 + 1, width)) | |||
| self.ln_pre = LayerNorm(width) | |||
| self.transformer = Transformer(width, layers, heads) | |||
| self.ln_post = LayerNorm(width) | |||
| self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) | |||
| def forward(self, x: torch.Tensor): | |||
| x = self.conv1(x) # shape = [*, width, grid, grid] | |||
| x = x.reshape(x.shape[0], x.shape[1], | |||
| -1) # shape = [*, width, grid ** 2] | |||
| x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||
| x1 = self.class_embedding.to(x.dtype) | |||
| x2 = torch.zeros( | |||
| x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) | |||
| x = torch.cat([x1 + x2, x], dim=1) # shape = [*, grid ** 2 + 1, width] | |||
| x = x + self.positional_embedding.to(x.dtype) | |||
| x = self.ln_pre(x) | |||
| x = x.permute(1, 0, 2) # NLD -> LND | |||
| x = self.transformer(x) | |||
| x = x.permute(1, 0, 2) # LND -> NLD | |||
| x = self.ln_post(x[:, 0, :]) | |||
| if self.proj is not None: | |||
| x = x @ self.proj | |||
| return x | |||
| class CLIP(nn.Module): | |||
| def __init__( | |||
| self, | |||
| embed_dim: int, | |||
| # vision | |||
| image_resolution: int, | |||
| vision_layers: Union[Tuple[int, int, int, int], int], | |||
| vision_width: int, | |||
| vision_patch_size: int, | |||
| # text | |||
| context_length: int, | |||
| vocab_size: int, | |||
| transformer_width: int, | |||
| transformer_heads: int, | |||
| transformer_layers: int): | |||
| super().__init__() | |||
| self.context_length = context_length | |||
| if isinstance(vision_layers, (tuple, list)): | |||
| vision_heads = vision_width * 32 // 64 | |||
| self.visual = ModifiedResNet( | |||
| layers=vision_layers, | |||
| output_dim=embed_dim, | |||
| heads=vision_heads, | |||
| input_resolution=image_resolution, | |||
| width=vision_width) | |||
| else: | |||
| vision_heads = vision_width // 64 | |||
| self.visual = VisionTransformer( | |||
| input_resolution=image_resolution, | |||
| patch_size=vision_patch_size, | |||
| width=vision_width, | |||
| layers=vision_layers, | |||
| heads=vision_heads, | |||
| output_dim=embed_dim) | |||
| self.transformer = Transformer( | |||
| width=transformer_width, | |||
| layers=transformer_layers, | |||
| heads=transformer_heads, | |||
| attn_mask=self.build_attention_mask()) | |||
| self.vocab_size = vocab_size | |||
| self.token_embedding = nn.Embedding(vocab_size, transformer_width) | |||
| self.positional_embedding = nn.Parameter( | |||
| torch.empty(self.context_length, transformer_width)) | |||
| self.ln_final = LayerNorm(transformer_width) | |||
| self.text_projection = nn.Parameter( | |||
| torch.empty(transformer_width, embed_dim)) | |||
| self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |||
| self.initialize_parameters() | |||
| def initialize_parameters(self): | |||
| nn.init.normal_(self.token_embedding.weight, std=0.02) | |||
| nn.init.normal_(self.positional_embedding, std=0.01) | |||
| if isinstance(self.visual, ModifiedResNet): | |||
| if self.visual.attnpool is not None: | |||
| std = self.visual.attnpool.c_proj.in_features**-0.5 | |||
| nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) | |||
| nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) | |||
| nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) | |||
| nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) | |||
| for resnet_block in [ | |||
| self.visual.layer1, self.visual.layer2, self.visual.layer3, | |||
| self.visual.layer4 | |||
| ]: | |||
| for name, param in resnet_block.named_parameters(): | |||
| if name.endswith('bn3.weight'): | |||
| nn.init.zeros_(param) | |||
| proj_std = (self.transformer.width**-0.5) * ( | |||
| (2 * self.transformer.layers)**-0.5) | |||
| attn_std = self.transformer.width**-0.5 | |||
| fc_std = (2 * self.transformer.width)**-0.5 | |||
| for block in self.transformer.resblocks: | |||
| nn.init.normal_(block.attn.in_proj_weight, std=attn_std) | |||
| nn.init.normal_(block.attn.out_proj.weight, std=proj_std) | |||
| nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) | |||
| nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) | |||
| if self.text_projection is not None: | |||
| nn.init.normal_( | |||
| self.text_projection, std=self.transformer.width**-0.5) | |||
| def build_attention_mask(self): | |||
| # lazily create causal attention mask, with full attention between the vision tokens | |||
| # pytorch uses additive attention mask; fill with -inf | |||
| mask = torch.empty(self.context_length, self.context_length) | |||
| mask.fill_(float('-inf')) | |||
| mask.triu_(1) # zero out the lower diagonal | |||
| return mask | |||
| @property | |||
| def dtype(self): | |||
| return self.visual.conv1.weight.dtype | |||
| def encode_image(self, image): | |||
| return self.visual(image.type(self.dtype)) | |||
| def encode_text(self, text): | |||
| x = self.token_embedding(text).type(self.dtype) | |||
| x = x + self.positional_embedding.type(self.dtype) | |||
| x = x.permute(1, 0, 2) # NLD -> LND | |||
| x = self.transformer(x) | |||
| x = x.permute(1, 0, 2) # LND -> NLD | |||
| x = self.ln_final(x).type(self.dtype) | |||
| x = x[torch.arange(x.shape[0]), | |||
| text.argmax(dim=-1)] @ self.text_projection | |||
| return x | |||
| def forward(self, image, text): | |||
| image_features = self.encode_image(image) | |||
| text_features = self.encode_text(text) | |||
| # normalized features | |||
| image_features = image_features / image_features.norm( | |||
| dim=1, keepdim=True) | |||
| text_features = text_features / text_features.norm(dim=1, keepdim=True) | |||
| # cosine similarity as logits | |||
| logit_scale = self.logit_scale.exp() | |||
| logits_per_image = logit_scale * image_features @ text_features.t() | |||
| logits_per_text = logits_per_image.t() | |||
| # shape = [global_batch_size, global_batch_size] | |||
| return logits_per_image, logits_per_text | |||
| def convert_weights(model: nn.Module): | |||
| """Convert applicable model parameters to fp16""" | |||
| def _convert_weights_to_fp16(ll): | |||
| if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Linear)): | |||
| ll.weight.data = ll.weight.data.half() | |||
| if ll.bias is not None: | |||
| ll.bias.data = ll.bias.data.half() | |||
| if isinstance(ll, nn.MultiheadAttention): | |||
| for attr in [ | |||
| *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], | |||
| 'in_proj_bias', 'bias_k', 'bias_v' | |||
| ]: | |||
| tensor = getattr(ll, attr) | |||
| if tensor is not None: | |||
| tensor.data = tensor.data.half() | |||
| for name in ['text_projection', 'proj']: | |||
| if hasattr(ll, name): | |||
| attr = getattr(ll, name) | |||
| if attr is not None: | |||
| attr.data = attr.data.half() | |||
| model.apply(_convert_weights_to_fp16) | |||
| def build_model(): | |||
| model = CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12) | |||
| convert_weights(model) | |||
| return model.eval() | |||
| @@ -0,0 +1,156 @@ | |||
| """ CLIP | |||
| Adapted from https://github.com/openai/CLIP. | |||
| Originally MIT License, Copyright (c) 2021 OpenAI. | |||
| """ | |||
| import gzip | |||
| import html | |||
| import os | |||
| from functools import lru_cache | |||
| import ftfy | |||
| import regex as re | |||
| @lru_cache() | |||
| def default_bpe(): | |||
| return os.path.join( | |||
| os.path.dirname(os.path.abspath(__file__)), | |||
| 'bpe_simple_vocab_16e6.txt.gz') | |||
| @lru_cache() | |||
| def bytes_to_unicode(): | |||
| """ | |||
| Returns list of utf-8 byte and a corresponding list of unicode strings. | |||
| The reversible bpe codes work on unicode strings. | |||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||
| This is a signficant percentage of your normal, say, 32K bpe vocab. | |||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||
| """ | |||
| bs = list(range(ord('!'), | |||
| ord('~') + 1)) + list(range( | |||
| ord('¡'), | |||
| ord('¬') + 1)) + list(range(ord('®'), | |||
| ord('ÿ') + 1)) | |||
| cs = bs[:] | |||
| n = 0 | |||
| for b in range(2**8): | |||
| if b not in bs: | |||
| bs.append(b) | |||
| cs.append(2**8 + n) | |||
| n += 1 | |||
| cs = [chr(n) for n in cs] | |||
| return dict(zip(bs, cs)) | |||
| def get_pairs(word): | |||
| """Return set of symbol pairs in a word. | |||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||
| """ | |||
| pairs = set() | |||
| prev_char = word[0] | |||
| for char in word[1:]: | |||
| pairs.add((prev_char, char)) | |||
| prev_char = char | |||
| return pairs | |||
| def basic_clean(text): | |||
| text = ftfy.fix_text(text) | |||
| text = html.unescape(html.unescape(text)) | |||
| return text.strip() | |||
| def whitespace_clean(text): | |||
| text = re.sub(r'\s+', ' ', text) | |||
| text = text.strip() | |||
| return text | |||
| class SimpleTokenizer(object): | |||
| def __init__(self, bpe_path: str = default_bpe()): | |||
| self.byte_encoder = bytes_to_unicode() | |||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||
| merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') | |||
| merges = merges[1:49152 - 256 - 2 + 1] | |||
| merges = [tuple(merge.split()) for merge in merges] | |||
| vocab = list(bytes_to_unicode().values()) | |||
| vocab = vocab + [v + '</w>' for v in vocab] | |||
| for merge in merges: | |||
| vocab.append(''.join(merge)) | |||
| vocab.extend(['<|startoftext|>', '<|endoftext|>']) | |||
| self.encoder = dict(zip(vocab, range(len(vocab)))) | |||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||
| self.bpe_ranks = dict(zip(merges, range(len(merges)))) | |||
| self.cache = { | |||
| '<|startoftext|>': '<|startoftext|>', | |||
| '<|endoftext|>': '<|endoftext|>' | |||
| } | |||
| self.pat = re.compile( | |||
| r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", | |||
| re.IGNORECASE) | |||
| def bpe(self, token): | |||
| if token in self.cache: | |||
| return self.cache[token] | |||
| word = tuple(token[:-1]) + (token[-1] + '</w>', ) | |||
| pairs = get_pairs(word) | |||
| if not pairs: | |||
| return token + '</w>' | |||
| while True: | |||
| bigram = min( | |||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||
| if bigram not in self.bpe_ranks: | |||
| break | |||
| first, second = bigram | |||
| new_word = [] | |||
| i = 0 | |||
| error_list = [] | |||
| while i < len(word): | |||
| try: | |||
| j = word.index(first, i) | |||
| new_word.extend(word[i:j]) | |||
| i = j | |||
| except Exception as err: | |||
| new_word.extend(word[i:]) | |||
| error_list.append(err) | |||
| break | |||
| if word[i] == first and i < len(word) - 1 and word[ | |||
| i + 1] == second: | |||
| new_word.append(first + second) | |||
| i += 2 | |||
| else: | |||
| new_word.append(word[i]) | |||
| i += 1 | |||
| new_word = tuple(new_word) | |||
| word = new_word | |||
| if len(word) == 1: | |||
| break | |||
| else: | |||
| pairs = get_pairs(word) | |||
| word = ' '.join(word) | |||
| self.cache[token] = word | |||
| return word | |||
| def encode(self, text): | |||
| bpe_tokens = [] | |||
| text = whitespace_clean(basic_clean(text)).lower() | |||
| for token in re.findall(self.pat, text): | |||
| token = ''.join(self.byte_encoder[b] | |||
| for b in token.encode('utf-8')) | |||
| bpe_tokens.extend(self.encoder[bpe_token] | |||
| for bpe_token in self.bpe(token).split(' ')) | |||
| return bpe_tokens | |||
| def decode(self, tokens): | |||
| text = ''.join([self.decoder[token] for token in tokens]) | |||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||
| 'utf-8', errors='replace').replace('</w>', ' ') | |||
| return text | |||
| @@ -243,6 +243,13 @@ TASK_OUTPUTS = { | |||
| # "output_img": np.ndarray with shape [height, width, 3] | |||
| # } | |||
| Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG], | |||
| # text driven segmentation result for single sample | |||
| # { | |||
| # "masks": [ | |||
| # np.array # 2D array containing only 0, 255 | |||
| # ] | |||
| # } | |||
| Tasks.text_driven_segmentation: [OutputKeys.MASKS], | |||
| # movide scene segmentation result for a single video | |||
| # { | |||
| @@ -149,6 +149,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_vitb_video-single-object-tracking_ostrack'), | |||
| Tasks.image_reid_person: (Pipelines.image_reid_person, | |||
| 'damo/cv_passvitb_image-reid-person_market'), | |||
| Tasks.text_driven_segmentation: | |||
| (Pipelines.text_driven_segmentation, | |||
| 'damo/cv_vitl16_segmentation_text-driven-seg'), | |||
| Tasks.movie_scene_segmentation: | |||
| (Pipelines.movie_scene_segmentation, | |||
| 'damo/cv_resnet50-bert_video-scene-segmentation_movienet') | |||
| @@ -44,6 +44,7 @@ if TYPE_CHECKING: | |||
| from .video_category_pipeline import VideoCategoryPipeline | |||
| from .virtual_try_on_pipeline import VirtualTryonPipeline | |||
| from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline | |||
| from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipleline | |||
| from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | |||
| else: | |||
| @@ -97,6 +98,8 @@ else: | |||
| 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | |||
| 'easycv_pipeline': | |||
| ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'], | |||
| 'text_driven_segmentation_pipeline': | |||
| ['TextDrivenSegmentationPipeline'], | |||
| 'movie_scene_segmentation_pipeline': | |||
| ['MovieSceneSegmentationPipeline'], | |||
| } | |||
| @@ -0,0 +1,51 @@ | |||
| from typing import Any, Dict | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import Tasks | |||
| @PIPELINES.register_module( | |||
| Tasks.text_driven_segmentation, | |||
| module_name=Pipelines.text_driven_segmentation) | |||
| class TextDrivenSegmentationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, auto_collate=False, **kwargs) | |||
| def preprocess(self, input: Dict) -> Dict[str, Any]: | |||
| img = LoadImage.convert_to_ndarray(input['image']) | |||
| img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img) | |||
| result = { | |||
| 'img': img_tensor, | |||
| 'ori_h': ori_h, | |||
| 'ori_w': ori_w, | |||
| 'crop_h': crop_h, | |||
| 'crop_w': crop_w, | |||
| 'text': input['text'], | |||
| } | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| outputs = self.model.inference(input['img'], input['text']) | |||
| result = { | |||
| 'data': outputs, | |||
| 'ori_h': input['ori_h'], | |||
| 'ori_w': input['ori_w'], | |||
| 'crop_h': input['crop_h'], | |||
| 'crop_w': input['crop_w'], | |||
| } | |||
| return result | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| data = self.model.postprocess(inputs['data'], inputs['crop_h'], | |||
| inputs['crop_w'], inputs['ori_h'], | |||
| inputs['ori_w']) | |||
| outputs = {OutputKeys.MASKS: data} | |||
| return outputs | |||
| @@ -36,6 +36,7 @@ class CVTasks(object): | |||
| image_segmentation = 'image-segmentation' | |||
| portrait_matting = 'portrait-matting' | |||
| text_driven_segmentation = 'text-driven-segmentation' | |||
| # image editing | |||
| skin_retouching = 'skin-retouching' | |||
| @@ -0,0 +1,28 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class TextDrivenSegmentationTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_text_driven_segmentation(self): | |||
| input_location = 'data/test/images/text_driven_segmentation.jpg' | |||
| test_input = { | |||
| 'image': input_location, | |||
| 'text': 'bear', | |||
| } | |||
| model_id = 'damo/cv_vitl16_segmentation_text-driven-seg' | |||
| shop_seg = pipeline(Tasks.text_driven_segmentation, model=model_id) | |||
| result = shop_seg(test_input) | |||
| import cv2 | |||
| # result[OutputKeys.MASKS] is segment map result,other keys are not used | |||
| cv2.imwrite(input_location + '_lseg.jpg', result[OutputKeys.MASKS]) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||