tingwei.gtw yingda.chen 3 years ago
parent
commit
4199af337e
14 changed files with 1597 additions and 1 deletion
  1. +3 -0    data/test/images/face_human_hand_detection.jpg
  2. +2 -0    modelscope/metainfo.py
  3. +20 -0   modelscope/models/cv/face_human_hand_detection/__init__.py
  4. +133 -0  modelscope/models/cv/face_human_hand_detection/det_infer.py
  5. +395 -0  modelscope/models/cv/face_human_hand_detection/ghost_pan.py
  6. +427 -0  modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py
  7. +64 -0   modelscope/models/cv/face_human_hand_detection/one_stage_detector.py
  8. +182 -0  modelscope/models/cv/face_human_hand_detection/shufflenetv2.py
  9. +277 -0  modelscope/models/cv/face_human_hand_detection/utils.py
  10. +10 -1  modelscope/outputs.py
  11. +3 -0   modelscope/pipelines/builder.py
  12. +42 -0  modelscope/pipelines/cv/face_human_hand_detection_pipeline.py
  13. +1 -0   modelscope/utils/constant.py
  14. +38 -0  tests/pipelines/test_face_human_hand_detection.py

+3 -0  data/test/images/face_human_hand_detection.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8fddc7be8381eb244cd692601f1c1e6cf3484b44bb4e73df0bc7de29352eb487
size 23889

+2 -0  modelscope/metainfo.py

@@ -40,6 +40,7 @@ class Models(object):
ulfd = 'ulfd'
video_inpainting = 'video-inpainting'
hand_static = 'hand-static'
face_human_hand_detection = 'face-human-hand-detection'

# EasyCV models
yolox = 'YOLOX'

@@ -181,6 +182,7 @@ class Pipelines(object):
video_inpainting = 'video-inpainting'
pst_action_recognition = 'patchshift-action-recognition'
hand_static = 'hand-static'
face_human_hand_detection = 'face-human-hand-detection'

# nlp tasks
sentence_similarity = 'sentence-similarity'


+20 -0  modelscope/models/cv/face_human_hand_detection/__init__.py

@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .det_infer import NanoDetForFaceHumanHandDetection

else:
_import_structure = {'det_infer': ['NanoDetForFaceHumanHandDetection']}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+133 -0  modelscope/models/cv/face_human_hand_detection/det_infer.py

@@ -0,0 +1,133 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import cv2
import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .one_stage_detector import OneStageDetector

logger = get_logger()


def load_model_weight(model_dir, device):
checkpoint = torch.load(
'{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
map_location=device)
state_dict = checkpoint['state_dict'].copy()
for k in checkpoint['state_dict']:
# weight-averaged checkpoints store keys as 'avg_model.xxx'; stripping the
# leading 'avg_' maps them back onto the plain 'model.xxx' keys
if k.startswith('avg_model.'):
v = state_dict.pop(k)
state_dict[k[4:]] = v

return state_dict


@MODELS.register_module(
Tasks.face_human_hand_detection,
module_name=Models.face_human_hand_detection)
class NanoDetForFaceHumanHandDetection(TorchModel):

def __init__(self, model_dir, device_id=0, *args, **kwargs):

super().__init__(
model_dir=model_dir, device_id=device_id, *args, **kwargs)

self.model = OneStageDetector()
if torch.cuda.is_available():
self.device = 'cuda'
logger.info('Use GPU ')
else:
self.device = 'cpu'
logger.info('Use CPU')

self.state_dict = load_model_weight(model_dir, self.device)
self.model.load_state_dict(self.state_dict, strict=False)
self.model.eval()
self.model.to(self.device)

def forward(self, x):
pred_result = self.model.inference(x)
return pred_result


def naive_collate(batch):
elem = batch[0]
if isinstance(elem, dict):
return {key: naive_collate([d[key] for d in batch]) for key in elem}
else:
return batch


def get_resize_matrix(raw_shape, dst_shape):

r_w, r_h = raw_shape
d_w, d_h = dst_shape
Rs = np.eye(3)

Rs[0, 0] *= d_w / r_w
Rs[1, 1] *= d_h / r_h
return Rs


def color_aug_and_norm(meta, mean, std):
img = meta['img'].astype(np.float32) / 255
mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3) / 255
std = np.array(std, dtype=np.float32).reshape(1, 1, 3) / 255
img = (img - mean) / std
meta['img'] = img
return meta


def img_process(meta, mean, std):
raw_img = meta['img']
height = raw_img.shape[0]
width = raw_img.shape[1]
dst_shape = [320, 320]
M = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
ResizeM = get_resize_matrix((width, height), dst_shape)
M = ResizeM @ M
img = cv2.warpPerspective(raw_img, M, dsize=tuple(dst_shape))
meta['img'] = img
meta['warp_matrix'] = M
meta = color_aug_and_norm(meta, mean, std)
return meta


def overlay_bbox_cv(dets, class_names, score_thresh):
all_box = []
for label in dets:
for bbox in dets[label]:
score = bbox[-1]
if score > score_thresh:
x0, y0, x1, y1 = [int(i) for i in bbox[:4]]
all_box.append([label, x0, y0, x1, y1, score])
all_box.sort(key=lambda v: v[5])
return all_box


mean = [103.53, 116.28, 123.675]
std = [57.375, 57.12, 58.395]
class_names = ['person', 'face', 'hand']


def inference(model, device, img_path):
img_info = {'id': 0}
img = cv2.imread(img_path)
height, width = img.shape[:2]
img_info['height'] = height
img_info['width'] = width
meta = dict(img_info=img_info, raw_img=img, img=img)

meta = img_process(meta, mean, std)
meta['img'] = torch.from_numpy(meta['img'].transpose(2, 0, 1)).to(device)
meta = naive_collate([meta])
meta['img'] = (meta['img'][0]).reshape(1, 3, 320, 320)
with torch.no_grad():
res = model(meta)
result = overlay_bbox_cv(res[0], class_names, score_thresh=0.35)
return result

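For reference, a minimal standalone sketch (not part of the commit) of what get_resize_matrix, img_process and color_aug_and_norm above do: an arbitrary image is scaled onto the fixed 320x320 network input with a plain scaling homography, normalized with the BGR mean/std defined above, and the inverse of that matrix is what post_process later uses (via warp_boxes) to map predicted boxes back to the original resolution. A random array stands in for cv2.imread.

import cv2
import numpy as np

raw = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # stand-in for cv2.imread(img_path)
h, w = raw.shape[:2]
dst_w, dst_h = 320, 320

# same scaling matrix that get_resize_matrix((w, h), (320, 320)) builds
M = np.eye(3)
M[0, 0] = dst_w / w
M[1, 1] = dst_h / h

img = cv2.warpPerspective(raw, M, dsize=(dst_w, dst_h))
print(img.shape)  # (320, 320, 3)

# color_aug_and_norm: scale to [0, 1], then normalize with the BGR mean/std
mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape(1, 1, 3) / 255
std = np.array([57.375, 57.12, 58.395], dtype=np.float32).reshape(1, 1, 3) / 255
img = (img.astype(np.float32) / 255 - mean) / std

# np.linalg.inv(M) is the warp matrix that maps 320x320 boxes back to 640x480
M_inv = np.linalg.inv(M)
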
+395 -0  modelscope/models/cv/face_human_hand_detection/ghost_pan.py

@@ -0,0 +1,395 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import ConvModule, DepthwiseConvModule, act_layers


def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v


def hard_sigmoid(x, inplace: bool = False):
if inplace:
return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0)
else:
return F.relu6(x + 3.0) / 6.0


class SqueezeExcite(nn.Module):

def __init__(self,
in_chs,
se_ratio=0.25,
reduced_base_chs=None,
activation='ReLU',
gate_fn=hard_sigmoid,
divisor=4,
**_):
super(SqueezeExcite, self).__init__()
self.gate_fn = gate_fn
reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio,
divisor)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layers(activation)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
x = x * self.gate_fn(x_se)
return x


class GhostModule(nn.Module):

def __init__(self,
inp,
oup,
kernel_size=1,
ratio=2,
dw_size=3,
stride=1,
activation='ReLU'):
super(GhostModule, self).__init__()
self.oup = oup
init_channels = math.ceil(oup / ratio)
new_channels = init_channels * (ratio - 1)

self.primary_conv = nn.Sequential(
nn.Conv2d(
inp,
init_channels,
kernel_size,
stride,
kernel_size // 2,
bias=False),
nn.BatchNorm2d(init_channels),
act_layers(activation) if activation else nn.Sequential(),
)

self.cheap_operation = nn.Sequential(
nn.Conv2d(
init_channels,
new_channels,
dw_size,
1,
dw_size // 2,
groups=init_channels,
bias=False,
),
nn.BatchNorm2d(new_channels),
act_layers(activation) if activation else nn.Sequential(),
)

def forward(self, x):
x1 = self.primary_conv(x)
x2 = self.cheap_operation(x1)
out = torch.cat([x1, x2], dim=1)
return out


class GhostBottleneck(nn.Module):
"""Ghost bottleneck w/ optional SE"""

def __init__(
self,
in_chs,
mid_chs,
out_chs,
dw_kernel_size=3,
stride=1,
activation='ReLU',
se_ratio=0.0,
):
super(GhostBottleneck, self).__init__()
has_se = se_ratio is not None and se_ratio > 0.0
self.stride = stride

# Point-wise expansion
self.ghost1 = GhostModule(in_chs, mid_chs, activation=activation)

# Depth-wise convolution
if self.stride > 1:
self.conv_dw = nn.Conv2d(
mid_chs,
mid_chs,
dw_kernel_size,
stride=stride,
padding=(dw_kernel_size - 1) // 2,
groups=mid_chs,
bias=False,
)
self.bn_dw = nn.BatchNorm2d(mid_chs)

if has_se:
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio)
else:
self.se = None

self.ghost2 = GhostModule(mid_chs, out_chs, activation=None)

if in_chs == out_chs and self.stride == 1:
self.shortcut = nn.Sequential()
else:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_chs,
in_chs,
dw_kernel_size,
stride=stride,
padding=(dw_kernel_size - 1) // 2,
groups=in_chs,
bias=False,
),
nn.BatchNorm2d(in_chs),
nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_chs),
)

def forward(self, x):
residual = x

x = self.ghost1(x)

if self.stride > 1:
x = self.conv_dw(x)
x = self.bn_dw(x)

if self.se is not None:
x = self.se(x)

x = self.ghost2(x)

x += self.shortcut(residual)
return x


class GhostBlocks(nn.Module):
"""Stack of GhostBottleneck used in GhostPAN.

Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
expand (int): Expand ratio of GhostBottleneck. Default: 1.
kernel_size (int): Kernel size of depthwise convolution. Default: 5.
num_blocks (int): Number of GhostBottleneck blocks. Default: 1.
use_res (bool): Whether to use residual connection. Default: False.
activation (str): Name of activation function. Default: LeakyReLU.
"""

def __init__(
self,
in_channels,
out_channels,
expand=1,
kernel_size=5,
num_blocks=1,
use_res=False,
activation='LeakyReLU',
):
super(GhostBlocks, self).__init__()
self.use_res = use_res
if use_res:
self.reduce_conv = ConvModule(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
activation=activation,
)
blocks = []
for _ in range(num_blocks):
blocks.append(
GhostBottleneck(
in_channels,
int(out_channels * expand),
out_channels,
dw_kernel_size=kernel_size,
activation=activation,
))
self.blocks = nn.Sequential(*blocks)

def forward(self, x):
out = self.blocks(x)
if self.use_res:
out = out + self.reduce_conv(x)
return out


class GhostPAN(nn.Module):
"""Path Aggregation Network with Ghost block.

Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 3
use_depthwise (bool): Whether to use depthwise separable convolutions in
blocks. Default: False
kernel_size (int): Kernel size of depthwise convolution. Default: 5.
expand (int): Expand ratio of GhostBottleneck. Default: 1.
num_blocks (int): Number of GhostBottleneck blocks. Default: 1.
use_res (bool): Whether to use residual connection. Default: False.
num_extra_level (int): Number of extra conv layers for more feature levels.
Default: 0.
upsample_cfg (dict): Config dict for interpolate layer.
Default: `dict(scale_factor=2, mode='bilinear')`
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
activation (str): Activation layer name.
Default: LeakyReLU.
"""

def __init__(
self,
in_channels,
out_channels,
use_depthwise=False,
kernel_size=5,
expand=1,
num_blocks=1,
use_res=False,
num_extra_level=0,
upsample_cfg=dict(scale_factor=2, mode='bilinear'),
norm_cfg=dict(type='BN'),
activation='LeakyReLU',
):
super(GhostPAN, self).__init__()
assert num_extra_level >= 0
assert num_blocks >= 1
self.in_channels = in_channels
self.out_channels = out_channels

conv = DepthwiseConvModule if use_depthwise else ConvModule

# build top-down blocks
self.upsample = nn.Upsample(**upsample_cfg)
self.reduce_layers = nn.ModuleList()
for idx in range(len(in_channels)):
self.reduce_layers.append(
ConvModule(
in_channels[idx],
out_channels,
1,
norm_cfg=norm_cfg,
activation=activation,
))
self.top_down_blocks = nn.ModuleList()
for idx in range(len(in_channels) - 1, 0, -1):
self.top_down_blocks.append(
GhostBlocks(
out_channels * 2,
out_channels,
expand,
kernel_size=kernel_size,
num_blocks=num_blocks,
use_res=use_res,
activation=activation,
))

# build bottom-up blocks
self.downsamples = nn.ModuleList()
self.bottom_up_blocks = nn.ModuleList()
for idx in range(len(in_channels) - 1):
self.downsamples.append(
conv(
out_channels,
out_channels,
kernel_size,
stride=2,
padding=kernel_size // 2,
norm_cfg=norm_cfg,
activation=activation,
))
self.bottom_up_blocks.append(
GhostBlocks(
out_channels * 2,
out_channels,
expand,
kernel_size=kernel_size,
num_blocks=num_blocks,
use_res=use_res,
activation=activation,
))

# extra layers
self.extra_lvl_in_conv = nn.ModuleList()
self.extra_lvl_out_conv = nn.ModuleList()
for i in range(num_extra_level):
self.extra_lvl_in_conv.append(
conv(
out_channels,
out_channels,
kernel_size,
stride=2,
padding=kernel_size // 2,
norm_cfg=norm_cfg,
activation=activation,
))
self.extra_lvl_out_conv.append(
conv(
out_channels,
out_channels,
kernel_size,
stride=2,
padding=kernel_size // 2,
norm_cfg=norm_cfg,
activation=activation,
))

def forward(self, inputs):
"""
Args:
inputs (tuple[Tensor]): input features.
Returns:
tuple[Tensor]: multi level features.
"""
assert len(inputs) == len(self.in_channels)
inputs = [
reduce(input_x)
for input_x, reduce in zip(inputs, self.reduce_layers)
]
# top-down path
inner_outs = [inputs[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_high = inner_outs[0]
feat_low = inputs[idx - 1]

upsample_feat = self.upsample(feat_high)

inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
torch.cat([upsample_feat, feat_low], 1))
inner_outs.insert(0, inner_out)

# bottom-up path
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_height = inner_outs[idx + 1]
downsample_feat = self.downsamples[idx](feat_low)
out = self.bottom_up_blocks[idx](
torch.cat([downsample_feat, feat_height], 1))
outs.append(out)

# extra layers
for extra_in_layer, extra_out_layer in zip(self.extra_lvl_in_conv,
self.extra_lvl_out_conv):
outs.append(extra_in_layer(inputs[-1]) + extra_out_layer(outs[-1]))

return tuple(outs)

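A quick shape check of the neck defined above — a sketch, not part of the commit, assuming the modelscope package containing this file is importable. It instantiates GhostPAN with the same configuration one_stage_detector.py uses below and feeds dummy backbone features for a 320x320 input.

import torch

from modelscope.models.cv.face_human_hand_detection.ghost_pan import GhostPAN

fpn = GhostPAN(
    in_channels=[116, 232, 464],  # ShuffleNetV2 1.0x stage 2/3/4 channels
    out_channels=96,
    use_depthwise=True,
    kernel_size=5,
    num_extra_level=1,  # adds one extra stride-64 level
    upsample_cfg=dict(scale_factor=2, mode='bilinear'),
    activation='LeakyReLU')

# backbone features for a 320x320 input at strides 8, 16 and 32
feats = (torch.randn(1, 116, 40, 40),
         torch.randn(1, 232, 20, 20),
         torch.randn(1, 464, 10, 10))

with torch.no_grad():
    outs = fpn(feats)
for out in outs:
    print(out.shape)  # (1, 96, 40, 40), (1, 96, 20, 20), (1, 96, 10, 10), (1, 96, 5, 5)
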
+427 -0  modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py

@@ -0,0 +1,427 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet

import math

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import nms

from .utils import ConvModule, DepthwiseConvModule


class Integral(nn.Module):
"""A fixed layer for calculating integral result from distribution.
This layer calculates the target location by :math: `sum{P(y_i) * y_i}`,
P(y_i) denotes the softmax vector that represents the discrete distribution
y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max}
Args:
reg_max (int): The maximal value of the discrete set. Default: 16. You
may want to reset it according to your new dataset or related
settings.
"""

def __init__(self, reg_max=16):
super(Integral, self).__init__()
self.reg_max = reg_max
self.register_buffer('project',
torch.linspace(0, self.reg_max, self.reg_max + 1))

def forward(self, x):
"""Forward feature from the regression head to get integral result of
bounding box location.
Args:
x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
n is self.reg_max.
Returns:
x (Tensor): Integral result of box locations, i.e., distance
offsets from the box center in four directions, shape (N, 4).
"""
shape = x.size()
x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1)
x = F.linear(x, self.project.type_as(x)).reshape(*shape[:-1], 4)
return x


def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
"""Performs non-maximum suppression in a batched fashion.
Modified from https://github.com/pytorch/vision/blob
/505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
In order to perform NMS independently per class, we add an offset to all
the boxes. The offset is dependent only on the class idx, and is large
enough so that boxes from different classes do not overlap.
Arguments:
boxes (torch.Tensor): boxes in shape (N, 4).
scores (torch.Tensor): scores in shape (N, ).
idxs (torch.Tensor): each index value correspond to a bbox cluster,
and NMS will not be applied between elements of different idxs,
shape (N, ).
nms_cfg (dict): specify nms type and other parameters like iou_thr.
Possible keys includes the following.
- iou_thr (float): IoU threshold used for NMS.
- split_thr (float): threshold number of boxes. In some cases the
number of boxes is large (e.g., 200k). To avoid OOM during
training, the users could set `split_thr` to a small value.
If the number of boxes is greater than the threshold, it will
perform NMS on each group of boxes separately and sequentially.
Defaults to 10000.
class_agnostic (bool): if true, nms is class agnostic,
i.e. IoU thresholding happens over all boxes,
regardless of the predicted class.
Returns:
tuple: kept dets and indice.
"""
nms_cfg_ = nms_cfg.copy()
class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
if class_agnostic:
boxes_for_nms = boxes
else:
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + 1)
boxes_for_nms = boxes + offsets[:, None]
nms_cfg_.pop('type', 'nms')
split_thr = nms_cfg_.pop('split_thr', 10000)
if len(boxes_for_nms) < split_thr:
keep = nms(boxes_for_nms, scores, **nms_cfg_)
boxes = boxes[keep]
scores = scores[keep]
else:
total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
for id in torch.unique(idxs):
mask = (idxs == id).nonzero(as_tuple=False).view(-1)
keep = nms(boxes_for_nms[mask], scores[mask], **nms_cfg_)
total_mask[mask[keep]] = True

keep = total_mask.nonzero(as_tuple=False).view(-1)
keep = keep[scores[keep].argsort(descending=True)]
boxes = boxes[keep]
scores = scores[keep]

return torch.cat([boxes, scores[:, None]], -1), keep


def multiclass_nms(multi_bboxes,
multi_scores,
score_thr,
nms_cfg,
max_num=-1,
score_factors=None):
"""NMS for multi-class bboxes.

Args:
multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
multi_scores (Tensor): shape (n, #class), where the last column
contains scores of the background class, but this will be ignored.
score_thr (float): bbox threshold, bboxes with scores lower than it
will not be considered.
nms_cfg (dict): NMS config dict, e.g. nms type and IoU threshold.
max_num (int): if there are more than max_num bboxes after NMS,
only top max_num will be kept.
score_factors (Tensor): The factors multiplied to scores before
applying NMS

Returns:
tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \
are 0-based.
"""
num_classes = multi_scores.size(1) - 1
if multi_bboxes.shape[1] > 4:
bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
else:
bboxes = multi_bboxes[:, None].expand(
multi_scores.size(0), num_classes, 4)
scores = multi_scores[:, :-1]

valid_mask = scores > score_thr

bboxes = torch.masked_select(
bboxes,
torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
-1)).view(-1, 4)
if score_factors is not None:
scores = scores * score_factors[:, None]
scores = torch.masked_select(scores, valid_mask)
labels = valid_mask.nonzero(as_tuple=False)[:, 1]

if bboxes.numel() == 0:
bboxes = multi_bboxes.new_zeros((0, 5))
labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)

if torch.onnx.is_in_onnx_export():
raise RuntimeError('[ONNX Error] Can not record NMS '
'as it has not been executed this time')
return bboxes, labels

dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

if max_num > 0:
dets = dets[:max_num]
keep = keep[:max_num]

return dets, labels[keep]


def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.

Args:
points (Tensor): Shape (n, 2), [x, y].
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom).
max_shape (tuple): Shape of the image.

Returns:
Tensor: Decoded bboxes.
"""
x1 = points[..., 0] - distance[..., 0]
y1 = points[..., 1] - distance[..., 1]
x2 = points[..., 0] + distance[..., 2]
y2 = points[..., 1] + distance[..., 3]
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
return torch.stack([x1, y1, x2, y2], -1)


def warp_boxes(boxes, M, width, height):
n = len(boxes)
if n:
xy = np.ones((n * 4, 3))
xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
xy = xy @ M.T
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
xy = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
return xy.astype(np.float32)
else:
return boxes


class NanoDetPlusHead(nn.Module):
"""Detection head used in NanoDet-Plus.

Args:
num_classes (int): Number of categories excluding the background
category.
input_channel (int): Number of channels of the input feature.
feat_channels (int): Number of channels of the feature.
Default: 96.
stacked_convs (int): Number of conv layers in the stacked convs.
Default: 2.
kernel_size (int): Size of the convolving kernel. Default: 5.
strides (list[int]): Strides of input multi-level feature maps.
Default: [8, 16, 32].
conv_type (str): Type of the convolution.
Default: "DWConv".
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN').
reg_max (int): The maximal value of the discrete set. Default: 7.
activation (str): Type of activation function. Default: "LeakyReLU".
assigner_cfg (dict): Config dict of the assigner. Default: dict(topk=13).
"""

def __init__(self,
num_classes,
input_channel,
feat_channels=96,
stacked_convs=2,
kernel_size=5,
strides=[8, 16, 32],
conv_type='DWConv',
norm_cfg=dict(type='BN'),
reg_max=7,
activation='LeakyReLU',
assigner_cfg=dict(topk=13),
**kwargs):
super(NanoDetPlusHead, self).__init__()
self.num_classes = num_classes
self.in_channels = input_channel
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.kernel_size = kernel_size
self.strides = strides
self.reg_max = reg_max
self.activation = activation
self.ConvModule = ConvModule if conv_type == 'Conv' else DepthwiseConvModule

self.norm_cfg = norm_cfg
self.distribution_project = Integral(self.reg_max)

self._init_layers()

def _init_layers(self):
self.cls_convs = nn.ModuleList()
for _ in self.strides:
cls_convs = self._build_not_shared_head()
self.cls_convs.append(cls_convs)

self.gfl_cls = nn.ModuleList([
nn.Conv2d(
self.feat_channels,
self.num_classes + 4 * (self.reg_max + 1),
1,
padding=0,
) for _ in self.strides
])

def _build_not_shared_head(self):
cls_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
cls_convs.append(
self.ConvModule(
chn,
self.feat_channels,
self.kernel_size,
stride=1,
padding=self.kernel_size // 2,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None,
activation=self.activation,
))
return cls_convs

def forward(self, feats):
if torch.onnx.is_in_onnx_export():
return self._forward_onnx(feats)
outputs = []
for feat, cls_convs, gfl_cls in zip(
feats,
self.cls_convs,
self.gfl_cls,
):
for conv in cls_convs:
feat = conv(feat)
output = gfl_cls(feat)
outputs.append(output.flatten(start_dim=2))
outputs = torch.cat(outputs, dim=2).permute(0, 2, 1)
return outputs

def post_process(self, preds, meta):
"""Prediction results post processing. Decode bboxes and rescale
to original image size.
Args:
preds (Tensor): Prediction output.
meta (dict): Meta info.
"""
cls_scores, bbox_preds = preds.split(
[self.num_classes, 4 * (self.reg_max + 1)], dim=-1)
result_list = self.get_bboxes(cls_scores, bbox_preds, meta)
det_results = {}
warp_matrixes = (
meta['warp_matrix']
if isinstance(meta['warp_matrix'], list) else meta['warp_matrix'])
img_heights = (
meta['img_info']['height'].cpu().numpy() if isinstance(
meta['img_info']['height'], torch.Tensor) else
meta['img_info']['height'])
img_widths = (
meta['img_info']['width'].cpu().numpy() if isinstance(
meta['img_info']['width'], torch.Tensor) else
meta['img_info']['width'])
img_ids = (
meta['img_info']['id'].cpu().numpy() if isinstance(
meta['img_info']['id'], torch.Tensor) else
meta['img_info']['id'])

for result, img_width, img_height, img_id, warp_matrix in zip(
result_list, img_widths, img_heights, img_ids, warp_matrixes):
det_result = {}
det_bboxes, det_labels = result
det_bboxes = det_bboxes.detach().cpu().numpy()
det_bboxes[:, :4] = warp_boxes(det_bboxes[:, :4],
np.linalg.inv(warp_matrix),
img_width, img_height)
classes = det_labels.detach().cpu().numpy()
for i in range(self.num_classes):
inds = classes == i
det_result[i] = np.concatenate(
[
det_bboxes[inds, :4].astype(np.float32),
det_bboxes[inds, 4:5].astype(np.float32),
],
axis=1,
).tolist()
det_results[img_id] = det_result
return det_results

def get_bboxes(self, cls_preds, reg_preds, img_metas):
"""Decode the outputs to bboxes.
Args:
cls_preds (Tensor): Shape (num_imgs, num_points, num_classes).
reg_preds (Tensor): Shape (num_imgs, num_points, 4 * (regmax + 1)).
img_metas (dict): Dict of image info.

Returns:
results_list (list[tuple]): List of detection bboxes and labels.
"""
device = cls_preds.device
b = cls_preds.shape[0]
input_height, input_width = img_metas['img'].shape[2:]
input_shape = (input_height, input_width)

featmap_sizes = [(math.ceil(input_height / stride),
math.ceil(input_width / stride))
for stride in self.strides]
mlvl_center_priors = [
self.get_single_level_center_priors(
b,
featmap_sizes[i],
stride,
dtype=torch.float32,
device=device,
) for i, stride in enumerate(self.strides)
]
center_priors = torch.cat(mlvl_center_priors, dim=1)
dis_preds = self.distribution_project(reg_preds) * center_priors[...,
2,
None]
bboxes = distance2bbox(
center_priors[..., :2], dis_preds, max_shape=input_shape)
scores = cls_preds.sigmoid()
result_list = []
for i in range(b):
score, bbox = scores[i], bboxes[i]
padding = score.new_zeros(score.shape[0], 1)
score = torch.cat([score, padding], dim=1)
results = multiclass_nms(
bbox,
score,
score_thr=0.05,
nms_cfg=dict(type='nms', iou_threshold=0.6),
max_num=100,
)
result_list.append(results)
return result_list

def get_single_level_center_priors(self, batch_size, featmap_size, stride,
dtype, device):
"""Generate centers of a single stage feature map.
Args:
batch_size (int): Number of images in one batch.
featmap_size (tuple[int]): height and width of the feature map
stride (int): down sample stride of the feature map
dtype (obj:`torch.dtype`): data type of the tensors
device (obj:`torch.device`): device of the tensors
Return:
priors (Tensor): center priors of a single level feature map.
"""
h, w = featmap_size
x_range = (torch.arange(w, dtype=dtype, device=device)) * stride
y_range = (torch.arange(h, dtype=dtype, device=device)) * stride
y, x = torch.meshgrid(y_range, x_range)
y = y.flatten()
x = x.flatten()
strides = x.new_full((x.shape[0], ), stride)
priors = torch.stack([x, y, strides, strides], dim=-1)
return priors.unsqueeze(0).repeat(batch_size, 1, 1)

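To make the head's decoding concrete, a small standalone toy example (not part of the commit) of the math implemented by Integral and distance2bbox above: each side of a box is predicted as a discrete distribution over {0, ..., reg_max} bins, reduced to its expectation, scaled by the stride of the prior point, and turned into x1/y1/x2/y2 around that point.

import torch
import torch.nn.functional as F

reg_max = 7
project = torch.linspace(0, reg_max, reg_max + 1)  # [0, 1, ..., 7]

# one prior point, 4 sides, (reg_max + 1) logits per side
logits = torch.randn(1, 4 * (reg_max + 1))
probs = F.softmax(logits.reshape(1, 4, reg_max + 1), dim=-1)
distances = F.linear(probs, project).reshape(1, 4)  # expected offset per side, in bins

stride = 8.0
cx, cy = 160.0, 160.0  # prior center on the 320x320 input
d = distances[0] * stride  # left, top, right, bottom in pixels
box = torch.tensor([cx - d[0], cy - d[1], cx + d[2], cy + d[3]])
print(box)  # x1, y1, x2, y2 on the network input
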
+64 -0  modelscope/models/cv/face_human_hand_detection/one_stage_detector.py

@@ -0,0 +1,64 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet

import torch
import torch.nn as nn

from .ghost_pan import GhostPAN
from .nanodet_plus_head import NanoDetPlusHead
from .shufflenetv2 import ShuffleNetV2


class OneStageDetector(nn.Module):

def __init__(self):
super(OneStageDetector, self).__init__()
self.backbone = ShuffleNetV2(
model_size='1.0x',
out_stages=(2, 3, 4),
with_last_conv=False,
kernal_size=3,
activation='LeakyReLU',
pretrain=False)
self.fpn = GhostPAN(
in_channels=[116, 232, 464],
out_channels=96,
use_depthwise=True,
kernel_size=5,
expand=1,
num_blocks=1,
use_res=False,
num_extra_level=1,
upsample_cfg=dict(scale_factor=2, mode='bilinear'),
norm_cfg=dict(type='BN'),
activation='LeakyReLU')
self.head = NanoDetPlusHead(
num_classes=3,
input_channel=96,
feat_channels=96,
stacked_convs=2,
kernel_size=5,
strides=[8, 16, 32, 64],
conv_type='DWConv',
norm_cfg=dict(type='BN'),
reg_max=7,
activation='LeakyReLU',
assigner_cfg=dict(topk=13))
self.epoch = 0

def forward(self, x):
x = self.backbone(x)
if hasattr(self, 'fpn'):
x = self.fpn(x)
if hasattr(self, 'head'):
x = self.head(x)
return x

def inference(self, meta):
with torch.no_grad():
torch.cuda.synchronize()
preds = self(meta['img'])
torch.cuda.synchronize()
results = self.head.post_process(preds, meta)
torch.cuda.synchronize()
return results

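A minimal forward-pass sanity check — a sketch, not part of the commit, assuming the package containing this file is importable. With the configuration above, the randomly initialized detector produces one flat prediction row per prior point across the four strides.

import torch

from modelscope.models.cv.face_human_hand_detection.one_stage_detector import \
    OneStageDetector

model = OneStageDetector().eval()
with torch.no_grad():
    preds = model(torch.randn(1, 3, 320, 320))

# 40*40 + 20*20 + 10*10 + 5*5 = 2125 priors for strides [8, 16, 32, 64];
# each row holds 3 class logits plus 4 * (reg_max + 1) = 32 regression logits
print(preds.shape)  # torch.Size([1, 2125, 35])
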
+182 -0  modelscope/models/cv/face_human_hand_detection/shufflenetv2.py

@@ -0,0 +1,182 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet

import torch
import torch.nn as nn

from .utils import act_layers


def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups

x = x.view(batchsize, groups, channels_per_group, height, width)

x = torch.transpose(x, 1, 2).contiguous()

x = x.view(batchsize, -1, height, width)

return x


class ShuffleV2Block(nn.Module):

def __init__(self, inp, oup, stride, activation='ReLU'):
super(ShuffleV2Block, self).__init__()

if not (1 <= stride <= 3):
raise ValueError('illegal stride value')
self.stride = stride

branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)

if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(
inp, inp, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(inp),
nn.Conv2d(
inp,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(branch_features),
act_layers(activation),
)
else:
self.branch1 = nn.Sequential()

self.branch2 = nn.Sequential(
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
act_layers(activation),
self.depthwise_conv(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
),
nn.BatchNorm2d(branch_features),
nn.Conv2d(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
act_layers(activation),
)

@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(
i, o, kernel_size, stride, padding, bias=bias, groups=i)

def forward(self, x):
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

out = channel_shuffle(out, 2)

return out


class ShuffleNetV2(nn.Module):

def __init__(
self,
model_size='1.5x',
out_stages=(2, 3, 4),
with_last_conv=False,
kernal_size=3,
activation='ReLU',
pretrain=True,
):
super(ShuffleNetV2, self).__init__()
assert set(out_stages).issubset((2, 3, 4))

print('model size is ', model_size)

self.stage_repeats = [4, 8, 4]
self.model_size = model_size
self.out_stages = out_stages
self.with_last_conv = with_last_conv
self.kernal_size = kernal_size
self.activation = activation
if model_size == '0.5x':
self._stage_out_channels = [24, 48, 96, 192, 1024]
elif model_size == '1.0x':
self._stage_out_channels = [24, 116, 232, 464, 1024]
elif model_size == '1.5x':
self._stage_out_channels = [24, 176, 352, 704, 1024]
elif model_size == '2.0x':
self._stage_out_channels = [24, 244, 488, 976, 2048]
else:
raise NotImplementedError

# building first layer
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
act_layers(activation),
)
input_channels = output_channels

self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(
stage_names, self.stage_repeats, self._stage_out_channels[1:]):
seq = [
ShuffleV2Block(
input_channels, output_channels, 2, activation=activation)
]
for i in range(repeats - 1):
seq.append(
ShuffleV2Block(
output_channels,
output_channels,
1,
activation=activation))
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
if self.with_last_conv:
conv5 = nn.Sequential(
nn.Conv2d(
input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
act_layers(activation),
)
self.stage4.add_module('conv5', conv5)

def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
output = []

for i in range(2, 5):
stage = getattr(self, 'stage{}'.format(i))
x = stage(x)
if i in self.out_stages:
output.append(x)
return tuple(output)

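A small standalone illustration (not part of the commit) of channel_shuffle, the operation that lets the two ShuffleV2Block branches exchange information: with groups=2, channel order [0 1 2 3 4 5] becomes [0 3 1 4 2 5].

import torch

x = torch.arange(6).float().reshape(1, 6, 1, 1)  # 6 channels, 1x1 spatial

groups = 2
b, c, h, w = x.shape
x = x.view(b, groups, c // groups, h, w)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(b, -1, h, w)

print(x.flatten().tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
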
+277 -0  modelscope/models/cv/face_human_hand_detection/utils.py

@@ -0,0 +1,277 @@
# The implementation here is modified based on nanodet,
# originally Apache 2.0 License and publicly available at https://github.com/RangiLyu/nanodet

import warnings

import torch
import torch.nn as nn

activations = {
'ReLU': nn.ReLU,
'LeakyReLU': nn.LeakyReLU,
'ReLU6': nn.ReLU6,
'SELU': nn.SELU,
'ELU': nn.ELU,
'GELU': nn.GELU,
'PReLU': nn.PReLU,
'SiLU': nn.SiLU,
'HardSwish': nn.Hardswish,
'Hardswish': nn.Hardswish,
None: nn.Identity,
}


def act_layers(name):
assert name in activations.keys()
if name == 'LeakyReLU':
return nn.LeakyReLU(negative_slope=0.1, inplace=True)
elif name == 'GELU':
return nn.GELU()
elif name == 'PReLU':
return nn.PReLU()
else:
return activations[name](inplace=True)


norm_cfg = {
'BN': ('bn', nn.BatchNorm2d),
'SyncBN': ('bn', nn.SyncBatchNorm),
'GN': ('gn', nn.GroupNorm),
}


def build_norm_layer(cfg, num_features, postfix=''):
"""Build normalization layer

Args:
cfg (dict): cfg should contain:
type (str): identify norm layer type.
layer args: args needed to instantiate a norm layer.
requires_grad (bool): [optional] whether stop gradient updates
num_features (int): number of channels from input.
postfix (int, str): appended into norm abbreviation to
create named layer.

Returns:
name (str): abbreviation + postfix
layer (nn.Module): created norm layer
"""
assert isinstance(cfg, dict) and 'type' in cfg
cfg_ = cfg.copy()

layer_type = cfg_.pop('type')
if layer_type not in norm_cfg:
raise KeyError('Unrecognized norm type {}'.format(layer_type))
else:
abbr, norm_layer = norm_cfg[layer_type]
if norm_layer is None:
raise NotImplementedError

assert isinstance(postfix, (int, str))
name = abbr + str(postfix)

requires_grad = cfg_.pop('requires_grad', True)
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
layer._specify_ddp_gpu_num(1)
else:
assert 'num_groups' in cfg_
layer = norm_layer(num_channels=num_features, **cfg_)

for param in layer.parameters():
param.requires_grad = requires_grad

return name, layer


class ConvModule(nn.Module):
"""A conv block that contains conv/norm/activation layers.

Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
conv_cfg (dict): Config dict for convolution layer.
norm_cfg (dict): Config dict for normalization layer.
activation (str): activation layer, "ReLU" by default.
inplace (bool): Whether to use inplace mode for activation.
order (tuple[str]): The order of conv/norm/activation layers. It is a
sequence of "conv", "norm" and "act". Examples are
("conv", "norm", "act") and ("act", "conv", "norm").
"""

def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias='auto',
conv_cfg=None,
norm_cfg=None,
activation='ReLU',
inplace=True,
order=('conv', 'norm', 'act'),
):
super(ConvModule, self).__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
assert activation is None or isinstance(activation, str)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.activation = activation
self.inplace = inplace
self.order = order
assert isinstance(self.order, tuple) and len(self.order) == 3
assert set(order) == {'conv', 'norm', 'act'}

self.with_norm = norm_cfg is not None
if bias == 'auto':
bias = False if self.with_norm else True
self.with_bias = bias

if self.with_norm and self.with_bias:
warnings.warn('ConvModule has norm and bias at the same time')

self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
self.in_channels = self.conv.in_channels
self.out_channels = self.conv.out_channels
self.kernel_size = self.conv.kernel_size
self.stride = self.conv.stride
self.padding = self.conv.padding
self.dilation = self.conv.dilation
self.transposed = self.conv.transposed
self.output_padding = self.conv.output_padding
self.groups = self.conv.groups

if self.with_norm:
if order.index('norm') > order.index('conv'):
norm_channels = out_channels
else:
norm_channels = in_channels
self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
self.add_module(self.norm_name, norm)
else:
self.norm_name = None

if self.activation:
self.act = act_layers(self.activation)

@property
def norm(self):
if self.norm_name:
return getattr(self, self.norm_name)
else:
return None

def forward(self, x, norm=True):
for layer in self.order:
if layer == 'conv':
x = self.conv(x)
elif layer == 'norm' and norm and self.with_norm:
x = self.norm(x)
elif layer == 'act' and self.activation:
x = self.act(x)
return x


class DepthwiseConvModule(nn.Module):

def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
bias='auto',
norm_cfg=dict(type='BN'),
activation='ReLU',
inplace=True,
order=('depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act'),
):
super(DepthwiseConvModule, self).__init__()
assert activation is None or isinstance(activation, str)
self.activation = activation
self.inplace = inplace
self.order = order
assert isinstance(self.order, tuple) and len(self.order) == 6
assert set(order) == {
'depthwise',
'dwnorm',
'act',
'pointwise',
'pwnorm',
'act',
}

self.with_norm = norm_cfg is not None
if bias == 'auto':
bias = False if self.with_norm else True
self.with_bias = bias

if self.with_norm and self.with_bias:
warnings.warn('ConvModule has norm and bias at the same time')

self.depthwise = nn.Conv2d(
in_channels,
in_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
bias=bias,
)
self.pointwise = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias)

self.in_channels = self.depthwise.in_channels
self.out_channels = self.pointwise.out_channels
self.kernel_size = self.depthwise.kernel_size
self.stride = self.depthwise.stride
self.padding = self.depthwise.padding
self.dilation = self.depthwise.dilation
self.transposed = self.depthwise.transposed
self.output_padding = self.depthwise.output_padding

if self.with_norm:
_, self.dwnorm = build_norm_layer(norm_cfg, in_channels)
_, self.pwnorm = build_norm_layer(norm_cfg, out_channels)

if self.activation:
self.act = act_layers(self.activation)

def forward(self, x, norm=True):
for layer_name in self.order:
if layer_name != 'act':
layer = self.__getattr__(layer_name)
x = layer(x)
elif layer_name == 'act' and self.activation:
x = self.act(x)
return x

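A brief comparison — a sketch, not part of the commit, assuming the package containing this file is importable — of ConvModule and DepthwiseConvModule: the depthwise/pointwise factorization is what makes the 5x5 kernels used by the FPN and head affordable.

import torch

from modelscope.models.cv.face_human_hand_detection.utils import (
    ConvModule, DepthwiseConvModule)

plain = ConvModule(
    96, 96, 5, padding=2, norm_cfg=dict(type='BN'), activation='LeakyReLU')
dw = DepthwiseConvModule(
    96, 96, 5, padding=2, norm_cfg=dict(type='BN'), activation='LeakyReLU')


def count(m):
    return sum(p.numel() for p in m.parameters())


print(count(plain), count(dw))  # roughly 230k vs. roughly 12k parameters

x = torch.randn(1, 96, 40, 40)
print(plain(x).shape, dw(x).shape)  # both (1, 96, 40, 40)
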
+10 -1  modelscope/outputs.py

@@ -649,8 +649,17 @@ TASK_OUTPUTS = {
# 'output': ['Done' / 'Decode_Error']
# }
Tasks.video_inpainting: [OutputKeys.OUTPUT],

# {
# 'output': ['bixin']
# }
Tasks.hand_static: [OutputKeys.OUTPUT]
Tasks.hand_static: [OutputKeys.OUTPUT],

# {
# 'output': [
# [2, 75, 287, 240, 510, 0.8335018754005432],
# [1, 127, 83, 332, 366, 0.9175254702568054],
# [0, 0, 0, 367, 639, 0.9693422317504883]]
# }
Tasks.face_human_hand_detection: [OutputKeys.OUTPUT],
}

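The block above documents the task output as rows of [label, x0, y0, x1, y1, score]. A short illustrative snippet (not part of the commit) for reading it, with the label indexing the class list ['person', 'face', 'hand'] defined in det_infer.py:

CLASS_NAMES = ['person', 'face', 'hand']

result = {
    'output': [
        [2, 75, 287, 240, 510, 0.8335018754005432],
        [1, 127, 83, 332, 366, 0.9175254702568054],
        [0, 0, 0, 367, 639, 0.9693422317504883],
    ]
}

for label, x0, y0, x1, y1, score in result['output']:
    print(f'{CLASS_NAMES[label]:6s} box=({x0}, {y0}, {x1}, {y1}) score={score:.3f}')
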
+3 -0  modelscope/pipelines/builder.py

@@ -183,6 +183,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_video-inpainting'),
Tasks.hand_static: (Pipelines.hand_static,
'damo/cv_mobileface_hand-static'),
Tasks.face_human_hand_detection:
(Pipelines.face_human_hand_detection,
'damo/cv_nanodet_face-human-hand-detection'),
}






+42 -0  modelscope/pipelines/cv/face_human_hand_detection_pipeline.py

@@ -0,0 +1,42 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_human_hand_detection import det_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.face_human_hand_detection,
module_name=Pipelines.face_human_hand_detection)
class NanoDetForFaceHumanHandDetectionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create face-human-hand detection pipeline for prediction
Args:
model: model id on modelscope hub.
"""

super().__init__(model=model, **kwargs)
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
return input

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

result = det_infer.inference(self.model, self.device,
input['input_path'])
logger.info(result)
return {OutputKeys.OUTPUT: result}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

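An end-to-end usage sketch (not part of the commit); the model id and the 'input_path' key mirror the test case added at the bottom of this commit.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(
    Tasks.face_human_hand_detection,
    model='damo/cv_nanodet_face-human-hand-detection')
result = detector(
    {'input_path': 'data/test/images/face_human_hand_detection.jpg'})
print(result)  # {'output': [[label, x0, y0, x1, y1, score], ...]}
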
+1 -0  modelscope/utils/constant.py

@@ -43,6 +43,7 @@ class CVTasks(object):
text_driven_segmentation = 'text-driven-segmentation'
shop_segmentation = 'shop-segmentation'
hand_static = 'hand-static'
face_human_hand_detection = 'face-human-hand-detection'


# image editing
skin_retouching = 'skin-retouching'


+38 -0  tests/pipelines/test_face_human_hand_detection.py

@@ -0,0 +1,38 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class FaceHumanHandTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_nanodet_face-human-hand-detection'
self.input = {
'input_path': 'data/test/images/face_human_hand_detection.jpg',
}

def pipeline_inference(self, pipeline: Pipeline, input: str):
result = pipeline(input)
logger.info(result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
face_human_hand_detection = pipeline(
Tasks.face_human_hand_detection, model=self.model_id)
self.pipeline_inference(face_human_hand_detection, self.input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
face_human_hand_detection = pipeline(Tasks.face_human_hand_detection)
self.pipeline_inference(face_human_hand_detection, self.input)


if __name__ == '__main__':
unittest.main()
