Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10260332master

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8fddc7be8381eb244cd692601f1c1e6cf3484b44bb4e73df0bc7de29352eb487
size 23889

@@ -40,6 +40,7 @@ class Models(object):
    ulfd = 'ulfd'
    video_inpainting = 'video-inpainting'
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'

    # EasyCV models
    yolox = 'YOLOX'

@@ -181,6 +182,7 @@ class Pipelines(object):
    video_inpainting = 'video-inpainting'
    pst_action_recognition = 'patchshift-action-recognition'
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'

@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .det_infer import NanoDetForFaceHumanHandDetection
else:
    _import_structure = {'det_infer': ['NanoDetForFaceHumanHandDetection']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

@@ -0,0 +1,133 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import cv2
import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .one_stage_detector import OneStageDetector

logger = get_logger()

def load_model_weight(model_dir, device):
    checkpoint = torch.load(
        '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
        map_location=device)
    state_dict = checkpoint['state_dict'].copy()
    for k in checkpoint['state_dict']:
        # Weight-averaged parameters are stored under 'avg_model.*'; dropping
        # the first 4 chars ('avg_') renames them to 'model.*' so they replace
        # the corresponding raw training weights.
        if k.startswith('avg_model.'):
            v = state_dict.pop(k)
            state_dict[k[4:]] = v
    return state_dict

@MODELS.register_module(
    Tasks.face_human_hand_detection,
    module_name=Models.face_human_hand_detection)
class NanoDetForFaceHumanHandDetection(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)
        self.model = OneStageDetector()
        if torch.cuda.is_available():
            self.device = 'cuda'
            logger.info('Use GPU')
        else:
            self.device = 'cpu'
            logger.info('Use CPU')
        self.state_dict = load_model_weight(model_dir, self.device)
        self.model.load_state_dict(self.state_dict, strict=False)
        self.model.eval()
        self.model.to(self.device)

    def forward(self, x):
        pred_result = self.model.inference(x)
        return pred_result


def naive_collate(batch):
    elem = batch[0]
    if isinstance(elem, dict):
        return {key: naive_collate([d[key] for d in batch]) for key in elem}
    else:
        return batch
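
Review note: naive_collate only regroups values per key, it does not stack tensors; a minimal sketch with made-up values makes the contract clear:

batch = [{'img': 1, 'id': 'a'}, {'img': 2, 'id': 'b'}]
assert naive_collate(batch) == {'img': [1, 2], 'id': ['a', 'b']}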

def get_resize_matrix(raw_shape, dst_shape):
    r_w, r_h = raw_shape
    d_w, d_h = dst_shape
    Rs = np.eye(3)
    Rs[0, 0] *= d_w / r_w
    Rs[1, 1] *= d_h / r_h
    return Rs


def color_aug_and_norm(meta, mean, std):
    img = meta['img'].astype(np.float32) / 255
    mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3) / 255
    std = np.array(std, dtype=np.float32).reshape(1, 1, 3) / 255
    img = (img - mean) / std
    meta['img'] = img
    return meta


def img_process(meta, mean, std):
    raw_img = meta['img']
    height = raw_img.shape[0]
    width = raw_img.shape[1]
    dst_shape = [320, 320]
    M = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
    ResizeM = get_resize_matrix((width, height), dst_shape)
    M = ResizeM @ M
    img = cv2.warpPerspective(raw_img, M, dsize=tuple(dst_shape))
    meta['img'] = img
    meta['warp_matrix'] = M
    meta = color_aug_and_norm(meta, mean, std)
    return meta
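
Review note: a quick sanity check of the preprocessing path, using a synthetic gray image instead of a real file (the values are illustrative, not from this PR):

import numpy as np
dummy = {'img': np.full((480, 640, 3), 128, dtype=np.uint8)}
out = img_process(dummy, [103.53, 116.28, 123.675], [57.375, 57.12, 58.395])
assert out['img'].shape == (320, 320, 3) and out['img'].dtype == np.float32
# the warp matrix scales x by 320/640 = 0.5 and y by 320/480
assert abs(out['warp_matrix'][0, 0] - 0.5) < 1e-6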

def overlay_bbox_cv(dets, class_names, score_thresh):
    all_box = []
    for label in dets:
        for bbox in dets[label]:
            score = bbox[-1]
            if score > score_thresh:
                x0, y0, x1, y1 = [int(i) for i in bbox[:4]]
                all_box.append([label, x0, y0, x1, y1, score])
    all_box.sort(key=lambda v: v[5])
    return all_box
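
Review note: despite its name, overlay_bbox_cv draws nothing; it flattens the per-class dict into [label, x0, y0, x1, y1, score] rows filtered by score and sorted ascending by score (class_names is accepted but unused). A tiny illustration with made-up boxes:

dets = {0: [[0, 0, 100, 100, 0.9]], 2: [[10, 10, 50, 50, 0.2]]}
assert overlay_bbox_cv(dets, ['person', 'face', 'hand'], 0.35) == \
    [[0, 0, 0, 100, 100, 0.9]]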

mean = [103.53, 116.28, 123.675]
std = [57.375, 57.12, 58.395]
class_names = ['person', 'face', 'hand']


def inference(model, device, img_path):
    img_info = {'id': 0}
    img = cv2.imread(img_path)
    height, width = img.shape[:2]
    img_info['height'] = height
    img_info['width'] = width
    meta = dict(img_info=img_info, raw_img=img, img=img)
    meta = img_process(meta, mean, std)
    meta['img'] = torch.from_numpy(meta['img'].transpose(2, 0, 1)).to(device)
    meta = naive_collate([meta])
    meta['img'] = (meta['img'][0]).reshape(1, 3, 320, 320)
    with torch.no_grad():
        res = model(meta)
    result = overlay_bbox_cv(res[0], class_names, score_thresh=0.35)
    return result

@@ -0,0 +1,395 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import ConvModule, DepthwiseConvModule, act_layers

def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
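
Review note: _make_divisible rounds a channel count to the nearest multiple of divisor without dropping more than 10%. A few worked values:

assert _make_divisible(37, 8) == 40   # int(37 + 4) // 8 * 8 = 40
assert _make_divisible(32, 8) == 32
assert _make_divisible(8.4, 8) == 8   # 8 >= 0.9 * 8.4, so no bump-up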

def hard_sigmoid(x, inplace: bool = False):
    if inplace:
        return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0)
    else:
        return F.relu6(x + 3.0) / 6.0
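
Review note: this is the standard hard-sigmoid relu6(x + 3) / 6 (it needs the torch.nn.functional import added above); a quick check of the saturation points:

t = torch.tensor([-4.0, 0.0, 4.0])
assert torch.allclose(hard_sigmoid(t), torch.tensor([0.0, 0.5, 1.0]))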

class SqueezeExcite(nn.Module):

    def __init__(self,
                 in_chs,
                 se_ratio=0.25,
                 reduced_base_chs=None,
                 activation='ReLU',
                 gate_fn=hard_sigmoid,
                 divisor=4,
                 **_):
        super(SqueezeExcite, self).__init__()
        self.gate_fn = gate_fn
        reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio,
                                      divisor)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.act1 = act_layers(activation)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

    def forward(self, x):
        x_se = self.avg_pool(x)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        x = x * self.gate_fn(x_se)
        return x

class GhostModule(nn.Module):

    def __init__(self,
                 inp,
                 oup,
                 kernel_size=1,
                 ratio=2,
                 dw_size=3,
                 stride=1,
                 activation='ReLU'):
        super(GhostModule, self).__init__()
        self.oup = oup
        init_channels = math.ceil(oup / ratio)
        new_channels = init_channels * (ratio - 1)
        self.primary_conv = nn.Sequential(
            nn.Conv2d(
                inp,
                init_channels,
                kernel_size,
                stride,
                kernel_size // 2,
                bias=False),
            nn.BatchNorm2d(init_channels),
            act_layers(activation) if activation else nn.Sequential(),
        )
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(
                init_channels,
                new_channels,
                dw_size,
                1,
                dw_size // 2,
                groups=init_channels,
                bias=False,
            ),
            nn.BatchNorm2d(new_channels),
            act_layers(activation) if activation else nn.Sequential(),
        )

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
        return out
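
Review note: GhostModule builds oup output channels cheaply — init_channels = ceil(oup / ratio) come from the dense conv and the rest from the depthwise "cheap operation" stacked on top. A shape check with arbitrary sizes:

gm = GhostModule(16, 32)                # init_channels = 16, new_channels = 16
y = gm(torch.randn(1, 16, 8, 8))
assert y.shape == (1, 32, 8, 8)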

class GhostBottleneck(nn.Module):
    """Ghost bottleneck w/ optional SE."""

    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        dw_kernel_size=3,
        stride=1,
        activation='ReLU',
        se_ratio=0.0,
    ):
        super(GhostBottleneck, self).__init__()
        has_se = se_ratio is not None and se_ratio > 0.0
        self.stride = stride

        # Point-wise expansion
        self.ghost1 = GhostModule(in_chs, mid_chs, activation=activation)

        # Depth-wise convolution
        if self.stride > 1:
            self.conv_dw = nn.Conv2d(
                mid_chs,
                mid_chs,
                dw_kernel_size,
                stride=stride,
                padding=(dw_kernel_size - 1) // 2,
                groups=mid_chs,
                bias=False,
            )
            self.bn_dw = nn.BatchNorm2d(mid_chs)

        if has_se:
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio)
        else:
            self.se = None

        self.ghost2 = GhostModule(mid_chs, out_chs, activation=None)

        if in_chs == out_chs and self.stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_chs,
                    in_chs,
                    dw_kernel_size,
                    stride=stride,
                    padding=(dw_kernel_size - 1) // 2,
                    groups=in_chs,
                    bias=False,
                ),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        residual = x
        x = self.ghost1(x)
        if self.stride > 1:
            x = self.conv_dw(x)
            x = self.bn_dw(x)
        if self.se is not None:
            x = self.se(x)
        x = self.ghost2(x)
        x += self.shortcut(residual)
        return x

class GhostBlocks(nn.Module):
    """Stack of GhostBottleneck used in GhostPAN.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expand (int): Expand ratio of GhostBottleneck. Default: 1.
        kernel_size (int): Kernel size of depthwise convolution. Default: 5.
        num_blocks (int): Number of GhostBottleneck blocks. Default: 1.
        use_res (bool): Whether to use residual connection. Default: False.
        activation (str): Name of activation function. Default: LeakyReLU.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        expand=1,
        kernel_size=5,
        num_blocks=1,
        use_res=False,
        activation='LeakyReLU',
    ):
        super(GhostBlocks, self).__init__()
        self.use_res = use_res
        if use_res:
            self.reduce_conv = ConvModule(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                activation=activation,
            )
        blocks = []
        for _ in range(num_blocks):
            blocks.append(
                GhostBottleneck(
                    in_channels,
                    int(out_channels * expand),
                    out_channels,
                    dw_kernel_size=kernel_size,
                    activation=activation,
                ))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        out = self.blocks(x)
        if self.use_res:
            out = out + self.reduce_conv(x)
        return out

class GhostPAN(nn.Module):
    """Path Aggregation Network with Ghost block.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Default: False
        kernel_size (int): Kernel size of depthwise convolution. Default: 5.
        expand (int): Expand ratio of GhostBottleneck. Default: 1.
        num_blocks (int): Number of GhostBottleneck blocks. Default: 1.
        use_res (bool): Whether to use residual connection. Default: False.
        num_extra_level (int): Number of extra conv layers for more feature levels.
            Default: 0.
        upsample_cfg (dict): Config dict for interpolate layer.
            Default: `dict(scale_factor=2, mode='bilinear')`
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN')
        activation (str): Activation layer name.
            Default: LeakyReLU.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        use_depthwise=False,
        kernel_size=5,
        expand=1,
        num_blocks=1,
        use_res=False,
        num_extra_level=0,
        upsample_cfg=dict(scale_factor=2, mode='bilinear'),
        norm_cfg=dict(type='BN'),
        activation='LeakyReLU',
    ):
        super(GhostPAN, self).__init__()
        assert num_extra_level >= 0
        assert num_blocks >= 1
        self.in_channels = in_channels
        self.out_channels = out_channels
        conv = DepthwiseConvModule if use_depthwise else ConvModule

        # build top-down blocks
        self.upsample = nn.Upsample(**upsample_cfg)
        self.reduce_layers = nn.ModuleList()
        for idx in range(len(in_channels)):
            self.reduce_layers.append(
                ConvModule(
                    in_channels[idx],
                    out_channels,
                    1,
                    norm_cfg=norm_cfg,
                    activation=activation,
                ))
        self.top_down_blocks = nn.ModuleList()
        for idx in range(len(in_channels) - 1, 0, -1):
            self.top_down_blocks.append(
                GhostBlocks(
                    out_channels * 2,
                    out_channels,
                    expand,
                    kernel_size=kernel_size,
                    num_blocks=num_blocks,
                    use_res=use_res,
                    activation=activation,
                ))

        # build bottom-up blocks
        self.downsamples = nn.ModuleList()
        self.bottom_up_blocks = nn.ModuleList()
        for idx in range(len(in_channels) - 1):
            self.downsamples.append(
                conv(
                    out_channels,
                    out_channels,
                    kernel_size,
                    stride=2,
                    padding=kernel_size // 2,
                    norm_cfg=norm_cfg,
                    activation=activation,
                ))
            self.bottom_up_blocks.append(
                GhostBlocks(
                    out_channels * 2,
                    out_channels,
                    expand,
                    kernel_size=kernel_size,
                    num_blocks=num_blocks,
                    use_res=use_res,
                    activation=activation,
                ))

        # extra layers
        self.extra_lvl_in_conv = nn.ModuleList()
        self.extra_lvl_out_conv = nn.ModuleList()
        for i in range(num_extra_level):
            self.extra_lvl_in_conv.append(
                conv(
                    out_channels,
                    out_channels,
                    kernel_size,
                    stride=2,
                    padding=kernel_size // 2,
                    norm_cfg=norm_cfg,
                    activation=activation,
                ))
            self.extra_lvl_out_conv.append(
                conv(
                    out_channels,
                    out_channels,
                    kernel_size,
                    stride=2,
                    padding=kernel_size // 2,
                    norm_cfg=norm_cfg,
                    activation=activation,
                ))

    def forward(self, inputs):
        """
        Args:
            inputs (tuple[Tensor]): input features.
        Returns:
            tuple[Tensor]: multi level features.
        """
        assert len(inputs) == len(self.in_channels)
        inputs = [
            reduce(input_x)
            for input_x, reduce in zip(inputs, self.reduce_layers)
        ]
        # top-down path
        inner_outs = [inputs[-1]]
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_high = inner_outs[0]
            feat_low = inputs[idx - 1]
            upsample_feat = self.upsample(feat_high)
            inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
                torch.cat([upsample_feat, feat_low], 1))
            inner_outs.insert(0, inner_out)
        # bottom-up path
        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1):
            feat_low = outs[-1]
            feat_high = inner_outs[idx + 1]
            downsample_feat = self.downsamples[idx](feat_low)
            out = self.bottom_up_blocks[idx](
                torch.cat([downsample_feat, feat_high], 1))
            outs.append(out)
        # extra layers
        for extra_in_layer, extra_out_layer in zip(self.extra_lvl_in_conv,
                                                   self.extra_lvl_out_conv):
            outs.append(extra_in_layer(inputs[-1]) + extra_out_layer(outs[-1]))
        return tuple(outs)
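
Review note: a shape walk-through with the exact configuration OneStageDetector uses below; with num_extra_level=1 the three backbone levels become four output levels (strides 8/16/32/64 for a 320x320 input). This is a sketch assuming the modules in this PR are importable:

pan = GhostPAN(
    in_channels=[116, 232, 464],
    out_channels=96,
    use_depthwise=True,
    num_extra_level=1)
feats = (torch.randn(1, 116, 40, 40), torch.randn(1, 232, 20, 20),
         torch.randn(1, 464, 10, 10))
outs = pan(feats)
assert [f.shape[-1] for f in outs] == [40, 20, 10, 5]
assert all(f.shape[1] == 96 for f in outs)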

@@ -0,0 +1,427 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet
import math

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import nms

from .utils import ConvModule, DepthwiseConvModule

class Integral(nn.Module):
    """A fixed layer for calculating integral result from distribution.

    This layer calculates the target location by :math:`sum{P(y_i) * y_i}`,
    where P(y_i) denotes the softmax vector that represents the discrete
    distribution and y_i denotes the discrete set,
    usually {0, 1, 2, ..., reg_max}.

    Args:
        reg_max (int): The maximal value of the discrete set. Default: 16. You
            may want to reset it according to your new dataset or related
            settings.
    """

    def __init__(self, reg_max=16):
        super(Integral, self).__init__()
        self.reg_max = reg_max
        self.register_buffer('project',
                             torch.linspace(0, self.reg_max, self.reg_max + 1))

    def forward(self, x):
        """Forward feature from the regression head to get integral result of
        bounding box location.

        Args:
            x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
                n is self.reg_max.
        Returns:
            x (Tensor): Integral result of box locations, i.e., distance
                offsets from the box center in four directions, shape (N, 4).
        """
        shape = x.size()
        x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1)
        x = F.linear(x, self.project.type_as(x)).reshape(*shape[:-1], 4)
        return x
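
Review note: with all-zero logits the softmax over each side's reg_max + 1 bins is uniform, so the expected offset is exactly reg_max / 2 — a handy smoke test of the projection buffer:

integral = Integral(reg_max=7)
out = integral(torch.zeros(2, 4 * 8))
assert out.shape == (2, 4)
assert torch.allclose(out, torch.full((2, 4), 3.5))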

def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
    """Performs non-maximum suppression in a batched fashion.

    Modified from https://github.com/pytorch/vision/blob
    /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
    In order to perform NMS independently per class, we add an offset to all
    the boxes. The offset is dependent only on the class idx, and is large
    enough so that boxes from different classes do not overlap.

    Arguments:
        boxes (torch.Tensor): boxes in shape (N, 4).
        scores (torch.Tensor): scores in shape (N, ).
        idxs (torch.Tensor): each index value corresponds to a bbox cluster,
            and NMS will not be applied between elements of different idxs,
            shape (N, ).
        nms_cfg (dict): specify nms type and other parameters like iou_thr.
            Possible keys include the following.

            - iou_thr (float): IoU threshold used for NMS.
            - split_thr (float): threshold number of boxes. In some cases the
              number of boxes is large (e.g., 200k). To avoid OOM during
              training, the users could set `split_thr` to a small value.
              If the number of boxes is greater than the threshold, it will
              perform NMS on each group of boxes separately and sequentially.
              Defaults to 10000.
        class_agnostic (bool): if true, nms is class agnostic,
            i.e. IoU thresholding happens over all boxes,
            regardless of the predicted class.

    Returns:
        tuple: kept dets and indices.
    """
    nms_cfg_ = nms_cfg.copy()
    class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
    if class_agnostic:
        boxes_for_nms = boxes
    else:
        max_coordinate = boxes.max()
        offsets = idxs.to(boxes) * (max_coordinate + 1)
        boxes_for_nms = boxes + offsets[:, None]
    nms_cfg_.pop('type', 'nms')
    split_thr = nms_cfg_.pop('split_thr', 10000)
    if len(boxes_for_nms) < split_thr:
        keep = nms(boxes_for_nms, scores, **nms_cfg_)
        boxes = boxes[keep]
        scores = scores[keep]
    else:
        total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
        for cls_id in torch.unique(idxs):
            mask = (idxs == cls_id).nonzero(as_tuple=False).view(-1)
            keep = nms(boxes_for_nms[mask], scores[mask], **nms_cfg_)
            total_mask[mask[keep]] = True
        keep = total_mask.nonzero(as_tuple=False).view(-1)
        keep = keep[scores[keep].argsort(descending=True)]
        boxes = boxes[keep]
        scores = scores[keep]
    return torch.cat([boxes, scores[:, None]], -1), keep
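
Review note: the class-offset trick in action — two identical boxes with different labels survive NMS because the offset moves them apart, whereas the same pair sharing one label would collapse to a single box:

boxes = torch.tensor([[0., 0., 10., 10.], [0., 0., 10., 10.]])
scores = torch.tensor([0.9, 0.8])
labels = torch.tensor([0, 1])
dets, keep = batched_nms(boxes, scores, labels,
                         dict(type='nms', iou_threshold=0.5))
assert len(keep) == 2    # with labels = [0, 0] only one box would remain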

def multiclass_nms(multi_bboxes,
                   multi_scores,
                   score_thr,
                   nms_cfg,
                   max_num=-1,
                   score_factors=None):
    """NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
        multi_scores (Tensor): shape (n, #class), where the last column
            contains scores of the background class, but this will be ignored.
        score_thr (float): bbox threshold, bboxes with scores lower than it
            will not be considered.
        nms_cfg (dict): NMS config, e.g. dict(type='nms', iou_threshold=0.6).
        max_num (int): if there are more than max_num bboxes after NMS,
            only top max_num will be kept.
        score_factors (Tensor): The factors multiplied to scores before
            applying NMS.

    Returns:
        tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels
            are 0-based.
    """
    num_classes = multi_scores.size(1) - 1
    if multi_bboxes.shape[1] > 4:
        bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
    else:
        bboxes = multi_bboxes[:, None].expand(
            multi_scores.size(0), num_classes, 4)
    scores = multi_scores[:, :-1]
    valid_mask = scores > score_thr
    bboxes = torch.masked_select(
        bboxes,
        torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
                    -1)).view(-1, 4)
    if score_factors is not None:
        scores = scores * score_factors[:, None]
    scores = torch.masked_select(scores, valid_mask)
    labels = valid_mask.nonzero(as_tuple=False)[:, 1]
    if bboxes.numel() == 0:
        bboxes = multi_bboxes.new_zeros((0, 5))
        labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
        if torch.onnx.is_in_onnx_export():
            raise RuntimeError('[ONNX Error] Can not record NMS '
                               'as it has not been executed this time')
        return bboxes, labels
    dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
    if max_num > 0:
        dets = dets[:max_num]
        keep = keep[:max_num]
    return dets, labels[keep]

def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Shape of the image.

    Returns:
        Tensor: Decoded bboxes.
    """
    x1 = points[..., 0] - distance[..., 0]
    y1 = points[..., 1] - distance[..., 1]
    x2 = points[..., 0] + distance[..., 2]
    y2 = points[..., 1] + distance[..., 3]
    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return torch.stack([x1, y1, x2, y2], -1)
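
Review note: decode check with hand-picked numbers — a center point at (50, 50) with (left, top, right, bottom) distances (10, 20, 30, 40) must yield the box (40, 30, 80, 90):

pts = torch.tensor([[50., 50.]])
dist = torch.tensor([[10., 20., 30., 40.]])
assert distance2bbox(pts, dist).tolist() == [[40., 30., 80., 90.]]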

def warp_boxes(boxes, M, width, height):
    n = len(boxes)
    if n:
        xy = np.ones((n * 4, 3))
        xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
        xy = xy @ M.T
        xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate(
            (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
        return xy.astype(np.float32)
    else:
        return boxes

class NanoDetPlusHead(nn.Module):
    """Detection head used in NanoDet-Plus.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        input_channel (int): Number of channels of the input feature.
        feat_channels (int): Number of channels of the feature.
            Default: 96.
        stacked_convs (int): Number of conv layers in the stacked convs.
            Default: 2.
        kernel_size (int): Size of the convolving kernel. Default: 5.
        strides (list[int]): Strides of input multi-level feature maps.
            Default: [8, 16, 32].
        conv_type (str): Type of the convolution.
            Default: "DWConv".
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        reg_max (int): The maximal value of the discrete set. Default: 7.
        activation (str): Type of activation function. Default: "LeakyReLU".
        assigner_cfg (dict): Config dict of the assigner. Default: dict(topk=13).
    """

    def __init__(self,
                 num_classes,
                 input_channel,
                 feat_channels=96,
                 stacked_convs=2,
                 kernel_size=5,
                 strides=[8, 16, 32],
                 conv_type='DWConv',
                 norm_cfg=dict(type='BN'),
                 reg_max=7,
                 activation='LeakyReLU',
                 assigner_cfg=dict(topk=13),
                 **kwargs):
        super(NanoDetPlusHead, self).__init__()
        self.num_classes = num_classes
        self.in_channels = input_channel
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.kernel_size = kernel_size
        self.strides = strides
        self.reg_max = reg_max
        self.activation = activation
        self.ConvModule = ConvModule if conv_type == 'Conv' else DepthwiseConvModule
        self.norm_cfg = norm_cfg
        self.distribution_project = Integral(self.reg_max)
        self._init_layers()

    def _init_layers(self):
        self.cls_convs = nn.ModuleList()
        for _ in self.strides:
            cls_convs = self._build_not_shared_head()
            self.cls_convs.append(cls_convs)
        self.gfl_cls = nn.ModuleList([
            nn.Conv2d(
                self.feat_channels,
                self.num_classes + 4 * (self.reg_max + 1),
                1,
                padding=0,
            ) for _ in self.strides
        ])

    def _build_not_shared_head(self):
        cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            cls_convs.append(
                self.ConvModule(
                    chn,
                    self.feat_channels,
                    self.kernel_size,
                    stride=1,
                    padding=self.kernel_size // 2,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None,
                    activation=self.activation,
                ))
        return cls_convs

    def forward(self, feats):
        if torch.onnx.is_in_onnx_export():
            return self._forward_onnx(feats)
        outputs = []
        for feat, cls_convs, gfl_cls in zip(
                feats,
                self.cls_convs,
                self.gfl_cls,
        ):
            for conv in cls_convs:
                feat = conv(feat)
            output = gfl_cls(feat)
            outputs.append(output.flatten(start_dim=2))
        outputs = torch.cat(outputs, dim=2).permute(0, 2, 1)
        return outputs

    def post_process(self, preds, meta):
        """Prediction results post processing. Decode bboxes and rescale
        to original image size.

        Args:
            preds (Tensor): Prediction output.
            meta (dict): Meta info.
        """
        cls_scores, bbox_preds = preds.split(
            [self.num_classes, 4 * (self.reg_max + 1)], dim=-1)
        result_list = self.get_bboxes(cls_scores, bbox_preds, meta)
        det_results = {}
        # wrap a single warp matrix in a list so the zip below iterates over
        # samples rather than over the rows of one 3x3 matrix
        warp_matrixes = (
            meta['warp_matrix'] if isinstance(meta['warp_matrix'], list) else
            [meta['warp_matrix']])
        img_heights = (
            meta['img_info']['height'].cpu().numpy() if isinstance(
                meta['img_info']['height'], torch.Tensor) else
            meta['img_info']['height'])
        img_widths = (
            meta['img_info']['width'].cpu().numpy() if isinstance(
                meta['img_info']['width'], torch.Tensor) else
            meta['img_info']['width'])
        img_ids = (
            meta['img_info']['id'].cpu().numpy() if isinstance(
                meta['img_info']['id'], torch.Tensor) else
            meta['img_info']['id'])
        for result, img_width, img_height, img_id, warp_matrix in zip(
                result_list, img_widths, img_heights, img_ids, warp_matrixes):
            det_result = {}
            det_bboxes, det_labels = result
            det_bboxes = det_bboxes.detach().cpu().numpy()
            det_bboxes[:, :4] = warp_boxes(det_bboxes[:, :4],
                                           np.linalg.inv(warp_matrix),
                                           img_width, img_height)
            classes = det_labels.detach().cpu().numpy()
            for i in range(self.num_classes):
                inds = classes == i
                det_result[i] = np.concatenate(
                    [
                        det_bboxes[inds, :4].astype(np.float32),
                        det_bboxes[inds, 4:5].astype(np.float32),
                    ],
                    axis=1,
                ).tolist()
            det_results[img_id] = det_result
        return det_results

    def get_bboxes(self, cls_preds, reg_preds, img_metas):
        """Decode the outputs to bboxes.

        Args:
            cls_preds (Tensor): Shape (num_imgs, num_points, num_classes).
            reg_preds (Tensor): Shape (num_imgs, num_points, 4 * (reg_max + 1)).
            img_metas (dict): Dict of image info.

        Returns:
            results_list (list[tuple]): List of detection bboxes and labels.
        """
        device = cls_preds.device
        b = cls_preds.shape[0]
        input_height, input_width = img_metas['img'].shape[2:]
        input_shape = (input_height, input_width)
        featmap_sizes = [(math.ceil(input_height / stride),
                          math.ceil(input_width / stride))
                         for stride in self.strides]
        # get grid cells of every feature level
        mlvl_center_priors = [
            self.get_single_level_center_priors(
                b,
                featmap_sizes[i],
                stride,
                dtype=torch.float32,
                device=device,
            ) for i, stride in enumerate(self.strides)
        ]
        center_priors = torch.cat(mlvl_center_priors, dim=1)
        dis_preds = self.distribution_project(reg_preds) * center_priors[..., 2, None]
        bboxes = distance2bbox(
            center_priors[..., :2], dis_preds, max_shape=input_shape)
        scores = cls_preds.sigmoid()
        result_list = []
        for i in range(b):
            # add a dummy background class column at the end of all labels
            score, bbox = scores[i], bboxes[i]
            padding = score.new_zeros(score.shape[0], 1)
            score = torch.cat([score, padding], dim=1)
            results = multiclass_nms(
                bbox,
                score,
                score_thr=0.05,
                nms_cfg=dict(type='nms', iou_threshold=0.6),
                max_num=100,
            )
            result_list.append(results)
        return result_list

    def get_single_level_center_priors(self, batch_size, featmap_size, stride,
                                       dtype, device):
        """Generate centers of a single stage feature map.

        Args:
            batch_size (int): Number of images in one batch.
            featmap_size (tuple[int]): height and width of the feature map.
            stride (int): downsample stride of the feature map.
            dtype (obj:`torch.dtype`): data type of the tensors.
            device (obj:`torch.device`): device of the tensors.
        Return:
            priors (Tensor): center priors of a single level feature map.
        """
        h, w = featmap_size
        x_range = (torch.arange(w, dtype=dtype, device=device)) * stride
        y_range = (torch.arange(h, dtype=dtype, device=device)) * stride
        y, x = torch.meshgrid(y_range, x_range)
        y = y.flatten()
        x = x.flatten()
        strides = x.new_full((x.shape[0], ), stride)
        priors = torch.stack([x, y, strides, strides], dim=-1)
        return priors.unsqueeze(0).repeat(batch_size, 1, 1)
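
Review note: each prior row is (x, y, stride, stride), with x varying fastest across the flattened feature map. A sketch for a 2x2 map at stride 8, assuming the head can be built on CPU with defaults:

head = NanoDetPlusHead(num_classes=3, input_channel=96)
p = head.get_single_level_center_priors(1, (2, 2), 8, torch.float32, 'cpu')
assert p.tolist() == [[[0., 0., 8., 8.], [8., 0., 8., 8.],
                       [0., 8., 8., 8.], [8., 8., 8., 8.]]]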

@@ -0,0 +1,64 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet
import torch
import torch.nn as nn

from .ghost_pan import GhostPAN
from .nanodet_plus_head import NanoDetPlusHead
from .shufflenetv2 import ShuffleNetV2


class OneStageDetector(nn.Module):

    def __init__(self):
        super(OneStageDetector, self).__init__()
        self.backbone = ShuffleNetV2(
            model_size='1.0x',
            out_stages=(2, 3, 4),
            with_last_conv=False,
            kernel_size=3,
            activation='LeakyReLU',
            pretrain=False)
        self.fpn = GhostPAN(
            in_channels=[116, 232, 464],
            out_channels=96,
            use_depthwise=True,
            kernel_size=5,
            expand=1,
            num_blocks=1,
            use_res=False,
            num_extra_level=1,
            upsample_cfg=dict(scale_factor=2, mode='bilinear'),
            norm_cfg=dict(type='BN'),
            activation='LeakyReLU')
        self.head = NanoDetPlusHead(
            num_classes=3,
            input_channel=96,
            feat_channels=96,
            stacked_convs=2,
            kernel_size=5,
            strides=[8, 16, 32, 64],
            conv_type='DWConv',
            norm_cfg=dict(type='BN'),
            reg_max=7,
            activation='LeakyReLU',
            assigner_cfg=dict(topk=13))
        self.epoch = 0

    def forward(self, x):
        x = self.backbone(x)
        if hasattr(self, 'fpn'):
            x = self.fpn(x)
        if hasattr(self, 'head'):
            x = self.head(x)
        return x

    def inference(self, meta):
        with torch.no_grad():
            # torch.cuda.synchronize() raises on CPU-only machines, and the
            # model wrapper above explicitly supports CPU, so only sync when
            # CUDA is actually available.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            preds = self(meta['img'])
            results = self.head.post_process(preds, meta)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        return results

@@ -0,0 +1,182 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet
import torch
import torch.nn as nn

from .utils import act_layers


def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    # reshape into (batch, groups, channels_per_group, h, w), swap the two
    # channel axes, then flatten back
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x
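
Review note: channel_shuffle interleaves the two branch halves; with 8 channels numbered 0-7 and groups=2 the new order is 0, 4, 1, 5, 2, 6, 3, 7:

x = torch.arange(8.).view(1, 8, 1, 1)
shuffled = channel_shuffle(x, 2)
assert shuffled.flatten().tolist() == [0., 4., 1., 5., 2., 6., 3., 7.]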

class ShuffleV2Block(nn.Module):

    def __init__(self, inp, oup, stride, activation='ReLU'):
        super(ShuffleV2Block, self).__init__()
        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride
        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(
                    inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(
                    inp,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False),
                nn.BatchNorm2d(branch_features),
                act_layers(activation),
            )
        else:
            self.branch1 = nn.Sequential()
        self.branch2 = nn.Sequential(
            nn.Conv2d(
                inp if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            act_layers(activation),
            self.depthwise_conv(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
            ),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            act_layers(activation),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(
            i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        out = channel_shuffle(out, 2)
        return out

class ShuffleNetV2(nn.Module):

    def __init__(
        self,
        model_size='1.5x',
        out_stages=(2, 3, 4),
        with_last_conv=False,
        kernel_size=3,
        activation='ReLU',
        pretrain=True,
    ):
        super(ShuffleNetV2, self).__init__()
        assert set(out_stages).issubset((2, 3, 4))
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        self.out_stages = out_stages
        self.with_last_conv = with_last_conv
        self.kernel_size = kernel_size
        self.activation = activation
        if model_size == '0.5x':
            self._stage_out_channels = [24, 48, 96, 192, 1024]
        elif model_size == '1.0x':
            self._stage_out_channels = [24, 116, 232, 464, 1024]
        elif model_size == '1.5x':
            self._stage_out_channels = [24, 176, 352, 704, 1024]
        elif model_size == '2.0x':
            self._stage_out_channels = [24, 244, 488, 976, 2048]
        else:
            raise NotImplementedError
        # building first layer
        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            act_layers(activation),
        )
        input_channels = output_channels
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
                stage_names, self.stage_repeats, self._stage_out_channels[1:]):
            seq = [
                ShuffleV2Block(
                    input_channels, output_channels, 2, activation=activation)
            ]
            for i in range(repeats - 1):
                seq.append(
                    ShuffleV2Block(
                        output_channels,
                        output_channels,
                        1,
                        activation=activation))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels
        output_channels = self._stage_out_channels[-1]
        if self.with_last_conv:
            conv5 = nn.Sequential(
                nn.Conv2d(
                    input_channels, output_channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(output_channels),
                act_layers(activation),
            )
            self.stage4.add_module('conv5', conv5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        output = []
        for i in range(2, 5):
            stage = getattr(self, 'stage{}'.format(i))
            x = stage(x)
            if i in self.out_stages:
                output.append(x)
        return tuple(output)
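
Review note: stride bookkeeping for the 320x320 input this PR feeds the model — conv1 and maxpool each halve the resolution, then every stage halves once more, so out_stages (2, 3, 4) sit at strides 8/16/32 and match GhostPAN's in_channels=[116, 232, 464]:

net = ShuffleNetV2(model_size='1.0x', pretrain=False)
c3, c4, c5 = net(torch.randn(1, 3, 320, 320))
assert c3.shape == (1, 116, 40, 40)
assert c4.shape == (1, 232, 20, 20)
assert c5.shape == (1, 464, 10, 10)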

@@ -0,0 +1,277 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet
import warnings

import torch
import torch.nn as nn

activations = {
    'ReLU': nn.ReLU,
    'LeakyReLU': nn.LeakyReLU,
    'ReLU6': nn.ReLU6,
    'SELU': nn.SELU,
    'ELU': nn.ELU,
    'GELU': nn.GELU,
    'PReLU': nn.PReLU,
    'SiLU': nn.SiLU,
    'HardSwish': nn.Hardswish,
    'Hardswish': nn.Hardswish,
    None: nn.Identity,
}


def act_layers(name):
    assert name in activations.keys()
    if name == 'LeakyReLU':
        return nn.LeakyReLU(negative_slope=0.1, inplace=True)
    elif name == 'GELU':
        return nn.GELU()
    elif name == 'PReLU':
        return nn.PReLU()
    else:
        return activations[name](inplace=True)

norm_cfg = {
    'BN': ('bn', nn.BatchNorm2d),
    'SyncBN': ('bn', nn.SyncBatchNorm),
    'GN': ('gn', nn.GroupNorm),
}


def build_norm_layer(cfg, num_features, postfix=''):
    """Build normalization layer.

    Args:
        cfg (dict): cfg should contain:
            type (str): identify norm layer type.
            layer args: args needed to instantiate a norm layer.
            requires_grad (bool): [optional] whether stop gradient updates
        num_features (int): number of channels from input.
        postfix (int, str): appended into norm abbreviation to
            create named layer.

    Returns:
        name (str): abbreviation + postfix
        layer (nn.Module): created norm layer
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    else:
        abbr, norm_layer = norm_cfg[layer_type]
        if norm_layer is None:
            raise NotImplementedError
    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)
    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
            layer._specify_ddp_gpu_num(1)
    else:
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)
    for param in layer.parameters():
        param.requires_grad = requires_grad
    return name, layer
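
Review note: the returned name is what ConvModule below registers the layer under ('bn'/'gn' plus the postfix); a quick check of both branches, where num_groups for GN must be caller-supplied:

name, bn = build_norm_layer(dict(type='BN'), 96)
assert name == 'bn' and isinstance(bn, nn.BatchNorm2d)
name, gn = build_norm_layer(dict(type='GN', num_groups=8), 96, postfix=2)
assert name == 'gn2' and isinstance(gn, nn.GroupNorm)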

class ConvModule(nn.Module):
    """A conv block that contains conv/norm/activation layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
        conv_cfg (dict): Config dict for convolution layer.
        norm_cfg (dict): Config dict for normalization layer.
        activation (str): activation layer, "ReLU" by default.
        inplace (bool): Whether to use inplace mode for activation.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias='auto',
        conv_cfg=None,
        norm_cfg=None,
        activation='ReLU',
        inplace=True,
        order=('conv', 'norm', 'act'),
    ):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert activation is None or isinstance(activation, str)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.activation = activation
        self.inplace = inplace
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == {'conv', 'norm', 'act'}
        self.with_norm = norm_cfg is not None
        if bias == 'auto':
            bias = False if self.with_norm else True
        self.with_bias = bias
        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = self.conv.padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups
        if self.with_norm:
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)
        else:
            self.norm_name = None
        if self.activation:
            self.act = act_layers(self.activation)

    @property
    def norm(self):
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def forward(self, x, norm=True):
        for layer in self.order:
            if layer == 'conv':
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and self.activation:
                x = self.act(x)
        return x

class DepthwiseConvModule(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias='auto',
        norm_cfg=dict(type='BN'),
        activation='ReLU',
        inplace=True,
        order=('depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act'),
    ):
        super(DepthwiseConvModule, self).__init__()
        assert activation is None or isinstance(activation, str)
        self.activation = activation
        self.inplace = inplace
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 6
        assert set(order) == {
            'depthwise',
            'dwnorm',
            'act',
            'pointwise',
            'pwnorm',
            'act',
        }
        self.with_norm = norm_cfg is not None
        if bias == 'auto':
            bias = False if self.with_norm else True
        self.with_bias = bias
        if self.with_norm and self.with_bias:
            warnings.warn(
                'DepthwiseConvModule has norm and bias at the same time')
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias)
        self.in_channels = self.depthwise.in_channels
        self.out_channels = self.pointwise.out_channels
        self.kernel_size = self.depthwise.kernel_size
        self.stride = self.depthwise.stride
        self.padding = self.depthwise.padding
        self.dilation = self.depthwise.dilation
        self.transposed = self.depthwise.transposed
        self.output_padding = self.depthwise.output_padding
        if self.with_norm:
            _, self.dwnorm = build_norm_layer(norm_cfg, in_channels)
            _, self.pwnorm = build_norm_layer(norm_cfg, out_channels)
        if self.activation:
            self.act = act_layers(self.activation)

    def forward(self, x, norm=True):
        for layer_name in self.order:
            if layer_name != 'act':
                layer = self.__getattr__(layer_name)
                x = layer(x)
            elif layer_name == 'act' and self.activation:
                x = self.act(x)
        return x

@@ -649,8 +649,17 @@ TASK_OUTPUTS = {
    #   'output': ['Done' / 'Decode_Error']
    # }
    Tasks.video_inpainting: [OutputKeys.OUTPUT],

    # {
    #   'output': ['bixin']
    # }
    Tasks.hand_static: [OutputKeys.OUTPUT],

    # {
    #   'output': [
    #       [2, 75, 287, 240, 510, 0.8335018754005432],
    #       [1, 127, 83, 332, 366, 0.9175254702568054],
    #       [0, 0, 0, 367, 639, 0.9693422317504883]]
    # }
    Tasks.face_human_hand_detection: [OutputKeys.OUTPUT],
}

@@ -183,6 +183,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                             'damo/cv_video-inpainting'),
    Tasks.hand_static: (Pipelines.hand_static,
                        'damo/cv_mobileface_hand-static'),
    Tasks.face_human_hand_detection:
    (Pipelines.face_human_hand_detection,
     'damo/cv_nanodet_face-human-hand-detection'),
}

@@ -0,0 +1,42 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_human_hand_detection import det_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_human_hand_detection,
    module_name=Pipelines.face_human_hand_detection)
class NanoDetForFaceHumanHandDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a face-human-hand detection pipeline for
        prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = det_infer.inference(self.model, self.device,
                                     input['input_path'])
        logger.info(result)
        return {OutputKeys.OUTPUT: result}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
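
Review note: end-to-end usage mirrors the unit test below; per the TASK_OUTPUTS comment above, each output row is [label, x0, y0, x1, y1, score] with labels indexing class_names = ['person', 'face', 'hand']:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(
    Tasks.face_human_hand_detection,
    model='damo/cv_nanodet_face-human-hand-detection')
result = detector(
    {'input_path': 'data/test/images/face_human_hand_detection.jpg'})
print(result[OutputKeys.OUTPUT])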

@@ -43,6 +43,7 @@ class CVTasks(object):
    text_driven_segmentation = 'text-driven-segmentation'
    shop_segmentation = 'shop-segmentation'
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'

    # image editing
    skin_retouching = 'skin-retouching'

@@ -0,0 +1,38 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class FaceHumanHandTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_nanodet_face-human-hand-detection'
        self.input = {
            'input_path': 'data/test/images/face_human_hand_detection.jpg',
        }

    def pipeline_inference(self, pipeline: Pipeline, input: str):
        result = pipeline(input)
        logger.info(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_human_hand_detection = pipeline(
            Tasks.face_human_hand_detection, model=self.model_id)
        self.pipeline_inference(face_human_hand_detection, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        face_human_hand_detection = pipeline(Tasks.face_human_hand_detection)
        self.pipeline_inference(face_human_hand_detection, self.input)


if __name__ == '__main__':
    unittest.main()