A text-guided semantic segmentation model: given an input text description, it segments out the object in the image that the text describes.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9942863
master
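For reviewers, a quick usage sketch of the new task. This is hedged: the model id and the pipeline input layout below are assumptions, since the pipeline and preprocessor files are not part of this excerpt; only the task constant and the output key come from this diff.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# hypothetical model id, for illustration only
seg = pipeline(Tasks.text_driven_segmentation, model='damo/<text-driven-seg-model-id>')
result = seg({'image': 'demo.jpg', 'text': 'the yellow dog'})  # input dict layout is an assumption
mask = result[OutputKeys.MASKS]  # uint8 mask with the same height/width as the input image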
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:2c7d2f279e3b317f1d0de18410a0585e122166fa2464c17b88a0c813f6c58bd4 | |||||
| size 67861 | |||||
| @@ -29,6 +29,7 @@ class Models(object): | |||||
| video_summarization = 'pgl-video-summarization' | |||||
| swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||||
| text_driven_segmentation = 'text-driven-segmentation' | |||||
| resnet50_bert = 'resnet50-bert' | |||||
| # EasyCV models | |||||
| @@ -143,6 +144,7 @@ class Pipelines(object): | |||||
| video_summarization = 'googlenet_pgl_video_summarization' | |||||
| image_semantic_segmentation = 'image-semantic-segmentation' | |||||
| image_reid_person = 'passvitb-image-reid-person' | |||||
| text_driven_segmentation = 'text-driven-segmentation' | |||||
| movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | |||||
| # nlp tasks | |||||
| @@ -0,0 +1 @@ | |||||
| from .lseg_base import TextDrivenSegmentation | |||||
| @@ -0,0 +1,170 @@ | |||||
| """ CLIP | |||||
| Adapted from https://github.com/openai/CLIP. | |||||
| Originally MIT License, Copyright (c) 2021 OpenAI. | |||||
| """ | |||||
| import hashlib | |||||
| import os | |||||
| import urllib | |||||
| import warnings | |||||
| from typing import Any, List, Union | |||||
| import torch | |||||
| from PIL import Image | |||||
| from pkg_resources import packaging | |||||
| from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, | |||||
| ToTensor) | |||||
| from tqdm import tqdm | |||||
| from .model import build_model | |||||
| from .simple_tokenizer import SimpleTokenizer as _Tokenizer | |||||
| try: | |||||
| from torchvision.transforms import InterpolationMode | |||||
| BICUBIC = InterpolationMode.BICUBIC | |||||
| except ImportError: | |||||
| BICUBIC = Image.BICUBIC | |||||
| if packaging.version.parse( | |||||
| torch.__version__) < packaging.version.parse('1.7.1'): | |||||
| warnings.warn('PyTorch version 1.7.1 or higher is recommended') | |||||
| __all__ = ['load', 'tokenize'] | |||||
| def _convert_image_to_rgb(image): | |||||
| return image.convert('RGB') | |||||
| def _transform(n_px): | |||||
| return Compose([ | |||||
| Resize(n_px, interpolation=BICUBIC), | |||||
| CenterCrop(n_px), | |||||
| _convert_image_to_rgb, | |||||
| ToTensor(), | |||||
| Normalize((0.48145466, 0.4578275, 0.40821073), | |||||
| (0.26862954, 0.26130258, 0.27577711)), | |||||
| ]) | |||||
| def load(name: str, | |||||
| device: Union[str, torch.device] = 'cuda' | |||||
| if torch.cuda.is_available() else 'cpu', | |||||
| jit: bool = False, | |||||
| root: str = None): | |||||
| if not jit: | |||||
| model = build_model().to(device) | |||||
| if str(device) == 'cpu': | |||||
| model.float() | |||||
| return model, _transform(model.visual.input_resolution) | |||||
| # patch the device names | |||||
| device_holder = torch.jit.trace( | |||||
| lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) | |||||
| device_node = [ | |||||
| n for n in device_holder.graph.findAllNodes('prim::Constant') | |||||
| if 'Device' in repr(n) | |||||
| ][-1] | |||||
| def patch_device(module): | |||||
| try: | |||||
| graphs = [module.graph] if hasattr(module, 'graph') else [] | |||||
| except RuntimeError: | |||||
| graphs = [] | |||||
| if hasattr(module, 'forward1'): | |||||
| graphs.append(module.forward1.graph) | |||||
| for graph in graphs: | |||||
| for node in graph.findAllNodes('prim::Constant'): | |||||
| if 'value' in node.attributeNames() and str( | |||||
| node['value']).startswith('cuda'): | |||||
| node.copyAttributes(device_node) | |||||
| model.apply(patch_device) | |||||
| patch_device(model.encode_image) | |||||
| patch_device(model.encode_text) | |||||
| # patch dtype to float32 on CPU | |||||
| if str(device) == 'cpu': | |||||
| float_holder = torch.jit.trace( | |||||
| lambda: torch.ones([]).float(), example_inputs=[]) | |||||
| float_input = list(float_holder.graph.findNode('aten::to').inputs())[1] | |||||
| float_node = float_input.node() | |||||
| def patch_float(module): | |||||
| try: | |||||
| graphs = [module.graph] if hasattr(module, 'graph') else [] | |||||
| except RuntimeError: | |||||
| graphs = [] | |||||
| if hasattr(module, 'forward1'): | |||||
| graphs.append(module.forward1.graph) | |||||
| for graph in graphs: | |||||
| for node in graph.findAllNodes('aten::to'): | |||||
| inputs = list(node.inputs()) | |||||
| for i in [ | |||||
| 1, 2 | |||||
| ]: # dtype can be the second or third argument to aten::to() | |||||
| if inputs[i].node()['value'] == 5: | |||||
| inputs[i].node().copyAttributes(float_node) | |||||
| model.apply(patch_float) | |||||
| patch_float(model.encode_image) | |||||
| patch_float(model.encode_text) | |||||
| model.float() | |||||
| return model, _transform(model.input_resolution.item()) | |||||
| def tokenize( | |||||
| _tokenizer, | |||||
| texts: Union[str, List[str]], | |||||
| context_length: int = 77, | |||||
| truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: | |||||
| """ | |||||
| Returns the tokenized representation of given input string(s) | |||||
| Parameters | |||||
| ---------- | |||||
| texts : Union[str, List[str]] | |||||
| An input string or a list of input strings to tokenize | |||||
| context_length : int | |||||
| The context length to use; all CLIP models use 77 as the context length | |||||
| truncate: bool | |||||
| Whether to truncate the text in case its encoding is longer than the context length | |||||
| Returns | |||||
| ------- | |||||
| A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. | |||||
| We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. | |||||
| """ | |||||
| if isinstance(texts, str): | |||||
| texts = [texts] | |||||
| sot_token = _tokenizer.encoder['<|startoftext|>'] | |||||
| eot_token = _tokenizer.encoder['<|endoftext|>'] | |||||
| all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] | |||||
| for text in texts] | |||||
| if packaging.version.parse( | |||||
| torch.__version__) < packaging.version.parse('1.8.0'): | |||||
| result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) | |||||
| else: | |||||
| result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) | |||||
| for i, tokens in enumerate(all_tokens): | |||||
| if len(tokens) > context_length: | |||||
| if truncate: | |||||
| tokens = tokens[:context_length] | |||||
| tokens[-1] = eot_token | |||||
| else: | |||||
| raise RuntimeError( | |||||
| f'Input {texts[i]} is too long for context length {context_length}' | |||||
| ) | |||||
| result[i, :len(tokens)] = torch.tensor(tokens) | |||||
| return result | |||||
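A minimal sketch of how this `tokenize` variant is meant to be called: unlike upstream CLIP, the tokenizer instance is passed in explicitly, and lseg_net.py builds it from the BPE vocab shipped in model_dir. The sketch assumes it runs inside this package; the vocab path is a placeholder.

# mirrors this file's own relative import
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

tokenizer = _Tokenizer('/path/to/model_dir/bpe_simple_vocab_16e6.txt.gz')  # placeholder path
ids = tokenize(tokenizer, ['others', 'the yellow dog'])
print(ids.shape)  # torch.Size([2, 77]); dtype is long on torch < 1.8.0, int otherwise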
| @@ -0,0 +1,28 @@ | |||||
| """ | |||||
| Adapted from https://github.com/isl-org/lang-seg. | |||||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from .lseg_net import LSeg | |||||
| class TextDrivenSegmentation(nn.Module): | |||||
| def __init__(self, model_dir): | |||||
| super(TextDrivenSegmentation, self).__init__() | |||||
| self.net = LSeg(model_dir=model_dir) | |||||
| self.model_dir = model_dir | |||||
| def forward(self, img, txt_list): | |||||
| b = img.size()[0] | |||||
| batch_name_list = txt_list | |||||
| xout_list = [] | |||||
| for i in range(b): | |||||
| labelset = ['others', batch_name_list[i]] | |||||
| xout = self.net(img[i:i + 1], labelset=labelset) | |||||
| xout_list.append(xout) | |||||
| score_map = torch.cat(xout_list, dim=0) | |||||
| return score_map | |||||
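Shape sketch (hedged): each image in the batch is scored against a two-entry labelset ('others' plus its own query text), so with a 640x640 input and the ViT-L/16 + DPT-style decoder defined below, the concatenated score map should come out as [B, 2, 640, 640]. A dummy forward, assuming model_dir contains the CLIP BPE vocab:

import torch

net = TextDrivenSegmentation(model_dir='/path/to/model_dir')  # placeholder directory
img = torch.randn(2, 3, 640, 640)
score_map = net(img, ['the yellow dog', 'a red car'])  # one query text per image in the batch
print(score_map.shape)  # expected: torch.Size([2, 2, 640, 640])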
| @@ -0,0 +1,334 @@ | |||||
| """ | |||||
| Adapted from https://github.com/isl-org/lang-seg. | |||||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from .lseg_vit import _make_pretrained_clip_vitl16_384, forward_vit | |||||
| def _make_encoder( | |||||
| backbone, | |||||
| features, | |||||
| use_pretrained=True, | |||||
| groups=1, | |||||
| expand=False, | |||||
| exportable=True, | |||||
| hooks=None, | |||||
| use_vit_only=False, | |||||
| use_readout='ignore', | |||||
| enable_attention_hooks=False, | |||||
| ): | |||||
| if backbone == 'clip_vitl16_384': | |||||
| clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384( | |||||
| use_pretrained, | |||||
| hooks=hooks, | |||||
| use_readout=use_readout, | |||||
| enable_attention_hooks=enable_attention_hooks, | |||||
| ) | |||||
| scratch = _make_scratch([256, 512, 1024, 1024], | |||||
| features, | |||||
| groups=groups, | |||||
| expand=expand) | |||||
| else: | |||||
| raise NotImplementedError(f"Backbone '{backbone}' not implemented") | |||||
| return clip_pretrained, pretrained, scratch | |||||
| def _make_scratch(in_shape, out_shape, groups=1, expand=False): | |||||
| scratch = nn.Module() | |||||
| out_shape1 = out_shape | |||||
| out_shape2 = out_shape | |||||
| out_shape3 = out_shape | |||||
| out_shape4 = out_shape | |||||
| if expand is True: | |||||
| out_shape1 = out_shape | |||||
| out_shape2 = out_shape * 2 | |||||
| out_shape3 = out_shape * 4 | |||||
| out_shape4 = out_shape * 8 | |||||
| scratch.layer1_rn = nn.Conv2d( | |||||
| in_shape[0], | |||||
| out_shape1, | |||||
| kernel_size=3, | |||||
| stride=1, | |||||
| padding=1, | |||||
| bias=False, | |||||
| groups=groups, | |||||
| ) | |||||
| scratch.layer2_rn = nn.Conv2d( | |||||
| in_shape[1], | |||||
| out_shape2, | |||||
| kernel_size=3, | |||||
| stride=1, | |||||
| padding=1, | |||||
| bias=False, | |||||
| groups=groups, | |||||
| ) | |||||
| scratch.layer3_rn = nn.Conv2d( | |||||
| in_shape[2], | |||||
| out_shape3, | |||||
| kernel_size=3, | |||||
| stride=1, | |||||
| padding=1, | |||||
| bias=False, | |||||
| groups=groups, | |||||
| ) | |||||
| scratch.layer4_rn = nn.Conv2d( | |||||
| in_shape[3], | |||||
| out_shape4, | |||||
| kernel_size=3, | |||||
| stride=1, | |||||
| padding=1, | |||||
| bias=False, | |||||
| groups=groups, | |||||
| ) | |||||
| return scratch | |||||
| class Interpolate(nn.Module): | |||||
| """Interpolation module.""" | |||||
| def __init__(self, scale_factor, mode, align_corners=False): | |||||
| """Init. | |||||
| Args: | |||||
| scale_factor (float): scaling | |||||
| mode (str): interpolation mode | |||||
| """ | |||||
| super(Interpolate, self).__init__() | |||||
| self.interp = nn.functional.interpolate | |||||
| self.scale_factor = scale_factor | |||||
| self.mode = mode | |||||
| self.align_corners = align_corners | |||||
| def forward(self, x): | |||||
| """Forward pass. | |||||
| Args: | |||||
| x (tensor): input | |||||
| Returns: | |||||
| tensor: interpolated data | |||||
| """ | |||||
| x = self.interp( | |||||
| x, | |||||
| scale_factor=self.scale_factor, | |||||
| mode=self.mode, | |||||
| align_corners=self.align_corners, | |||||
| ) | |||||
| return x | |||||
| class ResidualConvUnit(nn.Module): | |||||
| """Residual convolution module.""" | |||||
| def __init__(self, features): | |||||
| """Init. | |||||
| Args: | |||||
| features (int): number of features | |||||
| """ | |||||
| super().__init__() | |||||
| self.conv1 = nn.Conv2d( | |||||
| features, features, kernel_size=3, stride=1, padding=1, bias=True) | |||||
| self.conv2 = nn.Conv2d( | |||||
| features, features, kernel_size=3, stride=1, padding=1, bias=True) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| def forward(self, x): | |||||
| """Forward pass. | |||||
| Args: | |||||
| x (tensor): input | |||||
| Returns: | |||||
| tensor: output | |||||
| """ | |||||
| out = self.relu(x) | |||||
| out = self.conv1(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv2(out) | |||||
| return out + x | |||||
| class FeatureFusionBlock(nn.Module): | |||||
| """Feature fusion block.""" | |||||
| def __init__(self, features): | |||||
| """Init. | |||||
| Args: | |||||
| features (int): number of features | |||||
| """ | |||||
| super(FeatureFusionBlock, self).__init__() | |||||
| self.resConfUnit1 = ResidualConvUnit(features) | |||||
| self.resConfUnit2 = ResidualConvUnit(features) | |||||
| def forward(self, *xs): | |||||
| """Forward pass. | |||||
| Returns: | |||||
| tensor: output | |||||
| """ | |||||
| output = xs[0] | |||||
| if len(xs) == 2: | |||||
| output += self.resConfUnit1(xs[1]) | |||||
| output = self.resConfUnit2(output) | |||||
| output = nn.functional.interpolate( | |||||
| output, scale_factor=2, mode='bilinear', align_corners=True) | |||||
| return output | |||||
| class ResidualConvUnit_custom(nn.Module): | |||||
| """Residual convolution module.""" | |||||
| def __init__(self, features, activation, bn): | |||||
| """Init. | |||||
| Args: | |||||
| features (int): number of features | |||||
| """ | |||||
| super().__init__() | |||||
| self.bn = bn | |||||
| self.groups = 1 | |||||
| self.conv1 = nn.Conv2d( | |||||
| features, | |||||
| features, | |||||
| kernel_size=3, | |||||
| stride=1, | |||||
| padding=1, | |||||
| bias=not self.bn, | |||||
| groups=self.groups, | |||||
| ) | |||||
| self.conv2 = nn.Conv2d( | |||||
| features, | |||||
| features, | |||||
| kernel_size=3, | |||||
| stride=1, | |||||
| padding=1, | |||||
| bias=not self.bn, | |||||
| groups=self.groups, | |||||
| ) | |||||
| if self.bn is True: | |||||
| self.bn1 = nn.BatchNorm2d(features) | |||||
| self.bn2 = nn.BatchNorm2d(features) | |||||
| self.activation = activation | |||||
| self.skip_add = nn.quantized.FloatFunctional() | |||||
| def forward(self, x): | |||||
| """Forward pass. | |||||
| Args: | |||||
| x (tensor): input | |||||
| Returns: | |||||
| tensor: output | |||||
| """ | |||||
| out = self.activation(x) | |||||
| out = self.conv1(out) | |||||
| if self.bn is True: | |||||
| out = self.bn1(out) | |||||
| out = self.activation(out) | |||||
| out = self.conv2(out) | |||||
| if self.bn is True: | |||||
| out = self.bn2(out) | |||||
| if self.groups > 1: | |||||
| out = self.conv_merge(out) | |||||
| return self.skip_add.add(out, x) | |||||
| class FeatureFusionBlock_custom(nn.Module): | |||||
| """Feature fusion block.""" | |||||
| def __init__( | |||||
| self, | |||||
| features, | |||||
| activation, | |||||
| deconv=False, | |||||
| bn=False, | |||||
| expand=False, | |||||
| align_corners=True, | |||||
| ): | |||||
| """Init. | |||||
| Args: | |||||
| features (int): number of features | |||||
| """ | |||||
| super(FeatureFusionBlock_custom, self).__init__() | |||||
| self.deconv = deconv | |||||
| self.align_corners = align_corners | |||||
| self.groups = 1 | |||||
| self.expand = expand | |||||
| out_features = features | |||||
| if self.expand is True: | |||||
| out_features = features // 2 | |||||
| self.out_conv = nn.Conv2d( | |||||
| features, | |||||
| out_features, | |||||
| kernel_size=1, | |||||
| stride=1, | |||||
| padding=0, | |||||
| bias=True, | |||||
| groups=1, | |||||
| ) | |||||
| self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) | |||||
| self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) | |||||
| self.skip_add = nn.quantized.FloatFunctional() | |||||
| def forward(self, *xs): | |||||
| """Forward pass. | |||||
| Returns: | |||||
| tensor: output | |||||
| """ | |||||
| output = xs[0] | |||||
| if len(xs) == 2: | |||||
| res = self.resConfUnit1(xs[1]) | |||||
| output = self.skip_add.add(output, res) | |||||
| output = self.resConfUnit2(output) | |||||
| output = nn.functional.interpolate( | |||||
| output, | |||||
| scale_factor=2, | |||||
| mode='bilinear', | |||||
| align_corners=self.align_corners) | |||||
| output = self.out_conv(output) | |||||
| return output | |||||
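A small shape check for FeatureFusionBlock_custom (a sketch, not part of the PR): fusing a decoder path with a lateral feature keeps the channel count (expand=False) and doubles the spatial resolution via the bilinear upsample before out_conv.

import torch
import torch.nn as nn

block = FeatureFusionBlock_custom(256, nn.ReLU(False), bn=True, align_corners=True)
path = torch.randn(1, 256, 20, 20)     # output of the previous refinenet stage
lateral = torch.randn(1, 256, 20, 20)  # lateral layerX_rn feature at the same resolution
print(block(path, lateral).shape)      # torch.Size([1, 256, 40, 40])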
| @@ -0,0 +1,107 @@ | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from PIL import Image | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.cv.text_driven_segmentation import \ | |||||
| TextDrivenSegmentation | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| __all__ = ['TextDrivenSeg'] | |||||
| @MODELS.register_module( | |||||
| Tasks.text_driven_segmentation, | |||||
| module_name=Models.text_driven_segmentation) | |||||
| class TextDrivenSeg(TorchModel): | |||||
| """ text driven segmentation model. | |||||
| """ | |||||
| def __init__(self, model_dir, device_id=0, *args, **kwargs): | |||||
| super().__init__( | |||||
| model_dir=model_dir, device_id=device_id, *args, **kwargs) | |||||
| self.model = TextDrivenSegmentation(model_dir=model_dir) | |||||
| pretrained_params = torch.load('{}/{}'.format( | |||||
| model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) | |||||
| self.model.load_state_dict(pretrained_params) | |||||
| self.model.eval() | |||||
| if device_id >= 0 and torch.cuda.is_available(): | |||||
| self.model.to('cuda:{}'.format(device_id)) | |||||
| logger.info('Use GPU: {}'.format(device_id)) | |||||
| else: | |||||
| device_id = -1 | |||||
| logger.info('Use CPU for inference') | |||||
| self.device_id = device_id | |||||
| def preprocess(self, img, size=640): | |||||
| mean = [0.48145466, 0.4578275, 0.40821073] | |||||
| std = [0.26862954, 0.26130258, 0.27577711] | |||||
| h, w, c = img.shape | |||||
| max_hw = max(h, w) | |||||
| ratio = 1.0 * size / max_hw | |||||
| crop_h, crop_w = int(ratio * h), int(ratio * w) | |||||
| pil_img = Image.fromarray(img) | |||||
| pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR) | |||||
| np_img = np.array(pil_img, dtype=np.float32) / 255. | |||||
| for j in range(3): | |||||
| np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j] | |||||
| img_pad = np.zeros((size, size, 3), dtype=np.float32) | |||||
| img_pad[:crop_h, :crop_w] = np_img | |||||
| img_pad = torch.from_numpy(img_pad).permute(2, 0, | |||||
| 1).unsqueeze(0).float() | |||||
| return img_pad, h, w, crop_h, crop_w | |||||
| def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w): | |||||
| output = np.clip(tensors * 255., a_min=0, a_max=255.) | |||||
| crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8) | |||||
| pil_output = Image.fromarray(crop_output) | |||||
| pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR) | |||||
| np_output = np.array(pil_output, dtype=np.uint8) | |||||
| np_output[np_output < 128] = 0 | |||||
| np_output[np_output >= 128] = 255 | |||||
| np_output = np.uint8(np_output) | |||||
| return np_output | |||||
| def forward(self, image, text): | |||||
| """ | |||||
| image should be numpy array, dtype=np.uint8, shape: height*width*3 | |||||
| """ | |||||
| image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess( | |||||
| image, size=640) | |||||
| pred = self.inference(image_tensor, text) | |||||
| msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w) | |||||
| outputs = {OutputKeys.MASKS: msk} | |||||
| return outputs | |||||
| def inference(self, image, text): | |||||
| """ | |||||
| image should be tensor, 1 * 3 * 640 * 640 | |||||
| """ | |||||
| with torch.no_grad(): | |||||
| if self.device_id == -1: | |||||
| output = self.model(image, [text]) | |||||
| else: | |||||
| device = torch.device('cuda', self.device_id) | |||||
| output = self.model(image.to(device), [text]) | |||||
| output = F.interpolate(output, size=(640, 640), mode='bilinear') | |||||
| output = F.softmax(output, dim=1) | |||||
| output = torch.argmax(output, dim=1) | |||||
| output = output[0] | |||||
| if self.device_id == -1: | |||||
| pred = output.data.numpy() | |||||
| else: | |||||
| pred = output.data.cpu().numpy() | |||||
| del output | |||||
| return pred | |||||
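Direct-model usage sketch (hedged: assumes model_dir holds the weights file named by ModelFile.TORCH_MODEL_BIN_FILE plus the CLIP BPE vocab; the directory and image file names are placeholders):

import numpy as np
from PIL import Image
from modelscope.outputs import OutputKeys

model = TextDrivenSeg(model_dir='/path/to/model_dir', device_id=0)  # placeholder directory
image = np.array(Image.open('demo.jpg').convert('RGB'), dtype=np.uint8)  # H*W*3, uint8
result = model.forward(image, 'the yellow dog')
mask = result[OutputKeys.MASKS]  # uint8 mask, same H*W as the input, values in {0, 255}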
| @@ -0,0 +1,197 @@ | |||||
| """ | |||||
| Adapted from https://github.com/isl-org/lang-seg. | |||||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||||
| """ | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from . import clip | |||||
| from .lseg_blocks import (FeatureFusionBlock, FeatureFusionBlock_custom, | |||||
| Interpolate, _make_encoder, forward_vit) | |||||
| from .simple_tokenizer import SimpleTokenizer | |||||
| class depthwise_clipseg_conv(nn.Module): | |||||
| def __init__(self): | |||||
| super(depthwise_clipseg_conv, self).__init__() | |||||
| self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1) | |||||
| def depthwise_clipseg(self, x, channels): | |||||
| x = torch.cat( | |||||
| [self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], | |||||
| dim=1) | |||||
| return x | |||||
| def forward(self, x): | |||||
| channels = x.shape[1] | |||||
| out = self.depthwise_clipseg(x, channels) | |||||
| return out | |||||
| class depthwise_conv(nn.Module): | |||||
| def __init__(self, kernel_size=3, stride=1, padding=1): | |||||
| super(depthwise_conv, self).__init__() | |||||
| self.depthwise = nn.Conv2d( | |||||
| 1, 1, kernel_size=kernel_size, stride=stride, padding=padding) | |||||
| def forward(self, x): | |||||
| # support for 4D tensor with NCHW | |||||
| C, H, W = x.shape[1:] | |||||
| x = x.reshape(-1, 1, H, W) | |||||
| x = self.depthwise(x) | |||||
| x = x.view(-1, C, H, W) | |||||
| return x | |||||
| class depthwise_block(nn.Module): | |||||
| def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): | |||||
| super(depthwise_block, self).__init__() | |||||
| self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) | |||||
| if activation == 'relu': | |||||
| self.activation = nn.ReLU() | |||||
| elif activation == 'lrelu': | |||||
| self.activation = nn.LeakyReLU() | |||||
| elif activation == 'tanh': | |||||
| self.activation = nn.Tanh() | |||||
| def forward(self, x, act=True): | |||||
| x = self.depthwise(x) | |||||
| if act: | |||||
| x = self.activation(x) | |||||
| return x | |||||
| class bottleneck_block(nn.Module): | |||||
| def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): | |||||
| super(bottleneck_block, self).__init__() | |||||
| self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) | |||||
| if activation == 'relu': | |||||
| self.activation = nn.ReLU() | |||||
| elif activation == 'lrelu': | |||||
| self.activation = nn.LeakyReLU() | |||||
| elif activation == 'tanh': | |||||
| self.activation = nn.Tanh() | |||||
| def forward(self, x, act=True): | |||||
| sum_layer = x.max(dim=1, keepdim=True)[0] | |||||
| x = self.depthwise(x) | |||||
| x = x + sum_layer | |||||
| if act: | |||||
| x = self.activation(x) | |||||
| return x | |||||
| class BaseModel(torch.nn.Module): | |||||
| def load(self, path): | |||||
| """Load model from file. | |||||
| Args: | |||||
| path (str): file path | |||||
| """ | |||||
| parameters = torch.load(path, map_location=torch.device('cpu')) | |||||
| if 'optimizer' in parameters: | |||||
| parameters = parameters['model'] | |||||
| self.load_state_dict(parameters) | |||||
| def _make_fusion_block(features, use_bn): | |||||
| return FeatureFusionBlock_custom( | |||||
| features, | |||||
| activation=nn.ReLU(False), | |||||
| deconv=False, | |||||
| bn=use_bn, | |||||
| expand=False, | |||||
| align_corners=True, | |||||
| ) | |||||
| class LSeg(BaseModel): | |||||
| def __init__( | |||||
| self, | |||||
| features=256, | |||||
| backbone='clip_vitl16_384', | |||||
| readout='project', | |||||
| use_bn=True, | |||||
| model_dir=None, | |||||
| ): | |||||
| super(LSeg, self).__init__() | |||||
| hooks = { | |||||
| 'clip_vitl16_384': [5, 11, 17, 23], | |||||
| } | |||||
| # Instantiate backbone and reassemble blocks | |||||
| self.clip_pretrained, self.pretrained, self.scratch = _make_encoder( | |||||
| backbone, | |||||
| features, | |||||
| groups=1, | |||||
| expand=False, | |||||
| exportable=False, | |||||
| hooks=hooks[backbone], | |||||
| use_readout=readout, | |||||
| ) | |||||
| self.scratch.refinenet1 = _make_fusion_block(features, use_bn) | |||||
| self.scratch.refinenet2 = _make_fusion_block(features, use_bn) | |||||
| self.scratch.refinenet3 = _make_fusion_block(features, use_bn) | |||||
| self.scratch.refinenet4 = _make_fusion_block(features, use_bn) | |||||
| self.logit_scale = nn.Parameter(torch.ones([]) | |||||
| * np.log(1 / 0.07)).exp() | |||||
| self.out_c = 512 | |||||
| self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1) | |||||
| self.scratch.output_conv = nn.Sequential( | |||||
| Interpolate(scale_factor=2, mode='bilinear', align_corners=True), ) | |||||
| self.tau = 0.07 | |||||
| self.model_dir = model_dir | |||||
| self.tokenizer = SimpleTokenizer(model_dir | |||||
| + '/bpe_simple_vocab_16e6.txt.gz') | |||||
| def forward(self, x, labelset=''): | |||||
| text = clip.tokenize(self.tokenizer, labelset) | |||||
| layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) | |||||
| layer_1_rn = self.scratch.layer1_rn(layer_1) | |||||
| layer_2_rn = self.scratch.layer2_rn(layer_2) | |||||
| layer_3_rn = self.scratch.layer3_rn(layer_3) | |||||
| layer_4_rn = self.scratch.layer4_rn(layer_4) | |||||
| path_4 = self.scratch.refinenet4(layer_4_rn) | |||||
| path_3 = self.scratch.refinenet3(path_4, layer_3_rn) | |||||
| path_2 = self.scratch.refinenet2(path_3, layer_2_rn) | |||||
| path_1 = self.scratch.refinenet1(path_2, layer_1_rn) | |||||
| text = text.to(x.device) | |||||
| text_features = self.clip_pretrained.encode_text(text) | |||||
| image_features = self.scratch.head1(path_1) | |||||
| imshape = image_features.shape | |||||
| image_features = image_features.permute(0, 2, 3, | |||||
| 1).reshape(-1, self.out_c) | |||||
| # normalized features | |||||
| image_features = image_features / image_features.norm( | |||||
| dim=-1, keepdim=True) | |||||
| text_features = text_features / text_features.norm( | |||||
| dim=-1, keepdim=True) | |||||
| logits_per_image = image_features @ text_features.t() / self.tau | |||||
| out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], | |||||
| -1).permute(0, 3, 1, 2) | |||||
| out = self.scratch.output_conv(out) | |||||
| return out | |||||
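The per-pixel scoring in LSeg.forward reduces to a cosine similarity between L2-normalized pixel embeddings and text embeddings, scaled by 1/tau (tau=0.07). A standalone sketch of just that step, with made-up feature sizes:

import torch
import torch.nn.functional as F

tau = 0.07
pixel_feat = F.normalize(torch.randn(320 * 320, 512), dim=-1)  # head1 output flattened to [HW, 512]
text_feat = F.normalize(torch.randn(2, 512), dim=-1)           # ['others', query] text embeddings
logits = pixel_feat @ text_feat.t() / tau                       # [HW, 2]
score_map = logits.view(1, 320, 320, 2).permute(0, 3, 1, 2)     # [1, 2, H, W]; output_conv then upsamples 2x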
| @@ -0,0 +1,543 @@ | |||||
| """ | |||||
| Adapted from https://github.com/isl-org/lang-seg. | |||||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||||
| """ | |||||
| import math | |||||
| import types | |||||
| import timm | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import torch.utils.checkpoint as checkpoint | |||||
| from . import clip | |||||
| activations = {} | |||||
| def get_activation(name): | |||||
| def hook(model, input, output): | |||||
| activations[name] = output | |||||
| return hook | |||||
| attention = {} | |||||
| def get_attention(name): | |||||
| def hook(module, input, output): | |||||
| x = input[0] | |||||
| B, N, C = x.shape | |||||
| qkv = ( | |||||
| module.qkv(x).reshape(B, N, 3, module.num_heads, | |||||
| C // module.num_heads).permute( | |||||
| 2, 0, 3, 1, 4)) | |||||
| q, k, _ = ( | |||||
| qkv[0], | |||||
| qkv[1], | |||||
| qkv[2], | |||||
| ) # make torchscript happy (cannot use tensor as tuple) | |||||
| attn = (q @ k.transpose(-2, -1)) * module.scale | |||||
| attn = attn.softmax(dim=-1) # [:,:,1,1:] | |||||
| attention[name] = attn | |||||
| return hook | |||||
| def get_mean_attention_map(attn, token, shape): | |||||
| attn = attn[:, :, token, 1:] | |||||
| attn = attn.unflatten(2, torch.Size([shape[2] // 16, | |||||
| shape[3] // 16])).float() | |||||
| attn = torch.nn.functional.interpolate( | |||||
| attn, size=shape[2:], mode='bicubic', align_corners=False).squeeze(0) | |||||
| all_attn = torch.mean(attn, 0) | |||||
| return all_attn | |||||
| class Slice(nn.Module): | |||||
| def __init__(self, start_index=1): | |||||
| super(Slice, self).__init__() | |||||
| self.start_index = start_index | |||||
| def forward(self, x): | |||||
| return x[:, self.start_index:] | |||||
| class AddReadout(nn.Module): | |||||
| def __init__(self, start_index=1): | |||||
| super(AddReadout, self).__init__() | |||||
| self.start_index = start_index | |||||
| def forward(self, x): | |||||
| if self.start_index == 2: | |||||
| readout = (x[:, 0] + x[:, 1]) / 2 | |||||
| else: | |||||
| readout = x[:, 0] | |||||
| return x[:, self.start_index:] + readout.unsqueeze(1) | |||||
| class ProjectReadout(nn.Module): | |||||
| def __init__(self, in_features, start_index=1): | |||||
| super(ProjectReadout, self).__init__() | |||||
| self.start_index = start_index | |||||
| self.project = nn.Sequential( | |||||
| nn.Linear(2 * in_features, in_features), nn.GELU()) | |||||
| def forward(self, x): | |||||
| readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) | |||||
| features = torch.cat((x[:, self.start_index:], readout), -1) | |||||
| return self.project(features) | |||||
| class Transpose(nn.Module): | |||||
| def __init__(self, dim0, dim1): | |||||
| super(Transpose, self).__init__() | |||||
| self.dim0 = dim0 | |||||
| self.dim1 = dim1 | |||||
| def forward(self, x): | |||||
| x = x.transpose(self.dim0, self.dim1) | |||||
| return x | |||||
| def forward_vit(pretrained, x): | |||||
| b, c, h, w = x.shape | |||||
| # encoder | |||||
| _ = pretrained.model.forward_flex(x) | |||||
| layer_1 = pretrained.activations['1'] | |||||
| layer_2 = pretrained.activations['2'] | |||||
| layer_3 = pretrained.activations['3'] | |||||
| layer_4 = pretrained.activations['4'] | |||||
| layer_1 = pretrained.act_postprocess1[0:2](layer_1) | |||||
| layer_2 = pretrained.act_postprocess2[0:2](layer_2) | |||||
| layer_3 = pretrained.act_postprocess3[0:2](layer_3) | |||||
| layer_4 = pretrained.act_postprocess4[0:2](layer_4) | |||||
| unflatten = nn.Sequential( | |||||
| nn.Unflatten( | |||||
| 2, | |||||
| torch.Size([ | |||||
| h // pretrained.model.patch_size[1], | |||||
| w // pretrained.model.patch_size[0], | |||||
| ]), | |||||
| )) | |||||
| if layer_1.ndim == 3: | |||||
| layer_1 = unflatten(layer_1) | |||||
| if layer_2.ndim == 3: | |||||
| layer_2 = unflatten(layer_2) | |||||
| if layer_3.ndim == 3: | |||||
| layer_3 = unflatten(layer_3) | |||||
| if layer_4.ndim == 3: | |||||
| layer_4 = unflatten(layer_4) | |||||
| layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)]( | |||||
| layer_1) | |||||
| layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)]( | |||||
| layer_2) | |||||
| layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)]( | |||||
| layer_3) | |||||
| layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)]( | |||||
| layer_4) | |||||
| return layer_1, layer_2, layer_3, layer_4 | |||||
| def _resize_pos_embed(self, posemb, gs_h, gs_w): | |||||
| posemb_tok, posemb_grid = ( | |||||
| posemb[:, :self.start_index], | |||||
| posemb[0, self.start_index:], | |||||
| ) | |||||
| gs_old = int(math.sqrt(len(posemb_grid))) | |||||
| posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, | |||||
| -1).permute(0, 3, 1, 2) | |||||
| posemb_grid = F.interpolate( | |||||
| posemb_grid, size=(gs_h, gs_w), mode='bilinear') | |||||
| posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) | |||||
| posemb = torch.cat([posemb_tok, posemb_grid], dim=1) | |||||
| return posemb | |||||
| def forward_flex(self, x): | |||||
| b, c, h, w = x.shape | |||||
| pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], | |||||
| w // self.patch_size[0]) | |||||
| B = x.shape[0] | |||||
| if hasattr(self.patch_embed, 'backbone'): | |||||
| x = self.patch_embed.backbone(x) | |||||
| if isinstance(x, (list, tuple)): | |||||
| x = x[ | |||||
| -1] # last feature if backbone outputs list/tuple of features | |||||
| x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) | |||||
| if getattr(self, 'dist_token', None) is not None: | |||||
| cls_tokens = self.cls_token.expand( | |||||
| B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||||
| dist_token = self.dist_token.expand(B, -1, -1) | |||||
| x = torch.cat((cls_tokens, dist_token, x), dim=1) | |||||
| else: | |||||
| cls_tokens = self.cls_token.expand( | |||||
| B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||||
| x = torch.cat((cls_tokens, x), dim=1) | |||||
| x = x + pos_embed | |||||
| x = self.pos_drop(x) | |||||
| gradient_checkpoint = False | |||||
| for blk in self.blocks: | |||||
| if gradient_checkpoint: | |||||
| x = checkpoint.checkpoint(blk, x) | |||||
| else: | |||||
| x = blk(x) | |||||
| x = self.norm(x) | |||||
| return x | |||||
| def get_readout_oper(vit_features, features, use_readout, start_index=1): | |||||
| if use_readout == 'ignore': | |||||
| readout_oper = [Slice(start_index)] * len(features) | |||||
| elif use_readout == 'add': | |||||
| readout_oper = [AddReadout(start_index)] * len(features) | |||||
| elif use_readout == 'project': | |||||
| readout_oper = [ | |||||
| ProjectReadout(vit_features, start_index) for out_feat in features | |||||
| ] | |||||
| else: | |||||
| assert ( | |||||
| False | |||||
| ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" | |||||
| return readout_oper | |||||
| def adapt_input_conv(in_chans, conv_weight): | |||||
| conv_type = conv_weight.dtype | |||||
| conv_weight = conv_weight.float( | |||||
| ) # Some weights are in torch.half, ensure it's float for sum on CPU | |||||
| O, II, J, K = conv_weight.shape | |||||
| if in_chans == 1: | |||||
| if II > 3: | |||||
| assert conv_weight.shape[1] % 3 == 0 | |||||
| # For models with space2depth stems | |||||
| conv_weight = conv_weight.reshape(O, II // 3, 3, J, K) | |||||
| conv_weight = conv_weight.sum(dim=2, keepdim=False) | |||||
| else: | |||||
| conv_weight = conv_weight.sum(dim=1, keepdim=True) | |||||
| elif in_chans != 3: | |||||
| if II != 3: | |||||
| raise NotImplementedError( | |||||
| 'Weight format not supported by conversion.') | |||||
| else: | |||||
| # NOTE this strategy should be better than random init, but there could be other combinations of | |||||
| # the original RGB input layer weights that'd work better for specific cases. | |||||
| repeat = int(math.ceil(in_chans / 3)) | |||||
| conv_weight = conv_weight.repeat(1, repeat, 1, | |||||
| 1)[:, :in_chans, :, :] | |||||
| conv_weight *= (3 / float(in_chans)) | |||||
| conv_weight = conv_weight.to(conv_type) | |||||
| return conv_weight | |||||
| @torch.no_grad() | |||||
| def _load_weights(model, checkpoint_path, prefix=''): | |||||
| """ Load weights from .npz checkpoints for official Google Brain Flax implementation | |||||
| """ | |||||
| import numpy as np | |||||
| def _n2p(w, t=True): | |||||
| if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: | |||||
| w = w.flatten() | |||||
| if t: | |||||
| if w.ndim == 4: | |||||
| w = w.transpose([3, 2, 0, 1]) | |||||
| elif w.ndim == 3: | |||||
| w = w.transpose([2, 0, 1]) | |||||
| elif w.ndim == 2: | |||||
| w = w.transpose([1, 0]) | |||||
| return torch.from_numpy(w) | |||||
| w = np.load(checkpoint_path) | |||||
| if not prefix and 'opt/target/embedding/kernel' in w: | |||||
| prefix = 'opt/target/' | |||||
| if hasattr(model.patch_embed, 'backbone'): | |||||
| # hybrid | |||||
| backbone = model.patch_embed.backbone | |||||
| stem_only = not hasattr(backbone, 'stem') | |||||
| stem = backbone if stem_only else backbone.stem | |||||
| stem.conv.weight.copy_( | |||||
| adapt_input_conv(stem.conv.weight.shape[1], | |||||
| _n2p(w[f'{prefix}conv_root/kernel']))) | |||||
| stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) | |||||
| stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) | |||||
| if not stem_only: | |||||
| for i, stage in enumerate(backbone.stages): | |||||
| for j, block in enumerate(stage.blocks): | |||||
| bp = f'{prefix}block{i + 1}/unit{j + 1}/' | |||||
| for r in range(3): | |||||
| getattr(block, f'conv{r + 1}').weight.copy_( | |||||
| _n2p(w[f'{bp}conv{r + 1}/kernel'])) | |||||
| getattr(block, f'norm{r + 1}').weight.copy_( | |||||
| _n2p(w[f'{bp}gn{r + 1}/scale'])) | |||||
| getattr(block, f'norm{r + 1}').bias.copy_( | |||||
| _n2p(w[f'{bp}gn{r + 1}/bias'])) | |||||
| if block.downsample is not None: | |||||
| block.downsample.conv.weight.copy_( | |||||
| _n2p(w[f'{bp}conv_proj/kernel'])) | |||||
| block.downsample.norm.weight.copy_( | |||||
| _n2p(w[f'{bp}gn_proj/scale'])) | |||||
| block.downsample.norm.bias.copy_( | |||||
| _n2p(w[f'{bp}gn_proj/bias'])) | |||||
| embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) | |||||
| else: | |||||
| embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], | |||||
| _n2p(w[f'{prefix}embedding/kernel'])) | |||||
| model.patch_embed.proj.weight.copy_(embed_conv_w) | |||||
| model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) | |||||
| model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) | |||||
| pos_embed_w = _n2p( | |||||
| w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) | |||||
| if pos_embed_w.shape != model.pos_embed.shape: | |||||
| pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights | |||||
| pos_embed_w, model.pos_embed, getattr(model, 'num_prefix_tokens', | |||||
| 1), | |||||
| model.patch_embed.grid_size) | |||||
| model.pos_embed.copy_(pos_embed_w) | |||||
| model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) | |||||
| model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) | |||||
| if isinstance( | |||||
| model.head, nn.Linear | |||||
| ) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: | |||||
| model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) | |||||
| model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) | |||||
| # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights | |||||
| # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: | |||||
| # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) | |||||
| # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) | |||||
| for i, block in enumerate(model.blocks.children()): | |||||
| block_prefix = f'{prefix}Transformer/encoderblock_{i}/' | |||||
| mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' | |||||
| block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) | |||||
| block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) | |||||
| block.attn.qkv.weight.copy_( | |||||
| torch.cat([ | |||||
| _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T | |||||
| for n in ('query', 'key', 'value') | |||||
| ])) | |||||
| block.attn.qkv.bias.copy_( | |||||
| torch.cat([ | |||||
| _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) | |||||
| for n in ('query', 'key', 'value') | |||||
| ])) | |||||
| block.attn.proj.weight.copy_( | |||||
| _n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) | |||||
| block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) | |||||
| for r in range(2): | |||||
| getattr(block.mlp, f'fc{r + 1}').weight.copy_( | |||||
| _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) | |||||
| getattr(block.mlp, f'fc{r + 1}').bias.copy_( | |||||
| _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) | |||||
| block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) | |||||
| block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) | |||||
| def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): | |||||
| # Rescale the grid of position embeddings when loading from state_dict. Adapted from | |||||
| # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 | |||||
| ntok_new = posemb_new.shape[1] | |||||
| if num_prefix_tokens: | |||||
| posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[ | |||||
| 0, num_prefix_tokens:] | |||||
| ntok_new -= num_prefix_tokens | |||||
| else: | |||||
| posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] | |||||
| gs_old = int(math.sqrt(len(posemb_grid))) | |||||
| if not len(gs_new): # backwards compatibility | |||||
| gs_new = [int(math.sqrt(ntok_new))] * 2 | |||||
| assert len(gs_new) >= 2 | |||||
| posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, | |||||
| -1).permute(0, 3, 1, 2) | |||||
| posemb_grid = F.interpolate( | |||||
| posemb_grid, size=gs_new, mode='bicubic', align_corners=False) | |||||
| posemb_grid = posemb_grid.permute(0, 2, 3, | |||||
| 1).reshape(1, gs_new[0] * gs_new[1], -1) | |||||
| posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) | |||||
| return posemb | |||||
| def _make_pretrained_clip_vitl16_384(pretrained, | |||||
| use_readout='ignore', | |||||
| hooks=None, | |||||
| enable_attention_hooks=False): | |||||
| clip_pretrained, _ = clip.load('ViT-B/32', device='cpu', jit=False) | |||||
| # model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) | |||||
| model = timm.create_model('vit_large_patch16_384', pretrained=False) | |||||
| hooks = [5, 11, 17, 23] if hooks is None else hooks | |||||
| pretrained = _make_vit_b16_backbone( | |||||
| model, | |||||
| features=[256, 512, 1024, 1024], | |||||
| hooks=hooks, | |||||
| vit_features=1024, | |||||
| use_readout=use_readout, | |||||
| enable_attention_hooks=enable_attention_hooks, | |||||
| ) | |||||
| return clip_pretrained, pretrained | |||||
| def _make_vit_b16_backbone( | |||||
| model, | |||||
| features=[96, 192, 384, 768], | |||||
| size=[384, 384], | |||||
| hooks=[2, 5, 8, 11], | |||||
| vit_features=768, | |||||
| use_readout='ignore', | |||||
| start_index=1, | |||||
| enable_attention_hooks=False, | |||||
| ): | |||||
| pretrained = nn.Module() | |||||
| pretrained.model = model | |||||
| pretrained.model.blocks[hooks[0]].register_forward_hook( | |||||
| get_activation('1')) | |||||
| pretrained.model.blocks[hooks[1]].register_forward_hook( | |||||
| get_activation('2')) | |||||
| pretrained.model.blocks[hooks[2]].register_forward_hook( | |||||
| get_activation('3')) | |||||
| pretrained.model.blocks[hooks[3]].register_forward_hook( | |||||
| get_activation('4')) | |||||
| pretrained.activations = activations | |||||
| if enable_attention_hooks: | |||||
| pretrained.model.blocks[hooks[0]].attn.register_forward_hook( | |||||
| get_attention('attn_1')) | |||||
| pretrained.model.blocks[hooks[1]].attn.register_forward_hook( | |||||
| get_attention('attn_2')) | |||||
| pretrained.model.blocks[hooks[2]].attn.register_forward_hook( | |||||
| get_attention('attn_3')) | |||||
| pretrained.model.blocks[hooks[3]].attn.register_forward_hook( | |||||
| get_attention('attn_4')) | |||||
| pretrained.attention = attention | |||||
| readout_oper = get_readout_oper(vit_features, features, use_readout, | |||||
| start_index) | |||||
| # 32, 48, 136, 384 | |||||
| pretrained.act_postprocess1 = nn.Sequential( | |||||
| readout_oper[0], | |||||
| Transpose(1, 2), | |||||
| nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||||
| nn.Conv2d( | |||||
| in_channels=vit_features, | |||||
| out_channels=features[0], | |||||
| kernel_size=1, | |||||
| stride=1, | |||||
| padding=0, | |||||
| ), | |||||
| nn.ConvTranspose2d( | |||||
| in_channels=features[0], | |||||
| out_channels=features[0], | |||||
| kernel_size=4, | |||||
| stride=4, | |||||
| padding=0, | |||||
| bias=True, | |||||
| dilation=1, | |||||
| groups=1, | |||||
| ), | |||||
| ) | |||||
| pretrained.act_postprocess2 = nn.Sequential( | |||||
| readout_oper[1], | |||||
| Transpose(1, 2), | |||||
| nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||||
| nn.Conv2d( | |||||
| in_channels=vit_features, | |||||
| out_channels=features[1], | |||||
| kernel_size=1, | |||||
| stride=1, | |||||
| padding=0, | |||||
| ), | |||||
| nn.ConvTranspose2d( | |||||
| in_channels=features[1], | |||||
| out_channels=features[1], | |||||
| kernel_size=2, | |||||
| stride=2, | |||||
| padding=0, | |||||
| bias=True, | |||||
| dilation=1, | |||||
| groups=1, | |||||
| ), | |||||
| ) | |||||
| pretrained.act_postprocess3 = nn.Sequential( | |||||
| readout_oper[2], | |||||
| Transpose(1, 2), | |||||
| nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||||
| nn.Conv2d( | |||||
| in_channels=vit_features, | |||||
| out_channels=features[2], | |||||
| kernel_size=1, | |||||
| stride=1, | |||||
| padding=0, | |||||
| ), | |||||
| ) | |||||
| pretrained.act_postprocess4 = nn.Sequential( | |||||
| readout_oper[3], | |||||
| Transpose(1, 2), | |||||
| nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), | |||||
| nn.Conv2d( | |||||
| in_channels=vit_features, | |||||
| out_channels=features[3], | |||||
| kernel_size=1, | |||||
| stride=1, | |||||
| padding=0, | |||||
| ), | |||||
| nn.Conv2d( | |||||
| in_channels=features[3], | |||||
| out_channels=features[3], | |||||
| kernel_size=3, | |||||
| stride=2, | |||||
| padding=1, | |||||
| ), | |||||
| ) | |||||
| pretrained.model.start_index = start_index | |||||
| pretrained.model.patch_size = [16, 16] | |||||
| # We inject this function into the VisionTransformer instances so that | |||||
| # we can use it with interpolated position embeddings without modifying the library source. | |||||
| pretrained.model.forward_flex = types.MethodType(forward_flex, | |||||
| pretrained.model) | |||||
| pretrained.model._resize_pos_embed = types.MethodType( | |||||
| _resize_pos_embed, pretrained.model) | |||||
| return pretrained | |||||
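resize_pos_embed and the injected _resize_pos_embed let the ViT run at resolutions other than the pretraining size by bilinearly/bicubically resampling the patch-grid position embeddings. A quick standalone sketch (assumed sizes) for ViT-L/16 going from 384x384 (24x24 patches) to 640x640 (40x40 patches):

import torch

posemb = torch.randn(1, 1 + 24 * 24, 1024)      # pretrained ViT-L/16 at 384: cls token + 24x24 grid
posemb_new = torch.zeros(1, 1 + 40 * 40, 1024)  # target layout for a 640x640 input
resized = resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=(40, 40))
print(resized.shape)  # torch.Size([1, 1601, 1024])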
| @@ -0,0 +1,458 @@ | |||||
| """ | |||||
| Adapted from https://github.com/isl-org/lang-seg. | |||||
| Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. | |||||
| """ | |||||
| from collections import OrderedDict | |||||
| from typing import Tuple, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from torch import nn | |||||
| class Bottleneck(nn.Module): | |||||
| expansion = 4 | |||||
| def __init__(self, inplanes, planes, stride=1): | |||||
| super().__init__() | |||||
| # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 | |||||
| self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(planes) | |||||
| self.relu1 = nn.ReLU(inplace=True) | |||||
| self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) | |||||
| self.bn2 = nn.BatchNorm2d(planes) | |||||
| self.relu2 = nn.ReLU(inplace=True) | |||||
| self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() | |||||
| self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) | |||||
| self.bn3 = nn.BatchNorm2d(planes * self.expansion) | |||||
| self.relu3 = nn.ReLU(inplace=True) | |||||
| self.downsample = None | |||||
| self.stride = stride | |||||
| if stride > 1 or inplanes != planes * Bottleneck.expansion: | |||||
| # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 | |||||
| self.downsample = nn.Sequential( | |||||
| OrderedDict([('-1', nn.AvgPool2d(stride)), | |||||
| ('0', | |||||
| nn.Conv2d( | |||||
| inplanes, | |||||
| planes * self.expansion, | |||||
| 1, | |||||
| stride=1, | |||||
| bias=False)), | |||||
| ('1', nn.BatchNorm2d(planes * self.expansion))])) | |||||
| def forward(self, x: torch.Tensor): | |||||
| identity = x | |||||
| out = self.relu1(self.bn1(self.conv1(x))) | |||||
| out = self.relu2(self.bn2(self.conv2(out))) | |||||
| out = self.avgpool(out) | |||||
| out = self.bn3(self.conv3(out)) | |||||
| if self.downsample is not None: | |||||
| identity = self.downsample(x) | |||||
| out += identity | |||||
| out = self.relu3(out) | |||||
| return out | |||||
| class AttentionPool2d(nn.Module): | |||||
| def __init__(self, | |||||
| spacial_dim: int, | |||||
| embed_dim: int, | |||||
| num_heads: int, | |||||
| output_dim: int = None): | |||||
| super().__init__() | |||||
| self.positional_embedding = nn.Parameter( | |||||
| torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) | |||||
| self.k_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.q_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.v_proj = nn.Linear(embed_dim, embed_dim) | |||||
| self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) | |||||
| self.num_heads = num_heads | |||||
| def forward(self, x): | |||||
| x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC | |||||
| x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC | |||||
| x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC | |||||
| x, _ = F.multi_head_attention_forward( | |||||
| query=x[:1], | |||||
| key=x, | |||||
| value=x, | |||||
| embed_dim_to_check=x.shape[-1], | |||||
| num_heads=self.num_heads, | |||||
| q_proj_weight=self.q_proj.weight, | |||||
| k_proj_weight=self.k_proj.weight, | |||||
| v_proj_weight=self.v_proj.weight, | |||||
| in_proj_weight=None, | |||||
| in_proj_bias=torch.cat( | |||||
| [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), | |||||
| bias_k=None, | |||||
| bias_v=None, | |||||
| add_zero_attn=False, | |||||
| dropout_p=0, | |||||
| out_proj_weight=self.c_proj.weight, | |||||
| out_proj_bias=self.c_proj.bias, | |||||
| use_separate_proj_weight=True, | |||||
| training=self.training, | |||||
| need_weights=False) | |||||
| return x.squeeze(0) | |||||
| class ModifiedResNet(nn.Module): | |||||
| """ | |||||
| A ResNet class that is similar to torchvision's but contains the following changes: | |||||
| - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |||||
| - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |||||
| - The final pooling layer is a QKV attention instead of an average pool | |||||
| """ | |||||
| def __init__(self, | |||||
| layers, | |||||
| output_dim, | |||||
| heads, | |||||
| input_resolution=224, | |||||
| width=64): | |||||
| super().__init__() | |||||
| self.output_dim = output_dim | |||||
| self.input_resolution = input_resolution | |||||
| # the 3-layer stem | |||||
| self.conv1 = nn.Conv2d( | |||||
| 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |||||
| self.bn1 = nn.BatchNorm2d(width // 2) | |||||
| self.relu1 = nn.ReLU(inplace=True) | |||||
| self.conv2 = nn.Conv2d( | |||||
| width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |||||
| self.bn2 = nn.BatchNorm2d(width // 2) | |||||
| self.relu2 = nn.ReLU(inplace=True) | |||||
| self.conv3 = nn.Conv2d( | |||||
| width // 2, width, kernel_size=3, padding=1, bias=False) | |||||
| self.bn3 = nn.BatchNorm2d(width) | |||||
| self.relu3 = nn.ReLU(inplace=True) | |||||
| self.avgpool = nn.AvgPool2d(2) | |||||
| # residual layers | |||||
| self._inplanes = width # this is a *mutable* variable used during construction | |||||
| self.layer1 = self._make_layer(width, layers[0]) | |||||
| self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |||||
| self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |||||
| self.layer4 = self._make_layer(width * 8, layers[3], stride=2) | |||||
| embed_dim = width * 32 # the ResNet feature dimension | |||||
| self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, | |||||
| heads, output_dim) | |||||
| def _make_layer(self, planes, blocks, stride=1): | |||||
| layers = [Bottleneck(self._inplanes, planes, stride)] | |||||
| self._inplanes = planes * Bottleneck.expansion | |||||
| for _ in range(1, blocks): | |||||
| layers.append(Bottleneck(self._inplanes, planes)) | |||||
| return nn.Sequential(*layers) | |||||
| def forward(self, x): | |||||
| def stem(x): | |||||
| x = self.relu1(self.bn1(self.conv1(x))) | |||||
| x = self.relu2(self.bn2(self.conv2(x))) | |||||
| x = self.relu3(self.bn3(self.conv3(x))) | |||||
| x = self.avgpool(x) | |||||
| return x | |||||
| x = x.type(self.conv1.weight.dtype) | |||||
| x = stem(x) | |||||
| x = self.layer1(x) | |||||
| x = self.layer2(x) | |||||
| x = self.layer3(x) | |||||
| x = self.layer4(x) | |||||
| x = self.attnpool(x) | |||||
| return x | |||||
| class LayerNorm(nn.LayerNorm): | |||||
| """Subclass torch's LayerNorm to handle fp16.""" | |||||
| def forward(self, x: torch.Tensor): | |||||
| orig_type = x.dtype | |||||
| ret = super().forward(x.type(torch.float32)) | |||||
| return ret.type(orig_type) | |||||
| class QuickGELU(nn.Module): | |||||
| def forward(self, x: torch.Tensor): | |||||
| return x * torch.sigmoid(1.702 * x) | |||||
| class ResidualAttentionBlock(nn.Module): | |||||
| def __init__(self, | |||||
| d_model: int, | |||||
| n_head: int, | |||||
| attn_mask: torch.Tensor = None): | |||||
| super().__init__() | |||||
| self.attn = nn.MultiheadAttention(d_model, n_head) | |||||
| self.ln_1 = LayerNorm(d_model) | |||||
| self.mlp = nn.Sequential( | |||||
| OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), | |||||
| ('gelu', QuickGELU()), | |||||
| ('c_proj', nn.Linear(d_model * 4, d_model))])) | |||||
| self.ln_2 = LayerNorm(d_model) | |||||
| self.attn_mask = attn_mask | |||||
| def attention(self, x: torch.Tensor): | |||||
| self.attn_mask = self.attn_mask.to( | |||||
| dtype=x.dtype, | |||||
| device=x.device) if self.attn_mask is not None else None | |||||
| return self.attn( | |||||
| x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] | |||||
| def forward(self, x: torch.Tensor): | |||||
| x = x + self.attention(self.ln_1(x)) | |||||
| x = x + self.mlp(self.ln_2(x)) | |||||
| return x | |||||
| class Transformer(nn.Module): | |||||
| def __init__(self, width, layers, heads, attn_mask=None): | |||||
| super().__init__() | |||||
| self.width = width | |||||
| self.layers = layers | |||||
| self.resblocks = nn.Sequential(*[ | |||||
| ResidualAttentionBlock(width, heads, attn_mask) | |||||
| for _ in range(layers) | |||||
| ]) | |||||
| def forward(self, x: torch.Tensor): | |||||
| return self.resblocks(x) | |||||
| class VisionTransformer(nn.Module): | |||||
| def __init__(self, input_resolution: int, patch_size: int, width: int, | |||||
| layers: int, heads: int, output_dim: int): | |||||
| super().__init__() | |||||
| self.input_resolution = input_resolution | |||||
| self.output_dim = output_dim | |||||
| self.conv1 = nn.Conv2d( | |||||
| in_channels=3, | |||||
| out_channels=width, | |||||
| kernel_size=patch_size, | |||||
| stride=patch_size, | |||||
| bias=False) | |||||
| scale = width**-0.5 | |||||
| self.class_embedding = nn.Parameter(scale * torch.randn(width)) | |||||
| self.positional_embedding = nn.Parameter(scale * torch.randn( | |||||
| (input_resolution // patch_size)**2 + 1, width)) | |||||
| self.ln_pre = LayerNorm(width) | |||||
| self.transformer = Transformer(width, layers, heads) | |||||
| self.ln_post = LayerNorm(width) | |||||
| self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) | |||||
| def forward(self, x: torch.Tensor): | |||||
| x = self.conv1(x) # shape = [*, width, grid, grid] | |||||
| x = x.reshape(x.shape[0], x.shape[1], | |||||
| -1) # shape = [*, width, grid ** 2] | |||||
| x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |||||
| x1 = self.class_embedding.to(x.dtype) | |||||
| x2 = torch.zeros( | |||||
| x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) | |||||
| x = torch.cat([x1 + x2, x], dim=1) # shape = [*, grid ** 2 + 1, width] | |||||
| x = x + self.positional_embedding.to(x.dtype) | |||||
| x = self.ln_pre(x) | |||||
| x = x.permute(1, 0, 2) # NLD -> LND | |||||
| x = self.transformer(x) | |||||
| x = x.permute(1, 0, 2) # LND -> NLD | |||||
| x = self.ln_post(x[:, 0, :]) | |||||
| if self.proj is not None: | |||||
| x = x @ self.proj | |||||
| return x | |||||
| class CLIP(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| embed_dim: int, | |||||
| # vision | |||||
| image_resolution: int, | |||||
| vision_layers: Union[Tuple[int, int, int, int], int], | |||||
| vision_width: int, | |||||
| vision_patch_size: int, | |||||
| # text | |||||
| context_length: int, | |||||
| vocab_size: int, | |||||
| transformer_width: int, | |||||
| transformer_heads: int, | |||||
| transformer_layers: int): | |||||
| super().__init__() | |||||
| self.context_length = context_length | |||||
| if isinstance(vision_layers, (tuple, list)): | |||||
| vision_heads = vision_width * 32 // 64 | |||||
| self.visual = ModifiedResNet( | |||||
| layers=vision_layers, | |||||
| output_dim=embed_dim, | |||||
| heads=vision_heads, | |||||
| input_resolution=image_resolution, | |||||
| width=vision_width) | |||||
| else: | |||||
| vision_heads = vision_width // 64 | |||||
| self.visual = VisionTransformer( | |||||
| input_resolution=image_resolution, | |||||
| patch_size=vision_patch_size, | |||||
| width=vision_width, | |||||
| layers=vision_layers, | |||||
| heads=vision_heads, | |||||
| output_dim=embed_dim) | |||||
| self.transformer = Transformer( | |||||
| width=transformer_width, | |||||
| layers=transformer_layers, | |||||
| heads=transformer_heads, | |||||
| attn_mask=self.build_attention_mask()) | |||||
| self.vocab_size = vocab_size | |||||
| self.token_embedding = nn.Embedding(vocab_size, transformer_width) | |||||
| self.positional_embedding = nn.Parameter( | |||||
| torch.empty(self.context_length, transformer_width)) | |||||
| self.ln_final = LayerNorm(transformer_width) | |||||
| self.text_projection = nn.Parameter( | |||||
| torch.empty(transformer_width, embed_dim)) | |||||
| self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |||||
| self.initialize_parameters() | |||||
| def initialize_parameters(self): | |||||
| nn.init.normal_(self.token_embedding.weight, std=0.02) | |||||
| nn.init.normal_(self.positional_embedding, std=0.01) | |||||
| if isinstance(self.visual, ModifiedResNet): | |||||
| if self.visual.attnpool is not None: | |||||
| std = self.visual.attnpool.c_proj.in_features**-0.5 | |||||
| nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) | |||||
| nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) | |||||
| nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) | |||||
| nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) | |||||
| for resnet_block in [ | |||||
| self.visual.layer1, self.visual.layer2, self.visual.layer3, | |||||
| self.visual.layer4 | |||||
| ]: | |||||
| for name, param in resnet_block.named_parameters(): | |||||
| if name.endswith('bn3.weight'): | |||||
| nn.init.zeros_(param) | |||||
| proj_std = (self.transformer.width**-0.5) * ( | |||||
| (2 * self.transformer.layers)**-0.5) | |||||
| attn_std = self.transformer.width**-0.5 | |||||
| fc_std = (2 * self.transformer.width)**-0.5 | |||||
| for block in self.transformer.resblocks: | |||||
| nn.init.normal_(block.attn.in_proj_weight, std=attn_std) | |||||
| nn.init.normal_(block.attn.out_proj.weight, std=proj_std) | |||||
| nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) | |||||
| nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) | |||||
| if self.text_projection is not None: | |||||
| nn.init.normal_( | |||||
| self.text_projection, std=self.transformer.width**-0.5) | |||||
| def build_attention_mask(self): | |||||
| # lazily create the causal attention mask used by the text transformer | |||||
| # pytorch uses additive attention mask; fill with -inf | |||||
| mask = torch.empty(self.context_length, self.context_length) | |||||
| mask.fill_(float('-inf')) | |||||
| mask.triu_(1) # zero the diagonal and below; keep -inf strictly above it | |||||
| return mask | |||||
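The mask is additive and strictly upper-triangular, so each text token attends only to itself and earlier positions. A small sketch of what build_attention_mask produces for a context length of 4:

import torch

mask = torch.full((4, 4), float('-inf')).triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])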
| @property | |||||
| def dtype(self): | |||||
| return self.visual.conv1.weight.dtype | |||||
| def encode_image(self, image): | |||||
| return self.visual(image.type(self.dtype)) | |||||
| def encode_text(self, text): | |||||
| x = self.token_embedding(text).type(self.dtype) | |||||
| x = x + self.positional_embedding.type(self.dtype) | |||||
| x = x.permute(1, 0, 2) # NLD -> LND | |||||
| x = self.transformer(x) | |||||
| x = x.permute(1, 0, 2) # LND -> NLD | |||||
| x = self.ln_final(x).type(self.dtype) | |||||
| x = x[torch.arange(x.shape[0]), | |||||
| text.argmax(dim=-1)] @ self.text_projection | |||||
| return x | |||||
| def forward(self, image, text): | |||||
| image_features = self.encode_image(image) | |||||
| text_features = self.encode_text(text) | |||||
| # normalized features | |||||
| image_features = image_features / image_features.norm( | |||||
| dim=1, keepdim=True) | |||||
| text_features = text_features / text_features.norm(dim=1, keepdim=True) | |||||
| # cosine similarity as logits | |||||
| logit_scale = self.logit_scale.exp() | |||||
| logits_per_image = logit_scale * image_features @ text_features.t() | |||||
| logits_per_text = logits_per_image.t() | |||||
| # shape = [global_batch_size, global_batch_size] | |||||
| return logits_per_image, logits_per_text | |||||
| def convert_weights(model: nn.Module): | |||||
| """Convert applicable model parameters to fp16""" | |||||
| def _convert_weights_to_fp16(ll): | |||||
| if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Linear)): | |||||
| ll.weight.data = ll.weight.data.half() | |||||
| if ll.bias is not None: | |||||
| ll.bias.data = ll.bias.data.half() | |||||
| if isinstance(ll, nn.MultiheadAttention): | |||||
| for attr in [ | |||||
| *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], | |||||
| 'in_proj_bias', 'bias_k', 'bias_v' | |||||
| ]: | |||||
| tensor = getattr(ll, attr) | |||||
| if tensor is not None: | |||||
| tensor.data = tensor.data.half() | |||||
| for name in ['text_projection', 'proj']: | |||||
| if hasattr(ll, name): | |||||
| attr = getattr(ll, name) | |||||
| if attr is not None: | |||||
| attr.data = attr.data.half() | |||||
| model.apply(_convert_weights_to_fp16) | |||||
| def build_model(): | |||||
| # positional arguments map to: embed_dim, image_resolution, vision_layers, | |||||
| # vision_width, vision_patch_size, context_length, vocab_size, | |||||
| # transformer_width, transformer_heads, transformer_layers | |||||
| model = CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12) | |||||
| convert_weights(model) | |||||
| return model.eval() | |||||
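A minimal end-to-end smoke test of the model returned by build_model(), using random inputs; this is a sketch for illustration only, and the import path model is an assumption. The fp16 weights produced by convert_weights are cast back to fp32 so the check also runs on CPU:

import torch
from model import build_model  # assumed import path, for illustration only

clip_model = build_model().float()        # undo the fp16 cast for a CPU-friendly run
image = torch.randn(2, 3, 224, 224)       # image_resolution = 224
text = torch.randint(0, 49408, (2, 77))   # vocab_size = 49408, context_length = 77
with torch.no_grad():
    logits_per_image, logits_per_text = clip_model(image, text)
print(logits_per_image.shape)             # torch.Size([2, 2]): image-to-text similarity logits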
| @@ -0,0 +1,156 @@ | |||||
| """ CLIP | |||||
| Adapted from https://github.com/openai/CLIP. | |||||
| Originally MIT License, Copyright (c) 2021 OpenAI. | |||||
| """ | |||||
| import gzip | |||||
| import html | |||||
| import os | |||||
| from functools import lru_cache | |||||
| import ftfy | |||||
| import regex as re | |||||
| @lru_cache() | |||||
| def default_bpe(): | |||||
| return os.path.join( | |||||
| os.path.dirname(os.path.abspath(__file__)), | |||||
| 'bpe_simple_vocab_16e6.txt.gz') | |||||
| @lru_cache() | |||||
| def bytes_to_unicode(): | |||||
| """ | |||||
| Returns a dict mapping utf-8 bytes to printable unicode strings. | |||||
| The reversible bpe codes work on unicode strings. | |||||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||||
| This is a significant percentage of your normal, say, 32K bpe vocab. | |||||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings, | |||||
| while avoiding whitespace/control characters that the bpe code barfs on. | |||||
| """ | |||||
| bs = list(range(ord('!'), | |||||
| ord('~') + 1)) + list(range( | |||||
| ord('¡'), | |||||
| ord('¬') + 1)) + list(range(ord('®'), | |||||
| ord('ÿ') + 1)) | |||||
| cs = bs[:] | |||||
| n = 0 | |||||
| for b in range(2**8): | |||||
| if b not in bs: | |||||
| bs.append(b) | |||||
| cs.append(2**8 + n) | |||||
| n += 1 | |||||
| cs = [chr(n) for n in cs] | |||||
| return dict(zip(bs, cs)) | |||||
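The resulting table covers all 256 byte values: printable bytes map to themselves and the remaining ones are shifted into unused code points, so every byte has a printable stand-in. A quick check, assuming the function above is in scope:

m = bytes_to_unicode()
print(len(m))        # 256: every byte value has an entry
print(m[ord('a')])   # 'a': printable ASCII maps to itself
print(m[ord(' ')])   # a remapped printable character, since whitespace is excluded from bs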
| def get_pairs(word): | |||||
| """Return set of symbol pairs in a word. | |||||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||||
| """ | |||||
| pairs = set() | |||||
| prev_char = word[0] | |||||
| for char in word[1:]: | |||||
| pairs.add((prev_char, char)) | |||||
| prev_char = char | |||||
| return pairs | |||||
| def basic_clean(text): | |||||
| text = ftfy.fix_text(text) | |||||
| text = html.unescape(html.unescape(text)) | |||||
| return text.strip() | |||||
| def whitespace_clean(text): | |||||
| text = re.sub(r'\s+', ' ', text) | |||||
| text = text.strip() | |||||
| return text | |||||
| class SimpleTokenizer(object): | |||||
| def __init__(self, bpe_path: str = default_bpe()): | |||||
| self.byte_encoder = bytes_to_unicode() | |||||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||||
| merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') | |||||
| merges = merges[1:49152 - 256 - 2 + 1] | |||||
| merges = [tuple(merge.split()) for merge in merges] | |||||
| vocab = list(bytes_to_unicode().values()) | |||||
| vocab = vocab + [v + '</w>' for v in vocab] | |||||
| for merge in merges: | |||||
| vocab.append(''.join(merge)) | |||||
| vocab.extend(['<|startoftext|>', '<|endoftext|>']) | |||||
| self.encoder = dict(zip(vocab, range(len(vocab)))) | |||||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||||
| self.bpe_ranks = dict(zip(merges, range(len(merges)))) | |||||
| self.cache = { | |||||
| '<|startoftext|>': '<|startoftext|>', | |||||
| '<|endoftext|>': '<|endoftext|>' | |||||
| } | |||||
| self.pat = re.compile( | |||||
| r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", | |||||
| re.IGNORECASE) | |||||
| def bpe(self, token): | |||||
| if token in self.cache: | |||||
| return self.cache[token] | |||||
| word = tuple(token[:-1]) + (token[-1] + '</w>', ) | |||||
| pairs = get_pairs(word) | |||||
| if not pairs: | |||||
| return token + '</w>' | |||||
| while True: | |||||
| bigram = min( | |||||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||||
| if bigram not in self.bpe_ranks: | |||||
| break | |||||
| first, second = bigram | |||||
| new_word = [] | |||||
| i = 0 | |||||
| while i < len(word): | |||||
| try: | |||||
| j = word.index(first, i) | |||||
| new_word.extend(word[i:j]) | |||||
| i = j | |||||
| except ValueError: | |||||
| # `first` no longer occurs; keep the rest of the word and stop scanning | |||||
| new_word.extend(word[i:]) | |||||
| break | |||||
| if word[i] == first and i < len(word) - 1 and word[ | |||||
| i + 1] == second: | |||||
| new_word.append(first + second) | |||||
| i += 2 | |||||
| else: | |||||
| new_word.append(word[i]) | |||||
| i += 1 | |||||
| new_word = tuple(new_word) | |||||
| word = new_word | |||||
| if len(word) == 1: | |||||
| break | |||||
| else: | |||||
| pairs = get_pairs(word) | |||||
| word = ' '.join(word) | |||||
| self.cache[token] = word | |||||
| return word | |||||
| def encode(self, text): | |||||
| bpe_tokens = [] | |||||
| text = whitespace_clean(basic_clean(text)).lower() | |||||
| for token in re.findall(self.pat, text): | |||||
| token = ''.join(self.byte_encoder[b] | |||||
| for b in token.encode('utf-8')) | |||||
| bpe_tokens.extend(self.encoder[bpe_token] | |||||
| for bpe_token in self.bpe(token).split(' ')) | |||||
| return bpe_tokens | |||||
| def decode(self, tokens): | |||||
| text = ''.join([self.decoder[token] for token in tokens]) | |||||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||||
| 'utf-8', errors='replace').replace('</w>', ' ') | |||||
| return text | |||||
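A hedged usage sketch of the tokenizer above; it assumes the module is importable as simple_tokenizer (hypothetical path) and that bpe_simple_vocab_16e6.txt.gz sits next to it, as default_bpe() expects:

from simple_tokenizer import SimpleTokenizer  # assumed import path, for illustration only

tokenizer = SimpleTokenizer()
ids = tokenizer.encode('a photo of a bear')  # BPE ids; no start/end tokens are added by encode()
print(ids)
print(tokenizer.decode(ids))                 # roughly round-trips: 'a photo of a bear ' (trailing space from </w>)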
| @@ -243,6 +243,13 @@ TASK_OUTPUTS = { | |||||
| # "output_img": np.ndarray with shape [height, width, 3] | # "output_img": np.ndarray with shape [height, width, 3] | ||||
| # } | # } | ||||
| Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG], | Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG], | ||||
| # text driven segmentation result for a single sample | |||||
| # { | |||||
| # "masks": [ | |||||
| # np.array # 2D array containing only 0, 255 | |||||
| # ] | |||||
| # } | |||||
| Tasks.text_driven_segmentation: [OutputKeys.MASKS], | |||||
| # movie scene segmentation result for a single video | # movie scene segmentation result for a single video | ||||
| # { | # { | ||||
| @@ -149,6 +149,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_vitb_video-single-object-tracking_ostrack'), | 'damo/cv_vitb_video-single-object-tracking_ostrack'), | ||||
| Tasks.image_reid_person: (Pipelines.image_reid_person, | Tasks.image_reid_person: (Pipelines.image_reid_person, | ||||
| 'damo/cv_passvitb_image-reid-person_market'), | 'damo/cv_passvitb_image-reid-person_market'), | ||||
| Tasks.text_driven_segmentation: | |||||
| (Pipelines.text_driven_segmentation, | |||||
| 'damo/cv_vitl16_segmentation_text-driven-seg'), | |||||
| Tasks.movie_scene_segmentation: | Tasks.movie_scene_segmentation: | ||||
| (Pipelines.movie_scene_segmentation, | (Pipelines.movie_scene_segmentation, | ||||
| 'damo/cv_resnet50-bert_video-scene-segmentation_movienet') | 'damo/cv_resnet50-bert_video-scene-segmentation_movienet') | ||||
| @@ -44,6 +44,7 @@ if TYPE_CHECKING: | |||||
| from .video_category_pipeline import VideoCategoryPipeline | from .video_category_pipeline import VideoCategoryPipeline | ||||
| from .virtual_try_on_pipeline import VirtualTryonPipeline | from .virtual_try_on_pipeline import VirtualTryonPipeline | ||||
| from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline | from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline | ||||
| from .text_driven_segmentation_pipeline import TextDrivenSegmentationPipeline | |||||
| from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | ||||
| else: | else: | ||||
| @@ -97,6 +98,8 @@ else: | |||||
| 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | ||||
| 'easycv_pipeline': | 'easycv_pipeline': | ||||
| ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'], | ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'], | ||||
| 'text_driven_segmentation_pipeline': | |||||
| ['TextDrivenSegmentationPipeline'], | |||||
| 'movie_scene_segmentation_pipeline': | 'movie_scene_segmentation_pipeline': | ||||
| ['MovieSceneSegmentationPipeline'], | ['MovieSceneSegmentationPipeline'], | ||||
| } | } | ||||
| @@ -0,0 +1,51 @@ | |||||
| from typing import Any, Dict | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.constant import Tasks | |||||
| @PIPELINES.register_module( | |||||
| Tasks.text_driven_segmentation, | |||||
| module_name=Pipelines.text_driven_segmentation) | |||||
| class TextDrivenSegmentationPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, auto_collate=False, **kwargs) | |||||
| def preprocess(self, input: Dict) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_ndarray(input['image']) | |||||
| img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img) | |||||
| result = { | |||||
| 'img': img_tensor, | |||||
| 'ori_h': ori_h, | |||||
| 'ori_w': ori_w, | |||||
| 'crop_h': crop_h, | |||||
| 'crop_w': crop_w, | |||||
| 'text': input['text'], | |||||
| } | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| outputs = self.model.inference(input['img'], input['text']) | |||||
| result = { | |||||
| 'data': outputs, | |||||
| 'ori_h': input['ori_h'], | |||||
| 'ori_w': input['ori_w'], | |||||
| 'crop_h': input['crop_h'], | |||||
| 'crop_w': input['crop_w'], | |||||
| } | |||||
| return result | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| data = self.model.postprocess(inputs['data'], inputs['crop_h'], | |||||
| inputs['crop_w'], inputs['ori_h'], | |||||
| inputs['ori_w']) | |||||
| outputs = {OutputKeys.MASKS: data} | |||||
| return outputs | |||||
| @@ -36,6 +36,7 @@ class CVTasks(object): | |||||
| image_segmentation = 'image-segmentation' | image_segmentation = 'image-segmentation' | ||||
| portrait_matting = 'portrait-matting' | portrait_matting = 'portrait-matting' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | |||||
| # image editing | # image editing | ||||
| skin_retouching = 'skin-retouching' | skin_retouching = 'skin-retouching' | ||||
| @@ -0,0 +1,28 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class TextDrivenSegmentationTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_text_driven_segmentation(self): | |||||
| input_location = 'data/test/images/text_driven_segmentation.jpg' | |||||
| test_input = { | |||||
| 'image': input_location, | |||||
| 'text': 'bear', | |||||
| } | |||||
| model_id = 'damo/cv_vitl16_segmentation_text-driven-seg' | |||||
| seg_pipeline = pipeline(Tasks.text_driven_segmentation, model=model_id) | |||||
| result = seg_pipeline(test_input) | |||||
| import cv2 | |||||
| # result[OutputKeys.MASKS] is the segmentation mask; other keys are not used | |||||
| cv2.imwrite(input_location + '_lseg.jpg', result[OutputKeys.MASKS]) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||