diff --git a/data/test/images/face_human_hand_detection.jpg b/data/test/images/face_human_hand_detection.jpg new file mode 100644 index 00000000..f94bb547 --- /dev/null +++ b/data/test/images/face_human_hand_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fddc7be8381eb244cd692601f1c1e6cf3484b44bb4e73df0bc7de29352eb487 +size 23889 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 17b1dc40..54e09f7a 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -40,6 +40,7 @@ class Models(object): ulfd = 'ulfd' video_inpainting = 'video-inpainting' hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' # EasyCV models yolox = 'YOLOX' @@ -181,6 +182,7 @@ class Pipelines(object): video_inpainting = 'video-inpainting' pst_action_recognition = 'patchshift-action-recognition' hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/face_human_hand_detection/__init__.py b/modelscope/models/cv/face_human_hand_detection/__init__.py new file mode 100644 index 00000000..33a5fd2f --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .det_infer import NanoDetForFaceHumanHandDetection + +else: + _import_structure = {'det_infer': ['NanoDetForFaceHumanHandDetection']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_human_hand_detection/det_infer.py b/modelscope/models/cv/face_human_hand_detection/det_infer.py new file mode 100644 index 00000000..7a7225ee --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/det_infer.py @@ -0,0 +1,133 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
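+#
+# This module wraps the NanoDet one-stage detector for the
+# face-human-hand-detection task: load_model_weight() restores the averaged
+# ('avg_model.*') checkpoint weights, img_process() resizes and normalizes an
+# input image to 320x320, and inference() runs the detector and keeps the
+# boxes above the score threshold.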
+
+import cv2
+import numpy as np
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+from .one_stage_detector import OneStageDetector
+
+logger = get_logger()
+
+
+def load_model_weight(model_dir, device):
+    checkpoint = torch.load(
+        '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
+        map_location=device)
+    state_dict = checkpoint['state_dict'].copy()
+    # Rename averaged weights 'avg_model.xxx' to 'model.xxx'
+    # (k[4:] strips the leading 'avg_').
+    for k in checkpoint['state_dict']:
+        if k.startswith('avg_model.'):
+            v = state_dict.pop(k)
+            state_dict[k[4:]] = v
+
+    return state_dict
+
+
+@MODELS.register_module(
+    Tasks.face_human_hand_detection,
+    module_name=Models.face_human_hand_detection)
+class NanoDetForFaceHumanHandDetection(TorchModel):
+
+    def __init__(self, model_dir, device_id=0, *args, **kwargs):
+
+        super().__init__(
+            model_dir=model_dir, device_id=device_id, *args, **kwargs)
+
+        self.model = OneStageDetector()
+        if torch.cuda.is_available():
+            self.device = 'cuda'
+            logger.info('Use GPU')
+        else:
+            self.device = 'cpu'
+            logger.info('Use CPU')
+
+        self.state_dict = load_model_weight(model_dir, self.device)
+        self.model.load_state_dict(self.state_dict, strict=False)
+        self.model.eval()
+        self.model.to(self.device)
+
+    def forward(self, x):
+        pred_result = self.model.inference(x)
+        return pred_result
+
+
+def naive_collate(batch):
+    elem = batch[0]
+    if isinstance(elem, dict):
+        return {key: naive_collate([d[key] for d in batch]) for key in elem}
+    else:
+        return batch
+
+
+def get_resize_matrix(raw_shape, dst_shape):
+    """Get the 3x3 scaling matrix that maps raw_shape onto dst_shape."""
+    r_w, r_h = raw_shape
+    d_w, d_h = dst_shape
+    Rs = np.eye(3)
+
+    Rs[0, 0] *= d_w / r_w
+    Rs[1, 1] *= d_h / r_h
+    return Rs
+
+
+def color_aug_and_norm(meta, mean, std):
+    img = meta['img'].astype(np.float32) / 255
+    mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3) / 255
+    std = np.array(std, dtype=np.float32).reshape(1, 1, 3) / 255
+    img = (img - mean) / std
+    meta['img'] = img
+    return meta
+
+
+def img_process(meta, mean, std):
+    raw_img = meta['img']
+    height = raw_img.shape[0]
+    width = raw_img.shape[1]
+    dst_shape = [320, 320]
+    M = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
+    ResizeM = get_resize_matrix((width, height), dst_shape)
+    M = ResizeM @ M
+    img = cv2.warpPerspective(raw_img, M, dsize=tuple(dst_shape))
+    meta['img'] = img
+    meta['warp_matrix'] = M
+    meta = color_aug_and_norm(meta, mean, std)
+    return meta
+
+
+def overlay_bbox_cv(dets, class_names, score_thresh):
+    # Collect [label, x0, y0, x1, y1, score] for every box above the
+    # threshold, sorted by ascending score.
+    all_box = []
+    for label in dets:
+        for bbox in dets[label]:
+            score = bbox[-1]
+            if score > score_thresh:
+                x0, y0, x1, y1 = [int(i) for i in bbox[:4]]
+                all_box.append([label, x0, y0, x1, y1, score])
+    all_box.sort(key=lambda v: v[5])
+    return all_box
+
+
+mean = [103.53, 116.28, 123.675]
+std = [57.375, 57.12, 58.395]
+class_names = ['person', 'face', 'hand']
+
+
+def inference(model, device, img_path):
+    img_info = {'id': 0}
+    img = cv2.imread(img_path)
+    height, width = img.shape[:2]
+    img_info['height'] = height
+    img_info['width'] = width
+    meta = dict(img_info=img_info, raw_img=img, img=img)
+
+    meta = img_process(meta, mean, std)
+    meta['img'] = torch.from_numpy(meta['img'].transpose(2, 0, 1)).to(device)
+    meta = naive_collate([meta])
+    meta['img'] = (meta['img'][0]).reshape(1, 3, 320, 320)
+    with torch.no_grad():
+        res = model(meta)
+    result = overlay_bbox_cv(res[0], class_names, score_thresh=0.35)
+    return result
diff --git a/modelscope/models/cv/face_human_hand_detection/ghost_pan.py b/modelscope/models/cv/face_human_hand_detection/ghost_pan.py
new file mode 100644
index 00000000..e00de407
--- /dev/null
+++ b/modelscope/models/cv/face_human_hand_detection/ghost_pan.py
@@ -0,0 +1,395 @@
+# The implementation here is modified based on nanodet,
+# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet

+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import ConvModule, DepthwiseConvModule, act_layers
+
+
+def _make_divisible(v, divisor, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+def hard_sigmoid(x, inplace: bool = False):
+    if inplace:
+        return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0)
+    else:
+        return F.relu6(x + 3.0) / 6.0
+
+
+class SqueezeExcite(nn.Module):
+
+    def __init__(self,
+                 in_chs,
+                 se_ratio=0.25,
+                 reduced_base_chs=None,
+                 activation='ReLU',
+                 gate_fn=hard_sigmoid,
+                 divisor=4,
+                 **_):
+        super(SqueezeExcite, self).__init__()
+        self.gate_fn = gate_fn
+        reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio,
+                                      divisor)
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
+        self.act1 = act_layers(activation)
+        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
+
+    def forward(self, x):
+        x_se = self.avg_pool(x)
+        x_se = self.conv_reduce(x_se)
+        x_se = self.act1(x_se)
+        x_se = self.conv_expand(x_se)
+        x = x * self.gate_fn(x_se)
+        return x
+
+
+class GhostModule(nn.Module):
+
+    def __init__(self,
+                 inp,
+                 oup,
+                 kernel_size=1,
+                 ratio=2,
+                 dw_size=3,
+                 stride=1,
+                 activation='ReLU'):
+        super(GhostModule, self).__init__()
+        self.oup = oup
+        init_channels = math.ceil(oup / ratio)
+        new_channels = init_channels * (ratio - 1)
+
+        self.primary_conv = nn.Sequential(
+            nn.Conv2d(
+                inp,
+                init_channels,
+                kernel_size,
+                stride,
+                kernel_size // 2,
+                bias=False),
+            nn.BatchNorm2d(init_channels),
+            act_layers(activation) if activation else nn.Sequential(),
+        )
+
+        self.cheap_operation = nn.Sequential(
+            nn.Conv2d(
+                init_channels,
+                new_channels,
+                dw_size,
+                1,
+                dw_size // 2,
+                groups=init_channels,
+                bias=False,
+            ),
+            nn.BatchNorm2d(new_channels),
+            act_layers(activation) if activation else nn.Sequential(),
+        )
+
+    def forward(self, x):
+        x1 = self.primary_conv(x)
+        x2 = self.cheap_operation(x1)
+        out = torch.cat([x1, x2], dim=1)
+        return out
+
+
+class GhostBottleneck(nn.Module):
+    """Ghost bottleneck w/ optional SE"""
+
+    def __init__(
+        self,
+        in_chs,
+        mid_chs,
+        out_chs,
+        dw_kernel_size=3,
+        stride=1,
+        activation='ReLU',
+        se_ratio=0.0,
+    ):
+        super(GhostBottleneck, self).__init__()
+        has_se = se_ratio is not None and se_ratio > 0.0
+        self.stride = stride
+
+        # Point-wise expansion
+        self.ghost1 = GhostModule(in_chs, mid_chs, activation=activation)
+
+        # Depth-wise convolution
+        if self.stride > 1:
+            self.conv_dw = nn.Conv2d(
+                mid_chs,
+                mid_chs,
+                dw_kernel_size,
+                stride=stride,
+                padding=(dw_kernel_size - 1) // 2,
+                groups=mid_chs,
+                bias=False,
+            )
+            self.bn_dw = nn.BatchNorm2d(mid_chs)
+
+        if has_se:
+            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio)
+        else:
+            self.se = None
+
+        self.ghost2 = GhostModule(mid_chs, out_chs, activation=None)
+
+        if in_chs == out_chs and self.stride == 1:
+            self.shortcut = nn.Sequential()
+        else:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(
+                    in_chs,
+                    in_chs,
+                    dw_kernel_size,
+                    stride=stride,
+                    padding=(dw_kernel_size - 1) // 2,
+                    groups=in_chs,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(in_chs),
+                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
+                nn.BatchNorm2d(out_chs),
+            )
+
+    def forward(self, x):
+        residual = x
+
+        x = self.ghost1(x)
+
+        if self.stride > 1:
+            x = self.conv_dw(x)
+            x = self.bn_dw(x)
+
+        if self.se is not None:
+            x = self.se(x)
+
+        x = self.ghost2(x)
+
+        x += self.shortcut(residual)
+        return x
+
+
+class GhostBlocks(nn.Module):
+    """Stack of GhostBottleneck used in GhostPAN.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expand (int): Expand ratio of GhostBottleneck. Default: 1.
+        kernel_size (int): Kernel size of depthwise convolution. Default: 5.
+        num_blocks (int): Number of GhostBottleneck blocks. Default: 1.
+        use_res (bool): Whether to use residual connection. Default: False.
+        activation (str): Name of activation function. Default: LeakyReLU.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        expand=1,
+        kernel_size=5,
+        num_blocks=1,
+        use_res=False,
+        activation='LeakyReLU',
+    ):
+        super(GhostBlocks, self).__init__()
+        self.use_res = use_res
+        if use_res:
+            self.reduce_conv = ConvModule(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                activation=activation,
+            )
+        blocks = []
+        for _ in range(num_blocks):
+            blocks.append(
+                GhostBottleneck(
+                    in_channels,
+                    int(out_channels * expand),
+                    out_channels,
+                    dw_kernel_size=kernel_size,
+                    activation=activation,
+                ))
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        out = self.blocks(x)
+        if self.use_res:
+            out = out + self.reduce_conv(x)
+        return out
+
+
+class GhostPAN(nn.Module):
+    """Path Aggregation Network with Ghost block.
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale)
+        use_depthwise (bool): Whether to use depthwise separable convolution
+            in blocks. Default: False
+        kernel_size (int): Kernel size of depthwise convolution. Default: 5.
+        expand (int): Expand ratio of GhostBottleneck. Default: 1.
+        num_blocks (int): Number of GhostBottleneck blocks. Default: 1.
+        use_res (bool): Whether to use residual connection. Default: False.
+        num_extra_level (int): Number of extra conv layers for more feature levels.
+            Default: 0.
+        upsample_cfg (dict): Config dict for interpolate layer.
+            Default: `dict(scale_factor=2, mode='bilinear')`
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN')
+        activation (str): Activation layer name.
+            Default: LeakyReLU.
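+
+        Example (illustrative sketch, not part of the shipped config; the
+        channel and resolution values mirror the ShuffleNetV2 1.0x backbone
+        used by OneStageDetector in this PR):
+            >>> fpn = GhostPAN(in_channels=[116, 232, 464], out_channels=96)
+            >>> feats = [torch.rand(1, c, s, s)
+            ...          for c, s in zip([116, 232, 464], [40, 20, 10])]
+            >>> outs = fpn(feats)
+            >>> [tuple(o.shape) for o in outs]
+            [(1, 96, 40, 40), (1, 96, 20, 20), (1, 96, 10, 10)]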
+ """ + + def __init__( + self, + in_channels, + out_channels, + use_depthwise=False, + kernel_size=5, + expand=1, + num_blocks=1, + use_res=False, + num_extra_level=0, + upsample_cfg=dict(scale_factor=2, mode='bilinear'), + norm_cfg=dict(type='BN'), + activation='LeakyReLU', + ): + super(GhostPAN, self).__init__() + assert num_extra_level >= 0 + assert num_blocks >= 1 + self.in_channels = in_channels + self.out_channels = out_channels + + conv = DepthwiseConvModule if use_depthwise else ConvModule + + # build top-down blocks + self.upsample = nn.Upsample(**upsample_cfg) + self.reduce_layers = nn.ModuleList() + for idx in range(len(in_channels)): + self.reduce_layers.append( + ConvModule( + in_channels[idx], + out_channels, + 1, + norm_cfg=norm_cfg, + activation=activation, + )) + self.top_down_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1, 0, -1): + self.top_down_blocks.append( + GhostBlocks( + out_channels * 2, + out_channels, + expand, + kernel_size=kernel_size, + num_blocks=num_blocks, + use_res=use_res, + activation=activation, + )) + + # build bottom-up blocks + self.downsamples = nn.ModuleList() + self.bottom_up_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + self.bottom_up_blocks.append( + GhostBlocks( + out_channels * 2, + out_channels, + expand, + kernel_size=kernel_size, + num_blocks=num_blocks, + use_res=use_res, + activation=activation, + )) + + # extra layers + self.extra_lvl_in_conv = nn.ModuleList() + self.extra_lvl_out_conv = nn.ModuleList() + for i in range(num_extra_level): + self.extra_lvl_in_conv.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + self.extra_lvl_out_conv.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + + def forward(self, inputs): + """ + Args: + inputs (tuple[Tensor]): input features. + Returns: + tuple[Tensor]: multi level features. 
+ """ + assert len(inputs) == len(self.in_channels) + inputs = [ + reduce(input_x) + for input_x, reduce in zip(inputs, self.reduce_layers) + ] + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + + inner_outs[0] = feat_heigh + + upsample_feat = self.upsample(feat_heigh) + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + torch.cat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx]( + torch.cat([downsample_feat, feat_height], 1)) + outs.append(out) + + # extra layers + for extra_in_layer, extra_out_layer in zip(self.extra_lvl_in_conv, + self.extra_lvl_out_conv): + outs.append(extra_in_layer(inputs[-1]) + extra_out_layer(outs[-1])) + + return tuple(outs) diff --git a/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py b/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py new file mode 100644 index 00000000..7f5b50ec --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py @@ -0,0 +1,427 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import nms + +from .utils import ConvModule, DepthwiseConvModule + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + This layer calculates the target location by :math: `sum{P(y_i) * y_i}`, + P(y_i) denotes the softmax vector that represents the discrete distribution + y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max} + Args: + reg_max (int): The maximal value of the discrete set. Default: 16. You + may want to reset it according to your new dataset or related + settings. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + Args: + x (Tensor): Features of the regression head, shape (N, 4*(n+1)), + n is self.reg_max. + Returns: + x (Tensor): Integral result of box locations, i.e., distance + offsets from the box center in four directions, shape (N, 4). + """ + shape = x.size() + x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1) + x = F.linear(x, self.project.type_as(x)).reshape(*shape[:-1], 4) + return x + + +def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): + """Performs non-maximum suppression in a batched fashion. + Modified from https://github.com/pytorch/vision/blob + /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39. + In order to perform NMS independently per class, we add an offset to all + the boxes. The offset is dependent only on the class idx, and is large + enough so that boxes from different classes do not overlap. + Arguments: + boxes (torch.Tensor): boxes in shape (N, 4). + scores (torch.Tensor): scores in shape (N, ). 
+        idxs (torch.Tensor): each index value corresponds to a bbox cluster,
+            and NMS will not be applied between elements of different idxs,
+            shape (N, ).
+        nms_cfg (dict): specify nms type and other parameters like iou_thr.
+            Possible keys include the following.
+            - iou_thr (float): IoU threshold used for NMS.
+            - split_thr (float): threshold number of boxes. In some cases the
+              number of boxes is large (e.g., 200k). To avoid OOM during
+              training, the users could set `split_thr` to a small value.
+              If the number of boxes is greater than the threshold, it will
+              perform NMS on each group of boxes separately and sequentially.
+              Defaults to 10000.
+        class_agnostic (bool): if true, nms is class agnostic,
+            i.e. IoU thresholding happens over all boxes,
+            regardless of the predicted class.
+    Returns:
+        tuple: kept dets and indices.
+    """
+    nms_cfg_ = nms_cfg.copy()
+    class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
+    if class_agnostic:
+        boxes_for_nms = boxes
+    else:
+        # Shift the boxes of each class by a class-dependent offset so that
+        # boxes of different classes can never overlap during NMS.
+        max_coordinate = boxes.max()
+        offsets = idxs.to(boxes) * (max_coordinate + 1)
+        boxes_for_nms = boxes + offsets[:, None]
+    nms_cfg_.pop('type', 'nms')
+    split_thr = nms_cfg_.pop('split_thr', 10000)
+    if len(boxes_for_nms) < split_thr:
+        keep = nms(boxes_for_nms, scores, **nms_cfg_)
+        boxes = boxes[keep]
+        scores = scores[keep]
+    else:
+        total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
+        for class_id in torch.unique(idxs):
+            mask = (idxs == class_id).nonzero(as_tuple=False).view(-1)
+            keep = nms(boxes_for_nms[mask], scores[mask], **nms_cfg_)
+            total_mask[mask[keep]] = True
+
+        keep = total_mask.nonzero(as_tuple=False).view(-1)
+        keep = keep[scores[keep].argsort(descending=True)]
+        boxes = boxes[keep]
+        scores = scores[keep]
+
+    return torch.cat([boxes, scores[:, None]], -1), keep
+
+
+def multiclass_nms(multi_bboxes,
+                   multi_scores,
+                   score_thr,
+                   nms_cfg,
+                   max_num=-1,
+                   score_factors=None):
+    """NMS for multi-class bboxes.
+
+    Args:
+        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+        multi_scores (Tensor): shape (n, #class), where the last column
+            contains scores of the background class, but this will be ignored.
+        score_thr (float): bbox threshold, bboxes with scores lower than it
+            will not be considered.
+        nms_cfg (dict): NMS config, e.g. dict(type='nms', iou_threshold=0.6).
+        max_num (int): if there are more than max_num bboxes after NMS,
+            only top max_num will be kept.
+        score_factors (Tensor): The factors multiplied to scores before
+            applying NMS
+
+    Returns:
+        tuple: (bboxes, labels), tensors of shape (k, 5) and (k, ). Labels \
+            are 0-based.
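+
+    Example (illustrative sketch with made-up values; one foreground class
+    plus the ignored background column):
+        >>> boxes = torch.tensor([[0., 0., 10., 10.]])
+        >>> scores = torch.tensor([[0.9, 0.1]])
+        >>> dets, labels = multiclass_nms(
+        ...     boxes, scores, score_thr=0.3,
+        ...     nms_cfg=dict(type='nms', iou_threshold=0.6))
+        >>> dets.shape, labels.shape
+        (torch.Size([1, 5]), torch.Size([1]))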
+ """ + num_classes = multi_scores.size(1) - 1 + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + scores = multi_scores[:, :-1] + + valid_mask = scores > score_thr + + bboxes = torch.masked_select( + bboxes, + torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), + -1)).view(-1, 4) + if score_factors is not None: + scores = scores * score_factors[:, None] + scores = torch.masked_select(scores, valid_mask) + labels = valid_mask.nonzero(as_tuple=False)[:, 1] + + if bboxes.numel() == 0: + bboxes = multi_bboxes.new_zeros((0, 5)) + labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) + + if torch.onnx.is_in_onnx_export(): + raise RuntimeError('[ONNX Error] Can not record NMS ' + 'as it has not been executed this time') + return bboxes, labels + + dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) + + if max_num > 0: + dets = dets[:max_num] + keep = keep[:max_num] + + return dets, labels[keep] + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return torch.stack([x1, y1, x2, y2], -1) + + +def warp_boxes(boxes, M, width, height): + n = len(boxes) + if n: + xy = np.ones((n * 4, 3)) + xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) + xy = xy @ M.T + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) + xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) + return xy.astype(np.float32) + else: + return boxes + + +class NanoDetPlusHead(nn.Module): + """Detection head used in NanoDet-Plus. + + Args: + num_classes (int): Number of categories excluding the background + category. + loss (dict): Loss config. + input_channel (int): Number of channels of the input feature. + feat_channels (int): Number of channels of the feature. + Default: 96. + stacked_convs (int): Number of conv layers in the stacked convs. + Default: 2. + kernel_size (int): Size of the convolving kernel. Default: 5. + strides (list[int]): Strides of input multi-level feature maps. + Default: [8, 16, 32]. + conv_type (str): Type of the convolution. + Default: "DWConv". + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + reg_max (int): The maximal value of the discrete set. Default: 7. + activation (str): Type of activation function. Default: "LeakyReLU". + assigner_cfg (dict): Config dict of the assigner. Default: dict(topk=13). 
+ """ + + def __init__(self, + num_classes, + input_channel, + feat_channels=96, + stacked_convs=2, + kernel_size=5, + strides=[8, 16, 32], + conv_type='DWConv', + norm_cfg=dict(type='BN'), + reg_max=7, + activation='LeakyReLU', + assigner_cfg=dict(topk=13), + **kwargs): + super(NanoDetPlusHead, self).__init__() + self.num_classes = num_classes + self.in_channels = input_channel + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.kernel_size = kernel_size + self.strides = strides + self.reg_max = reg_max + self.activation = activation + self.ConvModule = ConvModule if conv_type == 'Conv' else DepthwiseConvModule + + self.norm_cfg = norm_cfg + self.distribution_project = Integral(self.reg_max) + + self._init_layers() + + def _init_layers(self): + self.cls_convs = nn.ModuleList() + for _ in self.strides: + cls_convs = self._buid_not_shared_head() + self.cls_convs.append(cls_convs) + + self.gfl_cls = nn.ModuleList([ + nn.Conv2d( + self.feat_channels, + self.num_classes + 4 * (self.reg_max + 1), + 1, + padding=0, + ) for _ in self.strides + ]) + + def _buid_not_shared_head(self): + cls_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + cls_convs.append( + self.ConvModule( + chn, + self.feat_channels, + self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + bias=self.norm_cfg is None, + activation=self.activation, + )) + return cls_convs + + def forward(self, feats): + if torch.onnx.is_in_onnx_export(): + return self._forward_onnx(feats) + outputs = [] + for feat, cls_convs, gfl_cls in zip( + feats, + self.cls_convs, + self.gfl_cls, + ): + for conv in cls_convs: + feat = conv(feat) + output = gfl_cls(feat) + outputs.append(output.flatten(start_dim=2)) + outputs = torch.cat(outputs, dim=2).permute(0, 2, 1) + return outputs + + def post_process(self, preds, meta): + """Prediction results post processing. Decode bboxes and rescale + to original image size. + Args: + preds (Tensor): Prediction output. + meta (dict): Meta info. 
+ """ + cls_scores, bbox_preds = preds.split( + [self.num_classes, 4 * (self.reg_max + 1)], dim=-1) + result_list = self.get_bboxes(cls_scores, bbox_preds, meta) + det_results = {} + warp_matrixes = ( + meta['warp_matrix'] + if isinstance(meta['warp_matrix'], list) else meta['warp_matrix']) + img_heights = ( + meta['img_info']['height'].cpu().numpy() if isinstance( + meta['img_info']['height'], torch.Tensor) else + meta['img_info']['height']) + img_widths = ( + meta['img_info']['width'].cpu().numpy() if isinstance( + meta['img_info']['width'], torch.Tensor) else + meta['img_info']['width']) + img_ids = ( + meta['img_info']['id'].cpu().numpy() if isinstance( + meta['img_info']['id'], torch.Tensor) else + meta['img_info']['id']) + + for result, img_width, img_height, img_id, warp_matrix in zip( + result_list, img_widths, img_heights, img_ids, warp_matrixes): + det_result = {} + det_bboxes, det_labels = result + det_bboxes = det_bboxes.detach().cpu().numpy() + det_bboxes[:, :4] = warp_boxes(det_bboxes[:, :4], + np.linalg.inv(warp_matrix), + img_width, img_height) + classes = det_labels.detach().cpu().numpy() + for i in range(self.num_classes): + inds = classes == i + det_result[i] = np.concatenate( + [ + det_bboxes[inds, :4].astype(np.float32), + det_bboxes[inds, 4:5].astype(np.float32), + ], + axis=1, + ).tolist() + det_results[img_id] = det_result + return det_results + + def get_bboxes(self, cls_preds, reg_preds, img_metas): + """Decode the outputs to bboxes. + Args: + cls_preds (Tensor): Shape (num_imgs, num_points, num_classes). + reg_preds (Tensor): Shape (num_imgs, num_points, 4 * (regmax + 1)). + img_metas (dict): Dict of image info. + + Returns: + results_list (list[tuple]): List of detection bboxes and labels. + """ + device = cls_preds.device + b = cls_preds.shape[0] + input_height, input_width = img_metas['img'].shape[2:] + input_shape = (input_height, input_width) + + featmap_sizes = [(math.ceil(input_height / stride), + math.ceil(input_width) / stride) + for stride in self.strides] + mlvl_center_priors = [ + self.get_single_level_center_priors( + b, + featmap_sizes[i], + stride, + dtype=torch.float32, + device=device, + ) for i, stride in enumerate(self.strides) + ] + center_priors = torch.cat(mlvl_center_priors, dim=1) + dis_preds = self.distribution_project(reg_preds) * center_priors[..., + 2, + None] + bboxes = distance2bbox( + center_priors[..., :2], dis_preds, max_shape=input_shape) + scores = cls_preds.sigmoid() + result_list = [] + for i in range(b): + score, bbox = scores[i], bboxes[i] + padding = score.new_zeros(score.shape[0], 1) + score = torch.cat([score, padding], dim=1) + results = multiclass_nms( + bbox, + score, + score_thr=0.05, + nms_cfg=dict(type='nms', iou_threshold=0.6), + max_num=100, + ) + result_list.append(results) + return result_list + + def get_single_level_center_priors(self, batch_size, featmap_size, stride, + dtype, device): + """Generate centers of a single stage feature map. + Args: + batch_size (int): Number of images in one batch. + featmap_size (tuple[int]): height and width of the feature map + stride (int): down sample stride of the feature map + dtype (obj:`torch.dtype`): data type of the tensors + device (obj:`torch.device`): device of the tensors + Return: + priors (Tensor): center priors of a single level feature map. 
+ """ + h, w = featmap_size + x_range = (torch.arange(w, dtype=dtype, device=device)) * stride + y_range = (torch.arange(h, dtype=dtype, device=device)) * stride + y, x = torch.meshgrid(y_range, x_range) + y = y.flatten() + x = x.flatten() + strides = x.new_full((x.shape[0], ), stride) + proiors = torch.stack([x, y, strides, strides], dim=-1) + return proiors.unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py b/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py new file mode 100644 index 00000000..c1d0a52f --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py @@ -0,0 +1,64 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import torch +import torch.nn as nn + +from .ghost_pan import GhostPAN +from .nanodet_plus_head import NanoDetPlusHead +from .shufflenetv2 import ShuffleNetV2 + + +class OneStageDetector(nn.Module): + + def __init__(self): + super(OneStageDetector, self).__init__() + self.backbone = ShuffleNetV2( + model_size='1.0x', + out_stages=(2, 3, 4), + with_last_conv=False, + kernal_size=3, + activation='LeakyReLU', + pretrain=False) + self.fpn = GhostPAN( + in_channels=[116, 232, 464], + out_channels=96, + use_depthwise=True, + kernel_size=5, + expand=1, + num_blocks=1, + use_res=False, + num_extra_level=1, + upsample_cfg=dict(scale_factor=2, mode='bilinear'), + norm_cfg=dict(type='BN'), + activation='LeakyReLU') + self.head = NanoDetPlusHead( + num_classes=3, + input_channel=96, + feat_channels=96, + stacked_convs=2, + kernel_size=5, + strides=[8, 16, 32, 64], + conv_type='DWConv', + norm_cfg=dict(type='BN'), + reg_max=7, + activation='LeakyReLU', + assigner_cfg=dict(topk=13)) + self.epoch = 0 + + def forward(self, x): + x = self.backbone(x) + if hasattr(self, 'fpn'): + x = self.fpn(x) + if hasattr(self, 'head'): + x = self.head(x) + return x + + def inference(self, meta): + with torch.no_grad(): + torch.cuda.synchronize() + preds = self(meta['img']) + torch.cuda.synchronize() + results = self.head.post_process(preds, meta) + torch.cuda.synchronize() + return results diff --git a/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py b/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py new file mode 100644 index 00000000..7f4dfc2a --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py @@ -0,0 +1,182 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import torch +import torch.nn as nn + +from .utils import act_layers + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + + x = x.view(batchsize, groups, channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + x = x.view(batchsize, -1, height, width) + + return x + + +class ShuffleV2Block(nn.Module): + + def __init__(self, inp, oup, stride, activation='ReLU'): + super(ShuffleV2Block, self).__init__() + + if not (1 <= stride <= 3): + raise ValueError('illegal stride value') + self.stride = stride + + branch_features = oup // 2 + assert (self.stride != 1) or (inp == branch_features << 1) + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv( + inp, inp, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(inp), 
+ nn.Conv2d( + inp, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False), + nn.BatchNorm2d(branch_features), + act_layers(activation), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + act_layers(activation), + self.depthwise_conv( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + ), + nn.BatchNorm2d(branch_features), + nn.Conv2d( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + act_layers(activation), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d( + i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + +class ShuffleNetV2(nn.Module): + + def __init__( + self, + model_size='1.5x', + out_stages=(2, 3, 4), + with_last_conv=False, + kernal_size=3, + activation='ReLU', + pretrain=True, + ): + super(ShuffleNetV2, self).__init__() + assert set(out_stages).issubset((2, 3, 4)) + + print('model size is ', model_size) + + self.stage_repeats = [4, 8, 4] + self.model_size = model_size + self.out_stages = out_stages + self.with_last_conv = with_last_conv + self.kernal_size = kernal_size + self.activation = activation + if model_size == '0.5x': + self._stage_out_channels = [24, 48, 96, 192, 1024] + elif model_size == '1.0x': + self._stage_out_channels = [24, 116, 232, 464, 1024] + elif model_size == '1.5x': + self._stage_out_channels = [24, 176, 352, 704, 1024] + elif model_size == '2.0x': + self._stage_out_channels = [24, 244, 488, 976, 2048] + else: + raise NotImplementedError + + # building first layer + input_channels = 3 + output_channels = self._stage_out_channels[0] + self.conv1 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), + nn.BatchNorm2d(output_channels), + act_layers(activation), + ) + input_channels = output_channels + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] + for name, repeats, output_channels in zip( + stage_names, self.stage_repeats, self._stage_out_channels[1:]): + seq = [ + ShuffleV2Block( + input_channels, output_channels, 2, activation=activation) + ] + for i in range(repeats - 1): + seq.append( + ShuffleV2Block( + output_channels, + output_channels, + 1, + activation=activation)) + setattr(self, name, nn.Sequential(*seq)) + input_channels = output_channels + output_channels = self._stage_out_channels[-1] + if self.with_last_conv: + conv5 = nn.Sequential( + nn.Conv2d( + input_channels, output_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(output_channels), + act_layers(activation), + ) + self.stage4.add_module('conv5', conv5) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + output = [] + + for i in range(2, 5): + stage = getattr(self, 'stage{}'.format(i)) + x = stage(x) + if i in self.out_stages: + output.append(x) + return tuple(output) diff --git a/modelscope/models/cv/face_human_hand_detection/utils.py b/modelscope/models/cv/face_human_hand_detection/utils.py new file 
mode 100644
index 00000000..f989c164
--- /dev/null
+++ b/modelscope/models/cv/face_human_hand_detection/utils.py
@@ -0,0 +1,277 @@
+# The implementation here is modified based on nanodet,
+# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet
+
+import warnings
+
+import torch
+import torch.nn as nn
+
+activations = {
+    'ReLU': nn.ReLU,
+    'LeakyReLU': nn.LeakyReLU,
+    'ReLU6': nn.ReLU6,
+    'SELU': nn.SELU,
+    'ELU': nn.ELU,
+    'GELU': nn.GELU,
+    'PReLU': nn.PReLU,
+    'SiLU': nn.SiLU,
+    'HardSwish': nn.Hardswish,
+    'Hardswish': nn.Hardswish,
+    None: nn.Identity,
+}
+
+
+def act_layers(name):
+    assert name in activations.keys()
+    if name == 'LeakyReLU':
+        return nn.LeakyReLU(negative_slope=0.1, inplace=True)
+    elif name == 'GELU':
+        return nn.GELU()
+    elif name == 'PReLU':
+        return nn.PReLU()
+    else:
+        return activations[name](inplace=True)
+
+
+norm_cfg = {
+    'BN': ('bn', nn.BatchNorm2d),
+    'SyncBN': ('bn', nn.SyncBatchNorm),
+    'GN': ('gn', nn.GroupNorm),
+}
+
+
+def build_norm_layer(cfg, num_features, postfix=''):
+    """Build normalization layer
+
+    Args:
+        cfg (dict): cfg should contain:
+            type (str): identify norm layer type.
+            layer args: args needed to instantiate a norm layer.
+            requires_grad (bool): [optional] whether stop gradient updates
+        num_features (int): number of channels from input.
+        postfix (int, str): appended into norm abbreviation to
+            create named layer.
+
+    Returns:
+        name (str): abbreviation + postfix
+        layer (nn.Module): created norm layer
+    """
+    assert isinstance(cfg, dict) and 'type' in cfg
+    cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in norm_cfg:
+        raise KeyError('Unrecognized norm type {}'.format(layer_type))
+    else:
+        abbr, norm_layer = norm_cfg[layer_type]
+        if norm_layer is None:
+            raise NotImplementedError
+
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)
+
+    requires_grad = cfg_.pop('requires_grad', True)
+    cfg_.setdefault('eps', 1e-5)
+    if layer_type != 'GN':
+        layer = norm_layer(num_features, **cfg_)
+        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+            layer._specify_ddp_gpu_num(1)
+    else:
+        assert 'num_groups' in cfg_
+        layer = norm_layer(num_channels=num_features, **cfg_)
+
+    for param in layer.parameters():
+        param.requires_grad = requires_grad
+
+    return name, layer
+
+
+class ConvModule(nn.Module):
+    """A conv block that contains conv/norm/activation layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+        conv_cfg (dict): Config dict for convolution layer.
+        norm_cfg (dict): Config dict for normalization layer.
+        activation (str): activation layer, "ReLU" by default.
+        inplace (bool): Whether to use inplace mode for activation.
+        order (tuple[str]): The order of conv/norm/activation layers. It is a
+            sequence of "conv", "norm" and "act". Examples are
+            ("conv", "norm", "act") and ("act", "conv", "norm").
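+
+        Example (illustrative sketch, not from the original file):
+            >>> m = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'))
+            >>> y = m(torch.rand(1, 3, 32, 32))
+            >>> tuple(y.shape)
+            (1, 16, 32, 32)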
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias='auto', + conv_cfg=None, + norm_cfg=None, + activation='ReLU', + inplace=True, + order=('conv', 'norm', 'act'), + ): + super(ConvModule, self).__init__() + assert conv_cfg is None or isinstance(conv_cfg, dict) + assert norm_cfg is None or isinstance(norm_cfg, dict) + assert activation is None or isinstance(activation, str) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.activation = activation + self.inplace = inplace + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == {'conv', 'norm', 'act'} + + self.with_norm = norm_cfg is not None + if bias == 'auto': + bias = False if self.with_norm else True + self.with_bias = bias + + if self.with_norm and self.with_bias: + warnings.warn('ConvModule has norm and bias at the same time') + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = self.conv.padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_norm: + if order.index('norm') > order.index('conv'): + norm_channels = out_channels + else: + norm_channels = in_channels + self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) + self.add_module(self.norm_name, norm) + else: + self.norm_name = None + + if self.activation: + self.act = act_layers(self.activation) + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def forward(self, x, norm=True): + for layer in self.order: + if layer == 'conv': + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + x = self.norm(x) + elif layer == 'act' and self.activation: + x = self.act(x) + return x + + +class DepthwiseConvModule(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias='auto', + norm_cfg=dict(type='BN'), + activation='ReLU', + inplace=True, + order=('depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act'), + ): + super(DepthwiseConvModule, self).__init__() + assert activation is None or isinstance(activation, str) + self.activation = activation + self.inplace = inplace + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 6 + assert set(order) == { + 'depthwise', + 'dwnorm', + 'act', + 'pointwise', + 'pwnorm', + 'act', + } + + self.with_norm = norm_cfg is not None + if bias == 'auto': + bias = False if self.with_norm else True + self.with_bias = bias + + if self.with_norm and self.with_bias: + warnings.warn('ConvModule has norm and bias at the same time') + + self.depthwise = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + ) + self.pointwise = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias) + + self.in_channels = self.depthwise.in_channels + self.out_channels = self.pointwise.out_channels + self.kernel_size = self.depthwise.kernel_size + self.stride = self.depthwise.stride + self.padding = 
+        self.dilation = self.depthwise.dilation
+        self.transposed = self.depthwise.transposed
+        self.output_padding = self.depthwise.output_padding
+
+        if self.with_norm:
+            _, self.dwnorm = build_norm_layer(norm_cfg, in_channels)
+            _, self.pwnorm = build_norm_layer(norm_cfg, out_channels)
+
+        if self.activation:
+            self.act = act_layers(self.activation)
+
+    def forward(self, x, norm=True):
+        for layer_name in self.order:
+            if layer_name != 'act':
+                layer = getattr(self, layer_name)
+                x = layer(x)
+            elif layer_name == 'act' and self.activation:
+                x = self.act(x)
+        return x
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 357afd07..52f3c47e 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -649,8 +649,17 @@ TASK_OUTPUTS = {
     # 'output': ['Done' / 'Decode_Error']
     # }
     Tasks.video_inpainting: [OutputKeys.OUTPUT],
+
     # {
     # 'output': ['bixin']
     # }
-    Tasks.hand_static: [OutputKeys.OUTPUT]
+    Tasks.hand_static: [OutputKeys.OUTPUT],
+
+    # Each box is [label, x0, y0, x1, y1, score]; labels: 0=person,
+    # 1=face, 2=hand.
+    # {
+    # 'output': [
+    #     [2, 75, 287, 240, 510, 0.8335018754005432],
+    #     [1, 127, 83, 332, 366, 0.9175254702568054],
+    #     [0, 0, 0, 367, 639, 0.9693422317504883]]
+    # }
+    Tasks.face_human_hand_detection: [OutputKeys.OUTPUT],
 }
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 4f6873b0..a14b07a6 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -183,6 +183,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                              'damo/cv_video-inpainting'),
     Tasks.hand_static: (Pipelines.hand_static,
                         'damo/cv_mobileface_hand-static'),
+    Tasks.face_human_hand_detection:
+    (Pipelines.face_human_hand_detection,
+     'damo/cv_nanodet_face-human-hand-detection'),
 }
diff --git a/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py b/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py
new file mode 100644
index 00000000..d9f214c9
--- /dev/null
+++ b/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py
@@ -0,0 +1,42 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+from typing import Any, Dict
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.face_human_hand_detection import det_infer
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.face_human_hand_detection,
+    module_name=Pipelines.face_human_hand_detection)
+class NanoDetForFaceHumanHandDetectionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """Use `model` to create a face-human-hand detection pipeline for
+        prediction.
+        Args:
+            model: model id on modelscope hub.
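+
+        Example (sketch of the intended usage; the model id and image path
+        are the ones used by the test in this PR):
+            >>> from modelscope.pipelines import pipeline
+            >>> from modelscope.utils.constant import Tasks
+            >>> detector = pipeline(
+            ...     Tasks.face_human_hand_detection,
+            ...     model='damo/cv_nanodet_face-human-hand-detection')
+            >>> detector({
+            ...     'input_path': 'data/test/images/face_human_hand_detection.jpg'
+            ... })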
+ """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + result = det_infer.inference(self.model, self.device, + input['input_path']) + logger.info(result) + return {OutputKeys.OUTPUT: result} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index b19c0fce..ac6846e4 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -43,6 +43,7 @@ class CVTasks(object): text_driven_segmentation = 'text-driven-segmentation' shop_segmentation = 'shop-segmentation' hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' # image editing skin_retouching = 'skin-retouching' diff --git a/tests/pipelines/test_face_human_hand_detection.py b/tests/pipelines/test_face_human_hand_detection.py new file mode 100644 index 00000000..7aaa67e7 --- /dev/null +++ b/tests/pipelines/test_face_human_hand_detection.py @@ -0,0 +1,38 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class FaceHumanHandTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_nanodet_face-human-hand-detection' + self.input = { + 'input_path': 'data/test/images/face_human_hand_detection.jpg', + } + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + logger.info(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_human_hand_detection = pipeline( + Tasks.face_human_hand_detection, model=self.model_id) + self.pipeline_inference(face_human_hand_detection, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + face_human_hand_detection = pipeline(Tasks.face_human_hand_detection) + self.pipeline_inference(face_human_hand_detection, self.input) + + +if __name__ == '__main__': + unittest.main()