From f508be89183cc2d9047bbb6fcbe23685a239959d Mon Sep 17 00:00:00 2001
From: ly261666
Date: Sat, 3 Sep 2022 23:48:42 +0800
Subject: [PATCH] =?UTF-8?q?[to=20#42322933]=20=E6=96=B0=E5=A2=9ERetinaFace?=
 =?UTF-8?q?=E4=BA=BA=E8=84=B8=E6=A3=80=E6=B5=8B=E5=99=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Add the RetinaFace face detection model;
2. Complete the Maas-cv CR standard self-check

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9945188
---
 data/test/images/retina_face_detection.jpg | 3 +
 modelscope/metainfo.py | 2 +
 .../cv/face_detection/retinaface/__init__.py | 0
 .../cv/face_detection/retinaface/detection.py | 137 ++++++++++++++++
 .../retinaface/models/__init__.py | 0
 .../face_detection/retinaface/models/net.py | 149 ++++++++++++++++++
 .../retinaface/models/retinaface.py | 145 +++++++++++++++++
 .../cv/face_detection/retinaface/utils.py | 123 +++++++++++++++
 modelscope/pipelines/base.py | 1 -
 .../cv/retina_face_detection_pipeline.py | 55 +++++++
 tests/pipelines/test_retina_face_detection.py | 33 ++++
 11 files changed, 647 insertions(+), 1 deletion(-)
 create mode 100644 data/test/images/retina_face_detection.jpg
 create mode 100644 modelscope/models/cv/face_detection/retinaface/__init__.py
 create mode 100755 modelscope/models/cv/face_detection/retinaface/detection.py
 create mode 100755 modelscope/models/cv/face_detection/retinaface/models/__init__.py
 create mode 100755 modelscope/models/cv/face_detection/retinaface/models/net.py
 create mode 100755 modelscope/models/cv/face_detection/retinaface/models/retinaface.py
 create mode 100755 modelscope/models/cv/face_detection/retinaface/utils.py
 create mode 100644 modelscope/pipelines/cv/retina_face_detection_pipeline.py
 create mode 100644 tests/pipelines/test_retina_face_detection.py

diff --git a/data/test/images/retina_face_detection.jpg b/data/test/images/retina_face_detection.jpg
new file mode 100644
index 00000000..c95881fe
--- /dev/null
+++ b/data/test/images/retina_face_detection.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
+size 87228
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index b1bf9600..9638268c 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -32,6 +32,7 @@ class Models(object):
     vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
     text_driven_segmentation = 'text-driven-segmentation'
     resnet50_bert = 'resnet50-bert'
+    retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
 
     # EasyCV models
@@ -118,6 +119,7 @@ class Pipelines(object):
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
+    retina_face_detection = 'resnet50-face-detection-retinaface'
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
diff --git a/modelscope/models/cv/face_detection/retinaface/__init__.py b/modelscope/models/cv/face_detection/retinaface/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/cv/face_detection/retinaface/detection.py b/modelscope/models/cv/face_detection/retinaface/detection.py
new file mode 100755
index 00000000..3dd31659
--- /dev/null
+++ b/modelscope/models/cv/face_detection/retinaface/detection.py
@@ -0,0 +1,137 @@
+# The implementation is based on resnet, available at
https://github.com/biubug6/Pytorch_Retinaface +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .models.retinaface import RetinaFace +from .utils import PriorBox, decode, decode_landm, py_cpu_nms + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.retinaface) +class RetinaFaceDetection(TorchModel): + + def __init__(self, model_path, device='cuda'): + super().__init__(model_path) + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.cfg = Config.from_file( + model_path.replace(ModelFile.TORCH_MODEL_FILE, + ModelFile.CONFIGURATION))['models'] + self.net = RetinaFace(cfg=self.cfg) + self.load_model() + self.device = device + self.net = self.net.to(self.device) + + self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device) + + def check_keys(self, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(self.net.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + assert len( + used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' + return True + + def remove_prefix(self, state_dict, prefix): + new_state_dict = dict() + for k, v in state_dict.items(): + if k.startswith(prefix): + new_state_dict[k[len(prefix):]] = v + else: + new_state_dict[k] = v + return new_state_dict + + def load_model(self, load_to_cpu=False): + pretrained_dict = torch.load( + self.model_path, map_location=torch.device('cpu')) + if 'state_dict' in pretrained_dict.keys(): + pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], + 'module.') + else: + pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') + self.check_keys(pretrained_dict) + self.net.load_state_dict(pretrained_dict, strict=False) + self.net.eval() + + def forward(self, input): + img_raw = input['img'].cpu().numpy() + img = np.float32(img_raw) + + im_height, im_width = img.shape[:2] + ss = 1.0 + # tricky + if max(im_height, im_width) > 1500: + ss = 1000.0 / max(im_height, im_width) + img = cv2.resize(img, (0, 0), fx=ss, fy=ss) + im_height, im_width = img.shape[:2] + + scale = torch.Tensor( + [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.to(self.device) + scale = scale.to(self.device) + + loc, conf, landms = self.net(img) # forward pass + del img + + confidence_threshold = 0.9 + nms_threshold = 0.4 + top_k = 5000 + keep_top_k = 750 + + priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) + priors = priorbox.forward() + priors = priors.to(self.device) + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) + boxes = boxes * scale + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = decode_landm( + landms.data.squeeze(0), prior_data, self.cfg['variance']) + scale1 = torch.Tensor([ + im_width, im_height, im_width, im_height, im_width, im_height, + im_width, im_height, im_width, im_height + ]) + scale1 = scale1.to(self.device) + landms = landms * scale1 + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = 
scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype( + np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + landms = landms[:keep_top_k, :] + + landms = landms.reshape((-1, 5, 2)) + landms = landms.reshape( + -1, + 10, + ) + return dets / ss, landms / ss diff --git a/modelscope/models/cv/face_detection/retinaface/models/__init__.py b/modelscope/models/cv/face_detection/retinaface/models/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/retinaface/models/net.py b/modelscope/models/cv/face_detection/retinaface/models/net.py new file mode 100755 index 00000000..3be7c4b9 --- /dev/null +++ b/modelscope/models/cv/face_detection/retinaface/models/net.py @@ -0,0 +1,149 @@ +# The implementation is based on resnet, available at https://github.com/biubug6/Pytorch_Retinaface +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.models._utils as _utils +from torch.autograd import Variable + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), + nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn( + in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn( + out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1( + in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1( + in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1( + in_channels_list[2], out_channels, stride=1, leaky=leaky) + + 
self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate( + output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate( + output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + x = x.view(-1, 256) + x = self.fc(x) + return x diff --git a/modelscope/models/cv/face_detection/retinaface/models/retinaface.py b/modelscope/models/cv/face_detection/retinaface/models/retinaface.py new file mode 100755 index 00000000..8d2001dd --- /dev/null +++ b/modelscope/models/cv/face_detection/retinaface/models/retinaface.py @@ -0,0 +1,145 @@ +# The implementation is based on resnet, available at https://github.com/biubug6/Pytorch_Retinaface +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.models._utils as _utils +import torchvision.models.detection.backbone_utils as backbone_utils + +from .net import FPN, SSH, MobileNetV1 + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d( + inchannels, + self.num_anchors * 2, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d( + inchannels, + num_anchors * 4, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d( + inchannels, + num_anchors * 10, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +class RetinaFace(nn.Module): + + def __init__(self, cfg=None): + """ + :param cfg: 
Network related settings. + """ + super(RetinaFace, self).__init__() + backbone = None + if cfg['name'] == 'Resnet50': + backbone = models.resnet50(pretrained=cfg['pretrain']) + else: + raise Exception('Invalid name') + + self.body = _utils.IntermediateLayerGetter(backbone, + cfg['return_layers']) + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = self._make_class_head( + fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = self._make_bbox_head( + fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = self._make_landmark_head( + fpn_num=3, inchannels=cfg['out_channel']) + + def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead + + def forward(self, inputs): + out = self.body(inputs) + + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat( + [self.BboxHead[i](feature) for i, feature in enumerate(features)], + dim=1) + classifications = torch.cat( + [self.ClassHead[i](feature) for i, feature in enumerate(features)], + dim=1) + ldm_regressions = torch.cat( + [self.LandmarkHead[i](feat) for i, feat in enumerate(features)], + dim=1) + + output = (bbox_regressions, F.softmax(classifications, + dim=-1), ldm_regressions) + return output diff --git a/modelscope/models/cv/face_detection/retinaface/utils.py b/modelscope/models/cv/face_detection/retinaface/utils.py new file mode 100755 index 00000000..60c9e2dd --- /dev/null +++ b/modelscope/models/cv/face_detection/retinaface/utils.py @@ -0,0 +1,123 @@ +# -------------------------------------------------------- +# Modified from https://github.com/biubug6/Pytorch_Retinaface +# -------------------------------------------------------- + +from itertools import product as product +from math import ceil + +import numpy as np +import torch + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ + ceil(self.image_size[0] / step), + ceil(self.image_size[1] / step) + ] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [ + x * self.steps[k] / self.image_size[1] + for x in [j + 0.5] + ] + dense_cy = [ + y * self.steps[k] / 
self.image_size[0] + for y in [i + 0.5] + ] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat( + (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] + b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] + c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] + d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] + e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] + landms = torch.cat((a, b, c, d, e), dim=1) + return landms diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index c0f3cbd0..d4f9c6bf 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -2,7 +2,6 @@ import os.path as osp from abc import ABC, abstractmethod -from contextlib import contextmanager from threading import Lock from typing import Any, Dict, Generator, List, Mapping, Union diff --git a/modelscope/pipelines/cv/retina_face_detection_pipeline.py b/modelscope/pipelines/cv/retina_face_detection_pipeline.py new file mode 100644 index 00000000..20111c11 --- /dev/null +++ b/modelscope/pipelines/cv/retina_face_detection_pipeline.py @@ -0,0 +1,55 @@ +import os.path as osp +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection.retinaface import detection +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_detection, module_name=Pipelines.retina_face_detection) +class RetinaFaceDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {ckpt_path}') + detector = detection.RetinaFaceDetection( + model_path=ckpt_path, device=self.device) + self.detector = detector + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float32) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector(input) + assert result is not None + bboxes = result[0][:, :4].tolist() + scores = result[0][:, 4].tolist() + lms = result[1].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: lms, + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/tests/pipelines/test_retina_face_detection.py b/tests/pipelines/test_retina_face_detection.py new file mode 100644 index 00000000..343e1c91 --- /dev/null +++ b/tests/pipelines/test_retina_face_detection.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os.path as osp +import unittest + +import cv2 + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_result +from modelscope.utils.test_utils import test_level + + +class RetinaFaceDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_resnet50_face-detection_retinaface' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + img_path = 'data/test/images/retina_face_detection.jpg' + + result = face_detection(img_path) + self.show_result(img_path, result) + + +if __name__ == '__main__': + unittest.main()
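
For reviewers who want to try the patch locally, a minimal usage sketch follows. It assumes the model id used in the new test, 'damo/cv_resnet50_face-detection_retinaface', is published on the ModelScope hub, and it relies only on the output keys that RetinaFaceDetectionPipeline.forward fills in (scores, boxes, keypoints).

# Minimal usage sketch for the new RetinaFace pipeline; the model id below is
# taken from tests/pipelines/test_retina_face_detection.py and is assumed to
# be available on the ModelScope hub.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_resnet50_face-detection_retinaface')
result = face_detection('data/test/images/retina_face_detection.jpg')

# RetinaFaceDetectionPipeline.forward returns plain Python lists:
#   OutputKeys.SCORES    - one confidence value per detected face
#   OutputKeys.BOXES     - [x1, y1, x2, y2] per face
#   OutputKeys.KEYPOINTS - 10 values per face (5 landmarks as x/y pairs)
for score, box in zip(result[OutputKeys.SCORES], result[OutputKeys.BOXES]):
    print(f'face at {box}, confidence {score:.3f}')

This mirrors the flow in the new test, minus the drawing step (draw_face_detection_result).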