From 674e625e7cc61b6876713e7485bf3203a23f961a Mon Sep 17 00:00:00 2001 From: "wendi.hwd" Date: Fri, 29 Jul 2022 11:48:51 +0800 Subject: [PATCH] [to #42322933]cv_det_branch_to_master Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9490080 --- data/test/images/image_detection.jpg | 3 + modelscope/metainfo.py | 3 + modelscope/models/cv/__init__.py | 3 +- .../models/cv/object_detection/__init__.py | 22 + .../models/cv/object_detection/mmdet_model.py | 92 +++ .../cv/object_detection/mmdet_ms/__init__.py | 4 + .../mmdet_ms/backbones/__init__.py | 3 + .../mmdet_ms/backbones/vit.py | 626 ++++++++++++++++++ .../mmdet_ms/dense_heads/__init__.py | 4 + .../mmdet_ms/dense_heads/anchor_head.py | 48 ++ .../mmdet_ms/dense_heads/rpn_head.py | 268 ++++++++ .../mmdet_ms/necks/__init__.py | 3 + .../cv/object_detection/mmdet_ms/necks/fpn.py | 207 ++++++ .../mmdet_ms/roi_heads/__init__.py | 8 + .../mmdet_ms/roi_heads/bbox_heads/__init__.py | 4 + .../roi_heads/bbox_heads/convfc_bbox_head.py | 229 +++++++ .../mmdet_ms/roi_heads/mask_heads/__init__.py | 3 + .../roi_heads/mask_heads/fcn_mask_head.py | 414 ++++++++++++ .../mmdet_ms/utils/__init__.py | 4 + .../mmdet_ms/utils/checkpoint.py | 558 ++++++++++++++++ .../mmdet_ms/utils/convModule_norm.py | 30 + modelscope/pipelines/base.py | 7 +- modelscope/pipelines/builder.py | 4 + modelscope/pipelines/cv/__init__.py | 2 + .../pipelines/cv/object_detection_pipeline.py | 51 ++ modelscope/utils/constant.py | 1 + requirements/cv.txt | 1 + tests/pipelines/test_object_detection.py | 56 ++ 28 files changed, 2655 insertions(+), 3 deletions(-) create mode 100644 data/test/images/image_detection.jpg create mode 100644 modelscope/models/cv/object_detection/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_model.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py create mode 100644 modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py create mode 100644 modelscope/pipelines/cv/object_detection_pipeline.py create mode 100644 tests/pipelines/test_object_detection.py diff --git a/data/test/images/image_detection.jpg b/data/test/images/image_detection.jpg new file mode 100644 index 00000000..37447ce3 --- /dev/null +++ 
b/data/test/images/image_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0218020651b6cdcc0051563f75750c8200d34fc49bf34cc053cd59c1f13cad03 +size 128624 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 20aa3586..9c5ed709 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -10,6 +10,7 @@ class Models(object): Model name should only contain model info but not task info. """ # vision models + detection = 'detection' scrfd = 'scrfd' classification_model = 'ClassificationModel' nafnet = 'nafnet' @@ -69,6 +70,8 @@ class Pipelines(object): action_recognition = 'TAdaConv_action-recognition' animal_recognation = 'resnet101-animal_recog' cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' + human_detection = 'resnet18-human-detection' + object_detection = 'vit-object-detection' image_classification = 'image-classification' face_detection = 'resnet-face-detection-scrfd10gkps' live_category = 'live-category' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 076e1f4e..a96c6370 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -3,4 +3,5 @@ from . import (action_recognition, animal_recognition, cartoon, cmdssl_video_embedding, face_detection, face_generation, image_classification, image_color_enhance, image_colorization, image_denoise, image_instance_segmentation, - image_to_image_translation, super_resolution, virual_tryon) + image_to_image_translation, object_detection, super_resolution, + virual_tryon) diff --git a/modelscope/models/cv/object_detection/__init__.py b/modelscope/models/cv/object_detection/__init__.py new file mode 100644 index 00000000..fa73686d --- /dev/null +++ b/modelscope/models/cv/object_detection/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
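+# Reviewer note (comment only, not part of the original submission): this
+# package wires its public symbols through ModelScope's LazyImportModule, so
+# importing modelscope.models.cv.object_detection stays cheap and the heavy
+# mmdet/mmcv dependencies are only pulled in once DetectionModel is resolved.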
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .mmdet_model import DetectionModel
+
+else:
+    _import_structure = {
+        'mmdet_model': ['DetectionModel'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    ) diff --git a/modelscope/models/cv/object_detection/mmdet_model.py b/modelscope/models/cv/object_detection/mmdet_model.py new file mode 100644 index 00000000..cc01b60c --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_model.py @@ -0,0 +1,92 @@
+import os.path as osp
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from .mmdet_ms.backbones import ViT
+from .mmdet_ms.dense_heads import RPNNHead
+from .mmdet_ms.necks import FPNF
+from .mmdet_ms.roi_heads import FCNMaskNHead, Shared4Conv1FCBBoxNHead
+
+
+@MODELS.register_module(Tasks.human_detection, module_name=Models.detection)
+@MODELS.register_module(Tasks.object_detection, module_name=Models.detection)
+class DetectionModel(TorchModel):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the model from ``model_dir``, the root directory of the model files."""
+        super().__init__(model_dir, *args, **kwargs)
+
+        from mmcv.runner import load_checkpoint
+        from mmdet.datasets import replace_ImageToTensor
+        from mmdet.datasets.pipelines import Compose
+        from mmdet.models import build_detector
+
+        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+        config_path = osp.join(model_dir, 'mmcv_config.py')
+        config = Config.from_file(config_path)
+        config.model.pretrained = None
+        self.model = build_detector(config.model)
+
+        checkpoint = load_checkpoint(
+            self.model, model_path, map_location='cpu')
+        self.class_names = checkpoint['meta']['CLASSES']
+        config.test_pipeline[0].type = 'LoadImageFromWebcam'
+        self.test_pipeline = Compose(
+            replace_ImageToTensor(config.test_pipeline))
+        self.model.cfg = config
+        self.model.eval()
+        self.score_thr = config.score_thr
+
+    def inference(self, data):
+        """Run inference; ``data`` is a dict with ``img`` and ``img_metas``, following the mmdet convention."""
+
+        with torch.no_grad():
+            results = self.model(return_loss=False, rescale=True, **data)
+        return results
+
+    def preprocess(self, image):
+        """Preprocess a numpy image into a dict with ``img`` and ``img_metas``, following the mmdet convention."""
+
+        from mmcv.parallel import collate, scatter
+        data = dict(img=image)
+        data = self.test_pipeline(data)
+        data = collate([data], samples_per_gpu=1)
+        data['img_metas'] = [
+            img_metas.data[0] for img_metas in data['img_metas']
+        ]
+        data['img'] = [img.data[0] for img in data['img']]
+
+        if next(self.model.parameters()).is_cuda:
+            data = scatter(data, [next(self.model.parameters()).device])[0]
+
+        return data
+
+    def postprocess(self, inputs):
+
+        if isinstance(inputs[0], tuple):
+            bbox_result, _ = inputs[0]
+        else:
+            bbox_result, _ = inputs[0], None
+        labels = [
+            np.full(bbox.shape[0], i, dtype=np.int32)
+            for i, bbox in enumerate(bbox_result)
+        ]
+        labels = np.concatenate(labels)
+
+        bbox_result = np.vstack(bbox_result)
+        scores = bbox_result[:, -1]
+        inds = scores > self.score_thr
+        if np.sum(np.array(inds).astype('int')) == 0:
+            return None, None, None
+        bboxes = bbox_result[inds, :]
+        labels = labels[inds]
+        scores = bboxes[:, 4]
+        bboxes = bboxes[:, 0:4] +
labels = [self.class_names[i_label] for i_label in labels] + return bboxes, scores, labels diff --git a/modelscope/models/cv/object_detection/mmdet_ms/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/__init__.py new file mode 100644 index 00000000..2e47ce76 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/__init__.py @@ -0,0 +1,4 @@ +from .backbones import ViT +from .dense_heads import AnchorNHead, RPNNHead +from .necks import FPNF +from .utils import ConvModule_Norm, load_checkpoint diff --git a/modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py new file mode 100644 index 00000000..3b34dad6 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py @@ -0,0 +1,3 @@ +from .vit import ViT + +__all__ = ['ViT'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py b/modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py new file mode 100644 index 00000000..53bda358 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py @@ -0,0 +1,626 @@ +# -------------------------------------------------------- +# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +# Github source: https://github.com/microsoft/unilm/tree/master/beit +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# By Hangbo Bao +# Based on timm, mmseg, setr, xcit and swin code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/fudan-zvg/SETR +# https://github.com/facebookresearch/xcit/ +# https://github.com/microsoft/Swin-Transformer +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from mmdet.models.builder import BACKBONES +from mmdet.utils import get_root_logger +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +from ..utils import load_checkpoint + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + window_size=None, + attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + self.window_size = window_size + q_size = window_size[0] + rel_sp_dim = 2 * q_size - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W, rel_pos_bias=None): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = calc_rel_pos_spatial(attn, q, self.window_size, + self.window_size, self.rel_pos_h, + self.rel_pos_w) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def calc_rel_pos_spatial( + attn, + q, + q_shape, + k_shape, + rel_pos_h, + rel_pos_w, +): + """ + Spatial Relative Positional Embeddings. + """ + sp_idx = 0 + q_h, q_w = q_shape + k_h, k_w = k_shape + # Scale up rel pos if shapes for q and k are different. 
+ q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = ( + torch.arange(q_h)[:, None] * q_h_ratio + - torch.arange(k_h)[None, :] * k_h_ratio) + dist_h += (k_h - 1) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = ( + torch.arange(q_w)[:, None] * q_w_ratio + - torch.arange(k_w)[None, :] * k_w_ratio) + dist_w += (k_w - 1) * k_w_ratio + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + B, n_head, q_N, dim = q.shape + r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim) + rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh) + rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw) + attn[:, :, sp_idx:, sp_idx:] = ( + attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :]).view( + B, -1, q_h * q_w, k_h * k_w) + + return attn + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + attn_head_dim=None): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + q_size = window_size[0] + rel_sp_dim = 2 * q_size - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W): + """ Forward function. 
+ Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + x = x.reshape(B_, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.window_size[1] + - W % self.window_size[1]) % self.window_size[1] + pad_b = (self.window_size[0] + - H % self.window_size[0]) % self.window_size[0] + + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x = window_partition( + x, self.window_size[0]) # nW*B, window_size, window_size, C + x = x.view(-1, self.window_size[1] * self.window_size[0], + C) # nW*B, window_size*window_size, C + B_w = x.shape[0] + N_w = x.shape[1] + qkv = self.qkv(x).reshape(B_w, N_w, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = calc_rel_pos_spatial(attn, q, self.window_size, + self.window_size, self.rel_pos_h, + self.rel_pos_w) + + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_w, N_w, C) + x = self.proj(x) + x = self.proj_drop(x) + + x = x.view(-1, self.window_size[1], self.window_size[0], C) + x = window_reverse(x, self.window_size[0], Hp, Wp) # B H' W' C + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B_, H * W, C) + + return x + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + window=False): + super().__init__() + self.norm1 = norm_layer(dim) + if not window: + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + attn_head_dim=attn_head_dim) + else: + self.attn = WindowAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if init_values is not None: + self.gamma_1 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, H, W): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path( + self.gamma_1 * self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * ( + img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, + backbone, + img_size=224, + feature_size=None, + in_chans=3, + embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. 
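+                # Reviewer note (comment only): a dummy zero tensor is pushed
+                # through the backbone in eval mode to measure the final
+                # feature map size and channel dim; the original training flag
+                # is restored afterwards.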
+ training = backbone.training + if training: + backbone.eval() + o = self.backbone( + torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class Norm2d(nn.Module): + + def __init__(self, embed_dim): + super().__init__() + self.ln = nn.LayerNorm(embed_dim, eps=1e-6) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = self.ln(x) + x = x.permute(0, 3, 1, 2).contiguous() + return x + + +@BACKBONES.register_module() +class ViT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=80, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + hybrid_backbone=None, + norm_layer=None, + init_values=None, + use_checkpoint=False, + use_abs_pos_emb=False, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + out_indices=[11], + interval=3, + pretrained=None): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, + img_size=img_size, + in_chans=in_chans, + embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.out_indices = out_indices + + if use_abs_pos_emb: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim)) + else: + self.pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + window_size=(14, 14) if + ((i + 1) % interval != 0) else self.patch_embed.patch_shape, + window=((i + 1) % interval != 0)) for i in range(depth) + ]) + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self.norm = norm_layer(embed_dim) + + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + Norm2d(embed_dim), + nn.GELU(), + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, + stride=2), ) + + self.fpn3 = nn.Identity() + + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + self.apply(self._init_weights) + self.fix_init_weight() + self.pretrained = pretrained + + def fix_init_weight(self): + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + 
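+            # Reviewer note (comment only): BEiT-style initialization fix; the
+            # attention and MLP output projections are scaled down by
+            # sqrt(2 * layer_id) (1-indexed) so that residual branches
+            # contribute less in deeper blocks.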
rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + pretrained = pretrained or self.pretrained + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + print(f'load from {pretrained}') + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + for i, blk in enumerate(self.blocks): + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, Hp, Wp) + + x = self.norm(x) + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp) + + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(ops)): + features.append(ops[i](xp)) + + return tuple(features) + + def forward(self, x): + + x = self.forward_features(x) + + return x diff --git a/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py new file mode 100644 index 00000000..0fba8c00 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py @@ -0,0 +1,4 @@ +from .anchor_head import AnchorNHead +from .rpn_head import RPNNHead + +__all__ = ['AnchorNHead', 'RPNNHead'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py new file mode 100644 index 00000000..b4114652 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modifed from source code avaiable via https://github.com/ViTAE-Transformer/ViTDet +from mmdet.models.builder import HEADS +from mmdet.models.dense_heads import AnchorHead + + +@HEADS.register_module() +class AnchorNHead(AnchorHead): + """Anchor-based head (RPN, RetinaNet, SSD, etc.). + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + feat_channels (int): Number of hidden channels. Used in child classes. + anchor_generator (dict): Config dict for anchor generator + bbox_coder (dict): Config of bounding box coder. 
+ reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Default False. It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + loss_cls (dict): Config of classification loss. + loss_bbox (dict): Config of localization loss. + train_cfg (dict): Training config of anchor head. + test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ # noqa: W605 + + def __init__(self, + num_classes, + in_channels, + feat_channels, + anchor_generator=None, + bbox_coder=None, + reg_decoded_bbox=False, + loss_cls=None, + loss_bbox=None, + train_cfg=None, + test_cfg=None, + norm_cfg=None, + init_cfg=None): + self.norm_cfg = norm_cfg + super(AnchorNHead, + self).__init__(num_classes, in_channels, feat_channels, + anchor_generator, bbox_coder, reg_decoded_bbox, + loss_cls, loss_bbox, train_cfg, test_cfg, + init_cfg) diff --git a/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py new file mode 100644 index 00000000..f53368ce --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py @@ -0,0 +1,268 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modifed from source code avaiable via https://github.com/ViTAE-Transformer/ViTDet +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import batched_nms +from mmdet.models.builder import HEADS + +from ..utils import ConvModule_Norm +from .anchor_head import AnchorNHead + + +@HEADS.register_module() +class RPNNHead(AnchorNHead): + """RPN head. + + Args: + in_channels (int): Number of channels in the input feature map. + init_cfg (dict or list[dict], optional): Initialization config dict. + num_convs (int): Number of convolution layers in the head. Default 1. + """ # noqa: W605 + + def __init__(self, + in_channels, + init_cfg=dict(type='Normal', layer='Conv2d', std=0.01), + num_convs=1, + **kwargs): + self.num_convs = num_convs + super(RPNNHead, self).__init__( + 1, in_channels, init_cfg=init_cfg, **kwargs) + + def _init_layers(self): + """Initialize layers of the head.""" + if self.num_convs > 1: + rpn_convs = [] + for i in range(self.num_convs): + if i == 0: + in_channels = self.in_channels + else: + in_channels = self.feat_channels + # use ``inplace=False`` to avoid error: one of the variables + # needed for gradient computation has been modified by an + # inplace operation. + rpn_convs.append( + ConvModule_Norm( + in_channels, + self.feat_channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + inplace=False)) + self.rpn_conv = nn.Sequential(*rpn_convs) + else: + self.rpn_conv = nn.Conv2d( + self.in_channels, self.feat_channels, 3, padding=1) + self.rpn_cls = nn.Conv2d(self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 1) + self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * 4, + 1) + + def forward_single(self, x): + """Forward feature map of a single scale level.""" + x = self.rpn_conv(x) + x = F.relu(x, inplace=True) + rpn_cls_score = self.rpn_cls(x) + rpn_bbox_pred = self.rpn_reg(x) + return rpn_cls_score, rpn_bbox_pred + + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. 
+ + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + losses = super(RPNNHead, self).loss( + cls_scores, + bbox_preds, + gt_bboxes, + None, + img_metas, + gt_bboxes_ignore=gt_bboxes_ignore) + return dict( + loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox']) + + def _get_bboxes_single(self, + cls_score_list, + bbox_pred_list, + score_factor_list, + mlvl_anchors, + img_meta, + cfg, + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_anchors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has + shape (num_anchors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. RPN head does not need this value. + mlvl_anchors (list[Tensor]): Anchors of all scale level + each item has shape (num_anchors, 4). + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + Tensor: Labeled boxes in shape (n, 5), where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. + """ + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + + # bboxes from different level should be independent during NMS, + # level_ids are used as labels for batched NMS to separate them + level_ids = [] + mlvl_scores = [] + mlvl_bbox_preds = [] + mlvl_valid_anchors = [] + nms_pre = cfg.get('nms_pre', -1) + for level_idx in range(len(cls_score_list)): + rpn_cls_score = cls_score_list[level_idx] + rpn_bbox_pred = bbox_pred_list[level_idx] + assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:] + rpn_cls_score = rpn_cls_score.permute(1, 2, 0) + if self.use_sigmoid_cls: + rpn_cls_score = rpn_cls_score.reshape(-1) + scores = rpn_cls_score.sigmoid() + else: + rpn_cls_score = rpn_cls_score.reshape(-1, 2) + # We set FG labels to [0, num_class-1] and BG label to + # num_class in RPN head since mmdet v2.5, which is unified to + # be consistent with other head since mmdet v2.0. In mmdet v2.0 + # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. 
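+                # Reviewer note (comment only): given that convention, column 0
+                # of the softmax output holds the foreground probability, which
+                # is kept as the proposal score below.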
+ scores = rpn_cls_score.softmax(dim=1)[:, 0] + rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4) + + anchors = mlvl_anchors[level_idx] + if 0 < nms_pre < scores.shape[0]: + # sort is faster than topk + # _, topk_inds = scores.topk(cfg.nms_pre) + ranked_scores, rank_inds = scores.sort(descending=True) + topk_inds = rank_inds[:nms_pre] + scores = ranked_scores[:nms_pre] + rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] + anchors = anchors[topk_inds, :] + + mlvl_scores.append(scores) + mlvl_bbox_preds.append(rpn_bbox_pred) + mlvl_valid_anchors.append(anchors) + level_ids.append( + scores.new_full((scores.size(0), ), + level_idx, + dtype=torch.long)) + + return self._bbox_post_process(mlvl_scores, mlvl_bbox_preds, + mlvl_valid_anchors, level_ids, cfg, + img_shape) + + def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors, + level_ids, cfg, img_shape, **kwargs): + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually with_nms is False is used for aug test. + + Args: + mlvl_scores (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_bboxes, num_class). + mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale + levels of a single image, each item has shape (num_bboxes, 4). + mlvl_valid_anchors (list[Tensor]): Anchors of all scale level + each item has shape (num_bboxes, 4). + level_ids (list[Tensor]): Indexes from all scale levels of a + single image, each item has shape (num_bboxes, ). + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + img_shape (tuple(int)): Shape of current image. + + Returns: + Tensor: Labeled boxes in shape (n, 5), where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. + """ + scores = torch.cat(mlvl_scores) + anchors = torch.cat(mlvl_valid_anchors) + rpn_bbox_pred = torch.cat(mlvl_bboxes) + proposals = self.bbox_coder.decode( + anchors, rpn_bbox_pred, max_shape=img_shape) + ids = torch.cat(level_ids) + + if cfg.min_bbox_size >= 0: + w = proposals[:, 2] - proposals[:, 0] + h = proposals[:, 3] - proposals[:, 1] + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + proposals = proposals[valid_mask] + scores = scores[valid_mask] + ids = ids[valid_mask] + + if proposals.numel() > 0: + dets, _ = batched_nms(proposals, scores, ids, cfg.nms) + else: + return proposals.new_zeros(0, 5) + + return dets[:cfg.max_per_img] + + def onnx_export(self, x, img_metas): + """Test without augmentation. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + img_metas (list[dict]): Meta info of each image. + Returns: + Tensor: dets of shape [N, num_det, 5]. + """ + cls_scores, bbox_preds = self(x) + + assert len(cls_scores) == len(bbox_preds) + + batch_bboxes, batch_scores = super(RPNNHead, self).onnx_export( + cls_scores, bbox_preds, img_metas=img_metas, with_nms=False) + # Use ONNX::NonMaxSuppression in deployment + from mmdet.core.export import add_dummy_nms_for_onnx + cfg = copy.deepcopy(self.test_cfg) + score_threshold = cfg.nms.get('score_thr', 0.0) + nms_pre = cfg.get('deploy_nms_pre', -1) + # Different from the normal forward doing NMS level by level, + # we do NMS across all levels when exporting ONNX. 
+ dets, _ = add_dummy_nms_for_onnx(batch_bboxes, batch_scores, + cfg.max_per_img, + cfg.nms.iou_threshold, + score_threshold, nms_pre, + cfg.max_per_img) + return dets diff --git a/modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py new file mode 100644 index 00000000..5b0b6210 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py @@ -0,0 +1,3 @@ +from .fpn import FPNF + +__all__ = ['FPNF'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py b/modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py new file mode 100644 index 00000000..52529b28 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py @@ -0,0 +1,207 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modifed from source code avaiable via https://github.com/ViTAE-Transformer/ViTDet +import torch.nn as nn +import torch.nn.functional as F +from mmcv.runner import BaseModule, auto_fp16 +from mmdet.models.builder import NECKS + +from ..utils import ConvModule_Norm + + +@NECKS.register_module() +class FPNF(BaseModule): + r"""Feature Pyramid Network. + + This is an implementation of paper `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (str): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(mode='nearest')` + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... 
print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + use_residual=True, + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super(FPNF, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + self.use_residual = use_residual + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + self.add_extra_convs = 'on_input' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule_Norm( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule_Norm( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule_Norm( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + @auto_fp16() + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + if self.use_residual: + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. 
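+                # Reviewer note (comment only): standard FPN top-down fusion;
+                # each coarser lateral is upsampled (by a fixed scale factor or
+                # to the finer level's exact size) and added in place to the
+                # finer lateral, after which the 3x3 fpn_convs smooth the sums.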
+ if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] += F.interpolate(laterals[i], + **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + + return tuple(outs) diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py new file mode 100644 index 00000000..a6be3775 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py @@ -0,0 +1,8 @@ +from .bbox_heads import (ConvFCBBoxNHead, Shared2FCBBoxNHead, + Shared4Conv1FCBBoxNHead) +from .mask_heads import FCNMaskNHead + +__all__ = [ + 'ConvFCBBoxNHead', 'Shared2FCBBoxNHead', 'Shared4Conv1FCBBoxNHead', + 'FCNMaskNHead' +] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py new file mode 100644 index 00000000..0d4d5b6b --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py @@ -0,0 +1,4 @@ +from .convfc_bbox_head import (ConvFCBBoxNHead, Shared2FCBBoxNHead, + Shared4Conv1FCBBoxNHead) + +__all__ = ['ConvFCBBoxNHead', 'Shared2FCBBoxNHead', 'Shared4Conv1FCBBoxNHead'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py new file mode 100644 index 00000000..d2e04b80 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py @@ -0,0 +1,229 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modifed from source code avaiable via https://github.com/ViTAE-Transformer/ViTDet +import torch.nn as nn +from mmdet.models.builder import HEADS +from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead +from mmdet.models.utils import build_linear_layer + +from ...utils import ConvModule_Norm + + +@HEADS.register_module() +class ConvFCBBoxNHead(BBoxHead): + r"""More general bbox head, with shared conv and fc layers and two optional + separated branches. + + .. 
code-block:: none + + /-> cls convs -> cls fcs -> cls + shared convs -> shared fcs + \-> reg convs -> reg fcs -> reg + """ # noqa: W605 + + def __init__(self, + num_shared_convs=0, + num_shared_fcs=0, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + conv_out_channels=256, + fc_out_channels=1024, + conv_cfg=None, + norm_cfg=None, + init_cfg=None, + *args, + **kwargs): + super(ConvFCBBoxNHead, self).__init__( + *args, init_cfg=init_cfg, **kwargs) + assert (num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs + + num_reg_convs + num_reg_fcs > 0) + if num_cls_convs > 0 or num_reg_convs > 0: + assert num_shared_fcs == 0 + if not self.with_cls: + assert num_cls_convs == 0 and num_cls_fcs == 0 + if not self.with_reg: + assert num_reg_convs == 0 and num_reg_fcs == 0 + self.num_shared_convs = num_shared_convs + self.num_shared_fcs = num_shared_fcs + self.num_cls_convs = num_cls_convs + self.num_cls_fcs = num_cls_fcs + self.num_reg_convs = num_reg_convs + self.num_reg_fcs = num_reg_fcs + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + # add shared convs and fcs + self.shared_convs, self.shared_fcs, last_layer_dim = \ + self._add_conv_fc_branch( + self.num_shared_convs, self.num_shared_fcs, self.in_channels, + True) + self.shared_out_channels = last_layer_dim + + # add cls specific branch + self.cls_convs, self.cls_fcs, self.cls_last_dim = \ + self._add_conv_fc_branch( + self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels) + + # add reg specific branch + self.reg_convs, self.reg_fcs, self.reg_last_dim = \ + self._add_conv_fc_branch( + self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels) + + if self.num_shared_fcs == 0 and not self.with_avg_pool: + if self.num_cls_fcs == 0: + self.cls_last_dim *= self.roi_feat_area + if self.num_reg_fcs == 0: + self.reg_last_dim *= self.roi_feat_area + + self.relu = nn.ReLU(inplace=True) + # reconstruct fc_cls and fc_reg since input channels are changed + if self.with_cls: + if self.custom_cls_channels: + cls_channels = self.loss_cls.get_cls_channels(self.num_classes) + else: + cls_channels = self.num_classes + 1 + self.fc_cls = build_linear_layer( + self.cls_predictor_cfg, + in_features=self.cls_last_dim, + out_features=cls_channels) + if self.with_reg: + out_dim_reg = (4 if self.reg_class_agnostic else 4 + * self.num_classes) + self.fc_reg = build_linear_layer( + self.reg_predictor_cfg, + in_features=self.reg_last_dim, + out_features=out_dim_reg) + + if init_cfg is None: + # when init_cfg is None, + # It has been set to + # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))], + # [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))] + # after `super(ConvFCBBoxHead, self).__init__()` + # we only need to append additional configuration + # for `shared_fcs`, `cls_fcs` and `reg_fcs` + self.init_cfg += [ + dict( + type='Xavier', + override=[ + dict(name='shared_fcs'), + dict(name='cls_fcs'), + dict(name='reg_fcs') + ]) + ] + + def _add_conv_fc_branch(self, + num_branch_convs, + num_branch_fcs, + in_channels, + is_shared=False): + """Add shared or separable branch. 
+ + convs -> avg pool (optional) -> fcs + """ + last_layer_dim = in_channels + # add branch specific conv layers + branch_convs = nn.ModuleList() + if num_branch_convs > 0: + for i in range(num_branch_convs): + conv_in_channels = ( + last_layer_dim if i == 0 else self.conv_out_channels) + branch_convs.append( + ConvModule_Norm( + conv_in_channels, + self.conv_out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + last_layer_dim = self.conv_out_channels + # add branch specific fc layers + branch_fcs = nn.ModuleList() + if num_branch_fcs > 0: + # for shared branch, only consider self.with_avg_pool + # for separated branches, also consider self.num_shared_fcs + if (is_shared + or self.num_shared_fcs == 0) and not self.with_avg_pool: + last_layer_dim *= self.roi_feat_area + for i in range(num_branch_fcs): + fc_in_channels = ( + last_layer_dim if i == 0 else self.fc_out_channels) + branch_fcs.append( + nn.Linear(fc_in_channels, self.fc_out_channels)) + last_layer_dim = self.fc_out_channels + return branch_convs, branch_fcs, last_layer_dim + + def forward(self, x): + # shared part + if self.num_shared_convs > 0: + for conv in self.shared_convs: + x = conv(x) + + if self.num_shared_fcs > 0: + if self.with_avg_pool: + x = self.avg_pool(x) + + x = x.flatten(1) + + for fc in self.shared_fcs: + x = self.relu(fc(x)) + # separate branches + x_cls = x + x_reg = x + + for conv in self.cls_convs: + x_cls = conv(x_cls) + if x_cls.dim() > 2: + if self.with_avg_pool: + x_cls = self.avg_pool(x_cls) + x_cls = x_cls.flatten(1) + for fc in self.cls_fcs: + x_cls = self.relu(fc(x_cls)) + + for conv in self.reg_convs: + x_reg = conv(x_reg) + if x_reg.dim() > 2: + if self.with_avg_pool: + x_reg = self.avg_pool(x_reg) + x_reg = x_reg.flatten(1) + for fc in self.reg_fcs: + x_reg = self.relu(fc(x_reg)) + + cls_score = self.fc_cls(x_cls) if self.with_cls else None + bbox_pred = self.fc_reg(x_reg) if self.with_reg else None + return cls_score, bbox_pred + + +@HEADS.register_module() +class Shared2FCBBoxNHead(ConvFCBBoxNHead): + + def __init__(self, fc_out_channels=1024, *args, **kwargs): + super(Shared2FCBBoxNHead, self).__init__( + num_shared_convs=0, + num_shared_fcs=2, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + fc_out_channels=fc_out_channels, + *args, + **kwargs) + + +@HEADS.register_module() +class Shared4Conv1FCBBoxNHead(ConvFCBBoxNHead): + + def __init__(self, fc_out_channels=1024, *args, **kwargs): + super(Shared4Conv1FCBBoxNHead, self).__init__( + num_shared_convs=4, + num_shared_fcs=1, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + fc_out_channels=fc_out_channels, + *args, + **kwargs) diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py new file mode 100644 index 00000000..8f816850 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py @@ -0,0 +1,3 @@ +from .fcn_mask_head import FCNMaskNHead + +__all__ = ['FCNMaskNHead'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py new file mode 100644 index 00000000..e5aedc98 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py @@ -0,0 +1,414 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Implementation in this file is modifed from source code avaiable via https://github.com/ViTAE-Transformer/ViTDet +from warnings import warn + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_conv_layer, build_upsample_layer +from mmcv.ops.carafe import CARAFEPack +from mmcv.runner import BaseModule, ModuleList, auto_fp16, force_fp32 +from mmdet.core import mask_target +from mmdet.models.builder import HEADS, build_loss +from torch.nn.modules.utils import _pair + +from ...utils import ConvModule_Norm + +BYTES_PER_FLOAT = 4 +# TODO: This memory limit may be too much or too little. It would be better to +# determine it based on available resources. +GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit + + +@HEADS.register_module() +class FCNMaskNHead(BaseModule): + + def __init__(self, + num_convs=4, + roi_feat_size=14, + in_channels=256, + conv_kernel_size=3, + conv_out_channels=256, + num_classes=80, + class_agnostic=False, + upsample_cfg=dict(type='deconv', scale_factor=2), + conv_cfg=None, + norm_cfg=None, + predictor_cfg=dict(type='Conv'), + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), + init_cfg=None): + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super(FCNMaskNHead, self).__init__(init_cfg) + self.upsample_cfg = upsample_cfg.copy() + if self.upsample_cfg['type'] not in [ + None, 'deconv', 'nearest', 'bilinear', 'carafe' + ]: + raise ValueError( + f'Invalid upsample method {self.upsample_cfg["type"]}, ' + 'accepted methods are "deconv", "nearest", "bilinear", ' + '"carafe"') + self.num_convs = num_convs + # WARN: roi_feat_size is reserved and not used + self.roi_feat_size = _pair(roi_feat_size) + self.in_channels = in_channels + self.conv_kernel_size = conv_kernel_size + self.conv_out_channels = conv_out_channels + self.upsample_method = self.upsample_cfg.get('type') + self.scale_factor = self.upsample_cfg.pop('scale_factor', None) + self.num_classes = num_classes + self.class_agnostic = class_agnostic + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.predictor_cfg = predictor_cfg + self.fp16_enabled = False + self.loss_mask = build_loss(loss_mask) + + self.convs = ModuleList() + for i in range(self.num_convs): + in_channels = ( + self.in_channels if i == 0 else self.conv_out_channels) + padding = (self.conv_kernel_size - 1) // 2 + self.convs.append( + ConvModule_Norm( + in_channels, + self.conv_out_channels, + self.conv_kernel_size, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + upsample_in_channels = ( + self.conv_out_channels if self.num_convs > 0 else in_channels) + upsample_cfg_ = self.upsample_cfg.copy() + if self.upsample_method is None: + self.upsample = None + elif self.upsample_method == 'deconv': + upsample_cfg_.update( + in_channels=upsample_in_channels, + out_channels=self.conv_out_channels, + kernel_size=self.scale_factor, + stride=self.scale_factor) + self.upsample = build_upsample_layer(upsample_cfg_) + elif self.upsample_method == 'carafe': + upsample_cfg_.update( + channels=upsample_in_channels, scale_factor=self.scale_factor) + self.upsample = build_upsample_layer(upsample_cfg_) + else: + # suppress warnings + align_corners = (None + if self.upsample_method == 'nearest' else False) + upsample_cfg_.update( + scale_factor=self.scale_factor, + mode=self.upsample_method, + align_corners=align_corners) + self.upsample = build_upsample_layer(upsample_cfg_) + + out_channels 
= 1 if self.class_agnostic else self.num_classes + logits_in_channel = ( + self.conv_out_channels + if self.upsample_method == 'deconv' else upsample_in_channels) + self.conv_logits = build_conv_layer(self.predictor_cfg, + logits_in_channel, out_channels, 1) + self.relu = nn.ReLU(inplace=True) + self.debug_imgs = None + + def init_weights(self): + super(FCNMaskNHead, self).init_weights() + for m in [self.upsample, self.conv_logits]: + if m is None: + continue + elif isinstance(m, CARAFEPack): + m.init_weights() + elif hasattr(m, 'weight') and hasattr(m, 'bias'): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(m.bias, 0) + + @auto_fp16() + def forward(self, x): + for conv in self.convs: + x = conv(x) + if self.upsample is not None: + x = self.upsample(x) + if self.upsample_method == 'deconv': + x = self.relu(x) + mask_pred = self.conv_logits(x) + return mask_pred + + def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg): + pos_proposals = [res.pos_bboxes for res in sampling_results] + pos_assigned_gt_inds = [ + res.pos_assigned_gt_inds for res in sampling_results + ] + mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, + gt_masks, rcnn_train_cfg) + return mask_targets + + @force_fp32(apply_to=('mask_pred', )) + def loss(self, mask_pred, mask_targets, labels): + """ + Example: + >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA + >>> N = 7 # N = number of extracted ROIs + >>> C, H, W = 11, 32, 32 + >>> # Create example instance of FCN Mask Head. + >>> # There are lots of variations depending on the configuration + >>> self = FCNMaskHead(num_classes=C, num_convs=1) + >>> inputs = torch.rand(N, self.in_channels, H, W) + >>> mask_pred = self.forward(inputs) + >>> sf = self.scale_factor + >>> labels = torch.randint(0, C, size=(N,)) + >>> # With the default properties the mask targets should indicate + >>> # a (potentially soft) single-class label + >>> mask_targets = torch.rand(N, H * sf, W * sf) + >>> loss = self.loss(mask_pred, mask_targets, labels) + >>> print('loss = {!r}'.format(loss)) + """ + loss = dict() + if mask_pred.size(0) == 0: + loss_mask = mask_pred.sum() + else: + if self.class_agnostic: + loss_mask = self.loss_mask(mask_pred, mask_targets, + torch.zeros_like(labels)) + else: + loss_mask = self.loss_mask(mask_pred, mask_targets, labels) + loss['loss_mask'] = loss_mask + return loss + + def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, + ori_shape, scale_factor, rescale): + """Get segmentation masks from mask_pred and bboxes. + + Args: + mask_pred (Tensor or ndarray): shape (n, #class, h, w). + For single-scale testing, mask_pred is the direct output of + model, whose type is Tensor, while for multi-scale testing, + it will be converted to numpy array outside of this method. + det_bboxes (Tensor): shape (n, 4/5) + det_labels (Tensor): shape (n, ) + rcnn_test_cfg (dict): rcnn testing config + ori_shape (Tuple): original image height and width, shape (2,) + scale_factor(ndarray | Tensor): If ``rescale is True``, box + coordinates are divided by this scale factor to fit + ``ori_shape``. + rescale (bool): If True, the resulting masks will be rescaled to + ``ori_shape``. + + Returns: + list[list]: encoded masks. The c-th item in the outer list + corresponds to the c-th class. Given the c-th outer list, the + i-th item in that inner list is the mask for the i-th box with + class label c. 
+ + Example: + >>> import mmcv + >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA + >>> N = 7 # N = number of extracted ROIs + >>> C, H, W = 11, 32, 32 + >>> # Create example instance of FCN Mask Head. + >>> self = FCNMaskHead(num_classes=C, num_convs=0) + >>> inputs = torch.rand(N, self.in_channels, H, W) + >>> mask_pred = self.forward(inputs) + >>> # Each input is associated with some bounding box + >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N) + >>> det_labels = torch.randint(0, C, size=(N,)) + >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, }) + >>> ori_shape = (H * 4, W * 4) + >>> scale_factor = torch.FloatTensor((1, 1)) + >>> rescale = False + >>> # Encoded masks are a list for each category. + >>> encoded_masks = self.get_seg_masks( + >>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape, + >>> scale_factor, rescale + >>> ) + >>> assert len(encoded_masks) == C + >>> assert sum(list(map(len, encoded_masks))) == N + """ + if isinstance(mask_pred, torch.Tensor): + mask_pred = mask_pred.sigmoid() + else: + # In AugTest, has been activated before + mask_pred = det_bboxes.new_tensor(mask_pred) + + device = mask_pred.device + cls_segms = [[] for _ in range(self.num_classes) + ] # BG is not included in num_classes + bboxes = det_bboxes[:, :4] + labels = det_labels + + # In most cases, scale_factor should have been + # converted to Tensor when rescale the bbox + if not isinstance(scale_factor, torch.Tensor): + if isinstance(scale_factor, float): + scale_factor = np.array([scale_factor] * 4) + warn('Scale_factor should be a Tensor or ndarray ' + 'with shape (4,), float would be deprecated. ') + assert isinstance(scale_factor, np.ndarray) + scale_factor = torch.Tensor(scale_factor) + + if rescale: + img_h, img_w = ori_shape[:2] + bboxes = bboxes / scale_factor.to(bboxes) + else: + w_scale, h_scale = scale_factor[0], scale_factor[1] + img_h = np.round(ori_shape[0] * h_scale.item()).astype(np.int32) + img_w = np.round(ori_shape[1] * w_scale.item()).astype(np.int32) + + N = len(mask_pred) + # The actual implementation split the input into chunks, + # and paste them chunk by chunk. + if device.type == 'cpu': + # CPU is most efficient when they are pasted one by one with + # skip_empty=True, so that it performs minimal number of + # operations. + num_chunks = N + else: + # GPU benefits from parallelism for larger chunks, + # but may have memory issue + # the types of img_w and img_h are np.int32, + # when the image resolution is large, + # the calculation of num_chunks will overflow. + # so we need to change the types of img_w and img_h to int. 
+ # See https://github.com/open-mmlab/mmdetection/pull/5191 + num_chunks = int( + np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT + / GPU_MEM_LIMIT)) + # assert (num_chunks <= N), 'Default GPU_MEM_LIMIT is too small; try increasing it' + assert num_chunks <= N, 'Default GPU_MEM_LIMIT is too small; try increasing it' + chunks = torch.chunk(torch.arange(N, device=device), num_chunks) + + threshold = rcnn_test_cfg.mask_thr_binary + im_mask = torch.zeros( + N, + img_h, + img_w, + device=device, + dtype=torch.bool if threshold >= 0 else torch.uint8) + + if not self.class_agnostic: + mask_pred = mask_pred[range(N), labels][:, None] + + for inds in chunks: + masks_chunk, spatial_inds = _do_paste_mask( + mask_pred[inds], + bboxes[inds], + img_h, + img_w, + skip_empty=device.type == 'cpu') + + if threshold >= 0: + masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool) + else: + # for visualization and debugging + masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8) + + im_mask[(inds, ) + spatial_inds] = masks_chunk + + for i in range(N): + cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy()) + return cls_segms + + def onnx_export(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, + ori_shape, **kwargs): + """Get segmentation masks from mask_pred and bboxes. + + Args: + mask_pred (Tensor): shape (n, #class, h, w). + det_bboxes (Tensor): shape (n, 4/5) + det_labels (Tensor): shape (n, ) + rcnn_test_cfg (dict): rcnn testing config + ori_shape (Tuple): original image height and width, shape (2,) + + Returns: + Tensor: a mask of shape (N, img_h, img_w). + """ + + mask_pred = mask_pred.sigmoid() + bboxes = det_bboxes[:, :4] + labels = det_labels + # No need to consider rescale and scale_factor while exporting to ONNX + img_h, img_w = ori_shape[:2] + threshold = rcnn_test_cfg.mask_thr_binary + if not self.class_agnostic: + box_inds = torch.arange(mask_pred.shape[0]) + mask_pred = mask_pred[box_inds, labels][:, None] + masks, _ = _do_paste_mask( + mask_pred, bboxes, img_h, img_w, skip_empty=False) + if threshold >= 0: + # should convert to float to avoid problems in TRT + masks = (masks >= threshold).to(dtype=torch.float) + return masks + + +def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): + """Paste instance masks according to boxes. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): N, 1, H, W + boxes (Tensor): N, 4 + img_h (int): Height of the image to be pasted. + img_w (int): Width of the image to be pasted. + skip_empty (bool): Only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. + + Returns: + tuple: (Tensor, tuple). The first item is mask tensor, the second one + is the slice object. + If skip_empty == False, the whole image will be pasted. It will + return a mask of shape (N, img_h, img_w) and an empty tuple. + If skip_empty == True, only area around the mask will be pasted. + A mask of shape (N, h', w') and its start and end coordinates + in the original image will be returned. + """ + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on COCO-scale dataset. 
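+    # The pasting below proceeds in three steps: (1) optionally clip the paste
+    # region to the union of all boxes (the skip_empty fast path used on CPU),
+    # (2) build per-box sampling grids with coordinates normalized to [-1, 1]
+    # relative to each box, and (3) bilinearly resample the ROI masks onto the
+    # image canvas with F.grid_sample.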
+ device = masks.device + if skip_empty: + x0_int, y0_int = torch.clamp( + boxes.min(dim=0).values.floor()[:2] - 1, + min=0).to(dtype=torch.int32) + x1_int = torch.clamp( + boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp( + boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + N = masks.shape[0] + + img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5 + img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + # img_x, img_y have shapes (N, w), (N, h) + # IsInf op is not supported with ONNX<=1.7.0 + if not torch.onnx.is_in_onnx_export(): + if torch.isinf(img_x).any(): + inds = torch.where(torch.isinf(img_x)) + img_x[inds] = 0 + if torch.isinf(img_y).any(): + inds = torch.where(torch.isinf(img_y)) + img_y[inds] = 0 + + gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + img_masks = F.grid_sample( + masks.to(dtype=torch.float32), grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + else: + return img_masks[:, 0], () diff --git a/modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py new file mode 100644 index 00000000..971a0232 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py @@ -0,0 +1,4 @@ +from .checkpoint import load_checkpoint +from .convModule_norm import ConvModule_Norm + +__all__ = ['load_checkpoint', 'ConvModule_Norm'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py b/modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py new file mode 100644 index 00000000..593af1cc --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py @@ -0,0 +1,558 @@ +# Copyright (c) Open-MMLab. All rights reserved. +# Implementation adopted from ViTAE-Transformer, source code avaiable via https://github.com/ViTAE-Transformer/ViTDet +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import mmcv +import torch +import torchvision +from mmcv.fileio import FileClient +from mmcv.fileio import load as load_file +from mmcv.parallel import is_module_wrapper +from mmcv.runner import get_dist_info +from torch.nn import functional as F +from torch.optim import Optimizer +from torch.utils import model_zoo + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. 
If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + print('finish load') + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load( + downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + allowed_backends = ['ceph'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 
0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. 
+ """ + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('torchvision://'): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('open-mmlab://'): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith('mmcls://'): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(('http://', 'https://')): + checkpoint = load_url_dist(filename) + elif filename.startswith('pavi://'): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith('s3://'): + checkpoint = load_fileclient_dist( + filename, backend='ceph', map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None, + load_ema=True): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. 
+ """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if load_ema and 'state_dict_ema' in checkpoint: + state_dict = checkpoint['state_dict_ema'] + # logger.info(f'loading from state_dict_ema') + logger.info('loading from state_dict_ema') + elif 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + # logger.info(f'loading from state_dict') + logger.info('loading from state_dict') + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + # logger.info(f'loading from model') + logger.info('loading from model') + print('loading from model') + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = { + k.replace('encoder.', ''): v + for k, v in state_dict.items() if k.startswith('encoder.') + } + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2) + + all_keys = list(state_dict.keys()) + for key in all_keys: + if 'relative_position_index' in key: + state_dict.pop(key) + + if 'relative_position_bias_table' in key: + state_dict.pop(key) + + if '.q_bias' in key: + q_bias = state_dict[key] + v_bias = state_dict[key.replace('q_bias', 'v_bias')] + qkv_bias = torch.cat([q_bias, torch.zeros_like(q_bias), v_bias], 0) + state_dict[key.replace('q_bias', 'qkv.bias')] = qkv_bias + + if '.v.bias' in key: + continue + + all_keys = list(state_dict.keys()) + new_state_dict = {} + for key in all_keys: + if 'qkv.bias' in key: + value = state_dict[key] + dim = value.shape[0] + selected_dim = (dim * 2) // 3 + new_state_dict[key.replace( + 'qkv.bias', 'pos_bias')] = state_dict[key][:selected_dim] + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + if table_key not in model.state_dict().keys(): + logger.warning( + 'relative_position_bias_table exits in pretrained model but not in current one, pass' + ) + continue + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f'Error in loading {table_key}, pass') + else: + if L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0) + rank, _ = get_dist_info() + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + H, W = model.patch_embed.patch_shape + num_patches = model.patch_embed.num_patches + num_extra_tokens = 1 + # height (== width) for the 
checkpoint position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + if rank == 0: + print('Position interpolate from %dx%d to %dx%d' % + (orig_size, orig_size, H, W)) + # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, + embedding_size).permute( + 0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(H, W), mode='bicubic', align_corners=False) + new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + # new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. 
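+
+    Example (a minimal sketch with a stand-alone layer):
+
+        >>> import torch.nn as nn
+        >>> conv = nn.Conv2d(3, 8, 3)
+        >>> sorted(get_state_dict(conv).keys())
+        ['bias', 'weight']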
+ """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() diff --git a/modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py b/modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py new file mode 100644 index 00000000..d81c24e1 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py @@ -0,0 +1,30 @@ +# Implementation adopted from ViTAE-Transformer, source code avaiable via https://github.com/ViTAE-Transformer/ViTDet + +from mmcv.cnn import ConvModule + + +class ConvModule_Norm(ConvModule): + + def __init__(self, in_channels, out_channels, kernel, **kwargs): + super().__init__(in_channels, out_channels, 
kernel, **kwargs) + + self.normType = kwargs.get('norm_cfg', {'type': ''}) + if self.normType is not None: + self.normType = self.normType['type'] + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == 'conv': + if self.with_explicit_padding: + x = self.padding_layer(x) + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + if 'LN' in self.normType: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = x.permute(0, 3, 1, 2).contiguous() + else: + x = self.norm(x) + elif layer == 'act' and activate and self.with_activation: + x = self.activate(x) + return x diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 8faf8691..b1d82557 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -62,6 +62,7 @@ class Pipeline(ABC): model: Union[InputModel, List[InputModel]] = None, preprocessor: Union[Preprocessor, List[Preprocessor]] = None, device: str = 'gpu', + auto_collate=True, **kwargs): """ Base class for pipeline. @@ -74,6 +75,7 @@ class Pipeline(ABC): model: (list of) Model name or model object preprocessor: (list of) Preprocessor object device (str): gpu device or cpu device to use + auto_collate (bool): automatically to convert data to tensor or not. """ if config_file is not None: self.cfg = Config.from_file(config_file) @@ -98,6 +100,7 @@ class Pipeline(ABC): self.device = create_device(self.device_name == 'cpu') self._model_prepare = False self._model_prepare_lock = Lock() + self._auto_collate = auto_collate def prepare_model(self): self._model_prepare_lock.acquire(timeout=600) @@ -252,7 +255,7 @@ class Pipeline(ABC): return self._collate_fn(torch.from_numpy(data)) elif isinstance(data, torch.Tensor): return data.to(self.device) - elif isinstance(data, (str, int, float, bool)): + elif isinstance(data, (str, int, float, bool, type(None))): return data elif isinstance(data, InputFeatures): return data @@ -270,7 +273,7 @@ class Pipeline(ABC): out = self.preprocess(input, **preprocess_params) with self.place_device(): - if self.framework == Frameworks.torch: + if self.framework == Frameworks.torch and self._auto_collate: with torch.no_grad(): out = self._collate_fn(out) out = self.forward(out, **forward_params) diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 15a367b9..a0e5b5af 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -35,6 +35,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { ), # TODO: revise back after passing the pr Tasks.image_matting: (Pipelines.image_matting, 'damo/cv_unet_image-matting'), + Tasks.human_detection: (Pipelines.human_detection, + 'damo/cv_resnet18_human-detection'), + Tasks.object_detection: (Pipelines.object_detection, + 'damo/cv_vit_object-detection_coco'), Tasks.image_denoise: (Pipelines.image_denoise, 'damo/cv_nafnet_image-denoise_sidd'), Tasks.text_classification: (Pipelines.sentiment_analysis, diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index f8a8f1d1..35230f08 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from .action_recognition_pipeline import ActionRecognitionPipeline from .animal_recognition_pipeline import AnimalRecognitionPipeline from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline + from .object_detection_pipeline import ObjectDetectionPipeline from .face_detection_pipeline import FaceDetectionPipeline from .face_recognition_pipeline import 
FaceRecognitionPipeline from .face_image_generation_pipeline import FaceImageGenerationPipeline @@ -30,6 +31,7 @@ else: 'action_recognition_pipeline': ['ActionRecognitionPipeline'], 'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], 'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], + 'object_detection_pipeline': ['ObjectDetectionPipeline'], 'face_detection_pipeline': ['FaceDetectionPipeline'], 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], 'face_recognition_pipeline': ['FaceRecognitionPipeline'], diff --git a/modelscope/pipelines/cv/object_detection_pipeline.py b/modelscope/pipelines/cv/object_detection_pipeline.py new file mode 100644 index 00000000..a604fb17 --- /dev/null +++ b/modelscope/pipelines/cv/object_detection_pipeline.py @@ -0,0 +1,51 @@ +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + + +@PIPELINES.register_module( + Tasks.human_detection, module_name=Pipelines.human_detection) +@PIPELINES.register_module( + Tasks.object_detection, module_name=Pipelines.object_detection) class ObjectDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + + img = LoadImage.convert_to_ndarray(input) + img = img.astype(float) + img = self.model.preprocess(img) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + outputs = self.model.inference(input['img']) + result = {'data': outputs} + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + bboxes, scores, labels = self.model.postprocess(inputs['data']) + if bboxes is None: + return None + outputs = { + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + OutputKeys.BOXES: bboxes + } + + return outputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index c4aace7e..1eab664c 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -20,6 +20,7 @@ class CVTasks(object): image_classification = 'image-classification' image_tagging = 'image-tagging' object_detection = 'object-detection' + human_detection = 'human-detection' image_segmentation = 'image-segmentation' image_editing = 'image-editing' image_generation = 'image-generation' diff --git a/requirements/cv.txt b/requirements/cv.txt index 521ab34f..a0f505c0 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -1,4 +1,5 @@ decord>=0.6.0 easydict tf_slim +timm torchvision diff --git a/tests/pipelines/test_object_detection.py b/tests/pipelines/test_object_detection.py new file mode 100644 index 00000000..8e4630a3 --- /dev/null +++ b/tests/pipelines/test_object_detection.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates.
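+# Pipeline-level smoke tests for the object-detection and human-detection
+# tasks.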
+import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + + +class ObjectDetectionTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_object_detection(self): + input_location = 'data/test/images/image_detection.jpg' + model_id = 'damo/cv_vit_object-detection_coco' + object_detect = pipeline(Tasks.object_detection, model=model_id) + result = object_detect(input_location) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_object_detection_with_default_task(self): + input_location = 'data/test/images/image_detection.jpg' + object_detect = pipeline(Tasks.object_detection) + result = object_detect(input_location) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_human_detection(self): + input_location = 'data/test/images/image_detection.jpg' + model_id = 'damo/cv_resnet18_human-detection' + human_detect = pipeline(Tasks.human_detection, model=model_id) + result = human_detect(input_location) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_human_detection_with_default_task(self): + input_location = 'data/test/images/image_detection.jpg' + human_detect = pipeline(Tasks.human_detection) + result = human_detect(input_location) + if result: + print(result) + else: + raise ValueError('process error') + + +if __name__ == '__main__': + unittest.main()