From cb570d586cb5f4a467de9aad1e058e3cd3276518 Mon Sep 17 00:00:00 2001 From: "shuying.shu" Date: Tue, 18 Oct 2022 16:10:10 +0800 Subject: [PATCH] add referring video object segmentation pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10400324 --- ...g_video_object_segmentation_test_video.mp4 | 3 + modelscope/metainfo.py | 2 + modelscope/models/cv/__init__.py | 3 +- .../__init__.py | 23 + .../model.py | 65 ++ .../utils/__init__.py | 4 + .../utils/backbone.py | 198 +++++ .../utils/misc.py | 234 ++++++ .../utils/mttr.py | 128 +++ .../utils/multimodal_transformer.py | 440 +++++++++++ .../utils/position_encoding_2d.py | 57 ++ .../utils/postprocessing.py | 119 +++ .../utils/segmentation.py | 137 ++++ .../utils/swin_transformer.py | 731 ++++++++++++++++++ modelscope/outputs.py | 6 + modelscope/pipelines/builder.py | 3 + modelscope/pipelines/cv/__init__.py | 4 + ...ring_video_object_segmentation_pipeline.py | 193 +++++ modelscope/utils/constant.py | 3 + requirements/cv.txt | 2 + ...est_referring_video_object_segmentation.py | 56 ++ 21 files changed, 2410 insertions(+), 1 deletion(-) create mode 100644 data/test/videos/referring_video_object_segmentation_test_video.mp4 create mode 100644 modelscope/models/cv/referring_video_object_segmentation/__init__.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/model.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/misc.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py create mode 100644 modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py create mode 100644 modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py create mode 100644 tests/pipelines/test_referring_video_object_segmentation.py diff --git a/data/test/videos/referring_video_object_segmentation_test_video.mp4 b/data/test/videos/referring_video_object_segmentation_test_video.mp4 new file mode 100644 index 00000000..529595a5 --- /dev/null +++ b/data/test/videos/referring_video_object_segmentation_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a49c9bc74a60860c360a4bf4509fe9db915279aaabd953f354f2c38e9be1e6cb +size 2924691 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 2dbff948..fc18ead9 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -34,6 +34,7 @@ class Models(object): vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' text_driven_segmentation = 'text-driven-segmentation' resnet50_bert = 'resnet50-bert' + referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' fer = 'fer' retinaface = 'retinaface' shop_segmentation = 'shop-segmentation' @@ -203,6 +204,7 @@ class Pipelines(object): face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' image_body_reshaping = 'flow-based-body-reshaping' + 
referring_video_object_segmentation = 'referring-video-object-segmentation'

    # nlp tasks
    automatic_post_editing = 'automatic-post-editing'
diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py
index fd950f4c..64039863 100644
--- a/modelscope/models/cv/__init__.py
+++ b/modelscope/models/cv/__init__.py
@@ -12,7 +12,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                image_to_image_generation, image_to_image_translation,
                movie_scene_segmentation, object_detection,
                product_retrieval_embedding, realtime_object_detection,
-               salient_detection, shop_segmentation, super_resolution,
+               referring_video_object_segmentation, salient_detection,
+               shop_segmentation, super_resolution,
                video_single_object_tracking, video_summarization,
                virual_tryon)
 # yapf: enable
diff --git a/modelscope/models/cv/referring_video_object_segmentation/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/__init__.py
new file mode 100644
index 00000000..58dbf7b0
--- /dev/null
+++ b/modelscope/models/cv/referring_video_object_segmentation/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+
+    from .model import ReferringVideoObjectSegmentation
+
+else:
+    _import_structure = {
+        'model': ['ReferringVideoObjectSegmentation'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/models/cv/referring_video_object_segmentation/model.py b/modelscope/models/cv/referring_video_object_segmentation/model.py
new file mode 100644
index 00000000..902a3416
--- /dev/null
+++ b/modelscope/models/cv/referring_video_object_segmentation/model.py
@@ -0,0 +1,65 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
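For orientation before the model wrapper below: the identifiers registered above are what users reach through ModelScope's standard entry point. A minimal usage sketch, assuming the pipeline added later in this patch takes the test video path together with a list of referring text queries; the model id and the exact input layout are placeholders, not taken from this diff:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# '<model-id>' and the (video_path, text_queries) input layout are assumptions
rvos = pipeline(Tasks.referring_video_object_segmentation, model='<model-id>')
result = rvos(('data/test/videos/referring_video_object_segmentation_test_video.mp4',
               ['the person riding the horse']))
# `result` is expected to hold the predicted per-frame masks for each query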
+import os.path as osp +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess, + nested_tensor_from_videos_list) + +logger = get_logger() + + +@MODELS.register_module( + Tasks.referring_video_object_segmentation, + module_name=Models.referring_video_object_segmentation) +class ReferringVideoObjectSegmentation(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, *args, **kwargs) + + config_path = osp.join(model_dir, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(config_path) + self.model = MTTR(**self.cfg.model) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + params_dict = torch.load(model_path, map_location='cpu') + if 'model_state_dict' in params_dict.keys(): + params_dict = params_dict['model_state_dict'] + self.model.load_state_dict(params_dict, strict=True) + + dataset_name = self.cfg.pipeline.dataset_name + if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences': + self.postprocessor = A2DSentencesPostProcess() + elif dataset_name == 'ref_youtube_vos': + self.postprocessor = ReferYoutubeVOSPostProcess() + else: + assert False, f'postprocessing for dataset: {dataset_name} is not supported' + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: + return inputs + + def inference(self, **kwargs): + window = kwargs['window'] + text_query = kwargs['text_query'] + video_metadata = kwargs['metadata'] + + window = nested_tensor_from_videos_list([window]) + valid_indices = torch.arange(len(window.tensors)) + if self._device_name == 'gpu': + valid_indices = valid_indices.cuda() + outputs = self.model(window, valid_indices, [text_query]) + window_masks = self.postprocessor( + outputs, [video_metadata], + window.tensors.shape[-2:])[0]['pred_masks'] + return window_masks + + def postprocess(self, inputs: Dict[str, Any], **kwargs): + return inputs diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py new file mode 100644 index 00000000..796bd6f4 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
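The inference method above handles one temporal window at a time, so a caller (such as the pipeline added later in this patch) has to slice the video and stitch the per-window masks back together. A rough sketch of that loop; segment_by_windows and window_size are illustrative names, and metadata is assumed to carry the 'resized_frame_size' / 'original_frame_size' keys consumed by ReferYoutubeVOSPostProcess later in this patch:

import torch

def segment_by_windows(model, frames, text_query, metadata, window_size=24):
    """frames: [T, C, H, W] tensor of resized, normalized video frames."""
    all_masks = []
    for start in range(0, frames.shape[0], window_size):
        window = frames[start:start + window_size]
        # with the ref_youtube_vos post-processor these masks come back at the
        # video's original resolution
        window_masks = model.inference(
            window=window, text_query=text_query, metadata=metadata)
        all_masks.append(window_masks)
    return torch.cat(all_masks, dim=0)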
+from .misc import nested_tensor_from_videos_list +from .mttr import MTTR +from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py b/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py new file mode 100644 index 00000000..afa384c1 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py @@ -0,0 +1,198 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import torch +import torch.nn.functional as F +import torchvision +from einops import rearrange +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter + +from .misc import NestedTensor, is_main_process +from .swin_transformer import SwinTransformer3D + + +class VideoSwinTransformerBackbone(nn.Module): + """ + A wrapper which allows using Video-Swin Transformer as a temporal encoder for MTTR. + Check out video-swin's original paper at: https://arxiv.org/abs/2106.13230 for more info about this architecture. + Only the 'tiny' version of video swin was tested and is currently supported in our project. + Additionally, we slightly modify video-swin to make it output per-frame embeddings as required by MTTR (check our + paper's supplementary for more details), and completely discard of its 4th block. + """ + + def __init__(self, backbone_pretrained, backbone_pretrained_path, + train_backbone, running_mode, **kwargs): + super(VideoSwinTransformerBackbone, self).__init__() + # patch_size is (1, 4, 4) instead of the original (2, 4, 4). + # this prevents swinT's original temporal downsampling so we can get per-frame features. + swin_backbone = SwinTransformer3D( + patch_size=(1, 4, 4), + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_size=(8, 7, 7), + drop_path_rate=0.1, + patch_norm=True) + if backbone_pretrained and running_mode == 'train': + state_dict = torch.load(backbone_pretrained_path)['state_dict'] + # extract swinT's kinetics-400 pretrained weights and ignore the rest (prediction head etc.) + state_dict = { + k[9:]: v + for k, v in state_dict.items() if 'backbone.' 
in k + } + + # sum over the patch embedding weight temporal dim [96, 3, 2, 4, 4] --> [96, 3, 1, 4, 4] + patch_embed_weight = state_dict['patch_embed.proj.weight'] + patch_embed_weight = patch_embed_weight.sum(dim=2, keepdims=True) + state_dict['patch_embed.proj.weight'] = patch_embed_weight + swin_backbone.load_state_dict(state_dict) + + self.patch_embed = swin_backbone.patch_embed + self.pos_drop = swin_backbone.pos_drop + self.layers = swin_backbone.layers[:-1] + self.downsamples = nn.ModuleList() + for layer in self.layers: + self.downsamples.append(layer.downsample) + layer.downsample = None + self.downsamples[ + -1] = None # downsampling after the last layer is not necessary + + self.layer_output_channels = [ + swin_backbone.embed_dim * 2**i for i in range(len(self.layers)) + ] + self.train_backbone = train_backbone + if not train_backbone: + for parameter in self.parameters(): + parameter.requires_grad_(False) + + def forward(self, samples: NestedTensor): + vid_frames = rearrange(samples.tensors, 't b c h w -> b c t h w') + + vid_embeds = self.patch_embed(vid_frames) + vid_embeds = self.pos_drop(vid_embeds) + layer_outputs = [] # layer outputs before downsampling + for layer, downsample in zip(self.layers, self.downsamples): + vid_embeds = layer(vid_embeds.contiguous()) + layer_outputs.append(vid_embeds) + if downsample: + vid_embeds = rearrange(vid_embeds, 'b c t h w -> b t h w c') + vid_embeds = downsample(vid_embeds) + vid_embeds = rearrange(vid_embeds, 'b t h w c -> b c t h w') + layer_outputs = [ + rearrange(o, 'b c t h w -> t b c h w') for o in layer_outputs + ] + + outputs = [] + orig_pad_mask = samples.mask + for l_out in layer_outputs: + pad_mask = F.interpolate( + orig_pad_mask.float(), size=l_out.shape[-2:]).to(torch.bool) + outputs.append(NestedTensor(l_out, pad_mask)) + return outputs + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + Modified from DETR https://github.com/facebookresearch/detr + BatchNorm2d where the batch statistics and the affine parameters are fixed. + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer('weight', torch.ones(n)) + self.register_buffer('bias', torch.zeros(n)) + self.register_buffer('running_mean', torch.zeros(n)) + self.register_buffer('running_var', torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, + self)._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, + unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class ResNetBackbone(nn.Module): + """ + Modified from DETR https://github.com/facebookresearch/detr + ResNet backbone with frozen BatchNorm. 
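+    When train_backbone is True only layer2-layer4 are unfrozen; the stem and layer1 always stay frozen.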
+ """ + + def __init__(self, + backbone_name: str = 'resnet50', + train_backbone: bool = True, + dilation: bool = True, + **kwargs): + super(ResNetBackbone, self).__init__() + backbone = getattr(torchvision.models, backbone_name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), + norm_layer=FrozenBatchNorm2d) + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + return_layers = { + 'layer1': '0', + 'layer2': '1', + 'layer3': '2', + 'layer4': '3' + } + self.body = IntermediateLayerGetter( + backbone, return_layers=return_layers) + output_channels = 512 if backbone_name in ('resnet18', + 'resnet34') else 2048 + self.layer_output_channels = [ + output_channels // 8, output_channels // 4, output_channels // 2, + output_channels + ] + + def forward(self, tensor_list: NestedTensor): + t, b, _, _, _ = tensor_list.tensors.shape + video_frames = rearrange(tensor_list.tensors, + 't b c h w -> (t b) c h w') + padding_masks = rearrange(tensor_list.mask, 't b h w -> (t b) h w') + features_list = self.body(video_frames) + out = [] + for _, f in features_list.items(): + resized_padding_masks = F.interpolate( + padding_masks[None].float(), + size=f.shape[-2:]).to(torch.bool)[0] + f = rearrange(f, '(t b) c h w -> t b c h w', t=t, b=b) + resized_padding_masks = rearrange( + resized_padding_masks, '(t b) h w -> t b h w', t=t, b=b) + out.append(NestedTensor(f, resized_padding_masks)) + return out + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +def init_backbone(backbone_name, **kwargs): + if backbone_name == 'swin-t': + return VideoSwinTransformerBackbone(**kwargs) + elif 'resnet' in backbone_name: + return ResNetBackbone(backbone_name, **kwargs) + assert False, f'error: backbone "{backbone_name}" is not supported' diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py b/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py new file mode 100644 index 00000000..ecf34b8c --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py @@ -0,0 +1,234 @@ +# Modified from DETR https://github.com/facebookresearch/detr +# Misc functions. +# Mostly copy-paste from torchvision references. 
+ +import pickle +from typing import List, Optional + +import torch +import torch.distributed as dist +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from torch import Tensor + +if float(torchvision.__version__.split('.')[1]) < 7.0: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to('cuda') + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device='cuda') + size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append( + torch.empty((max_size, ), dtype=torch.uint8, device='cuda')) + if local_size != max_size: + padding = torch.empty( + size=(max_size - local_size, ), dtype=torch.uint8, device='cuda') + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + """ + This function receives a list of image tensors and returns a NestedTensor of the padded images, along with their + padding masks (true for padding areas, false otherwise). 
+ """ + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img) + m[:img.shape[1], :img.shape[2]] = False + return NestedTensor(tensor, mask) + + +def nested_tensor_from_videos_list(videos_list: List[Tensor]): + """ + This function receives a list of videos (each of shape [T, C, H, W]) and returns a NestedTensor of the padded + videos (shape [T, B, C, PH, PW], along with their padding masks (true for padding areas, false otherwise, of shape + [T, B, PH, PW]. + """ + max_size = _max_by_axis([list(img.shape) for img in videos_list]) + padded_batch_shape = [len(videos_list)] + max_size + b, t, c, h, w = padded_batch_shape + dtype = videos_list[0].dtype + device = videos_list[0].device + padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device) + videos_pad_masks = torch.ones((b, t, h, w), + dtype=torch.bool, + device=device) + for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, + padded_videos, + videos_pad_masks): + pad_vid_frames[:vid_frames.shape[0], :, :vid_frames. + shape[2], :vid_frames.shape[3]].copy_(vid_frames) + vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames. + shape[3]] = False + # transpose the temporal and batch dims and create a NestedTensor: + return NestedTensor( + padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1)) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def interpolate(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. 
+ """ + if float(torchvision.__version__.split('.')[1]) < 7.0: + if input.numel() > 0: + return torch.nn.functional.interpolate(input, size, scale_factor, + mode, align_corners) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, + mode, align_corners) diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py b/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py new file mode 100644 index 00000000..e603df6c --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py @@ -0,0 +1,128 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from .backbone import init_backbone +from .misc import NestedTensor +from .multimodal_transformer import MultimodalTransformer +from .segmentation import FPNSpatialDecoder + + +class MTTR(nn.Module): + """ The main module of the Multimodal Tracking Transformer """ + + def __init__(self, + num_queries, + mask_kernels_dim=8, + aux_loss=False, + **kwargs): + """ + Parameters: + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + MTTR can detect in a single image. In our paper we use 50 in all settings. + mask_kernels_dim: dim of the segmentation kernels and of the feature maps outputted by the spatial decoder. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.backbone = init_backbone(**kwargs) + self.transformer = MultimodalTransformer(**kwargs) + d_model = self.transformer.d_model + self.is_referred_head = nn.Linear( + d_model, + 2) # binary 'is referred?' prediction head for object queries + self.instance_kernels_head = MLP( + d_model, d_model, output_dim=mask_kernels_dim, num_layers=2) + self.obj_queries = nn.Embedding( + num_queries, d_model) # pos embeddings for the object queries + self.vid_embed_proj = nn.Conv2d( + self.backbone.layer_output_channels[-1], d_model, kernel_size=1) + self.spatial_decoder = FPNSpatialDecoder( + d_model, self.backbone.layer_output_channels[:-1][::-1], + mask_kernels_dim) + self.aux_loss = aux_loss + + def forward(self, samples: NestedTensor, valid_indices, text_queries): + """The forward expects a NestedTensor, which consists of: + - samples.tensor: Batched frames of shape [time x batch_size x 3 x H x W] + - samples.mask: A binary mask of shape [time x batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_is_referred": The reference prediction logits for all queries. + Shape: [time x batch_size x num_queries x 2] + - "pred_masks": The mask logits for all queries. + Shape: [time x batch_size x num_queries x H_mask x W_mask] + - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of + dictionaries containing the two above keys for each decoder layer. + """ + backbone_out = self.backbone(samples) + # keep only the valid frames (frames which are annotated): + # (for example, in a2d-sentences only the center frame in each window is annotated). 
+ for layer_out in backbone_out: + layer_out.tensors = layer_out.tensors.index_select( + 0, valid_indices) + layer_out.mask = layer_out.mask.index_select(0, valid_indices) + bbone_final_layer_output = backbone_out[-1] + vid_embeds, vid_pad_mask = bbone_final_layer_output.decompose() + + T, B, _, _, _ = vid_embeds.shape + vid_embeds = rearrange(vid_embeds, 't b c h w -> (t b) c h w') + vid_embeds = self.vid_embed_proj(vid_embeds) + vid_embeds = rearrange( + vid_embeds, '(t b) c h w -> t b c h w', t=T, b=B) + + transformer_out = self.transformer(vid_embeds, vid_pad_mask, + text_queries, + self.obj_queries.weight) + # hs is: [L, T, B, N, D] where L is number of decoder layers + # vid_memory is: [T, B, D, H, W] + # txt_memory is a list of length T*B of [S, C] where S might be different for each sentence + # encoder_middle_layer_outputs is a list of [T, B, H, W, D] + hs, vid_memory, txt_memory = transformer_out + + vid_memory = rearrange(vid_memory, 't b d h w -> (t b) d h w') + bbone_middle_layer_outputs = [ + rearrange(o.tensors, 't b d h w -> (t b) d h w') + for o in backbone_out[:-1][::-1] + ] + decoded_frame_features = self.spatial_decoder( + vid_memory, bbone_middle_layer_outputs) + decoded_frame_features = rearrange( + decoded_frame_features, '(t b) d h w -> t b d h w', t=T, b=B) + instance_kernels = self.instance_kernels_head(hs) # [L, T, B, N, C] + # output masks is: [L, T, B, N, H_mask, W_mask] + output_masks = torch.einsum('ltbnc,tbchw->ltbnhw', instance_kernels, + decoded_frame_features) + outputs_is_referred = self.is_referred_head(hs) # [L, T, B, N, 2] + + layer_outputs = [] + for pm, pir in zip(output_masks, outputs_is_referred): + layer_out = {'pred_masks': pm, 'pred_is_referred': pir} + layer_outputs.append(layer_out) + out = layer_outputs[ + -1] # the output for the last decoder layer is used by default + if self.aux_loss: + out['aux_outputs'] = layer_outputs[:-1] + return out + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py new file mode 100644 index 00000000..8c24e397 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py @@ -0,0 +1,440 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# MTTR Multimodal Transformer class. 
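Before the multimodal transformer, a quick shape check of the dynamic-convolution mask head used in MTTR.forward above (a standalone sketch with made-up sizes): each per-layer, per-frame instance kernel of dim C acts as a 1x1 convolution over the decoded frame features.

import torch

L_, T, B, N, C, H, W = 3, 1, 2, 50, 8, 64, 80       # made-up sizes
instance_kernels = torch.rand(L_, T, B, N, C)        # [L, T, B, N, C]
decoded_frame_features = torch.rand(T, B, C, H, W)   # [T, B, C, H, W]
output_masks = torch.einsum('ltbnc,tbchw->ltbnhw',
                            instance_kernels, decoded_frame_features)
print(output_masks.shape)  # torch.Size([3, 1, 2, 50, 64, 80])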
+# Modified from DETR https://github.com/facebookresearch/detr + +import copy +import os +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor, nn +from transformers import RobertaModel, RobertaTokenizerFast + +from .position_encoding_2d import PositionEmbeddingSine2D + +os.environ[ + 'TOKENIZERS_PARALLELISM'] = 'false' # this disables a huggingface tokenizer warning (printed every epoch) + + +class MultimodalTransformer(nn.Module): + + def __init__(self, + num_encoder_layers=3, + num_decoder_layers=3, + text_encoder_type='roberta-base', + freeze_text_encoder=True, + **kwargs): + super().__init__() + self.d_model = kwargs['d_model'] + encoder_layer = TransformerEncoderLayer(**kwargs) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers) + decoder_layer = TransformerDecoderLayer(**kwargs) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + norm=nn.LayerNorm(self.d_model), + return_intermediate=True) + self.pos_encoder_2d = PositionEmbeddingSine2D() + self._reset_parameters() + + self.text_encoder = RobertaModel.from_pretrained(text_encoder_type) + self.text_encoder.pooler = None # this pooler is never used, this is a hack to avoid DDP problems... + self.tokenizer = RobertaTokenizerFast.from_pretrained( + text_encoder_type) + self.freeze_text_encoder = freeze_text_encoder + if freeze_text_encoder: + for p in self.text_encoder.parameters(): + p.requires_grad_(False) + + self.txt_proj = FeatureResizer( + input_feat_size=self.text_encoder.config.hidden_size, + output_feat_size=self.d_model, + dropout=kwargs['dropout'], + ) + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, vid_embeds, vid_pad_mask, text_queries, obj_queries): + device = vid_embeds.device + t, b, _, h, w = vid_embeds.shape + + txt_memory, txt_pad_mask = self.forward_text(text_queries, device) + # add temporal dim to txt memory & padding mask: + txt_memory = repeat(txt_memory, 's b c -> s (t b) c', t=t) + txt_pad_mask = repeat(txt_pad_mask, 'b s -> (t b) s', t=t) + + vid_embeds = rearrange(vid_embeds, 't b c h w -> (h w) (t b) c') + # Concat the image & text embeddings on the sequence dimension + encoder_src_seq = torch.cat((vid_embeds, txt_memory), dim=0) + seq_mask = torch.cat( + (rearrange(vid_pad_mask, 't b h w -> (t b) (h w)'), txt_pad_mask), + dim=1) + # vid_pos_embed is: [T*B, H, W, d_model] + vid_pos_embed = self.pos_encoder_2d( + rearrange(vid_pad_mask, 't b h w -> (t b) h w'), self.d_model) + # use zeros in place of pos embeds for the text sequence: + pos_embed = torch.cat( + (rearrange(vid_pos_embed, 't_b h w c -> (h w) t_b c'), + torch.zeros_like(txt_memory)), + dim=0) + + memory = self.encoder( + encoder_src_seq, src_key_padding_mask=seq_mask, + pos=pos_embed) # [S, T*B, C] + vid_memory = rearrange( + memory[:h * w, :, :], + '(h w) (t b) c -> t b c h w', + h=h, + w=w, + t=t, + b=b) + txt_memory = memory[h * w:, :, :] + txt_memory = rearrange(txt_memory, 's t_b c -> t_b s c') + txt_memory = [ + t_mem[~pad_mask] + for t_mem, pad_mask in zip(txt_memory, txt_pad_mask) + ] # remove padding + + # add T*B dims to query embeds (was: [N, C], where N is the number of object queries): + obj_queries = repeat(obj_queries, 'n c -> n (t b) c', t=t, b=b) + tgt = torch.zeros_like(obj_queries) # [N, T*B, C] + + # hs is [L, N, T*B, C] where L is number of layers in the decoder + hs = self.decoder( + tgt, + memory, + 
memory_key_padding_mask=seq_mask, + pos=pos_embed, + query_pos=obj_queries) + hs = rearrange(hs, 'l n (t b) c -> l t b n c', t=t, b=b) + return hs, vid_memory, txt_memory + + def forward_text(self, text_queries, device): + tokenized_queries = self.tokenizer.batch_encode_plus( + text_queries, padding='longest', return_tensors='pt') + tokenized_queries = tokenized_queries.to(device) + with torch.inference_mode(mode=self.freeze_text_encoder): + encoded_text = self.text_encoder(**tokenized_queries) + # Transpose memory because pytorch's attention expects sequence first + txt_memory = rearrange(encoded_text.last_hidden_state, + 'b s c -> s b c') + txt_memory = self.txt_proj( + txt_memory) # change text embeddings dim to model dim + # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer + txt_pad_mask = tokenized_queries.attention_mask.ne(1).bool() # [B, S] + return txt_memory, txt_pad_mask + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, + decoder_layer, + num_layers, + norm=None, + return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, + d_model, + nheads, + dim_feedforward=2048, + dropout=0.1, + activation='relu', + normalize_before=False, + **kwargs): + super().__init__() + self.self_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = 
normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, + k, + value=src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, + k, + value=src2, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, + d_model, + nheads, + dim_feedforward=2048, + dropout=0.1, + activation='relu', + normalize_before=False, + **kwargs): + super().__init__() + self.self_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, + k, + value=tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + 
memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, + k, + value=tgt2, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, + pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class FeatureResizer(nn.Module): + """ + This class takes as input a set of embeddings of dimension C1 and outputs a set of + embedding of dimension C2, after a linear transformation, dropout and normalization (LN). + """ + + def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): + super().__init__() + self.do_ln = do_ln + # Object feature encoding + self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) + self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) + self.dropout = nn.Dropout(dropout) + + def forward(self, encoder_features): + x = self.fc(encoder_features) + if self.do_ln: + x = self.layer_norm(x) + output = self.dropout(x) + return output + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == 'relu': + return F.relu + if activation == 'gelu': + return F.gelu + if activation == 'glu': + return F.glu + raise RuntimeError(F'activation should be relu/gelu, not {activation}.') diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py b/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py new file mode 100644 index 00000000..f9ef05a1 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py @@ -0,0 +1,57 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +# 2D sine positional encodings for the visual features in the multimodal transformer. + +import math + +import torch +from torch import Tensor, nn + + +class PositionEmbeddingSine2D(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + + def __init__(self, temperature=10000, normalize=True, scale=None): + super().__init__() + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError('normalize should be True if scale is passed') + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, mask: Tensor, hidden_dim: int): + """ + @param mask: a tensor of shape [B, H, W] + @param hidden_dim: int + @return: + """ + num_pos_feats = hidden_dim // 2 + + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature**(2 * (dim_t // 2) / num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3) + return pos diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py b/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py new file mode 100644 index 00000000..64582140 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py @@ -0,0 +1,119 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import numpy as np +import pycocotools.mask as mask_util +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class A2DSentencesPostProcess(nn.Module): + """ + This module converts the model's output into the format expected by the coco api for the given task + """ + + def __init__(self): + super(A2DSentencesPostProcess, self).__init__() + + @torch.inference_mode() + def forward(self, outputs, resized_padded_sample_size, + resized_sample_sizes, orig_sample_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + resized_padded_sample_size: size of samples (input to model) after size augmentation + padding. + resized_sample_sizes: size of samples after size augmentation but without padding. 
+ orig_sample_sizes: original size of the samples (no augmentations or padding) + """ + pred_is_referred = outputs['pred_is_referred'] + prob = F.softmax(pred_is_referred, dim=-1) + scores = prob[..., 0] + pred_masks = outputs['pred_masks'] + pred_masks = F.interpolate( + pred_masks, + size=resized_padded_sample_size, + mode='bilinear', + align_corners=False) + pred_masks = (pred_masks.sigmoid() > 0.5) + processed_pred_masks, rle_masks = [], [] + for f_pred_masks, resized_size, orig_size in zip( + pred_masks, resized_sample_sizes, orig_sample_sizes): + f_mask_h, f_mask_w = resized_size # resized shape without padding + # remove the samples' padding + f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, : + f_mask_w].unsqueeze(1) + # resize the samples back to their original dataset (target) size for evaluation + f_pred_masks_processed = F.interpolate( + f_pred_masks_no_pad.float(), size=orig_size, mode='nearest') + f_pred_rle_masks = [ + mask_util.encode( + np.array( + mask[0, :, :, np.newaxis], dtype=np.uint8, + order='F'))[0] + for mask in f_pred_masks_processed.cpu() + ] + processed_pred_masks.append(f_pred_masks_processed) + rle_masks.append(f_pred_rle_masks) + predictions = [{ + 'scores': s, + 'masks': m, + 'rle_masks': rle + } for s, m, rle in zip(scores, processed_pred_masks, rle_masks)] + return predictions + + +class ReferYoutubeVOSPostProcess(nn.Module): + """ + This module converts the model's output into the format expected by the coco api for the given task + """ + + def __init__(self): + super(ReferYoutubeVOSPostProcess, self).__init__() + + @torch.inference_mode() + def forward(self, outputs, videos_metadata, samples_shape_with_padding): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + videos_metadata: a dictionary with each video's metadata. + samples_shape_with_padding: size of the batch frames with padding. 
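+            Note: each video's metadata is expected to provide 'resized_frame_size' and 'original_frame_size'; both are used below to strip padding and resize the masks back to the source resolution.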
+ """ + pred_is_referred = outputs['pred_is_referred'] + prob_is_referred = F.softmax(pred_is_referred, dim=-1) + # note we average on the temporal dim to compute score per trajectory: + trajectory_scores = prob_is_referred[..., 0].mean(dim=0) + pred_trajectory_indices = torch.argmax(trajectory_scores, dim=-1) + pred_masks = rearrange(outputs['pred_masks'], + 't b nq h w -> b t nq h w') + # keep only the masks of the chosen trajectories: + b = pred_masks.shape[0] + pred_masks = pred_masks[torch.arange(b), :, pred_trajectory_indices] + # resize the predicted masks to the size of the model input (which might include padding) + pred_masks = F.interpolate( + pred_masks, + size=samples_shape_with_padding, + mode='bilinear', + align_corners=False) + # apply a threshold to create binary masks: + pred_masks = (pred_masks.sigmoid() > 0.5) + # remove the padding per video (as videos might have different resolutions and thus different padding): + preds_by_video = [] + for video_pred_masks, video_metadata in zip(pred_masks, + videos_metadata): + # size of the model input batch frames without padding: + resized_h, resized_w = video_metadata['resized_frame_size'] + video_pred_masks = video_pred_masks[:, :resized_h, : + resized_w].unsqueeze( + 1) # remove the padding + # resize the masks back to their original frames dataset size for evaluation: + original_frames_size = video_metadata['original_frame_size'] + tuple_size = tuple(original_frames_size.cpu().numpy()) + video_pred_masks = F.interpolate( + video_pred_masks.float(), size=tuple_size, mode='nearest') + video_pred_masks = video_pred_masks.to(torch.uint8).cpu() + # combine the predicted masks and the video metadata to create a final predictions dict: + video_pred = {**video_metadata, **{'pred_masks': video_pred_masks}} + preds_by_video.append(video_pred) + return preds_by_video diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py b/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py new file mode 100644 index 00000000..b3228820 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py @@ -0,0 +1,137 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class FPNSpatialDecoder(nn.Module): + """ + An FPN-like spatial decoder. Generates high-res, semantically rich features which serve as the base for creating + instance segmentation masks. 
+ """ + + def __init__(self, context_dim, fpn_dims, mask_kernels_dim=8): + super().__init__() + + inter_dims = [ + context_dim, context_dim // 2, context_dim // 4, context_dim // 8, + context_dim // 16 + ] + self.lay1 = torch.nn.Conv2d(context_dim, inter_dims[0], 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, inter_dims[0]) + self.lay2 = torch.nn.Conv2d(inter_dims[0], inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.context_dim = context_dim + + self.add_extra_layer = len(fpn_dims) == 3 + if self.add_extra_layer: + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + self.lay5 = torch.nn.Conv2d( + inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d( + inter_dims[4], mask_kernels_dim, 3, padding=1) + else: + self.out_lay = torch.nn.Conv2d( + inter_dims[3], mask_kernels_dim, 3, padding=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor, layer_features: List[Tensor]): + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(layer_features[0]) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(layer_features[1]) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + if self.add_extra_layer: + cur_fpn = self.adapter3(layer_features[2]) + x = cur_fpn + F.interpolate( + x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +def dice_loss(inputs, targets, num_masks): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + +def sigmoid_focal_loss(inputs, + targets, + num_masks, + alpha: float = 0.25, + gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). 
+ gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits( + inputs, targets, reduction='none') + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t)**gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_masks diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py new file mode 100644 index 00000000..9a08ef48 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py @@ -0,0 +1,731 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from Video-Swin-Transformer https://github.com/SwinTransformer/Video-Swin-Transformer + +from functools import lru_cache, reduce +from operator import mul + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], + window_size[1], W // window_size[2], window_size[2], C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, + 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + + +def window_reverse(windows, window_size, B, D, H, W): + """ + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], + W // window_size[2], window_size[0], window_size[1], + window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention3D(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. 
+ window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wd, Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + wd, wh, ww = window_size + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, + coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] + - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:N, :N].reshape(-1)].reshape( + N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock3D(nn.Module): + """ Swin Transformer Block. 
+ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=(2, 7, 7), + shift_size=(0, 0, 0), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_checkpoint=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + + assert 0 <= self.shift_size[0] < self.window_size[ + 0], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[1] < self.window_size[ + 1], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[2] < self.window_size[ + 2], 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention3D( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + + x = self.norm1(x) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll( + x, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + # partition windows + x_windows = window_partition(shifted_x, + window_size) # B*nW, Wd*Wh*Ww, C + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C, ))) + shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, + Wp) # B D' H' W' C + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll( + shifted_x, + shifts=(shift_size[0], shift_size[1], shift_size[2]), + dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. + """ + + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). 
+ """ + B, D, H, W, C = x.shape + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +# cache each stage results +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], + -shift_size[0]), slice( + -shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], + -shift_size[1]), slice( + -shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], + -shift_size[2]), slice( + -shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, + window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + return attn_mask + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(1, 7, 7), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock3D( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) for i in range(depth) + ]) + + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, C, D, H, W). 
+ """ + # calculate attention mask for SW-MSA + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, 'b d h w c -> b c d h w') + return x + + +class PatchEmbed3D(nn.Module): + """ Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad( + x, + (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # B C D Wh Ww + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + + return x + + +class SwinTransformer3D(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + patch_size (int | tuple(int)): Patch size. Default: (4,4,4). + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer: Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. 
+ """ + + def __init__(self, + pretrained=None, + pretrained2d=True, + patch_size=(4, 4, 4), + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(2, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=False, + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrained = pretrained + self.pretrained2d = pretrained2d + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + self.window_size = window_size + self.patch_size = patch_size + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if i_layer < self.num_layers - 1 else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.num_features = int(embed_dim * 2**(self.num_layers - 1)) + + # add a norm layer for each output + self.norm = norm_layer(self.num_features) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1: + self.pos_drop.eval() + for i in range(0, self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def inflate_weights(self, logger): + """Inflate the swin2d parameters to swin3d. + + The differences between swin3d and swin2d mainly lie in an extra + axis. To utilize the pretrained parameters in 2d model, + the weight of swin2d models should be inflated to fit in the shapes of + the 3d counterpart. + + Args: + logger (logging.Logger): The logger used to print + debugging infomation. 
+ """ + checkpoint = torch.load(self.pretrained, map_location='cpu') + state_dict = checkpoint['model'] + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k] + for k in attn_mask_keys: + del state_dict[k] + + state_dict['patch_embed.proj.weight'] = state_dict[ + 'patch_embed.proj.weight'].unsqueeze(2).repeat( + 1, 1, self.patch_size[0], 1, 1) / self.patch_size[0] + + # bicubic interpolate relative_position_bias_table if not match + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for k in relative_position_bias_table_keys: + relative_position_bias_table_pretrained = state_dict[k] + relative_position_bias_table_current = self.state_dict()[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + wd = self.window_size[0] + if nH1 != nH2: + logger.warning(f'Error in loading {k}, passing') + else: + if L1 != L2: + S1 = int(L1**0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute( + 1, 0).view(1, nH1, S1, S1), + size=(2 * self.window_size[1] - 1, + 2 * self.window_size[2] - 1), + mode='bicubic') + relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view( + nH2, L2).permute(1, 0) + state_dict[k] = relative_position_bias_table_pretrained.repeat( + 2 * wd - 1, 1) + + msg = self.load_state_dict(state_dict, strict=False) + logger.info(msg) + logger.info(f"=> loaded successfully '{self.pretrained}'") + del checkpoint + torch.cuda.empty_cache() + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x.contiguous()) + + x = rearrange(x, 'n c d h w -> n d h w c') + x = self.norm(x) + x = rearrange(x, 'n d h w c -> n c d h w') + + return x + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer3D, self).train(mode) + self._freeze_stages() diff --git a/modelscope/outputs.py b/modelscope/outputs.py index a49ddacf..fbe15646 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -417,6 +417,12 @@ TASK_OUTPUTS = { # } Tasks.video_summarization: [OutputKeys.OUTPUT], + # referring video object segmentation result for a single video + # { + # "masks": [np.array # 2D array with shape [height, width]] + # } + Tasks.referring_video_object_segmentation: [OutputKeys.MASKS], + # ============ nlp tasks =================== # text classification result for single sample diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 174d10b1..8098bdec 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -202,6 +202,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'), Tasks.product_segmentation: (Pipelines.product_segmentation, 'damo/cv_F3Net_product-segmentation'), + Tasks.referring_video_object_segmentation: + (Pipelines.referring_video_object_segmentation, + 'damo/cv_swin-t_referring_video-object-segmentation'), } diff --git 
a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index f84f5fe5..97cd8761 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -58,6 +58,7 @@ if TYPE_CHECKING: from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipelin from .hand_static_pipeline import HandStaticPipeline + from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline else: _import_structure = { @@ -128,6 +129,9 @@ else: ['FacialExpressionRecognitionPipeline'], 'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], 'hand_static_pipeline': ['HandStaticPipeline'], + 'referring_video_object_segmentation_pipeline': [ + 'ReferringVideoObjectSegmentationPipeline' + ], } import sys diff --git a/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py new file mode 100644 index 00000000..d264b386 --- /dev/null +++ b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py @@ -0,0 +1,193 @@ +# The implementation here is modified based on MTTR, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/mttr2021/MTTR +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as F +from einops import rearrange +from moviepy.editor import AudioFileClip, ImageSequenceClip, VideoFileClip +from PIL import Image, ImageDraw, ImageFont, ImageOps +from tqdm import tqdm + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.referring_video_object_segmentation, + module_name=Pipelines.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """use `model` to create a referring video object segmentation pipeline for prediction + + Args: + model: model id on modelscope hub + """ + _device = kwargs.pop('device', 'gpu') + if torch.cuda.is_available() and _device == 'gpu': + self.device = 'gpu' + else: + self.device = 'cpu' + super().__init__(model=model, device=self.device, **kwargs) + + logger.info('Load model done!') + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ + + Args: + input: path of the input video + + """ + assert isinstance(input, tuple) and len( + input + ) == 4, 'error - input type must be tuple and input length must be 4' + self.input_video_pth, text_queries, start_pt, end_pt = input + + assert 0 < end_pt - start_pt <= 10, 'error - the subclip length must be 0-10 seconds long' + assert 1 <= len( + text_queries) <= 2, 'error - 1-2 input text queries are expected' + + # extract the relevant subclip: + self.input_clip_pth = 'input_clip.mp4' + with VideoFileClip(self.input_video_pth) as video: + subclip = video.subclip(start_pt, end_pt) + subclip.write_videofile(self.input_clip_pth) + + self.window_length = 24 # length of window during inference + self.window_overlap = 6 # overlap (in frames) between consecutive windows + + self.video, audio, self.meta = torchvision.io.read_video( + 
filename=self.input_clip_pth) + self.video = rearrange(self.video, 't h w c -> t c h w') + + input_video = F.resize(self.video, size=360, max_size=640) + if self.device_name == 'gpu': + input_video = input_video.cuda() + + input_video = input_video.to(torch.float).div_(255) + input_video = F.normalize( + input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + video_metadata = { + 'resized_frame_size': input_video.shape[-2:], + 'original_frame_size': self.video.shape[-2:] + } + + # partition the clip into overlapping windows of frames: + windows = [ + input_video[i:i + self.window_length] + for i in range(0, len(input_video), self.window_length + - self.window_overlap) + ] + # clean up the text queries: + self.text_queries = [' '.join(q.lower().split()) for q in text_queries] + + result = { + 'text_queries': self.text_queries, + 'windows': windows, + 'video_metadata': video_metadata + } + + return result + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + pred_masks_per_query = [] + t, _, h, w = self.video.shape + for text_query in tqdm(input['text_queries'], desc='text queries'): + pred_masks = torch.zeros(size=(t, 1, h, w)) + for i, window in enumerate( + tqdm(input['windows'], desc='windows')): + + window_masks = self.model.inference( + window=window, + text_query=text_query, + metadata=input['video_metadata']) + + win_start_idx = i * ( + self.window_length - self.window_overlap) + pred_masks[win_start_idx:win_start_idx + + self.window_length] = window_masks + pred_masks_per_query.append(pred_masks) + return pred_masks_per_query + + def postprocess(self, inputs) -> Dict[str, Any]: + if self.model.cfg.pipeline.save_masked_video: + # RGB colors for instance masks: + light_blue = (41, 171, 226) + purple = (237, 30, 121) + dark_green = (35, 161, 90) + orange = (255, 148, 59) + colors = np.array([light_blue, purple, dark_green, orange]) + + # width (in pixels) of the black strip above the video on which the text queries will be displayed: + text_border_height_per_query = 36 + + video_np = rearrange(self.video, + 't c h w -> t h w c').numpy() / 255.0 + + # del video + pred_masks_per_frame = rearrange( + torch.stack(inputs), 'q t 1 h w -> t q h w').numpy() + masked_video = [] + for vid_frame, frame_masks in tqdm( + zip(video_np, pred_masks_per_frame), + total=len(video_np), + desc='applying masks...'): + # apply the masks: + for inst_mask, color in zip(frame_masks, colors): + vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0) + vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8)) + # visualize the text queries: + vid_frame = ImageOps.expand( + vid_frame, + border=(0, len(self.text_queries) + * text_border_height_per_query, 0, 0)) + W, H = vid_frame.size + draw = ImageDraw.Draw(vid_frame) + font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30) + for i, (text_query, color) in enumerate( + zip(self.text_queries, colors), start=1): + w, h = draw.textsize(text_query, font=font) + draw.text(((W - w) / 2, + (text_border_height_per_query * i) - h - 3), + text_query, + fill=tuple(color) + (255, ), + font=font) + masked_video.append(np.array(vid_frame)) + print(type(vid_frame)) + print(type(masked_video[0])) + print(masked_video[0].shape) + # generate and save the output clip: + + assert self.model.cfg.pipeline.output_path + output_clip_path = self.model.cfg.pipeline.output_path + clip = ImageSequenceClip( + sequence=masked_video, fps=self.meta['video_fps']) + clip = 
clip.set_audio(AudioFileClip(self.input_clip_pth))
+        clip.write_videofile(
+            output_clip_path, fps=self.meta['video_fps'], audio=True)
+        del masked_video
+
+        result = {OutputKeys.MASKS: inputs}
+        return result
+
+
+def apply_mask(image, mask, color, transparency=0.7):
+    mask = mask[..., np.newaxis].repeat(repeats=3, axis=2)
+    mask = mask * transparency
+    color_matrix = np.ones(image.shape, dtype=float) * color
+    out_image = color_matrix * mask + image * (1.0 - mask)
+    return out_image
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 0eb369da..6ba58c19 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -80,6 +80,9 @@ class CVTasks(object):
     virtual_try_on = 'virtual-try-on'
     movie_scene_segmentation = 'movie-scene-segmentation'
 
+    # video segmentation
+    referring_video_object_segmentation = 'referring-video-object-segmentation'
+
     # video editing
     video_inpainting = 'video-inpainting'
 
diff --git a/requirements/cv.txt b/requirements/cv.txt
index eb38beb1..d23fab3a 100644
--- a/requirements/cv.txt
+++ b/requirements/cv.txt
@@ -1,4 +1,5 @@
 albumentations>=1.0.3
+av>=9.2.0
 easydict
 fairscale>=0.4.1
 fastai>=1.0.51
@@ -14,6 +15,7 @@ lpips
 ml_collections
 mmcls>=0.21.0
 mmdet>=2.25.0
+moviepy>=1.0.3
 networkx>=2.5
 numba
 onnxruntime>=1.10
diff --git a/tests/pipelines/test_referring_video_object_segmentation.py b/tests/pipelines/test_referring_video_object_segmentation.py
new file mode 100644
index 00000000..3e81d9c3
--- /dev/null
+++ b/tests/pipelines/test_referring_video_object_segmentation.py
@@ -0,0 +1,56 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class ReferringVideoObjectSegmentationTest(unittest.TestCase,
+                                           DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = Tasks.referring_video_object_segmentation
+        self.model_id = 'damo/cv_swin-t_referring_video-object-segmentation'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_referring_video_object_segmentation(self):
+        input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4'
+        text_queries = [
+            'guy in black performing tricks on a bike',
+            'a black bike used to perform tricks'
+        ]
+        start_pt, end_pt = 4, 14
+        input_tuple = (input_location, text_queries, start_pt, end_pt)
+        pp = pipeline(
+            Tasks.referring_video_object_segmentation, model=self.model_id)
+        result = pp(input_tuple)
+        if result:
+            print(result)
+        else:
+            raise ValueError('process error')
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_referring_video_object_segmentation_with_default_task(self):
+        input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4'
+        text_queries = [
+            'guy in black performing tricks on a bike',
+            'a black bike used to perform tricks'
+        ]
+        start_pt, end_pt = 4, 14
+        input_tuple = (input_location, text_queries, start_pt, end_pt)
+        pp = pipeline(Tasks.referring_video_object_segmentation)
+        result = pp(input_tuple)
+        if result:
+            print(result)
+        else:
+            raise ValueError('process error')
+
+    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()
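
Usage note for reviewers (not part of the patch): the snippet below is a minimal sketch of how the new pipeline is expected to be invoked, mirroring the unit test above. The model id, test video path and query strings come from the test; the shape comment reflects what postprocess() returns under OutputKeys.MASKS (one mask tensor per text query).

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# build the pipeline from the default model registered in builder.py
segmenter = pipeline(
    Tasks.referring_video_object_segmentation,
    model='damo/cv_swin-t_referring_video-object-segmentation')

video_path = 'data/test/videos/referring_video_object_segmentation_test_video.mp4'
text_queries = [
    'guy in black performing tricks on a bike',
    'a black bike used to perform tricks',
]
# input is a 4-tuple: (video path, 1-2 text queries, start second, end second);
# preprocess() limits the subclip to at most 10 seconds
result = segmenter((video_path, text_queries, 4, 14))

masks = result[OutputKeys.MASKS]  # list with one mask tensor per text query
for query, mask in zip(text_queries, masks):
    print(query, tuple(mask.shape))  # (num_frames, 1, height, width)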
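
The overlapping-window logic split between preprocess() and forward() is the subtlest part of the pipeline, so here is a small self-contained sketch of how the windows tile a clip. window_length and window_overlap match the values hard-coded in preprocess(); the 60-frame clip length is an arbitrary example.

import torch

window_length = 24  # frames per inference window (see preprocess())
window_overlap = 6  # frames shared by consecutive windows
num_frames = 60     # hypothetical clip length

frames = torch.arange(num_frames)  # stand-in for the resized video frames
stride = window_length - window_overlap
windows = [frames[i:i + window_length] for i in range(0, num_frames, stride)]

# forward() writes each window's masks into pred_masks[start:start + window_length],
# so frames in the 6-frame overlap keep the prediction of the *later* window.
owner = torch.zeros(num_frames, dtype=torch.long)
for i, win in enumerate(windows):
    start = i * stride
    owner[start:start + len(win)] = i

print(len(windows))    # 4 windows cover the 60 frames
print(owner.tolist())  # which window each frame's final mask comes from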
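
For the Swin helpers added in utils/swin_transformer.py, the shape bookkeeping of window_partition()/window_reverse() is worth spelling out. The sketch below imports the helpers from the new module (assuming the optional cv requirements such as einops and timm are installed); tensor sizes are arbitrary multiples of the default window_size (2, 7, 7), since the helpers expect pre-padded inputs.

import torch
from modelscope.models.cv.referring_video_object_segmentation.utils.swin_transformer import (
    window_partition, window_reverse)

B, D, H, W, C = 2, 4, 14, 14, 96
window_size = (2, 7, 7)

x = torch.randn(B, D, H, W, C)
windows = window_partition(x, window_size)
# one row per 3D window, flattened to Wd*Wh*Ww tokens
assert windows.shape == (B * (D // 2) * (H // 7) * (W // 7), 2 * 7 * 7, C)

y = window_reverse(windows, window_size, B, D, H, W)
assert y.shape == x.shape and torch.equal(x, y)  # exact round trip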
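
Finally, a toy call to the mask losses that close out the segmentation module (the dice_loss/sigmoid_focal_loss definitions earlier in this patch). Going by the hunk ordering these appear to live in utils/segmentation.py, so the import path below is an assumption rather than something stated in the patch; shapes follow the flattened (num_masks, H*W) convention both losses rely on, and all values are arbitrary.

import torch
# assumed location of the loss helpers added by this patch
from modelscope.models.cv.referring_video_object_segmentation.utils.segmentation import (
    dice_loss, sigmoid_focal_loss)

num_masks, hw = 4, 32 * 32
logits = torch.randn(num_masks, hw)                  # raw mask logits
targets = (torch.rand(num_masks, hw) > 0.5).float()  # binary ground-truth masks

print(dice_loss(logits, targets, num_masks))           # scalar tensor
print(sigmoid_focal_loss(logits, targets, num_masks))  # scalar tensor (alpha=0.25, gamma=2)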