Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9490080master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:0218020651b6cdcc0051563f75750c8200d34fc49bf34cc053cd59c1f13cad03 | |||
| size 128624 | |||
| @@ -10,6 +10,7 @@ class Models(object): | |||
| Model name should only contain model info but not task info. | |||
| """ | |||
| # vision models | |||
| detection = 'detection' | |||
| scrfd = 'scrfd' | |||
| classification_model = 'ClassificationModel' | |||
| nafnet = 'nafnet' | |||
| @@ -69,6 +70,8 @@ class Pipelines(object): | |||
| action_recognition = 'TAdaConv_action-recognition' | |||
| animal_recognation = 'resnet101-animal_recog' | |||
| cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | |||
| human_detection = 'resnet18-human-detection' | |||
| object_detection = 'vit-object-detection' | |||
| image_classification = 'image-classification' | |||
| face_detection = 'resnet-face-detection-scrfd10gkps' | |||
| live_category = 'live-category' | |||
| @@ -3,4 +3,5 @@ from . import (action_recognition, animal_recognition, cartoon, | |||
| cmdssl_video_embedding, face_detection, face_generation, | |||
| image_classification, image_color_enhance, image_colorization, | |||
| image_denoise, image_instance_segmentation, | |||
| image_to_image_translation, super_resolution, virual_tryon) | |||
| image_to_image_translation, object_detection, super_resolution, | |||
| virual_tryon) | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .mmdet_model import DetectionModel | |||
| else: | |||
| _import_structure = { | |||
| 'mmdet_model': ['DetectionModel'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
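Reviewer note: a minimal sketch of how this lazy-import `__init__.py` is expected to behave. The package path `modelscope.models.cv.object_detection` is an assumption (the diff does not show file paths); importing the package stays cheap and the mmdet-heavy module is only loaded on first attribute access:

    # Hedged usage sketch; the package path is an assumption, not part of the diff.
    from modelscope.models.cv import object_detection

    # No mmdet/mmcv import has happened yet; LazyImportModule defers it.
    model_cls = object_detection.DetectionModel  # first access triggers the import of .mmdet_model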
| @@ -0,0 +1,92 @@ | |||
| import os.path as osp | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base.base_torch_model import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from .mmdet_ms.backbones import ViT | |||
| from .mmdet_ms.dense_heads import RPNNHead | |||
| from .mmdet_ms.necks import FPNF | |||
| from .mmdet_ms.roi_heads import FCNMaskNHead, Shared4Conv1FCBBoxNHead | |||
| @MODELS.register_module(Tasks.human_detection, module_name=Models.detection) | |||
| @MODELS.register_module(Tasks.object_detection, module_name=Models.detection) | |||
| class DetectionModel(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """str -- model file root.""" | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from mmcv.runner import load_checkpoint | |||
| from mmdet.datasets import replace_ImageToTensor | |||
| from mmdet.datasets.pipelines import Compose | |||
| from mmdet.models import build_detector | |||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
| config_path = osp.join(model_dir, 'mmcv_config.py') | |||
| config = Config.from_file(config_path) | |||
| config.model.pretrained = None | |||
| self.model = build_detector(config.model) | |||
| checkpoint = load_checkpoint( | |||
| self.model, model_path, map_location='cpu') | |||
| self.class_names = checkpoint['meta']['CLASSES'] | |||
| config.test_pipeline[0].type = 'LoadImageFromWebcam' | |||
| self.test_pipeline = Compose( | |||
| replace_ImageToTensor(config.test_pipeline)) | |||
| self.model.cfg = config | |||
| self.model.eval() | |||
| self.score_thr = config.score_thr | |||
| def inference(self, data): | |||
| """data is dict,contain img and img_metas,follow with mmdet.""" | |||
| with torch.no_grad(): | |||
| results = self.model(return_loss=False, rescale=True, **data) | |||
| return results | |||
| def preprocess(self, image): | |||
| """image is numpy return is dict contain img and img_metas,follow with mmdet.""" | |||
| from mmcv.parallel import collate, scatter | |||
| data = dict(img=image) | |||
| data = self.test_pipeline(data) | |||
| data = collate([data], samples_per_gpu=1) | |||
| data['img_metas'] = [ | |||
| img_metas.data[0] for img_metas in data['img_metas'] | |||
| ] | |||
| data['img'] = [img.data[0] for img in data['img']] | |||
| if next(self.model.parameters()).is_cuda: | |||
| data = scatter(data, [next(self.model.parameters()).device])[0] | |||
| return data | |||
| def postprocess(self, inputs): | |||
| if isinstance(inputs[0], tuple): | |||
| bbox_result, _ = inputs[0] | |||
| else: | |||
| bbox_result, _ = inputs[0], None | |||
| labels = [ | |||
| np.full(bbox.shape[0], i, dtype=np.int32) | |||
| for i, bbox in enumerate(bbox_result) | |||
| ] | |||
| labels = np.concatenate(labels) | |||
| bbox_result = np.vstack(bbox_result) | |||
| scores = bbox_result[:, -1] | |||
| inds = scores > self.score_thr | |||
| if not np.any(inds): | |||
| return None, None, None | |||
| bboxes = bbox_result[inds, :] | |||
| labels = labels[inds] | |||
| scores = bboxes[:, 4] | |||
| bboxes = bboxes[:, 0:4] | |||
| labels = [self.class_names[i_label] for i_label in labels] | |||
| return bboxes, scores, labels | |||
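Reviewer note: a minimal end-to-end usage sketch of the `DetectionModel` above, assuming a local model directory that contains `mmcv_config.py` and the checkpoint named by `ModelFile.TORCH_MODEL_FILE` (the module path, directory and image are placeholders, not part of the change):

    import cv2
    # Module path assumed; the diff does not show where mmdet_model.py lives.
    from modelscope.models.cv.object_detection.mmdet_model import DetectionModel

    model = DetectionModel(model_dir='/path/to/model_dir')   # placeholder directory
    img = cv2.imread('demo.jpg')                             # BGR numpy array, matching LoadImageFromWebcam
    data = model.preprocess(img)                             # dict with img / img_metas
    results = model.inference(data)                          # raw mmdet results
    bboxes, scores, labels = model.postprocess(results)      # (None, None, None) if nothing passes score_thr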
| @@ -0,0 +1,4 @@ | |||
| from .backbones import ViT | |||
| from .dense_heads import AnchorNHead, RPNNHead | |||
| from .necks import FPNF | |||
| from .utils import ConvModule_Norm, load_checkpoint | |||
| @@ -0,0 +1,3 @@ | |||
| from .vit import ViT | |||
| __all__ = ['ViT'] | |||
| @@ -0,0 +1,626 @@ | |||
| # -------------------------------------------------------- | |||
| # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) | |||
| # Github source: https://github.com/microsoft/unilm/tree/master/beit | |||
| # Copyright (c) 2021 Microsoft | |||
| # Licensed under The MIT License [see LICENSE for details] | |||
| # By Hangbo Bao | |||
| # Based on timm, mmseg, setr, xcit and swin code bases | |||
| # https://github.com/rwightman/pytorch-image-models/tree/master/timm | |||
| # https://github.com/fudan-zvg/SETR | |||
| # https://github.com/facebookresearch/xcit/ | |||
| # https://github.com/microsoft/Swin-Transformer | |||
| # --------------------------------------------------------' | |||
| import math | |||
| from functools import partial | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torch.utils.checkpoint as checkpoint | |||
| from mmdet.models.builder import BACKBONES | |||
| from mmdet.utils import get_root_logger | |||
| from timm.models.layers import drop_path, to_2tuple, trunc_normal_ | |||
| from ..utils import load_checkpoint | |||
| class DropPath(nn.Module): | |||
| """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |||
| """ | |||
| def __init__(self, drop_prob=None): | |||
| super(DropPath, self).__init__() | |||
| self.drop_prob = drop_prob | |||
| def forward(self, x): | |||
| return drop_path(x, self.drop_prob, self.training) | |||
| def extra_repr(self): | |||
| return 'p={}'.format(self.drop_prob) | |||
| class Mlp(nn.Module): | |||
| def __init__(self, | |||
| in_features, | |||
| hidden_features=None, | |||
| out_features=None, | |||
| act_layer=nn.GELU, | |||
| drop=0.): | |||
| super().__init__() | |||
| out_features = out_features or in_features | |||
| hidden_features = hidden_features or in_features | |||
| self.fc1 = nn.Linear(in_features, hidden_features) | |||
| self.act = act_layer() | |||
| self.fc2 = nn.Linear(hidden_features, out_features) | |||
| self.drop = nn.Dropout(drop) | |||
| def forward(self, x): | |||
| x = self.fc1(x) | |||
| x = self.act(x) | |||
| x = self.fc2(x) | |||
| x = self.drop(x) | |||
| return x | |||
| class Attention(nn.Module): | |||
| def __init__(self, | |||
| dim, | |||
| num_heads=8, | |||
| qkv_bias=False, | |||
| qk_scale=None, | |||
| attn_drop=0., | |||
| proj_drop=0., | |||
| window_size=None, | |||
| attn_head_dim=None): | |||
| super().__init__() | |||
| self.num_heads = num_heads | |||
| head_dim = dim // num_heads | |||
| if attn_head_dim is not None: | |||
| head_dim = attn_head_dim | |||
| all_head_dim = head_dim * self.num_heads | |||
| # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights | |||
| self.scale = qk_scale or head_dim**-0.5 | |||
| self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) | |||
| self.window_size = window_size | |||
| q_size = window_size[0] | |||
| rel_sp_dim = 2 * q_size - 1 | |||
| self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) | |||
| self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) | |||
| self.attn_drop = nn.Dropout(attn_drop) | |||
| self.proj = nn.Linear(all_head_dim, dim) | |||
| self.proj_drop = nn.Dropout(proj_drop) | |||
| def forward(self, x, H, W, rel_pos_bias=None): | |||
| B, N, C = x.shape | |||
| qkv = self.qkv(x) | |||
| qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) | |||
| q, k, v = qkv[0], qkv[1], qkv[ | |||
| 2] # make torchscript happy (cannot use tensor as tuple) | |||
| q = q * self.scale | |||
| attn = (q @ k.transpose(-2, -1)) | |||
| attn = calc_rel_pos_spatial(attn, q, self.window_size, | |||
| self.window_size, self.rel_pos_h, | |||
| self.rel_pos_w) | |||
| attn = attn.softmax(dim=-1) | |||
| attn = self.attn_drop(attn) | |||
| x = (attn @ v).transpose(1, 2).reshape(B, N, -1) | |||
| x = self.proj(x) | |||
| x = self.proj_drop(x) | |||
| return x | |||
| def window_partition(x, window_size): | |||
| """ | |||
| Args: | |||
| x: (B, H, W, C) | |||
| window_size (int): window size | |||
| Returns: | |||
| windows: (num_windows*B, window_size, window_size, C) | |||
| """ | |||
| B, H, W, C = x.shape | |||
| x = x.view(B, H // window_size, window_size, W // window_size, window_size, | |||
| C) | |||
| windows = x.permute(0, 1, 3, 2, 4, | |||
| 5).contiguous().view(-1, window_size, window_size, C) | |||
| return windows | |||
| def window_reverse(windows, window_size, H, W): | |||
| """ | |||
| Args: | |||
| windows: (num_windows*B, window_size, window_size, C) | |||
| window_size (int): Window size | |||
| H (int): Height of image | |||
| W (int): Width of image | |||
| Returns: | |||
| x: (B, H, W, C) | |||
| """ | |||
| B = int(windows.shape[0] / (H * W / window_size / window_size)) | |||
| x = windows.view(B, H // window_size, W // window_size, window_size, | |||
| window_size, -1) | |||
| x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) | |||
| return x | |||
| def calc_rel_pos_spatial( | |||
| attn, | |||
| q, | |||
| q_shape, | |||
| k_shape, | |||
| rel_pos_h, | |||
| rel_pos_w, | |||
| ): | |||
| """ | |||
| Spatial Relative Positional Embeddings. | |||
| """ | |||
| sp_idx = 0 | |||
| q_h, q_w = q_shape | |||
| k_h, k_w = k_shape | |||
| # Scale up rel pos if shapes for q and k are different. | |||
| q_h_ratio = max(k_h / q_h, 1.0) | |||
| k_h_ratio = max(q_h / k_h, 1.0) | |||
| dist_h = ( | |||
| torch.arange(q_h)[:, None] * q_h_ratio | |||
| - torch.arange(k_h)[None, :] * k_h_ratio) | |||
| dist_h += (k_h - 1) * k_h_ratio | |||
| q_w_ratio = max(k_w / q_w, 1.0) | |||
| k_w_ratio = max(q_w / k_w, 1.0) | |||
| dist_w = ( | |||
| torch.arange(q_w)[:, None] * q_w_ratio | |||
| - torch.arange(k_w)[None, :] * k_w_ratio) | |||
| dist_w += (k_w - 1) * k_w_ratio | |||
| Rh = rel_pos_h[dist_h.long()] | |||
| Rw = rel_pos_w[dist_w.long()] | |||
| B, n_head, q_N, dim = q.shape | |||
| r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim) | |||
| rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh) | |||
| rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw) | |||
| attn[:, :, sp_idx:, sp_idx:] = ( | |||
| attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) | |||
| + rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :]).view( | |||
| B, -1, q_h * q_w, k_h * k_w) | |||
| return attn | |||
| class WindowAttention(nn.Module): | |||
| """ Window based multi-head self attention (W-MSA) module with relative position bias. | |||
| It supports both shifted and non-shifted windows. | |||
| Args: | |||
| dim (int): Number of input channels. | |||
| window_size (tuple[int]): The height and width of the window. | |||
| num_heads (int): Number of attention heads. | |||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |||
| attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |||
| proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |||
| """ | |||
| def __init__(self, | |||
| dim, | |||
| window_size, | |||
| num_heads, | |||
| qkv_bias=True, | |||
| qk_scale=None, | |||
| attn_drop=0., | |||
| proj_drop=0., | |||
| attn_head_dim=None): | |||
| super().__init__() | |||
| self.dim = dim | |||
| self.window_size = window_size # Wh, Ww | |||
| self.num_heads = num_heads | |||
| head_dim = dim // num_heads | |||
| self.scale = qk_scale or head_dim**-0.5 | |||
| q_size = window_size[0] | |||
| rel_sp_dim = 2 * q_size - 1 | |||
| self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) | |||
| self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) | |||
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||
| self.attn_drop = nn.Dropout(attn_drop) | |||
| self.proj = nn.Linear(dim, dim) | |||
| self.proj_drop = nn.Dropout(proj_drop) | |||
| self.softmax = nn.Softmax(dim=-1) | |||
| def forward(self, x, H, W): | |||
| """ Forward function. | |||
| Args: | |||
| x: input features with shape of (num_windows*B, N, C) | |||
| H, W: spatial resolution of the input feature map | |||
| """ | |||
| B_, N, C = x.shape | |||
| x = x.reshape(B_, H, W, C) | |||
| pad_l = pad_t = 0 | |||
| pad_r = (self.window_size[1] | |||
| - W % self.window_size[1]) % self.window_size[1] | |||
| pad_b = (self.window_size[0] | |||
| - H % self.window_size[0]) % self.window_size[0] | |||
| x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||
| _, Hp, Wp, _ = x.shape | |||
| x = window_partition( | |||
| x, self.window_size[0]) # nW*B, window_size, window_size, C | |||
| x = x.view(-1, self.window_size[1] * self.window_size[0], | |||
| C) # nW*B, window_size*window_size, C | |||
| B_w = x.shape[0] | |||
| N_w = x.shape[1] | |||
| qkv = self.qkv(x).reshape(B_w, N_w, 3, self.num_heads, | |||
| C // self.num_heads).permute(2, 0, 3, 1, 4) | |||
| q, k, v = qkv[0], qkv[1], qkv[ | |||
| 2] # make torchscript happy (cannot use tensor as tuple) | |||
| q = q * self.scale | |||
| attn = (q @ k.transpose(-2, -1)) | |||
| attn = calc_rel_pos_spatial(attn, q, self.window_size, | |||
| self.window_size, self.rel_pos_h, | |||
| self.rel_pos_w) | |||
| attn = self.softmax(attn) | |||
| attn = self.attn_drop(attn) | |||
| x = (attn @ v).transpose(1, 2).reshape(B_w, N_w, C) | |||
| x = self.proj(x) | |||
| x = self.proj_drop(x) | |||
| x = x.view(-1, self.window_size[1], self.window_size[0], C) | |||
| x = window_reverse(x, self.window_size[0], Hp, Wp) # B H' W' C | |||
| if pad_r > 0 or pad_b > 0: | |||
| x = x[:, :H, :W, :].contiguous() | |||
| x = x.view(B_, H * W, C) | |||
| return x | |||
| class Block(nn.Module): | |||
| def __init__(self, | |||
| dim, | |||
| num_heads, | |||
| mlp_ratio=4., | |||
| qkv_bias=False, | |||
| qk_scale=None, | |||
| drop=0., | |||
| attn_drop=0., | |||
| drop_path=0., | |||
| init_values=None, | |||
| act_layer=nn.GELU, | |||
| norm_layer=nn.LayerNorm, | |||
| window_size=None, | |||
| attn_head_dim=None, | |||
| window=False): | |||
| super().__init__() | |||
| self.norm1 = norm_layer(dim) | |||
| if not window: | |||
| self.attn = Attention( | |||
| dim, | |||
| num_heads=num_heads, | |||
| qkv_bias=qkv_bias, | |||
| qk_scale=qk_scale, | |||
| attn_drop=attn_drop, | |||
| proj_drop=drop, | |||
| window_size=window_size, | |||
| attn_head_dim=attn_head_dim) | |||
| else: | |||
| self.attn = WindowAttention( | |||
| dim, | |||
| num_heads=num_heads, | |||
| qkv_bias=qkv_bias, | |||
| qk_scale=qk_scale, | |||
| attn_drop=attn_drop, | |||
| proj_drop=drop, | |||
| window_size=window_size, | |||
| attn_head_dim=attn_head_dim) | |||
| # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | |||
| self.drop_path = DropPath( | |||
| drop_path) if drop_path > 0. else nn.Identity() | |||
| self.norm2 = norm_layer(dim) | |||
| mlp_hidden_dim = int(dim * mlp_ratio) | |||
| self.mlp = Mlp( | |||
| in_features=dim, | |||
| hidden_features=mlp_hidden_dim, | |||
| act_layer=act_layer, | |||
| drop=drop) | |||
| if init_values is not None: | |||
| self.gamma_1 = nn.Parameter( | |||
| init_values * torch.ones((dim)), requires_grad=True) | |||
| self.gamma_2 = nn.Parameter( | |||
| init_values * torch.ones((dim)), requires_grad=True) | |||
| else: | |||
| self.gamma_1, self.gamma_2 = None, None | |||
| def forward(self, x, H, W): | |||
| if self.gamma_1 is None: | |||
| x = x + self.drop_path(self.attn(self.norm1(x), H, W)) | |||
| x = x + self.drop_path(self.mlp(self.norm2(x))) | |||
| else: | |||
| x = x + self.drop_path( | |||
| self.gamma_1 * self.attn(self.norm1(x), H, W)) | |||
| x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) | |||
| return x | |||
| class PatchEmbed(nn.Module): | |||
| """ Image to Patch Embedding | |||
| """ | |||
| def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): | |||
| super().__init__() | |||
| img_size = to_2tuple(img_size) | |||
| patch_size = to_2tuple(patch_size) | |||
| num_patches = (img_size[1] // patch_size[1]) * ( | |||
| img_size[0] // patch_size[0]) | |||
| self.patch_shape = (img_size[0] // patch_size[0], | |||
| img_size[1] // patch_size[1]) | |||
| self.img_size = img_size | |||
| self.patch_size = patch_size | |||
| self.num_patches = num_patches | |||
| self.proj = nn.Conv2d( | |||
| in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||
| def forward(self, x, **kwargs): | |||
| B, C, H, W = x.shape | |||
| # FIXME look at relaxing size constraints | |||
| # assert H == self.img_size[0] and W == self.img_size[1], \ | |||
| # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." | |||
| x = self.proj(x) | |||
| Hp, Wp = x.shape[2], x.shape[3] | |||
| x = x.flatten(2).transpose(1, 2) | |||
| return x, (Hp, Wp) | |||
| class HybridEmbed(nn.Module): | |||
| """ CNN Feature Map Embedding | |||
| Extract feature map from CNN, flatten, project to embedding dim. | |||
| """ | |||
| def __init__(self, | |||
| backbone, | |||
| img_size=224, | |||
| feature_size=None, | |||
| in_chans=3, | |||
| embed_dim=768): | |||
| super().__init__() | |||
| assert isinstance(backbone, nn.Module) | |||
| img_size = to_2tuple(img_size) | |||
| self.img_size = img_size | |||
| self.backbone = backbone | |||
| if feature_size is None: | |||
| with torch.no_grad(): | |||
| # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature | |||
| # map for all networks, the feature metadata has reliable channel and stride info, but using | |||
| # stride to calc feature dim requires info about padding of each stage that isn't captured. | |||
| training = backbone.training | |||
| if training: | |||
| backbone.eval() | |||
| o = self.backbone( | |||
| torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] | |||
| feature_size = o.shape[-2:] | |||
| feature_dim = o.shape[1] | |||
| backbone.train(training) | |||
| else: | |||
| feature_size = to_2tuple(feature_size) | |||
| feature_dim = self.backbone.feature_info.channels()[-1] | |||
| self.num_patches = feature_size[0] * feature_size[1] | |||
| self.proj = nn.Linear(feature_dim, embed_dim) | |||
| def forward(self, x): | |||
| x = self.backbone(x)[-1] | |||
| x = x.flatten(2).transpose(1, 2) | |||
| x = self.proj(x) | |||
| return x | |||
| class Norm2d(nn.Module): | |||
| def __init__(self, embed_dim): | |||
| super().__init__() | |||
| self.ln = nn.LayerNorm(embed_dim, eps=1e-6) | |||
| def forward(self, x): | |||
| x = x.permute(0, 2, 3, 1) | |||
| x = self.ln(x) | |||
| x = x.permute(0, 3, 1, 2).contiguous() | |||
| return x | |||
| @BACKBONES.register_module() | |||
| class ViT(nn.Module): | |||
| """ Vision Transformer with support for patch or hybrid CNN input stage | |||
| """ | |||
| def __init__(self, | |||
| img_size=224, | |||
| patch_size=16, | |||
| in_chans=3, | |||
| num_classes=80, | |||
| embed_dim=768, | |||
| depth=12, | |||
| num_heads=12, | |||
| mlp_ratio=4., | |||
| qkv_bias=False, | |||
| qk_scale=None, | |||
| drop_rate=0., | |||
| attn_drop_rate=0., | |||
| drop_path_rate=0., | |||
| hybrid_backbone=None, | |||
| norm_layer=None, | |||
| init_values=None, | |||
| use_checkpoint=False, | |||
| use_abs_pos_emb=False, | |||
| use_rel_pos_bias=False, | |||
| use_shared_rel_pos_bias=False, | |||
| out_indices=[11], | |||
| interval=3, | |||
| pretrained=None): | |||
| super().__init__() | |||
| norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | |||
| self.num_classes = num_classes | |||
| self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | |||
| if hybrid_backbone is not None: | |||
| self.patch_embed = HybridEmbed( | |||
| hybrid_backbone, | |||
| img_size=img_size, | |||
| in_chans=in_chans, | |||
| embed_dim=embed_dim) | |||
| else: | |||
| self.patch_embed = PatchEmbed( | |||
| img_size=img_size, | |||
| patch_size=patch_size, | |||
| in_chans=in_chans, | |||
| embed_dim=embed_dim) | |||
| num_patches = self.patch_embed.num_patches | |||
| self.out_indices = out_indices | |||
| if use_abs_pos_emb: | |||
| self.pos_embed = nn.Parameter( | |||
| torch.zeros(1, num_patches, embed_dim)) | |||
| else: | |||
| self.pos_embed = None | |||
| self.pos_drop = nn.Dropout(p=drop_rate) | |||
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) | |||
| ] # stochastic depth decay rule | |||
| self.use_rel_pos_bias = use_rel_pos_bias | |||
| self.use_checkpoint = use_checkpoint | |||
| self.blocks = nn.ModuleList([ | |||
| Block( | |||
| dim=embed_dim, | |||
| num_heads=num_heads, | |||
| mlp_ratio=mlp_ratio, | |||
| qkv_bias=qkv_bias, | |||
| qk_scale=qk_scale, | |||
| drop=drop_rate, | |||
| attn_drop=attn_drop_rate, | |||
| drop_path=dpr[i], | |||
| norm_layer=norm_layer, | |||
| init_values=init_values, | |||
| window_size=(14, 14) if | |||
| ((i + 1) % interval != 0) else self.patch_embed.patch_shape, | |||
| window=((i + 1) % interval != 0)) for i in range(depth) | |||
| ]) | |||
| if self.pos_embed is not None: | |||
| trunc_normal_(self.pos_embed, std=.02) | |||
| self.norm = norm_layer(embed_dim) | |||
| self.fpn1 = nn.Sequential( | |||
| nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), | |||
| Norm2d(embed_dim), | |||
| nn.GELU(), | |||
| nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), | |||
| ) | |||
| self.fpn2 = nn.Sequential( | |||
| nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, | |||
| stride=2), ) | |||
| self.fpn3 = nn.Identity() | |||
| self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.apply(self._init_weights) | |||
| self.fix_init_weight() | |||
| self.pretrained = pretrained | |||
| def fix_init_weight(self): | |||
| def rescale(param, layer_id): | |||
| param.div_(math.sqrt(2.0 * layer_id)) | |||
| for layer_id, layer in enumerate(self.blocks): | |||
| rescale(layer.attn.proj.weight.data, layer_id + 1) | |||
| rescale(layer.mlp.fc2.weight.data, layer_id + 1) | |||
| def _init_weights(self, m): | |||
| if isinstance(m, nn.Linear): | |||
| trunc_normal_(m.weight, std=.02) | |||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||
| nn.init.constant_(m.bias, 0) | |||
| elif isinstance(m, nn.LayerNorm): | |||
| nn.init.constant_(m.bias, 0) | |||
| nn.init.constant_(m.weight, 1.0) | |||
| def init_weights(self, pretrained=None): | |||
| """Initialize the weights in backbone. | |||
| Args: | |||
| pretrained (str, optional): Path to pre-trained weights. | |||
| Defaults to None. | |||
| """ | |||
| pretrained = pretrained or self.pretrained | |||
| def _init_weights(m): | |||
| if isinstance(m, nn.Linear): | |||
| trunc_normal_(m.weight, std=.02) | |||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||
| nn.init.constant_(m.bias, 0) | |||
| elif isinstance(m, nn.LayerNorm): | |||
| nn.init.constant_(m.bias, 0) | |||
| nn.init.constant_(m.weight, 1.0) | |||
| if isinstance(pretrained, str): | |||
| self.apply(_init_weights) | |||
| logger = get_root_logger() | |||
| logger.info(f'load from {pretrained}') | |||
| load_checkpoint(self, pretrained, strict=False, logger=logger) | |||
| elif pretrained is None: | |||
| self.apply(_init_weights) | |||
| else: | |||
| raise TypeError('pretrained must be a str or None') | |||
| def get_num_layers(self): | |||
| return len(self.blocks) | |||
| @torch.jit.ignore | |||
| def no_weight_decay(self): | |||
| return {'pos_embed', 'cls_token'} | |||
| def forward_features(self, x): | |||
| B, C, H, W = x.shape | |||
| x, (Hp, Wp) = self.patch_embed(x) | |||
| batch_size, seq_len, _ = x.size() | |||
| if self.pos_embed is not None: | |||
| x = x + self.pos_embed | |||
| x = self.pos_drop(x) | |||
| features = [] | |||
| for i, blk in enumerate(self.blocks): | |||
| if self.use_checkpoint: | |||
| x = checkpoint.checkpoint(blk, x, Hp, Wp)  # blk.forward also needs H, W | |||
| else: | |||
| x = blk(x, Hp, Wp) | |||
| x = self.norm(x) | |||
| xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp) | |||
| ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] | |||
| for i in range(len(ops)): | |||
| features.append(ops[i](xp)) | |||
| return tuple(features) | |||
| def forward(self, x): | |||
| x = self.forward_features(x) | |||
| return x | |||
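Reviewer note: the backbone registers itself with mmdet's `BACKBONES` registry, so `mmcv_config.py` can reference it purely by name. An illustrative backbone/neck fragment follows; the hyper-parameter values are assumptions chosen only to show the wiring, not the shipped configuration:

    # Illustrative config fragment; the real values ship with the model files.
    model = dict(
        backbone=dict(
            type='ViT',              # resolved via the @BACKBONES.register_module() above
            img_size=1024,
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            drop_path_rate=0.1,
            use_abs_pos_emb=True,
            interval=3),
        neck=dict(
            type='FPNF',             # registered in necks/fpn.py later in this diff
            in_channels=[768, 768, 768, 768],  # fpn1..fpn4 all keep embed_dim channels
            out_channels=256,
            num_outs=5))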
| @@ -0,0 +1,4 @@ | |||
| from .anchor_head import AnchorNHead | |||
| from .rpn_head import RPNNHead | |||
| __all__ = ['AnchorNHead', 'RPNNHead'] | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| # Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet | |||
| from mmdet.models.builder import HEADS | |||
| from mmdet.models.dense_heads import AnchorHead | |||
| @HEADS.register_module() | |||
| class AnchorNHead(AnchorHead): | |||
| """Anchor-based head (RPN, RetinaNet, SSD, etc.). | |||
| Args: | |||
| num_classes (int): Number of categories excluding the background | |||
| category. | |||
| in_channels (int): Number of channels in the input feature map. | |||
| feat_channels (int): Number of hidden channels. Used in child classes. | |||
| anchor_generator (dict): Config dict for anchor generator | |||
| bbox_coder (dict): Config of bounding box coder. | |||
| reg_decoded_bbox (bool): If true, the regression loss would be | |||
| applied directly on decoded bounding boxes, converting both | |||
| the predicted boxes and regression targets to absolute | |||
| coordinates format. Default False. It should be `True` when | |||
| using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. | |||
| loss_cls (dict): Config of classification loss. | |||
| loss_bbox (dict): Config of localization loss. | |||
| train_cfg (dict): Training config of anchor head. | |||
| test_cfg (dict): Testing config of anchor head. | |||
| init_cfg (dict or list[dict], optional): Initialization config dict. | |||
| """ # noqa: W605 | |||
| def __init__(self, | |||
| num_classes, | |||
| in_channels, | |||
| feat_channels, | |||
| anchor_generator=None, | |||
| bbox_coder=None, | |||
| reg_decoded_bbox=False, | |||
| loss_cls=None, | |||
| loss_bbox=None, | |||
| train_cfg=None, | |||
| test_cfg=None, | |||
| norm_cfg=None, | |||
| init_cfg=None): | |||
| self.norm_cfg = norm_cfg | |||
| super(AnchorNHead, | |||
| self).__init__(num_classes, in_channels, feat_channels, | |||
| anchor_generator, bbox_coder, reg_decoded_bbox, | |||
| loss_cls, loss_bbox, train_cfg, test_cfg, | |||
| init_cfg) | |||
| @@ -0,0 +1,268 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| # Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet | |||
| import copy | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from mmcv.ops import batched_nms | |||
| from mmdet.models.builder import HEADS | |||
| from ..utils import ConvModule_Norm | |||
| from .anchor_head import AnchorNHead | |||
| @HEADS.register_module() | |||
| class RPNNHead(AnchorNHead): | |||
| """RPN head. | |||
| Args: | |||
| in_channels (int): Number of channels in the input feature map. | |||
| init_cfg (dict or list[dict], optional): Initialization config dict. | |||
| num_convs (int): Number of convolution layers in the head. Default 1. | |||
| """ # noqa: W605 | |||
| def __init__(self, | |||
| in_channels, | |||
| init_cfg=dict(type='Normal', layer='Conv2d', std=0.01), | |||
| num_convs=1, | |||
| **kwargs): | |||
| self.num_convs = num_convs | |||
| super(RPNNHead, self).__init__( | |||
| 1, in_channels, init_cfg=init_cfg, **kwargs) | |||
| def _init_layers(self): | |||
| """Initialize layers of the head.""" | |||
| if self.num_convs > 1: | |||
| rpn_convs = [] | |||
| for i in range(self.num_convs): | |||
| if i == 0: | |||
| in_channels = self.in_channels | |||
| else: | |||
| in_channels = self.feat_channels | |||
| # use ``inplace=False`` to avoid error: one of the variables | |||
| # needed for gradient computation has been modified by an | |||
| # inplace operation. | |||
| rpn_convs.append( | |||
| ConvModule_Norm( | |||
| in_channels, | |||
| self.feat_channels, | |||
| 3, | |||
| padding=1, | |||
| norm_cfg=self.norm_cfg, | |||
| inplace=False)) | |||
| self.rpn_conv = nn.Sequential(*rpn_convs) | |||
| else: | |||
| self.rpn_conv = nn.Conv2d( | |||
| self.in_channels, self.feat_channels, 3, padding=1) | |||
| self.rpn_cls = nn.Conv2d(self.feat_channels, | |||
| self.num_base_priors * self.cls_out_channels, | |||
| 1) | |||
| self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * 4, | |||
| 1) | |||
| def forward_single(self, x): | |||
| """Forward feature map of a single scale level.""" | |||
| x = self.rpn_conv(x) | |||
| x = F.relu(x, inplace=True) | |||
| rpn_cls_score = self.rpn_cls(x) | |||
| rpn_bbox_pred = self.rpn_reg(x) | |||
| return rpn_cls_score, rpn_bbox_pred | |||
| def loss(self, | |||
| cls_scores, | |||
| bbox_preds, | |||
| gt_bboxes, | |||
| img_metas, | |||
| gt_bboxes_ignore=None): | |||
| """Compute losses of the head. | |||
| Args: | |||
| cls_scores (list[Tensor]): Box scores for each scale level | |||
| Has shape (N, num_anchors * num_classes, H, W) | |||
| bbox_preds (list[Tensor]): Box energies / deltas for each scale | |||
| level with shape (N, num_anchors * 4, H, W) | |||
| gt_bboxes (list[Tensor]): Ground truth bboxes for each image with | |||
| shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. | |||
| img_metas (list[dict]): Meta information of each image, e.g., | |||
| image size, scaling factor, etc. | |||
| gt_bboxes_ignore (None | list[Tensor]): specify which bounding | |||
| boxes can be ignored when computing the loss. | |||
| Returns: | |||
| dict[str, Tensor]: A dictionary of loss components. | |||
| """ | |||
| losses = super(RPNNHead, self).loss( | |||
| cls_scores, | |||
| bbox_preds, | |||
| gt_bboxes, | |||
| None, | |||
| img_metas, | |||
| gt_bboxes_ignore=gt_bboxes_ignore) | |||
| return dict( | |||
| loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox']) | |||
| def _get_bboxes_single(self, | |||
| cls_score_list, | |||
| bbox_pred_list, | |||
| score_factor_list, | |||
| mlvl_anchors, | |||
| img_meta, | |||
| cfg, | |||
| rescale=False, | |||
| with_nms=True, | |||
| **kwargs): | |||
| """Transform outputs of a single image into bbox predictions. | |||
| Args: | |||
| cls_score_list (list[Tensor]): Box scores from all scale | |||
| levels of a single image, each item has shape | |||
| (num_anchors * num_classes, H, W). | |||
| bbox_pred_list (list[Tensor]): Box energies / deltas from | |||
| all scale levels of a single image, each item has | |||
| shape (num_anchors * 4, H, W). | |||
| score_factor_list (list[Tensor]): Score factor from all scale | |||
| levels of a single image. RPN head does not need this value. | |||
| mlvl_anchors (list[Tensor]): Anchors of all scale level | |||
| each item has shape (num_anchors, 4). | |||
| img_meta (dict): Image meta info. | |||
| cfg (mmcv.Config): Test / postprocessing configuration, | |||
| if None, test_cfg would be used. | |||
| rescale (bool): If True, return boxes in original image space. | |||
| Default: False. | |||
| with_nms (bool): If True, do nms before return boxes. | |||
| Default: True. | |||
| Returns: | |||
| Tensor: Labeled boxes in shape (n, 5), where the first 4 columns | |||
| are bounding box positions (tl_x, tl_y, br_x, br_y) and the | |||
| 5-th column is a score between 0 and 1. | |||
| """ | |||
| cfg = self.test_cfg if cfg is None else cfg | |||
| cfg = copy.deepcopy(cfg) | |||
| img_shape = img_meta['img_shape'] | |||
| # bboxes from different level should be independent during NMS, | |||
| # level_ids are used as labels for batched NMS to separate them | |||
| level_ids = [] | |||
| mlvl_scores = [] | |||
| mlvl_bbox_preds = [] | |||
| mlvl_valid_anchors = [] | |||
| nms_pre = cfg.get('nms_pre', -1) | |||
| for level_idx in range(len(cls_score_list)): | |||
| rpn_cls_score = cls_score_list[level_idx] | |||
| rpn_bbox_pred = bbox_pred_list[level_idx] | |||
| assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:] | |||
| rpn_cls_score = rpn_cls_score.permute(1, 2, 0) | |||
| if self.use_sigmoid_cls: | |||
| rpn_cls_score = rpn_cls_score.reshape(-1) | |||
| scores = rpn_cls_score.sigmoid() | |||
| else: | |||
| rpn_cls_score = rpn_cls_score.reshape(-1, 2) | |||
| # We set FG labels to [0, num_class-1] and BG label to | |||
| # num_class in RPN head since mmdet v2.5, which is unified to | |||
| # be consistent with other head since mmdet v2.0. In mmdet v2.0 | |||
| # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. | |||
| scores = rpn_cls_score.softmax(dim=1)[:, 0] | |||
| rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4) | |||
| anchors = mlvl_anchors[level_idx] | |||
| if 0 < nms_pre < scores.shape[0]: | |||
| # sort is faster than topk | |||
| # _, topk_inds = scores.topk(cfg.nms_pre) | |||
| ranked_scores, rank_inds = scores.sort(descending=True) | |||
| topk_inds = rank_inds[:nms_pre] | |||
| scores = ranked_scores[:nms_pre] | |||
| rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] | |||
| anchors = anchors[topk_inds, :] | |||
| mlvl_scores.append(scores) | |||
| mlvl_bbox_preds.append(rpn_bbox_pred) | |||
| mlvl_valid_anchors.append(anchors) | |||
| level_ids.append( | |||
| scores.new_full((scores.size(0), ), | |||
| level_idx, | |||
| dtype=torch.long)) | |||
| return self._bbox_post_process(mlvl_scores, mlvl_bbox_preds, | |||
| mlvl_valid_anchors, level_ids, cfg, | |||
| img_shape) | |||
| def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors, | |||
| level_ids, cfg, img_shape, **kwargs): | |||
| """bbox post-processing method. | |||
| The boxes are rescaled to the original image scale and NMS is | |||
| applied. Usually with_nms=False is used for augmented testing. | |||
| Args: | |||
| mlvl_scores (list[Tensor]): Box scores from all scale | |||
| levels of a single image, each item has shape | |||
| (num_bboxes, num_class). | |||
| mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale | |||
| levels of a single image, each item has shape (num_bboxes, 4). | |||
| mlvl_valid_anchors (list[Tensor]): Anchors of all scale level | |||
| each item has shape (num_bboxes, 4). | |||
| level_ids (list[Tensor]): Indexes from all scale levels of a | |||
| single image, each item has shape (num_bboxes, ). | |||
| cfg (mmcv.Config): Test / postprocessing configuration, | |||
| if None, test_cfg would be used. | |||
| img_shape (tuple(int)): Shape of current image. | |||
| Returns: | |||
| Tensor: Labeled boxes in shape (n, 5), where the first 4 columns | |||
| are bounding box positions (tl_x, tl_y, br_x, br_y) and the | |||
| 5-th column is a score between 0 and 1. | |||
| """ | |||
| scores = torch.cat(mlvl_scores) | |||
| anchors = torch.cat(mlvl_valid_anchors) | |||
| rpn_bbox_pred = torch.cat(mlvl_bboxes) | |||
| proposals = self.bbox_coder.decode( | |||
| anchors, rpn_bbox_pred, max_shape=img_shape) | |||
| ids = torch.cat(level_ids) | |||
| if cfg.min_bbox_size >= 0: | |||
| w = proposals[:, 2] - proposals[:, 0] | |||
| h = proposals[:, 3] - proposals[:, 1] | |||
| valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) | |||
| if not valid_mask.all(): | |||
| proposals = proposals[valid_mask] | |||
| scores = scores[valid_mask] | |||
| ids = ids[valid_mask] | |||
| if proposals.numel() > 0: | |||
| dets, _ = batched_nms(proposals, scores, ids, cfg.nms) | |||
| else: | |||
| return proposals.new_zeros(0, 5) | |||
| return dets[:cfg.max_per_img] | |||
| def onnx_export(self, x, img_metas): | |||
| """Test without augmentation. | |||
| Args: | |||
| x (tuple[Tensor]): Features from the upstream network, each is | |||
| a 4D-tensor. | |||
| img_metas (list[dict]): Meta info of each image. | |||
| Returns: | |||
| Tensor: dets of shape [N, num_det, 5]. | |||
| """ | |||
| cls_scores, bbox_preds = self(x) | |||
| assert len(cls_scores) == len(bbox_preds) | |||
| batch_bboxes, batch_scores = super(RPNNHead, self).onnx_export( | |||
| cls_scores, bbox_preds, img_metas=img_metas, with_nms=False) | |||
| # Use ONNX::NonMaxSuppression in deployment | |||
| from mmdet.core.export import add_dummy_nms_for_onnx | |||
| cfg = copy.deepcopy(self.test_cfg) | |||
| score_threshold = cfg.nms.get('score_thr', 0.0) | |||
| nms_pre = cfg.get('deploy_nms_pre', -1) | |||
| # Different from the normal forward doing NMS level by level, | |||
| # we do NMS across all levels when exporting ONNX. | |||
| dets, _ = add_dummy_nms_for_onnx(batch_bboxes, batch_scores, | |||
| cfg.max_per_img, | |||
| cfg.nms.iou_threshold, | |||
| score_threshold, nms_pre, | |||
| cfg.max_per_img) | |||
| return dets | |||
| @@ -0,0 +1,3 @@ | |||
| from .fpn import FPNF | |||
| __all__ = ['FPNF'] | |||
| @@ -0,0 +1,207 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| # Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from mmcv.runner import BaseModule, auto_fp16 | |||
| from mmdet.models.builder import NECKS | |||
| from ..utils import ConvModule_Norm | |||
| @NECKS.register_module() | |||
| class FPNF(BaseModule): | |||
| r"""Feature Pyramid Network. | |||
| This is an implementation of paper `Feature Pyramid Networks for Object | |||
| Detection <https://arxiv.org/abs/1612.03144>`_. | |||
| Args: | |||
| in_channels (List[int]): Number of input channels per scale. | |||
| out_channels (int): Number of output channels (used at each scale) | |||
| num_outs (int): Number of output scales. | |||
| start_level (int): Index of the start input backbone level used to | |||
| build the feature pyramid. Default: 0. | |||
| end_level (int): Index of the end input backbone level (exclusive) to | |||
| build the feature pyramid. Default: -1, which means the last level. | |||
| add_extra_convs (bool | str): If bool, it decides whether to add conv | |||
| layers on top of the original feature maps. Default to False. | |||
| If True, it is equivalent to `add_extra_convs='on_input'`. | |||
| If str, it specifies the source feature map of the extra convs. | |||
| Only the following options are allowed | |||
| - 'on_input': Last feat map of neck inputs (i.e. backbone feature). | |||
| - 'on_lateral': Last feature map after lateral convs. | |||
| - 'on_output': The last output feature map after fpn convs. | |||
| relu_before_extra_convs (bool): Whether to apply relu before the extra | |||
| conv. Default: False. | |||
| no_norm_on_lateral (bool): Whether to apply norm on lateral. | |||
| Default: False. | |||
| conv_cfg (dict): Config dict for convolution layer. Default: None. | |||
| norm_cfg (dict): Config dict for normalization layer. Default: None. | |||
| act_cfg (str): Config dict for activation layer in ConvModule. | |||
| Default: None. | |||
| upsample_cfg (dict): Config dict for interpolate layer. | |||
| Default: `dict(mode='nearest')` | |||
| init_cfg (dict or list[dict], optional): Initialization config dict. | |||
| Example: | |||
| >>> import torch | |||
| >>> in_channels = [2, 3, 5, 7] | |||
| >>> scales = [340, 170, 84, 43] | |||
| >>> inputs = [torch.rand(1, c, s, s) | |||
| ... for c, s in zip(in_channels, scales)] | |||
| >>> self = FPNF(in_channels, 11, len(in_channels)).eval() | |||
| >>> outputs = self.forward(inputs) | |||
| >>> for i in range(len(outputs)): | |||
| ... print(f'outputs[{i}].shape = {outputs[i].shape}') | |||
| outputs[0].shape = torch.Size([1, 11, 340, 340]) | |||
| outputs[1].shape = torch.Size([1, 11, 170, 170]) | |||
| outputs[2].shape = torch.Size([1, 11, 84, 84]) | |||
| outputs[3].shape = torch.Size([1, 11, 43, 43]) | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| num_outs, | |||
| start_level=0, | |||
| end_level=-1, | |||
| add_extra_convs=False, | |||
| relu_before_extra_convs=False, | |||
| no_norm_on_lateral=False, | |||
| conv_cfg=None, | |||
| norm_cfg=None, | |||
| act_cfg=None, | |||
| use_residual=True, | |||
| upsample_cfg=dict(mode='nearest'), | |||
| init_cfg=dict( | |||
| type='Xavier', layer='Conv2d', distribution='uniform')): | |||
| super(FPNF, self).__init__(init_cfg) | |||
| assert isinstance(in_channels, list) | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.num_ins = len(in_channels) | |||
| self.num_outs = num_outs | |||
| self.relu_before_extra_convs = relu_before_extra_convs | |||
| self.no_norm_on_lateral = no_norm_on_lateral | |||
| self.fp16_enabled = False | |||
| self.upsample_cfg = upsample_cfg.copy() | |||
| self.use_residual = use_residual | |||
| if end_level == -1: | |||
| self.backbone_end_level = self.num_ins | |||
| assert num_outs >= self.num_ins - start_level | |||
| else: | |||
| # if end_level < inputs, no extra level is allowed | |||
| self.backbone_end_level = end_level | |||
| assert end_level <= len(in_channels) | |||
| assert num_outs == end_level - start_level | |||
| self.start_level = start_level | |||
| self.end_level = end_level | |||
| self.add_extra_convs = add_extra_convs | |||
| assert isinstance(add_extra_convs, (str, bool)) | |||
| if isinstance(add_extra_convs, str): | |||
| # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' | |||
| assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') | |||
| elif add_extra_convs: # True | |||
| self.add_extra_convs = 'on_input' | |||
| self.lateral_convs = nn.ModuleList() | |||
| self.fpn_convs = nn.ModuleList() | |||
| for i in range(self.start_level, self.backbone_end_level): | |||
| l_conv = ConvModule_Norm( | |||
| in_channels[i], | |||
| out_channels, | |||
| 1, | |||
| conv_cfg=conv_cfg, | |||
| norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, | |||
| act_cfg=act_cfg, | |||
| inplace=False) | |||
| fpn_conv = ConvModule_Norm( | |||
| out_channels, | |||
| out_channels, | |||
| 3, | |||
| padding=1, | |||
| conv_cfg=conv_cfg, | |||
| norm_cfg=norm_cfg, | |||
| act_cfg=act_cfg, | |||
| inplace=False) | |||
| self.lateral_convs.append(l_conv) | |||
| self.fpn_convs.append(fpn_conv) | |||
| # add extra conv layers (e.g., RetinaNet) | |||
| extra_levels = num_outs - self.backbone_end_level + self.start_level | |||
| if self.add_extra_convs and extra_levels >= 1: | |||
| for i in range(extra_levels): | |||
| if i == 0 and self.add_extra_convs == 'on_input': | |||
| in_channels = self.in_channels[self.backbone_end_level - 1] | |||
| else: | |||
| in_channels = out_channels | |||
| extra_fpn_conv = ConvModule_Norm( | |||
| in_channels, | |||
| out_channels, | |||
| 3, | |||
| stride=2, | |||
| padding=1, | |||
| conv_cfg=conv_cfg, | |||
| norm_cfg=norm_cfg, | |||
| act_cfg=act_cfg, | |||
| inplace=False) | |||
| self.fpn_convs.append(extra_fpn_conv) | |||
| @auto_fp16() | |||
| def forward(self, inputs): | |||
| """Forward function.""" | |||
| assert len(inputs) == len(self.in_channels) | |||
| # build laterals | |||
| laterals = [ | |||
| lateral_conv(inputs[i + self.start_level]) | |||
| for i, lateral_conv in enumerate(self.lateral_convs) | |||
| ] | |||
| # build top-down path | |||
| used_backbone_levels = len(laterals) | |||
| if self.use_residual: | |||
| for i in range(used_backbone_levels - 1, 0, -1): | |||
| # In some cases, fixing `scale factor` (e.g. 2) is preferred, but | |||
| # it cannot co-exist with `size` in `F.interpolate`. | |||
| if 'scale_factor' in self.upsample_cfg: | |||
| laterals[i - 1] += F.interpolate(laterals[i], | |||
| **self.upsample_cfg) | |||
| else: | |||
| prev_shape = laterals[i - 1].shape[2:] | |||
| laterals[i - 1] += F.interpolate( | |||
| laterals[i], size=prev_shape, **self.upsample_cfg) | |||
| # build outputs | |||
| # part 1: from original levels | |||
| outs = [ | |||
| self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) | |||
| ] | |||
| # part 2: add extra levels | |||
| if self.num_outs > len(outs): | |||
| # use max pool to get more levels on top of outputs | |||
| # (e.g., Faster R-CNN, Mask R-CNN) | |||
| if not self.add_extra_convs: | |||
| for i in range(self.num_outs - used_backbone_levels): | |||
| outs.append(F.max_pool2d(outs[-1], 1, stride=2)) | |||
| # add conv layers on top of original feature maps (RetinaNet) | |||
| else: | |||
| if self.add_extra_convs == 'on_input': | |||
| extra_source = inputs[self.backbone_end_level - 1] | |||
| elif self.add_extra_convs == 'on_lateral': | |||
| extra_source = laterals[-1] | |||
| elif self.add_extra_convs == 'on_output': | |||
| extra_source = outs[-1] | |||
| else: | |||
| raise NotImplementedError | |||
| outs.append(self.fpn_convs[used_backbone_levels](extra_source)) | |||
| for i in range(used_backbone_levels + 1, self.num_outs): | |||
| if self.relu_before_extra_convs: | |||
| outs.append(self.fpn_convs[i](F.relu(outs[-1]))) | |||
| else: | |||
| outs.append(self.fpn_convs[i](outs[-1])) | |||
| return tuple(outs) | |||
| @@ -0,0 +1,8 @@ | |||
| from .bbox_heads import (ConvFCBBoxNHead, Shared2FCBBoxNHead, | |||
| Shared4Conv1FCBBoxNHead) | |||
| from .mask_heads import FCNMaskNHead | |||
| __all__ = [ | |||
| 'ConvFCBBoxNHead', 'Shared2FCBBoxNHead', 'Shared4Conv1FCBBoxNHead', | |||
| 'FCNMaskNHead' | |||
| ] | |||
| @@ -0,0 +1,4 @@ | |||
| from .convfc_bbox_head import (ConvFCBBoxNHead, Shared2FCBBoxNHead, | |||
| Shared4Conv1FCBBoxNHead) | |||
| __all__ = ['ConvFCBBoxNHead', 'Shared2FCBBoxNHead', 'Shared4Conv1FCBBoxNHead'] | |||
| @@ -0,0 +1,229 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| # Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet | |||
| import torch.nn as nn | |||
| from mmdet.models.builder import HEADS | |||
| from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead | |||
| from mmdet.models.utils import build_linear_layer | |||
| from ...utils import ConvModule_Norm | |||
| @HEADS.register_module() | |||
| class ConvFCBBoxNHead(BBoxHead): | |||
| r"""More general bbox head, with shared conv and fc layers and two optional | |||
| separated branches. | |||
| .. code-block:: none | |||
| /-> cls convs -> cls fcs -> cls | |||
| shared convs -> shared fcs | |||
| \-> reg convs -> reg fcs -> reg | |||
| """ # noqa: W605 | |||
| def __init__(self, | |||
| num_shared_convs=0, | |||
| num_shared_fcs=0, | |||
| num_cls_convs=0, | |||
| num_cls_fcs=0, | |||
| num_reg_convs=0, | |||
| num_reg_fcs=0, | |||
| conv_out_channels=256, | |||
| fc_out_channels=1024, | |||
| conv_cfg=None, | |||
| norm_cfg=None, | |||
| init_cfg=None, | |||
| *args, | |||
| **kwargs): | |||
| super(ConvFCBBoxNHead, self).__init__( | |||
| *args, init_cfg=init_cfg, **kwargs) | |||
| assert (num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs | |||
| + num_reg_convs + num_reg_fcs > 0) | |||
| if num_cls_convs > 0 or num_reg_convs > 0: | |||
| assert num_shared_fcs == 0 | |||
| if not self.with_cls: | |||
| assert num_cls_convs == 0 and num_cls_fcs == 0 | |||
| if not self.with_reg: | |||
| assert num_reg_convs == 0 and num_reg_fcs == 0 | |||
| self.num_shared_convs = num_shared_convs | |||
| self.num_shared_fcs = num_shared_fcs | |||
| self.num_cls_convs = num_cls_convs | |||
| self.num_cls_fcs = num_cls_fcs | |||
| self.num_reg_convs = num_reg_convs | |||
| self.num_reg_fcs = num_reg_fcs | |||
| self.conv_out_channels = conv_out_channels | |||
| self.fc_out_channels = fc_out_channels | |||
| self.conv_cfg = conv_cfg | |||
| self.norm_cfg = norm_cfg | |||
| # add shared convs and fcs | |||
| self.shared_convs, self.shared_fcs, last_layer_dim = \ | |||
| self._add_conv_fc_branch( | |||
| self.num_shared_convs, self.num_shared_fcs, self.in_channels, | |||
| True) | |||
| self.shared_out_channels = last_layer_dim | |||
| # add cls specific branch | |||
| self.cls_convs, self.cls_fcs, self.cls_last_dim = \ | |||
| self._add_conv_fc_branch( | |||
| self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels) | |||
| # add reg specific branch | |||
| self.reg_convs, self.reg_fcs, self.reg_last_dim = \ | |||
| self._add_conv_fc_branch( | |||
| self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels) | |||
| if self.num_shared_fcs == 0 and not self.with_avg_pool: | |||
| if self.num_cls_fcs == 0: | |||
| self.cls_last_dim *= self.roi_feat_area | |||
| if self.num_reg_fcs == 0: | |||
| self.reg_last_dim *= self.roi_feat_area | |||
| self.relu = nn.ReLU(inplace=True) | |||
| # reconstruct fc_cls and fc_reg since input channels are changed | |||
| if self.with_cls: | |||
| if self.custom_cls_channels: | |||
| cls_channels = self.loss_cls.get_cls_channels(self.num_classes) | |||
| else: | |||
| cls_channels = self.num_classes + 1 | |||
| self.fc_cls = build_linear_layer( | |||
| self.cls_predictor_cfg, | |||
| in_features=self.cls_last_dim, | |||
| out_features=cls_channels) | |||
| if self.with_reg: | |||
| out_dim_reg = (4 if self.reg_class_agnostic else 4 | |||
| * self.num_classes) | |||
| self.fc_reg = build_linear_layer( | |||
| self.reg_predictor_cfg, | |||
| in_features=self.reg_last_dim, | |||
| out_features=out_dim_reg) | |||
| if init_cfg is None: | |||
| # when init_cfg is None, | |||
| # It has been set to | |||
| # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))], | |||
| # [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))] | |||
| # after `super(ConvFCBBoxHead, self).__init__()` | |||
| # we only need to append additional configuration | |||
| # for `shared_fcs`, `cls_fcs` and `reg_fcs` | |||
| self.init_cfg += [ | |||
| dict( | |||
| type='Xavier', | |||
| override=[ | |||
| dict(name='shared_fcs'), | |||
| dict(name='cls_fcs'), | |||
| dict(name='reg_fcs') | |||
| ]) | |||
| ] | |||
| def _add_conv_fc_branch(self, | |||
| num_branch_convs, | |||
| num_branch_fcs, | |||
| in_channels, | |||
| is_shared=False): | |||
| """Add shared or separable branch. | |||
| convs -> avg pool (optional) -> fcs | |||
| """ | |||
| last_layer_dim = in_channels | |||
| # add branch specific conv layers | |||
| branch_convs = nn.ModuleList() | |||
| if num_branch_convs > 0: | |||
| for i in range(num_branch_convs): | |||
| conv_in_channels = ( | |||
| last_layer_dim if i == 0 else self.conv_out_channels) | |||
| branch_convs.append( | |||
| ConvModule_Norm( | |||
| conv_in_channels, | |||
| self.conv_out_channels, | |||
| 3, | |||
| padding=1, | |||
| conv_cfg=self.conv_cfg, | |||
| norm_cfg=self.norm_cfg)) | |||
| last_layer_dim = self.conv_out_channels | |||
| # add branch specific fc layers | |||
| branch_fcs = nn.ModuleList() | |||
| if num_branch_fcs > 0: | |||
| # for shared branch, only consider self.with_avg_pool | |||
| # for separated branches, also consider self.num_shared_fcs | |||
| if (is_shared | |||
| or self.num_shared_fcs == 0) and not self.with_avg_pool: | |||
| last_layer_dim *= self.roi_feat_area | |||
| for i in range(num_branch_fcs): | |||
| fc_in_channels = ( | |||
| last_layer_dim if i == 0 else self.fc_out_channels) | |||
| branch_fcs.append( | |||
| nn.Linear(fc_in_channels, self.fc_out_channels)) | |||
| last_layer_dim = self.fc_out_channels | |||
| return branch_convs, branch_fcs, last_layer_dim | |||
| def forward(self, x): | |||
| # shared part | |||
| if self.num_shared_convs > 0: | |||
| for conv in self.shared_convs: | |||
| x = conv(x) | |||
| if self.num_shared_fcs > 0: | |||
| if self.with_avg_pool: | |||
| x = self.avg_pool(x) | |||
| x = x.flatten(1) | |||
| for fc in self.shared_fcs: | |||
| x = self.relu(fc(x)) | |||
| # separate branches | |||
| x_cls = x | |||
| x_reg = x | |||
| for conv in self.cls_convs: | |||
| x_cls = conv(x_cls) | |||
| if x_cls.dim() > 2: | |||
| if self.with_avg_pool: | |||
| x_cls = self.avg_pool(x_cls) | |||
| x_cls = x_cls.flatten(1) | |||
| for fc in self.cls_fcs: | |||
| x_cls = self.relu(fc(x_cls)) | |||
| for conv in self.reg_convs: | |||
| x_reg = conv(x_reg) | |||
| if x_reg.dim() > 2: | |||
| if self.with_avg_pool: | |||
| x_reg = self.avg_pool(x_reg) | |||
| x_reg = x_reg.flatten(1) | |||
| for fc in self.reg_fcs: | |||
| x_reg = self.relu(fc(x_reg)) | |||
| cls_score = self.fc_cls(x_cls) if self.with_cls else None | |||
| bbox_pred = self.fc_reg(x_reg) if self.with_reg else None | |||
| return cls_score, bbox_pred | |||
| @HEADS.register_module() | |||
| class Shared2FCBBoxNHead(ConvFCBBoxNHead): | |||
| def __init__(self, fc_out_channels=1024, *args, **kwargs): | |||
| super(Shared2FCBBoxNHead, self).__init__( | |||
| num_shared_convs=0, | |||
| num_shared_fcs=2, | |||
| num_cls_convs=0, | |||
| num_cls_fcs=0, | |||
| num_reg_convs=0, | |||
| num_reg_fcs=0, | |||
| fc_out_channels=fc_out_channels, | |||
| *args, | |||
| **kwargs) | |||
| @HEADS.register_module() | |||
| class Shared4Conv1FCBBoxNHead(ConvFCBBoxNHead): | |||
| def __init__(self, fc_out_channels=1024, *args, **kwargs): | |||
| super(Shared4Conv1FCBBoxNHead, self).__init__( | |||
| num_shared_convs=4, | |||
| num_shared_fcs=1, | |||
| num_cls_convs=0, | |||
| num_cls_fcs=0, | |||
| num_reg_convs=0, | |||
| num_reg_fcs=0, | |||
| fc_out_channels=fc_out_channels, | |||
| *args, | |||
| **kwargs) | |||
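Reviewer note: the ROI heads registered above are likewise referenced by name from the config. An illustrative fragment, with the ROI-head class and numeric values assumed purely for the sketch rather than taken from the shipped configuration:

    # Illustrative config fragment; the real values ship with the model files.
    roi_head = dict(
        type='StandardRoIHead',                  # stock mmdet ROI head, assumed here
        bbox_head=dict(
            type='Shared4Conv1FCBBoxNHead',      # registered above
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            norm_cfg=dict(type='LN')),
        mask_head=dict(
            type='FCNMaskNHead',                 # registered in mask_heads/fcn_mask_head.py later in this diff
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80))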
| @@ -0,0 +1,3 @@ | |||
| from .fcn_mask_head import FCNMaskNHead | |||
| __all__ = ['FCNMaskNHead'] | |||
| @@ -0,0 +1,414 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| # Implementation in this file is modified from source code available via https://github.com/ViTAE-Transformer/ViTDet | |||
| from warnings import warn | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from mmcv.cnn import ConvModule, build_conv_layer, build_upsample_layer | |||
| from mmcv.ops.carafe import CARAFEPack | |||
| from mmcv.runner import BaseModule, ModuleList, auto_fp16, force_fp32 | |||
| from mmdet.core import mask_target | |||
| from mmdet.models.builder import HEADS, build_loss | |||
| from torch.nn.modules.utils import _pair | |||
| from ...utils import ConvModule_Norm | |||
| BYTES_PER_FLOAT = 4 | |||
| # TODO: This memory limit may be too much or too little. It would be better to | |||
| # determine it based on available resources. | |||
| GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit | |||
| @HEADS.register_module() | |||
| class FCNMaskNHead(BaseModule): | |||
| def __init__(self, | |||
| num_convs=4, | |||
| roi_feat_size=14, | |||
| in_channels=256, | |||
| conv_kernel_size=3, | |||
| conv_out_channels=256, | |||
| num_classes=80, | |||
| class_agnostic=False, | |||
| upsample_cfg=dict(type='deconv', scale_factor=2), | |||
| conv_cfg=None, | |||
| norm_cfg=None, | |||
| predictor_cfg=dict(type='Conv'), | |||
| loss_mask=dict( | |||
| type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), | |||
| init_cfg=None): | |||
| assert init_cfg is None, 'To prevent abnormal initialization ' \ | |||
| 'behavior, init_cfg is not allowed to be set' | |||
| super(FCNMaskNHead, self).__init__(init_cfg) | |||
| self.upsample_cfg = upsample_cfg.copy() | |||
| if self.upsample_cfg['type'] not in [ | |||
| None, 'deconv', 'nearest', 'bilinear', 'carafe' | |||
| ]: | |||
| raise ValueError( | |||
| f'Invalid upsample method {self.upsample_cfg["type"]}, ' | |||
| 'accepted methods are "deconv", "nearest", "bilinear", ' | |||
| '"carafe"') | |||
| self.num_convs = num_convs | |||
| # WARN: roi_feat_size is reserved and not used | |||
| self.roi_feat_size = _pair(roi_feat_size) | |||
| self.in_channels = in_channels | |||
| self.conv_kernel_size = conv_kernel_size | |||
| self.conv_out_channels = conv_out_channels | |||
| self.upsample_method = self.upsample_cfg.get('type') | |||
| self.scale_factor = self.upsample_cfg.pop('scale_factor', None) | |||
| self.num_classes = num_classes | |||
| self.class_agnostic = class_agnostic | |||
| self.conv_cfg = conv_cfg | |||
| self.norm_cfg = norm_cfg | |||
| self.predictor_cfg = predictor_cfg | |||
| self.fp16_enabled = False | |||
| self.loss_mask = build_loss(loss_mask) | |||
| self.convs = ModuleList() | |||
| for i in range(self.num_convs): | |||
| in_channels = ( | |||
| self.in_channels if i == 0 else self.conv_out_channels) | |||
| padding = (self.conv_kernel_size - 1) // 2 | |||
| self.convs.append( | |||
| ConvModule_Norm( | |||
| in_channels, | |||
| self.conv_out_channels, | |||
| self.conv_kernel_size, | |||
| padding=padding, | |||
| conv_cfg=conv_cfg, | |||
| norm_cfg=norm_cfg)) | |||
| upsample_in_channels = ( | |||
| self.conv_out_channels if self.num_convs > 0 else in_channels) | |||
| upsample_cfg_ = self.upsample_cfg.copy() | |||
| if self.upsample_method is None: | |||
| self.upsample = None | |||
| elif self.upsample_method == 'deconv': | |||
| upsample_cfg_.update( | |||
| in_channels=upsample_in_channels, | |||
| out_channels=self.conv_out_channels, | |||
| kernel_size=self.scale_factor, | |||
| stride=self.scale_factor) | |||
| self.upsample = build_upsample_layer(upsample_cfg_) | |||
| elif self.upsample_method == 'carafe': | |||
| upsample_cfg_.update( | |||
| channels=upsample_in_channels, scale_factor=self.scale_factor) | |||
| self.upsample = build_upsample_layer(upsample_cfg_) | |||
| else: | |||
| # suppress warnings | |||
| align_corners = (None | |||
| if self.upsample_method == 'nearest' else False) | |||
| upsample_cfg_.update( | |||
| scale_factor=self.scale_factor, | |||
| mode=self.upsample_method, | |||
| align_corners=align_corners) | |||
| self.upsample = build_upsample_layer(upsample_cfg_) | |||
| out_channels = 1 if self.class_agnostic else self.num_classes | |||
| logits_in_channel = ( | |||
| self.conv_out_channels | |||
| if self.upsample_method == 'deconv' else upsample_in_channels) | |||
| self.conv_logits = build_conv_layer(self.predictor_cfg, | |||
| logits_in_channel, out_channels, 1) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.debug_imgs = None | |||
| def init_weights(self): | |||
| super(FCNMaskNHead, self).init_weights() | |||
| for m in [self.upsample, self.conv_logits]: | |||
| if m is None: | |||
| continue | |||
| elif isinstance(m, CARAFEPack): | |||
| m.init_weights() | |||
| elif hasattr(m, 'weight') and hasattr(m, 'bias'): | |||
| nn.init.kaiming_normal_( | |||
| m.weight, mode='fan_out', nonlinearity='relu') | |||
| nn.init.constant_(m.bias, 0) | |||
| @auto_fp16() | |||
| def forward(self, x): | |||
| for conv in self.convs: | |||
| x = conv(x) | |||
| if self.upsample is not None: | |||
| x = self.upsample(x) | |||
| if self.upsample_method == 'deconv': | |||
| x = self.relu(x) | |||
| mask_pred = self.conv_logits(x) | |||
| return mask_pred | |||
| def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg): | |||
| pos_proposals = [res.pos_bboxes for res in sampling_results] | |||
| pos_assigned_gt_inds = [ | |||
| res.pos_assigned_gt_inds for res in sampling_results | |||
| ] | |||
| mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, | |||
| gt_masks, rcnn_train_cfg) | |||
| return mask_targets | |||
| @force_fp32(apply_to=('mask_pred', )) | |||
| def loss(self, mask_pred, mask_targets, labels): | |||
| """ | |||
| Example: | |||
| >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA | |||
| >>> N = 7 # N = number of extracted ROIs | |||
| >>> C, H, W = 11, 32, 32 | |||
| >>> # Create example instance of FCN Mask Head. | |||
| >>> # There are lots of variations depending on the configuration | |||
| >>> self = FCNMaskHead(num_classes=C, num_convs=1) | |||
| >>> inputs = torch.rand(N, self.in_channels, H, W) | |||
| >>> mask_pred = self.forward(inputs) | |||
| >>> sf = self.scale_factor | |||
| >>> labels = torch.randint(0, C, size=(N,)) | |||
| >>> # With the default properties the mask targets should indicate | |||
| >>> # a (potentially soft) single-class label | |||
| >>> mask_targets = torch.rand(N, H * sf, W * sf) | |||
| >>> loss = self.loss(mask_pred, mask_targets, labels) | |||
| >>> print('loss = {!r}'.format(loss)) | |||
| """ | |||
| loss = dict() | |||
| if mask_pred.size(0) == 0: | |||
| loss_mask = mask_pred.sum() | |||
| else: | |||
| if self.class_agnostic: | |||
| loss_mask = self.loss_mask(mask_pred, mask_targets, | |||
| torch.zeros_like(labels)) | |||
| else: | |||
| loss_mask = self.loss_mask(mask_pred, mask_targets, labels) | |||
| loss['loss_mask'] = loss_mask | |||
| return loss | |||
| def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, | |||
| ori_shape, scale_factor, rescale): | |||
| """Get segmentation masks from mask_pred and bboxes. | |||
| Args: | |||
| mask_pred (Tensor or ndarray): shape (n, #class, h, w). | |||
| For single-scale testing, mask_pred is the direct output of | |||
| model, whose type is Tensor, while for multi-scale testing, | |||
| it will be converted to numpy array outside of this method. | |||
| det_bboxes (Tensor): shape (n, 4/5) | |||
| det_labels (Tensor): shape (n, ) | |||
| rcnn_test_cfg (dict): rcnn testing config | |||
| ori_shape (Tuple): original image height and width, shape (2,) | |||
| scale_factor(ndarray | Tensor): If ``rescale is True``, box | |||
| coordinates are divided by this scale factor to fit | |||
| ``ori_shape``. | |||
| rescale (bool): If True, the resulting masks will be rescaled to | |||
| ``ori_shape``. | |||
| Returns: | |||
| list[list]: encoded masks. The c-th item in the outer list | |||
| corresponds to the c-th class. Given the c-th outer list, the | |||
| i-th item in that inner list is the mask for the i-th box with | |||
| class label c. | |||
| Example: | |||
| >>> import mmcv | |||
| >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA | |||
| >>> N = 7 # N = number of extracted ROIs | |||
| >>> C, H, W = 11, 32, 32 | |||
| >>> # Create example instance of FCN Mask Head. | |||
| >>> self = FCNMaskHead(num_classes=C, num_convs=0) | |||
| >>> inputs = torch.rand(N, self.in_channels, H, W) | |||
| >>> mask_pred = self.forward(inputs) | |||
| >>> # Each input is associated with some bounding box | |||
| >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N) | |||
| >>> det_labels = torch.randint(0, C, size=(N,)) | |||
| >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, }) | |||
| >>> ori_shape = (H * 4, W * 4) | |||
| >>> scale_factor = torch.FloatTensor((1, 1)) | |||
| >>> rescale = False | |||
| >>> # Encoded masks are a list for each category. | |||
| >>> encoded_masks = self.get_seg_masks( | |||
| >>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape, | |||
| >>> scale_factor, rescale | |||
| >>> ) | |||
| >>> assert len(encoded_masks) == C | |||
| >>> assert sum(list(map(len, encoded_masks))) == N | |||
| """ | |||
| if isinstance(mask_pred, torch.Tensor): | |||
| mask_pred = mask_pred.sigmoid() | |||
| else: | |||
| # In aug-test, mask_pred has already been activated (sigmoid applied) upstream | |||
| mask_pred = det_bboxes.new_tensor(mask_pred) | |||
| device = mask_pred.device | |||
| cls_segms = [[] for _ in range(self.num_classes) | |||
| ] # BG is not included in num_classes | |||
| bboxes = det_bboxes[:, :4] | |||
| labels = det_labels | |||
| # In most cases, scale_factor should have been | |||
| # converted to a Tensor when rescaling the bbox | |||
| if not isinstance(scale_factor, torch.Tensor): | |||
| if isinstance(scale_factor, float): | |||
| scale_factor = np.array([scale_factor] * 4) | |||
| warn('Scale_factor should be a Tensor or ndarray ' | |||
| 'with shape (4,); float input is deprecated. ') | |||
| assert isinstance(scale_factor, np.ndarray) | |||
| scale_factor = torch.Tensor(scale_factor) | |||
| if rescale: | |||
| img_h, img_w = ori_shape[:2] | |||
| bboxes = bboxes / scale_factor.to(bboxes) | |||
| else: | |||
| w_scale, h_scale = scale_factor[0], scale_factor[1] | |||
| img_h = np.round(ori_shape[0] * h_scale.item()).astype(np.int32) | |||
| img_w = np.round(ori_shape[1] * w_scale.item()).astype(np.int32) | |||
| N = len(mask_pred) | |||
| # The actual implementation splits the input into chunks | |||
| # and pastes them chunk by chunk. | |||
| if device.type == 'cpu': | |||
| # CPU is most efficient when masks are pasted one by one with | |||
| # skip_empty=True, so that the minimal number of | |||
| # operations is performed. | |||
| num_chunks = N | |||
| else: | |||
| # GPU benefits from parallelism for larger chunks, | |||
| # but may run into memory issues. | |||
| # img_w and img_h are np.int32, so when the image resolution | |||
| # is large the calculation of num_chunks can overflow; | |||
| # cast img_w and img_h to Python int first. | |||
| # See https://github.com/open-mmlab/mmdetection/pull/5191 | |||
| num_chunks = int( | |||
| np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT | |||
| / GPU_MEM_LIMIT)) | |||
| assert num_chunks <= N, 'Default GPU_MEM_LIMIT is too small; try increasing it' | |||
| chunks = torch.chunk(torch.arange(N, device=device), num_chunks) | |||
| threshold = rcnn_test_cfg.mask_thr_binary | |||
| im_mask = torch.zeros( | |||
| N, | |||
| img_h, | |||
| img_w, | |||
| device=device, | |||
| dtype=torch.bool if threshold >= 0 else torch.uint8) | |||
| if not self.class_agnostic: | |||
| mask_pred = mask_pred[range(N), labels][:, None] | |||
| for inds in chunks: | |||
| masks_chunk, spatial_inds = _do_paste_mask( | |||
| mask_pred[inds], | |||
| bboxes[inds], | |||
| img_h, | |||
| img_w, | |||
| skip_empty=device.type == 'cpu') | |||
| if threshold >= 0: | |||
| masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool) | |||
| else: | |||
| # for visualization and debugging | |||
| masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8) | |||
| im_mask[(inds, ) + spatial_inds] = masks_chunk | |||
| for i in range(N): | |||
| cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy()) | |||
| return cls_segms | |||
| def onnx_export(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, | |||
| ori_shape, **kwargs): | |||
| """Get segmentation masks from mask_pred and bboxes. | |||
| Args: | |||
| mask_pred (Tensor): shape (n, #class, h, w). | |||
| det_bboxes (Tensor): shape (n, 4/5) | |||
| det_labels (Tensor): shape (n, ) | |||
| rcnn_test_cfg (dict): rcnn testing config | |||
| ori_shape (Tuple): original image height and width, shape (2,) | |||
| Returns: | |||
| Tensor: a mask of shape (N, img_h, img_w). | |||
| """ | |||
| mask_pred = mask_pred.sigmoid() | |||
| bboxes = det_bboxes[:, :4] | |||
| labels = det_labels | |||
| # No need to consider rescale and scale_factor while exporting to ONNX | |||
| img_h, img_w = ori_shape[:2] | |||
| threshold = rcnn_test_cfg.mask_thr_binary | |||
| if not self.class_agnostic: | |||
| box_inds = torch.arange(mask_pred.shape[0]) | |||
| mask_pred = mask_pred[box_inds, labels][:, None] | |||
| masks, _ = _do_paste_mask( | |||
| mask_pred, bboxes, img_h, img_w, skip_empty=False) | |||
| if threshold >= 0: | |||
| # should convert to float to avoid problems in TRT | |||
| masks = (masks >= threshold).to(dtype=torch.float) | |||
| return masks | |||
| def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): | |||
| """Paste instance masks according to boxes. | |||
| This implementation is modified from | |||
| https://github.com/facebookresearch/detectron2/ | |||
| Args: | |||
| masks (Tensor): N, 1, H, W | |||
| boxes (Tensor): N, 4 | |||
| img_h (int): Height of the image to be pasted. | |||
| img_w (int): Width of the image to be pasted. | |||
| skip_empty (bool): Only paste masks within the region that | |||
| tightly bounds all boxes, and return the results for this region | |||
| only. This is an important optimization for CPU. | |||
| Returns: | |||
| tuple: (Tensor, tuple). The first item is mask tensor, the second one | |||
| is the slice object. | |||
| If skip_empty == False, the whole image will be pasted. It will | |||
| return a mask of shape (N, img_h, img_w) and an empty tuple. | |||
| If skip_empty == True, only area around the mask will be pasted. | |||
| A mask of shape (N, h', w') and its start and end coordinates | |||
| in the original image will be returned. | |||
| """ | |||
| # On GPU, paste all masks together (up to chunk size) | |||
| # by using the entire image to sample the masks | |||
| # Compared to pasting them one by one, | |||
| # this has more operations but is faster on COCO-scale dataset. | |||
| device = masks.device | |||
| if skip_empty: | |||
| x0_int, y0_int = torch.clamp( | |||
| boxes.min(dim=0).values.floor()[:2] - 1, | |||
| min=0).to(dtype=torch.int32) | |||
| x1_int = torch.clamp( | |||
| boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) | |||
| y1_int = torch.clamp( | |||
| boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) | |||
| else: | |||
| x0_int, y0_int = 0, 0 | |||
| x1_int, y1_int = img_w, img_h | |||
| x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 | |||
| N = masks.shape[0] | |||
| img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5 | |||
| img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5 | |||
| img_y = (img_y - y0) / (y1 - y0) * 2 - 1 | |||
| img_x = (img_x - x0) / (x1 - x0) * 2 - 1 | |||
| # img_x, img_y have shapes (N, w), (N, h) | |||
| # IsInf op is not supported with ONNX<=1.7.0 | |||
| if not torch.onnx.is_in_onnx_export(): | |||
| if torch.isinf(img_x).any(): | |||
| inds = torch.where(torch.isinf(img_x)) | |||
| img_x[inds] = 0 | |||
| if torch.isinf(img_y).any(): | |||
| inds = torch.where(torch.isinf(img_y)) | |||
| img_y[inds] = 0 | |||
| gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) | |||
| gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) | |||
| grid = torch.stack([gx, gy], dim=3) | |||
| img_masks = F.grid_sample( | |||
| masks.to(dtype=torch.float32), grid, align_corners=False) | |||
| if skip_empty: | |||
| return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) | |||
| else: | |||
| return img_masks[:, 0], () | |||
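For reference, a minimal sketch of how this helper behaves in both modes; the mask and box values below are purely illustrative and assume `_do_paste_mask` defined above is in scope.

```python
# Illustrative sketch only; assumes _do_paste_mask (defined above) is importable.
import torch

masks = torch.rand(1, 1, 28, 28)                 # one 28x28 predicted mask
boxes = torch.tensor([[8.0, 8.0, 40.0, 40.0]])   # x1, y1, x2, y2 in a 64x64 image

# CPU-style call: paste only the tight region around the boxes and return
# the slices that locate this region inside the full image.
crop, (ys, xs) = _do_paste_mask(masks, boxes, img_h=64, img_w=64, skip_empty=True)

# GPU-style call: paste onto the full canvas; the second element is an empty tuple.
full, _ = _do_paste_mask(masks, boxes, img_h=64, img_w=64, skip_empty=False)
assert full.shape == (1, 64, 64)
```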
| @@ -0,0 +1,4 @@ | |||
| from .checkpoint import load_checkpoint | |||
| from .convModule_norm import ConvModule_Norm | |||
| __all__ = ['load_checkpoint', 'ConvModule_Norm'] | |||
| @@ -0,0 +1,558 @@ | |||
| # Copyright (c) Open-MMLab. All rights reserved. | |||
| # Implementation adapted from ViTAE-Transformer; source code available via https://github.com/ViTAE-Transformer/ViTDet | |||
| import io | |||
| import os | |||
| import os.path as osp | |||
| import pkgutil | |||
| import time | |||
| import warnings | |||
| from collections import OrderedDict | |||
| from importlib import import_module | |||
| from tempfile import TemporaryDirectory | |||
| import mmcv | |||
| import torch | |||
| import torchvision | |||
| from mmcv.fileio import FileClient | |||
| from mmcv.fileio import load as load_file | |||
| from mmcv.parallel import is_module_wrapper | |||
| from mmcv.runner import get_dist_info | |||
| from torch.nn import functional as F | |||
| from torch.optim import Optimizer | |||
| from torch.utils import model_zoo | |||
| def load_state_dict(module, state_dict, strict=False, logger=None): | |||
| """Load state_dict to a module. | |||
| This method is modified from :meth:`torch.nn.Module.load_state_dict`. | |||
| Default value for ``strict`` is set to ``False`` and the message for | |||
| param mismatch will be shown even if strict is False. | |||
| Args: | |||
| module (Module): Module that receives the state_dict. | |||
| state_dict (OrderedDict): Weights. | |||
| strict (bool): whether to strictly enforce that the keys | |||
| in :attr:`state_dict` match the keys returned by this module's | |||
| :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. | |||
| logger (:obj:`logging.Logger`, optional): Logger to log the error | |||
| message. If not specified, print function will be used. | |||
| """ | |||
| unexpected_keys = [] | |||
| all_missing_keys = [] | |||
| err_msg = [] | |||
| metadata = getattr(state_dict, '_metadata', None) | |||
| state_dict = state_dict.copy() | |||
| if metadata is not None: | |||
| state_dict._metadata = metadata | |||
| # use _load_from_state_dict to enable checkpoint version control | |||
| def load(module, prefix=''): | |||
| # recursively check parallel module in case that the model has a | |||
| # complicated structure, e.g., nn.Module(nn.Module(DDP)) | |||
| if is_module_wrapper(module): | |||
| module = module.module | |||
| local_metadata = {} if metadata is None else metadata.get( | |||
| prefix[:-1], {}) | |||
| module._load_from_state_dict(state_dict, prefix, local_metadata, True, | |||
| all_missing_keys, unexpected_keys, | |||
| err_msg) | |||
| for name, child in module._modules.items(): | |||
| if child is not None: | |||
| load(child, prefix + name + '.') | |||
| load(module) | |||
| load = None # break load->load reference cycle | |||
| missing_keys = [ | |||
| key for key in all_missing_keys if 'num_batches_tracked' not in key | |||
| ] | |||
| if unexpected_keys: | |||
| err_msg.append('unexpected key in source ' | |||
| f'state_dict: {", ".join(unexpected_keys)}\n') | |||
| if missing_keys: | |||
| err_msg.append( | |||
| f'missing keys in source state_dict: {", ".join(missing_keys)}\n') | |||
| rank, _ = get_dist_info() | |||
| if len(err_msg) > 0 and rank == 0: | |||
| err_msg.insert( | |||
| 0, 'The model and loaded state dict do not match exactly\n') | |||
| err_msg = '\n'.join(err_msg) | |||
| if strict: | |||
| raise RuntimeError(err_msg) | |||
| elif logger is not None: | |||
| logger.warning(err_msg) | |||
| else: | |||
| print(err_msg) | |||
| print('finish load') | |||
| def load_url_dist(url, model_dir=None): | |||
| """In distributed setting, this function only download checkpoint at local | |||
| rank 0.""" | |||
| rank, world_size = get_dist_info() | |||
| rank = int(os.environ.get('LOCAL_RANK', rank)) | |||
| if rank == 0: | |||
| checkpoint = model_zoo.load_url(url, model_dir=model_dir) | |||
| if world_size > 1: | |||
| torch.distributed.barrier() | |||
| if rank > 0: | |||
| checkpoint = model_zoo.load_url(url, model_dir=model_dir) | |||
| return checkpoint | |||
| def load_pavimodel_dist(model_path, map_location=None): | |||
| """In distributed setting, this function only download checkpoint at local | |||
| rank 0.""" | |||
| try: | |||
| from pavi import modelcloud | |||
| except ImportError: | |||
| raise ImportError( | |||
| 'Please install pavi to load checkpoint from modelcloud.') | |||
| rank, world_size = get_dist_info() | |||
| rank = int(os.environ.get('LOCAL_RANK', rank)) | |||
| if rank == 0: | |||
| model = modelcloud.get(model_path) | |||
| with TemporaryDirectory() as tmp_dir: | |||
| downloaded_file = osp.join(tmp_dir, model.name) | |||
| model.download(downloaded_file) | |||
| checkpoint = torch.load(downloaded_file, map_location=map_location) | |||
| if world_size > 1: | |||
| torch.distributed.barrier() | |||
| if rank > 0: | |||
| model = modelcloud.get(model_path) | |||
| with TemporaryDirectory() as tmp_dir: | |||
| downloaded_file = osp.join(tmp_dir, model.name) | |||
| model.download(downloaded_file) | |||
| checkpoint = torch.load( | |||
| downloaded_file, map_location=map_location) | |||
| return checkpoint | |||
| def load_fileclient_dist(filename, backend, map_location): | |||
| """In distributed setting, this function only download checkpoint at local | |||
| rank 0.""" | |||
| rank, world_size = get_dist_info() | |||
| rank = int(os.environ.get('LOCAL_RANK', rank)) | |||
| allowed_backends = ['ceph'] | |||
| if backend not in allowed_backends: | |||
| raise ValueError(f'Load from Backend {backend} is not supported.') | |||
| if rank == 0: | |||
| fileclient = FileClient(backend=backend) | |||
| buffer = io.BytesIO(fileclient.get(filename)) | |||
| checkpoint = torch.load(buffer, map_location=map_location) | |||
| if world_size > 1: | |||
| torch.distributed.barrier() | |||
| if rank > 0: | |||
| fileclient = FileClient(backend=backend) | |||
| buffer = io.BytesIO(fileclient.get(filename)) | |||
| checkpoint = torch.load(buffer, map_location=map_location) | |||
| return checkpoint | |||
| def get_torchvision_models(): | |||
| model_urls = dict() | |||
| for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): | |||
| if ispkg: | |||
| continue | |||
| _zoo = import_module(f'torchvision.models.{name}') | |||
| if hasattr(_zoo, 'model_urls'): | |||
| _urls = getattr(_zoo, 'model_urls') | |||
| model_urls.update(_urls) | |||
| return model_urls | |||
| def get_external_models(): | |||
| mmcv_home = _get_mmcv_home() | |||
| default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') | |||
| default_urls = load_file(default_json_path) | |||
| assert isinstance(default_urls, dict) | |||
| external_json_path = osp.join(mmcv_home, 'open_mmlab.json') | |||
| if osp.exists(external_json_path): | |||
| external_urls = load_file(external_json_path) | |||
| assert isinstance(external_urls, dict) | |||
| default_urls.update(external_urls) | |||
| return default_urls | |||
| def get_mmcls_models(): | |||
| mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') | |||
| mmcls_urls = load_file(mmcls_json_path) | |||
| return mmcls_urls | |||
| def get_deprecated_model_names(): | |||
| deprecate_json_path = osp.join(mmcv.__path__[0], | |||
| 'model_zoo/deprecated.json') | |||
| deprecate_urls = load_file(deprecate_json_path) | |||
| assert isinstance(deprecate_urls, dict) | |||
| return deprecate_urls | |||
| def _process_mmcls_checkpoint(checkpoint): | |||
| state_dict = checkpoint['state_dict'] | |||
| new_state_dict = OrderedDict() | |||
| for k, v in state_dict.items(): | |||
| if k.startswith('backbone.'): | |||
| new_state_dict[k[9:]] = v | |||
| new_checkpoint = dict(state_dict=new_state_dict) | |||
| return new_checkpoint | |||
| def _load_checkpoint(filename, map_location=None): | |||
| """Load checkpoint from somewhere (modelzoo, file, url). | |||
| Args: | |||
| filename (str): Accept local filepath, URL, ``torchvision://xxx``, | |||
| ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for | |||
| details. | |||
| map_location (str | None): Same as :func:`torch.load`. Default: None. | |||
| Returns: | |||
| dict | OrderedDict: The loaded checkpoint. It can be either an | |||
| OrderedDict storing model weights or a dict containing other | |||
| information, which depends on the checkpoint. | |||
| """ | |||
| if filename.startswith('modelzoo://'): | |||
| warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' | |||
| 'use "torchvision://" instead') | |||
| model_urls = get_torchvision_models() | |||
| model_name = filename[11:] | |||
| checkpoint = load_url_dist(model_urls[model_name]) | |||
| elif filename.startswith('torchvision://'): | |||
| model_urls = get_torchvision_models() | |||
| model_name = filename[14:] | |||
| checkpoint = load_url_dist(model_urls[model_name]) | |||
| elif filename.startswith('open-mmlab://'): | |||
| model_urls = get_external_models() | |||
| model_name = filename[13:] | |||
| deprecated_urls = get_deprecated_model_names() | |||
| if model_name in deprecated_urls: | |||
| warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' | |||
| f'of open-mmlab://{deprecated_urls[model_name]}') | |||
| model_name = deprecated_urls[model_name] | |||
| model_url = model_urls[model_name] | |||
| # check if is url | |||
| if model_url.startswith(('http://', 'https://')): | |||
| checkpoint = load_url_dist(model_url) | |||
| else: | |||
| filename = osp.join(_get_mmcv_home(), model_url) | |||
| if not osp.isfile(filename): | |||
| raise IOError(f'{filename} is not a checkpoint file') | |||
| checkpoint = torch.load(filename, map_location=map_location) | |||
| elif filename.startswith('mmcls://'): | |||
| model_urls = get_mmcls_models() | |||
| model_name = filename[8:] | |||
| checkpoint = load_url_dist(model_urls[model_name]) | |||
| checkpoint = _process_mmcls_checkpoint(checkpoint) | |||
| elif filename.startswith(('http://', 'https://')): | |||
| checkpoint = load_url_dist(filename) | |||
| elif filename.startswith('pavi://'): | |||
| model_path = filename[7:] | |||
| checkpoint = load_pavimodel_dist(model_path, map_location=map_location) | |||
| elif filename.startswith('s3://'): | |||
| checkpoint = load_fileclient_dist( | |||
| filename, backend='ceph', map_location=map_location) | |||
| else: | |||
| if not osp.isfile(filename): | |||
| raise IOError(f'{filename} is not a checkpoint file') | |||
| checkpoint = torch.load(filename, map_location=map_location) | |||
| return checkpoint | |||
| def load_checkpoint(model, | |||
| filename, | |||
| map_location='cpu', | |||
| strict=False, | |||
| logger=None, | |||
| load_ema=True): | |||
| """Load checkpoint from a file or URI. | |||
| Args: | |||
| model (Module): Module to load checkpoint. | |||
| filename (str): Accept local filepath, URL, ``torchvision://xxx``, | |||
| ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for | |||
| details. | |||
| map_location (str): Same as :func:`torch.load`. | |||
| strict (bool): Whether to strictly enforce that the keys in the | |||
| checkpoint match the keys of the model. Default: False. | |||
| logger (:mod:`logging.Logger` or None): The logger for error messages. | |||
| load_ema (bool): Whether to prefer the EMA weights | |||
| (``state_dict_ema``) when they are present. Default: True. | |||
| Returns: | |||
| dict or OrderedDict: The loaded checkpoint. | |||
| """ | |||
| checkpoint = _load_checkpoint(filename, map_location) | |||
| # OrderedDict is a subclass of dict | |||
| if not isinstance(checkpoint, dict): | |||
| raise RuntimeError( | |||
| f'No state_dict found in checkpoint file {filename}') | |||
| # get state_dict from checkpoint | |||
| if load_ema and 'state_dict_ema' in checkpoint: | |||
| state_dict = checkpoint['state_dict_ema'] | |||
| logger.info('loading from state_dict_ema') | |||
| elif 'state_dict' in checkpoint: | |||
| state_dict = checkpoint['state_dict'] | |||
| logger.info('loading from state_dict') | |||
| elif 'model' in checkpoint: | |||
| state_dict = checkpoint['model'] | |||
| logger.info('loading from model') | |||
| else: | |||
| state_dict = checkpoint | |||
| # strip prefix of state_dict | |||
| if list(state_dict.keys())[0].startswith('module.'): | |||
| state_dict = {k[7:]: v for k, v in state_dict.items()} | |||
| # for MoBY, load model of online branch | |||
| if sorted(list(state_dict.keys()))[0].startswith('encoder'): | |||
| state_dict = { | |||
| k.replace('encoder.', ''): v | |||
| for k, v in state_dict.items() if k.startswith('encoder.') | |||
| } | |||
| # reshape absolute position embedding | |||
| if state_dict.get('absolute_pos_embed') is not None: | |||
| absolute_pos_embed = state_dict['absolute_pos_embed'] | |||
| N1, L, C1 = absolute_pos_embed.size() | |||
| N2, C2, H, W = model.absolute_pos_embed.size() | |||
| if N1 != N2 or C1 != C2 or L != H * W: | |||
| logger.warning('Error in loading absolute_pos_embed, pass') | |||
| else: | |||
| state_dict['absolute_pos_embed'] = absolute_pos_embed.view( | |||
| N2, H, W, C2).permute(0, 3, 1, 2) | |||
| all_keys = list(state_dict.keys()) | |||
| for key in all_keys: | |||
| if 'relative_position_index' in key: | |||
| state_dict.pop(key) | |||
| if 'relative_position_bias_table' in key: | |||
| state_dict.pop(key) | |||
| if '.q_bias' in key: | |||
| q_bias = state_dict[key] | |||
| v_bias = state_dict[key.replace('q_bias', 'v_bias')] | |||
| qkv_bias = torch.cat([q_bias, torch.zeros_like(q_bias), v_bias], 0) | |||
| state_dict[key.replace('q_bias', 'qkv.bias')] = qkv_bias | |||
| if '.v.bias' in key: | |||
| continue | |||
| all_keys = list(state_dict.keys()) | |||
| new_state_dict = {} | |||
| for key in all_keys: | |||
| if 'qkv.bias' in key: | |||
| value = state_dict[key] | |||
| dim = value.shape[0] | |||
| selected_dim = (dim * 2) // 3 | |||
| new_state_dict[key.replace( | |||
| 'qkv.bias', 'pos_bias')] = state_dict[key][:selected_dim] | |||
| # interpolate position bias table if needed | |||
| relative_position_bias_table_keys = [ | |||
| k for k in state_dict.keys() if 'relative_position_bias_table' in k | |||
| ] | |||
| for table_key in relative_position_bias_table_keys: | |||
| table_pretrained = state_dict[table_key] | |||
| if table_key not in model.state_dict().keys(): | |||
| logger.warning( | |||
| 'relative_position_bias_table exists in the pretrained model but not in the current one, pass' | |||
| ) | |||
| continue | |||
| table_current = model.state_dict()[table_key] | |||
| L1, nH1 = table_pretrained.size() | |||
| L2, nH2 = table_current.size() | |||
| if nH1 != nH2: | |||
| logger.warning(f'Error in loading {table_key}, pass') | |||
| else: | |||
| if L1 != L2: | |||
| S1 = int(L1**0.5) | |||
| S2 = int(L2**0.5) | |||
| table_pretrained_resized = F.interpolate( | |||
| table_pretrained.permute(1, 0).view(1, nH1, S1, S1), | |||
| size=(S2, S2), | |||
| mode='bicubic') | |||
| state_dict[table_key] = table_pretrained_resized.view( | |||
| nH2, L2).permute(1, 0) | |||
| rank, _ = get_dist_info() | |||
| if 'pos_embed' in state_dict: | |||
| pos_embed_checkpoint = state_dict['pos_embed'] | |||
| embedding_size = pos_embed_checkpoint.shape[-1] | |||
| H, W = model.patch_embed.patch_shape | |||
| num_patches = model.patch_embed.num_patches | |||
| num_extra_tokens = 1 | |||
| # height (== width) for the checkpoint position embedding | |||
| orig_size = int( | |||
| (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) | |||
| # height (== width) for the new position embedding | |||
| new_size = int(num_patches**0.5) | |||
| # class_token and dist_token are kept unchanged | |||
| if orig_size != new_size: | |||
| if rank == 0: | |||
| print('Position interpolate from %dx%d to %dx%d' % | |||
| (orig_size, orig_size, H, W)) | |||
| # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] | |||
| # only the position tokens are interpolated | |||
| pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] | |||
| pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, | |||
| embedding_size).permute( | |||
| 0, 3, 1, 2) | |||
| pos_tokens = torch.nn.functional.interpolate( | |||
| pos_tokens, size=(H, W), mode='bicubic', align_corners=False) | |||
| new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) | |||
| # new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) | |||
| state_dict['pos_embed'] = new_pos_embed | |||
| # load state_dict | |||
| load_state_dict(model, state_dict, strict, logger) | |||
| return checkpoint | |||
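As a usage reference only: a minimal sketch of calling this loader on an mmdet detector. The config and checkpoint file names are placeholders, and a logger is passed explicitly because the loader reports which state dict it selected.

```python
# Sketch with placeholder file names; assumes load_checkpoint (above) is in scope.
import logging

from mmcv import Config
from mmdet.models import build_detector

logger = logging.getLogger(__name__)

cfg = Config.fromfile('model_config.py')      # placeholder mmdet config
detector = build_detector(cfg.model)

# load_ema=True prefers 'state_dict_ema' when the checkpoint contains it;
# position embeddings and relative position bias tables are resized to the
# current model before the weights are loaded.
checkpoint = load_checkpoint(
    detector,
    'vitdet_backbone.pth',                    # placeholder checkpoint path
    map_location='cpu',
    strict=False,
    logger=logger,
    load_ema=True)
```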
| def weights_to_cpu(state_dict): | |||
| """Copy a model state_dict to cpu. | |||
| Args: | |||
| state_dict (OrderedDict): Model weights on GPU. | |||
| Returns: | |||
| OrderedDict: Model weights on CPU. | |||
| """ | |||
| state_dict_cpu = OrderedDict() | |||
| for key, val in state_dict.items(): | |||
| state_dict_cpu[key] = val.cpu() | |||
| return state_dict_cpu | |||
| def _save_to_state_dict(module, destination, prefix, keep_vars): | |||
| """Saves module state to `destination` dictionary. | |||
| This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. | |||
| Args: | |||
| module (nn.Module): The module to generate state_dict. | |||
| destination (dict): A dict where state will be stored. | |||
| prefix (str): The prefix for parameters and buffers used in this | |||
| module. | |||
| """ | |||
| for name, param in module._parameters.items(): | |||
| if param is not None: | |||
| destination[prefix + name] = param if keep_vars else param.detach() | |||
| for name, buf in module._buffers.items(): | |||
| # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d | |||
| if buf is not None: | |||
| destination[prefix + name] = buf if keep_vars else buf.detach() | |||
| def get_state_dict(module, destination=None, prefix='', keep_vars=False): | |||
| """Returns a dictionary containing a whole state of the module. | |||
| Both parameters and persistent buffers (e.g. running averages) are | |||
| included. Keys are corresponding parameter and buffer names. | |||
| This method is modified from :meth:`torch.nn.Module.state_dict` to | |||
| recursively check parallel module in case that the model has a complicated | |||
| structure, e.g., nn.Module(nn.Module(DDP)). | |||
| Args: | |||
| module (nn.Module): The module to generate state_dict. | |||
| destination (OrderedDict): Returned dict for the state of the | |||
| module. | |||
| prefix (str): Prefix of the key. | |||
| keep_vars (bool): Whether to keep the variable property of the | |||
| parameters. Default: False. | |||
| Returns: | |||
| dict: A dictionary containing a whole state of the module. | |||
| """ | |||
| # recursively check parallel module in case that the model has a | |||
| # complicated structure, e.g., nn.Module(nn.Module(DDP)) | |||
| if is_module_wrapper(module): | |||
| module = module.module | |||
| # below is the same as torch.nn.Module.state_dict() | |||
| if destination is None: | |||
| destination = OrderedDict() | |||
| destination._metadata = OrderedDict() | |||
| destination._metadata[prefix[:-1]] = local_metadata = dict( | |||
| version=module._version) | |||
| _save_to_state_dict(module, destination, prefix, keep_vars) | |||
| for name, child in module._modules.items(): | |||
| if child is not None: | |||
| get_state_dict( | |||
| child, destination, prefix + name + '.', keep_vars=keep_vars) | |||
| for hook in module._state_dict_hooks.values(): | |||
| hook_result = hook(module, destination, prefix, local_metadata) | |||
| if hook_result is not None: | |||
| destination = hook_result | |||
| return destination | |||
| def save_checkpoint(model, filename, optimizer=None, meta=None): | |||
| """Save checkpoint to file. | |||
| The checkpoint will have 3 fields: ``meta``, ``state_dict`` and | |||
| ``optimizer``. By default ``meta`` will contain version and time info. | |||
| Args: | |||
| model (Module): Module whose params are to be saved. | |||
| filename (str): Checkpoint filename. | |||
| optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. | |||
| meta (dict, optional): Metadata to be saved in checkpoint. | |||
| """ | |||
| if meta is None: | |||
| meta = {} | |||
| elif not isinstance(meta, dict): | |||
| raise TypeError(f'meta must be a dict or None, but got {type(meta)}') | |||
| meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) | |||
| if is_module_wrapper(model): | |||
| model = model.module | |||
| if hasattr(model, 'CLASSES') and model.CLASSES is not None: | |||
| # save class name to the meta | |||
| meta.update(CLASSES=model.CLASSES) | |||
| checkpoint = { | |||
| 'meta': meta, | |||
| 'state_dict': weights_to_cpu(get_state_dict(model)) | |||
| } | |||
| # save optimizer state dict in the checkpoint | |||
| if isinstance(optimizer, Optimizer): | |||
| checkpoint['optimizer'] = optimizer.state_dict() | |||
| elif isinstance(optimizer, dict): | |||
| checkpoint['optimizer'] = {} | |||
| for name, optim in optimizer.items(): | |||
| checkpoint['optimizer'][name] = optim.state_dict() | |||
| if filename.startswith('pavi://'): | |||
| try: | |||
| from pavi import modelcloud | |||
| from pavi.exception import NodeNotFoundError | |||
| except ImportError: | |||
| raise ImportError( | |||
| 'Please install pavi to load checkpoint from modelcloud.') | |||
| model_path = filename[7:] | |||
| root = modelcloud.Folder() | |||
| model_dir, model_name = osp.split(model_path) | |||
| try: | |||
| model = modelcloud.get(model_dir) | |||
| except NodeNotFoundError: | |||
| model = root.create_training_model(model_dir) | |||
| with TemporaryDirectory() as tmp_dir: | |||
| checkpoint_file = osp.join(tmp_dir, model_name) | |||
| with open(checkpoint_file, 'wb') as f: | |||
| torch.save(checkpoint, f) | |||
| f.flush() | |||
| model.create_file(checkpoint_file, name=model_name) | |||
| else: | |||
| mmcv.mkdir_or_exist(osp.dirname(filename)) | |||
| # immediately flush buffer | |||
| with open(filename, 'wb') as f: | |||
| torch.save(checkpoint, f) | |||
| f.flush() | |||
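And a small sketch of the matching save path, assuming `save_checkpoint` from this module is in scope; the output path and metadata values are illustrative.

```python
# Illustrative sketch; output path and meta values are made up.
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Writes a dict with 'meta' (epoch, mmcv_version, time), 'state_dict'
# (moved to CPU) and 'optimizer' to disk.
save_checkpoint(model, 'work_dirs/epoch_12.pth',
                optimizer=optimizer, meta=dict(epoch=12))
```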
| @@ -0,0 +1,30 @@ | |||
| # Implementation adapted from ViTAE-Transformer; source code available via https://github.com/ViTAE-Transformer/ViTDet | |||
| from mmcv.cnn import ConvModule | |||
| class ConvModule_Norm(ConvModule): | |||
| def __init__(self, in_channels, out_channels, kernel, **kwargs): | |||
| super().__init__(in_channels, out_channels, kernel, **kwargs) | |||
| self.normType = kwargs.get('norm_cfg', {'type': ''}) | |||
| if self.normType is not None: | |||
| self.normType = self.normType['type'] | |||
| def forward(self, x, activate=True, norm=True): | |||
| for layer in self.order: | |||
| if layer == 'conv': | |||
| if self.with_explicit_padding: | |||
| x = self.padding_layer(x) | |||
| x = self.conv(x) | |||
| elif layer == 'norm' and norm and self.with_norm: | |||
| if 'LN' in self.normType: | |||
| x = x.permute(0, 2, 3, 1) | |||
| x = self.norm(x) | |||
| x = x.permute(0, 3, 1, 2).contiguous() | |||
| else: | |||
| x = self.norm(x) | |||
| elif layer == 'act' and activate and self.with_activation: | |||
| x = self.activate(x) | |||
| return x | |||
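A minimal sketch of the LayerNorm path, assuming `ConvModule_Norm` defined above is importable (exact import path not shown); the channel and spatial sizes are illustrative.

```python
# Illustrative sketch; assumes ConvModule_Norm (above) is in scope.
import torch

conv = ConvModule_Norm(
    256, 256, 3, padding=1,
    norm_cfg=dict(type='LN'))   # LayerNorm over the channel dimension

x = torch.rand(2, 256, 14, 14)
# With an 'LN' norm_cfg the feature map is permuted to channels-last,
# normalized, then permuted back, so the output keeps the NCHW layout.
out = conv(x)
assert out.shape == (2, 256, 14, 14)
```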
| @@ -62,6 +62,7 @@ class Pipeline(ABC): | |||
| model: Union[InputModel, List[InputModel]] = None, | |||
| preprocessor: Union[Preprocessor, List[Preprocessor]] = None, | |||
| device: str = 'gpu', | |||
| auto_collate=True, | |||
| **kwargs): | |||
| """ Base class for pipeline. | |||
| @@ -74,6 +75,7 @@ class Pipeline(ABC): | |||
| model: (list of) Model name or model object | |||
| preprocessor: (list of) Preprocessor object | |||
| device (str): gpu device or cpu device to use | |||
| auto_collate (bool): Whether to automatically convert the input data to tensors. | |||
| """ | |||
| if config_file is not None: | |||
| self.cfg = Config.from_file(config_file) | |||
| @@ -98,6 +100,7 @@ class Pipeline(ABC): | |||
| self.device = create_device(self.device_name == 'cpu') | |||
| self._model_prepare = False | |||
| self._model_prepare_lock = Lock() | |||
| self._auto_collate = auto_collate | |||
| def prepare_model(self): | |||
| self._model_prepare_lock.acquire(timeout=600) | |||
| @@ -252,7 +255,7 @@ class Pipeline(ABC): | |||
| return self._collate_fn(torch.from_numpy(data)) | |||
| elif isinstance(data, torch.Tensor): | |||
| return data.to(self.device) | |||
| elif isinstance(data, (str, int, float, bool)): | |||
| elif isinstance(data, (str, int, float, bool, type(None))): | |||
| return data | |||
| elif isinstance(data, InputFeatures): | |||
| return data | |||
| @@ -270,7 +273,7 @@ class Pipeline(ABC): | |||
| out = self.preprocess(input, **preprocess_params) | |||
| with self.place_device(): | |||
| if self.framework == Frameworks.torch: | |||
| if self.framework == Frameworks.torch and self._auto_collate: | |||
| with torch.no_grad(): | |||
| out = self._collate_fn(out) | |||
| out = self.forward(out, **forward_params) | |||
| @@ -35,6 +35,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| ), # TODO: revise back after passing the pr | |||
| Tasks.image_matting: (Pipelines.image_matting, | |||
| 'damo/cv_unet_image-matting'), | |||
| Tasks.human_detection: (Pipelines.human_detection, | |||
| 'damo/cv_resnet18_human-detection'), | |||
| Tasks.object_detection: (Pipelines.object_detection, | |||
| 'damo/cv_vit_object-detection_coco'), | |||
| Tasks.image_denoise: (Pipelines.image_denoise, | |||
| 'damo/cv_nafnet_image-denoise_sidd'), | |||
| Tasks.text_classification: (Pipelines.sentiment_analysis, | |||
| @@ -7,6 +7,7 @@ if TYPE_CHECKING: | |||
| from .action_recognition_pipeline import ActionRecognitionPipeline | |||
| from .animal_recognition_pipeline import AnimalRecognitionPipeline | |||
| from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline | |||
| from .object_detection_pipeline import ObjectDetectionPipeline | |||
| from .face_detection_pipeline import FaceDetectionPipeline | |||
| from .face_recognition_pipeline import FaceRecognitionPipeline | |||
| from .face_image_generation_pipeline import FaceImageGenerationPipeline | |||
| @@ -30,6 +31,7 @@ else: | |||
| 'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
| 'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], | |||
| 'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], | |||
| 'object_detection_pipeline': ['ObjectDetectionPipeline'], | |||
| 'face_detection_pipeline': ['FaceDetectionPipeline'], | |||
| 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], | |||
| 'face_recognition_pipeline': ['FaceRecognitionPipeline'], | |||
| @@ -0,0 +1,51 @@ | |||
| from typing import Any, Dict | |||
| import numpy as np | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @PIPELINES.register_module( | |||
| Tasks.human_detection, module_name=Pipelines.human_detection) | |||
| @PIPELINES.register_module( | |||
| Tasks.object_detection, module_name=Pipelines.object_detection) | |||
| class ObjectDetectionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, auto_collate=False, **kwargs) | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| img = LoadImage.convert_to_ndarray(input) | |||
| img = img.astype(np.float32)  # np.float is deprecated in recent numpy | |||
| img = self.model.preprocess(img) | |||
| result = {'img': img} | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| outputs = self.model.inference(input['img']) | |||
| result = {'data': outputs} | |||
| return result | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| bboxes, scores, labels = self.model.postprocess(inputs['data']) | |||
| if bboxes is None: | |||
| return None | |||
| outputs = { | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.LABELS: labels, | |||
| OutputKeys.BOXES: bboxes | |||
| } | |||
| return outputs | |||
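As a quick consumption sketch (the model id and test image path are the ones used by the defaults and tests elsewhere in this change), the returned dict holds aligned score, label and box arrays:

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(
    Tasks.object_detection, model='damo/cv_vit_object-detection_coco')
result = detector('data/test/images/image_detection.jpg')

# postprocess() returns None when nothing is detected, otherwise parallel arrays.
if result is not None:
    for box, score, label in zip(result[OutputKeys.BOXES],
                                 result[OutputKeys.SCORES],
                                 result[OutputKeys.LABELS]):
        print(label, score, box)
```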
| @@ -20,6 +20,7 @@ class CVTasks(object): | |||
| image_classification = 'image-classification' | |||
| image_tagging = 'image-tagging' | |||
| object_detection = 'object-detection' | |||
| human_detection = 'human-detection' | |||
| image_segmentation = 'image-segmentation' | |||
| image_editing = 'image-editing' | |||
| image_generation = 'image-generation' | |||
| @@ -1,4 +1,5 @@ | |||
| decord>=0.6.0 | |||
| easydict | |||
| tf_slim | |||
| timm | |||
| torchvision | |||
| @@ -0,0 +1,56 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import test_level | |||
| class ObjectDetectionTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_object_detection(self): | |||
| input_location = 'data/test/images/image_detection.jpg' | |||
| model_id = 'damo/cv_vit_object-detection_coco' | |||
| object_detect = pipeline(Tasks.object_detection, model=model_id) | |||
| result = object_detect(input_location) | |||
| if result: | |||
| print(result) | |||
| else: | |||
| raise ValueError('process error') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_object_detection_with_default_task(self): | |||
| input_location = 'data/test/images/image_detection.jpg' | |||
| object_detect = pipeline(Tasks.object_detection) | |||
| result = object_detect(input_location) | |||
| if result: | |||
| print(result) | |||
| else: | |||
| raise ValueError('process error') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_human_detection(self): | |||
| input_location = 'data/test/images/image_detection.jpg' | |||
| model_id = 'damo/cv_resnet18_human-detection' | |||
| human_detect = pipeline(Tasks.human_detection, model=model_id) | |||
| result = human_detect(input_location) | |||
| if result: | |||
| print(result) | |||
| else: | |||
| raise ValueError('process error') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_human_detection_with_default_task(self): | |||
| input_location = 'data/test/images/image_detection.jpg' | |||
| human_detect = pipeline(Tasks.human_detection) | |||
| result = human_detect(input_location) | |||
| if result: | |||
| print(result) | |||
| else: | |||
| raise ValueError('process error') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||