
[to #42322933] Add the MTCNN face detector

1. Completed the MaaS-cv CR-standard self-review
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9951519

    * [to #42322933] Add the MTCNN face detector
Branch: master
Author: ly261666 (committed by yingda.chen), 3 years ago
Parent commit: c12957a9eb
12 changed files with 756 additions and 2 deletions:

  1. data/test/images/mtcnn_face_detection.jpg (+3, -0)
  2. modelscope/metainfo.py (+2, -0)
  3. modelscope/models/cv/face_detection/__init__.py (+4, -1)
  4. modelscope/models/cv/face_detection/mtcnn/__init__.py (+1, -0)
  5. modelscope/models/cv/face_detection/mtcnn/models/__init__.py (+0, -0)
  6. modelscope/models/cv/face_detection/mtcnn/models/box_utils.py (+240, -0)
  7. modelscope/models/cv/face_detection/mtcnn/models/detector.py (+149, -0)
  8. modelscope/models/cv/face_detection/mtcnn/models/first_stage.py (+100, -0)
  9. modelscope/models/cv/face_detection/mtcnn/models/get_nets.py (+160, -0)
  10. modelscope/pipelines/cv/__init__.py (+3, -1)
  11. modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py (+56, -0)
  12. tests/pipelines/test_mtcnn_face_detection.py (+38, -0)

data/test/images/mtcnn_face_detection.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
size 87228

modelscope/metainfo.py (+2, -0)

@@ -35,6 +35,7 @@ class Models(object):
     fer = 'fer'
     retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
+    mtcnn = 'mtcnn'
     ulfd = 'ulfd'

     # EasyCV models
@@ -127,6 +128,7 @@ class Pipelines(object):
     ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
+    mtcnn_face_detection = 'manual-face-detection-mtcnn'
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'


modelscope/models/cv/face_detection/__init__.py (+4, -1)

@@ -4,12 +4,15 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
+    from .mtcnn import MtcnnFaceDetector
     from .retinaface import RetinaFaceDetection
     from .ulfd_slim import UlfdFaceDetector

 else:
     _import_structure = {
         'ulfd_slim': ['UlfdFaceDetector'],
-        'retinaface': ['RetinaFaceDetection']
+        'retinaface': ['RetinaFaceDetection'],
+        'mtcnn': ['MtcnnFaceDetector']
     }

     import sys


modelscope/models/cv/face_detection/mtcnn/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .models.detector import MtcnnFaceDetector

modelscope/models/cv/face_detection/mtcnn/models/__init__.py (+0, -0, empty file)


modelscope/models/cv/face_detection/mtcnn/models/box_utils.py (+240, -0)

@@ -0,0 +1,240 @@
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
import numpy as np
from PIL import Image


def nms(boxes, overlap_threshold=0.5, mode='union'):
    """Non-maximum suppression.

    Arguments:
        boxes: a float numpy array of shape [n, 5],
            where each row is (xmin, ymin, xmax, ymax, score).
        overlap_threshold: a float number.
        mode: 'union' or 'min'.

    Returns:
        list with indices of the selected boxes
    """

    # if there are no boxes, return the empty list
    if len(boxes) == 0:
        return []

    # list of picked indices
    pick = []

    # grab the coordinates of the bounding boxes
    x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]

    area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
    ids = np.argsort(score)  # in increasing order

    while len(ids) > 0:

        # grab index of the largest value
        last = len(ids) - 1
        i = ids[last]
        pick.append(i)

        # compute intersections
        # of the box with the largest score
        # with the rest of boxes

        # left top corner of intersection boxes
        ix1 = np.maximum(x1[i], x1[ids[:last]])
        iy1 = np.maximum(y1[i], y1[ids[:last]])

        # right bottom corner of intersection boxes
        ix2 = np.minimum(x2[i], x2[ids[:last]])
        iy2 = np.minimum(y2[i], y2[ids[:last]])

        # width and height of intersection boxes
        w = np.maximum(0.0, ix2 - ix1 + 1.0)
        h = np.maximum(0.0, iy2 - iy1 + 1.0)

        # intersections' areas
        inter = w * h
        if mode == 'min':
            overlap = inter / np.minimum(area[i], area[ids[:last]])
        elif mode == 'union':
            # intersection over union (IoU)
            overlap = inter / (area[i] + area[ids[:last]] - inter)

        # delete all boxes where overlap is too big
        ids = np.delete(
            ids,
            np.concatenate([[last],
                            np.where(overlap > overlap_threshold)[0]]))

    return pick


def convert_to_square(bboxes):
    """Convert bounding boxes to a square form.

    Arguments:
        bboxes: a float numpy array of shape [n, 5].

    Returns:
        a float numpy array of shape [n, 5],
        squared bounding boxes.
    """

    square_bboxes = np.zeros_like(bboxes)
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    h = y2 - y1 + 1.0
    w = x2 - x1 + 1.0
    max_side = np.maximum(h, w)
    square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
    square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
    square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
    square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
    return square_bboxes


def calibrate_box(bboxes, offsets):
    """Transform bounding boxes to be more like true bounding boxes.
    'offsets' is one of the outputs of the nets.

    Arguments:
        bboxes: a float numpy array of shape [n, 5].
        offsets: a float numpy array of shape [n, 4].

    Returns:
        a float numpy array of shape [n, 5].
    """
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    w = x2 - x1 + 1.0
    h = y2 - y1 + 1.0
    w = np.expand_dims(w, 1)
    h = np.expand_dims(h, 1)

    # this is what is happening here:
    # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
    # x1_true = x1 + tx1*w
    # y1_true = y1 + ty1*h
    # x2_true = x2 + tx2*w
    # y2_true = y2 + ty2*h
    # below is just a more compact form of this

    # are offsets always such that
    # x1 < x2 and y1 < y2 ?

    translation = np.hstack([w, h, w, h]) * offsets
    bboxes[:, 0:4] = bboxes[:, 0:4] + translation
    return bboxes


def get_image_boxes(bounding_boxes, img, size=24):
    """Cut out boxes from the image.

    Arguments:
        bounding_boxes: a float numpy array of shape [n, 5].
        img: an instance of PIL.Image.
        size: an integer, size of cutouts.

    Returns:
        a float numpy array of shape [n, 3, size, size].
    """

    num_boxes = len(bounding_boxes)
    width, height = img.size

    [dy, edy, dx, edx, y, ey, x, ex, w,
     h] = correct_bboxes(bounding_boxes, width, height)
    img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')

    for i in range(num_boxes):
        img_box = np.zeros((h[i], w[i], 3), 'uint8')

        img_array = np.asarray(img, 'uint8')
        img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] =\
            img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]

        # resize
        img_box = Image.fromarray(img_box)
        img_box = img_box.resize((size, size), Image.BILINEAR)
        img_box = np.asarray(img_box, 'float32')

        img_boxes[i, :, :, :] = _preprocess(img_box)

    return img_boxes


def correct_bboxes(bboxes, width, height):
    """Crop boxes that are too big and get coordinates
    with respect to cutouts.

    Arguments:
        bboxes: a float numpy array of shape [n, 5],
            where each row is (xmin, ymin, xmax, ymax, score).
        width: a float number.
        height: a float number.

    Returns:
        dy, dx, edy, edx: int numpy arrays of shape [n],
            coordinates of the boxes with respect to the cutouts.
        y, x, ey, ex: int numpy arrays of shape [n],
            corrected ymin, xmin, ymax, xmax.
        h, w: int numpy arrays of shape [n],
            just heights and widths of boxes.

        in the following order:
            [dy, edy, dx, edx, y, ey, x, ex, w, h].
    """

    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
    num_boxes = bboxes.shape[0]

    # 'e' stands for end
    # (x, y) -> (ex, ey)
    x, y, ex, ey = x1, y1, x2, y2

    # we need to cut out a box from the image.
    # (x, y, ex, ey) are corrected coordinates of the box
    # in the image.
    # (dx, dy, edx, edy) are coordinates of the box in the cutout
    # from the image.
    dx, dy = np.zeros((num_boxes, )), np.zeros((num_boxes, ))
    edx, edy = w.copy() - 1.0, h.copy() - 1.0

    # if box's bottom right corner is too far right
    ind = np.where(ex > width - 1.0)[0]
    edx[ind] = w[ind] + width - 2.0 - ex[ind]
    ex[ind] = width - 1.0

    # if box's bottom right corner is too low
    ind = np.where(ey > height - 1.0)[0]
    edy[ind] = h[ind] + height - 2.0 - ey[ind]
    ey[ind] = height - 1.0

    # if box's top left corner is too far left
    ind = np.where(x < 0.0)[0]
    dx[ind] = 0.0 - x[ind]
    x[ind] = 0.0

    # if box's top left corner is too high
    ind = np.where(y < 0.0)[0]
    dy[ind] = 0.0 - y[ind]
    y[ind] = 0.0

    return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
    return_list = [i.astype('int32') for i in return_list]

    return return_list


def _preprocess(img):
    """Preprocessing step before feeding the network.

    Arguments:
        img: a float numpy array of shape [h, w, c].

    Returns:
        a float numpy array of shape [1, c, h, w].
    """
    img = img.transpose((2, 0, 1))
    img = np.expand_dims(img, 0)
    img = (img - 127.5) * 0.0078125
    return img
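
A quick sanity check of the helpers above (not part of the commit; a minimal sketch assuming the module path introduced by this diff) on two hand-made, heavily overlapping boxes:

    import numpy as np

    from modelscope.models.cv.face_detection.mtcnn.models.box_utils import (
        convert_to_square, nms)

    boxes = np.array([
        [10., 10., 50., 60., 0.95],  # 41x51 box, high score
        [12., 12., 52., 62., 0.60],  # IoU ~0.84 with the first box
    ], dtype='float32')

    keep = nms(boxes, overlap_threshold=0.5)  # -> [0]; the 0.60 box is dropped
    squared = convert_to_square(boxes[keep])
    print(squared[:, 2] - squared[:, 0] + 1.0)  # [51.]: width grew to max side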

modelscope/models/cv/face_detection/mtcnn/models/detector.py (+149, -0)

@@ -0,0 +1,149 @@
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torch.autograd import Variable

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from .box_utils import calibrate_box, convert_to_square, get_image_boxes, nms
from .first_stage import run_first_stage
from .get_nets import ONet, PNet, RNet


@MODELS.register_module(Tasks.face_detection, module_name=Models.mtcnn)
class MtcnnFaceDetector(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        torch.set_grad_enabled(False)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device

        self.pnet = PNet(model_path=os.path.join(self.model_path, 'pnet.npy'))
        self.rnet = RNet(model_path=os.path.join(self.model_path, 'rnet.npy'))
        self.onet = ONet(model_path=os.path.join(self.model_path, 'onet.npy'))

        self.pnet = self.pnet.to(device)
        self.rnet = self.rnet.to(device)
        self.onet = self.onet.to(device)

    def forward(self, input):
        image = Image.fromarray(np.uint8(input['img'].cpu().numpy()))
        pnet = self.pnet
        rnet = self.rnet
        onet = self.onet
        onet.eval()

        min_face_size = 20.0
        thresholds = [0.7, 0.8, 0.9]
        nms_thresholds = [0.7, 0.7, 0.7]

        # BUILD AN IMAGE PYRAMID
        width, height = image.size
        min_length = min(height, width)

        min_detection_size = 12
        factor = 0.707  # sqrt(0.5)

        # scales for scaling the image
        scales = []

        m = min_detection_size / min_face_size
        min_length *= m

        factor_count = 0
        while min_length > min_detection_size:
            scales.append(m * factor**factor_count)
            min_length *= factor
            factor_count += 1

        # STAGE 1

        # it will be returned
        bounding_boxes = []

        # run P-Net on different scales
        for s in scales:
            boxes = run_first_stage(
                image,
                pnet,
                scale=s,
                threshold=thresholds[0],
                device=self.device)
            bounding_boxes.append(boxes)

        # collect boxes (and offsets, and scores) from different scales
        bounding_boxes = [i for i in bounding_boxes if i is not None]
        bounding_boxes = np.vstack(bounding_boxes)

        keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
        bounding_boxes = bounding_boxes[keep]

        # use offsets predicted by pnet to transform bounding boxes
        bounding_boxes = calibrate_box(bounding_boxes[:, 0:5],
                                       bounding_boxes[:, 5:])
        # shape [n_boxes, 5]

        bounding_boxes = convert_to_square(bounding_boxes)
        bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

        # STAGE 2

        img_boxes = get_image_boxes(bounding_boxes, image, size=24)
        img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True)
        output = rnet(img_boxes.to(self.device))
        offsets = output[0].cpu().data.numpy()  # shape [n_boxes, 4]
        probs = output[1].cpu().data.numpy()  # shape [n_boxes, 2]

        keep = np.where(probs[:, 1] > thresholds[1])[0]
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, ))
        offsets = offsets[keep]

        keep = nms(bounding_boxes, nms_thresholds[1])
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
        bounding_boxes = convert_to_square(bounding_boxes)
        bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

        # STAGE 3

        img_boxes = get_image_boxes(bounding_boxes, image, size=48)
        if len(img_boxes) == 0:
            return [], []
        img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True)
        output = onet(img_boxes.to(self.device))
        landmarks = output[0].cpu().data.numpy()  # shape [n_boxes, 10]
        offsets = output[1].cpu().data.numpy()  # shape [n_boxes, 4]
        probs = output[2].cpu().data.numpy()  # shape [n_boxes, 2]

        keep = np.where(probs[:, 1] > thresholds[2])[0]
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, ))
        offsets = offsets[keep]
        landmarks = landmarks[keep]

        # compute landmark points
        width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
        height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
        xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
        landmarks[:, 0:5] = np.expand_dims(
            xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
        landmarks[:, 5:10] = np.expand_dims(
            ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]

        bounding_boxes = calibrate_box(bounding_boxes, offsets)
        keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
        bounding_boxes = bounding_boxes[keep]
        landmarks = landmarks[keep]
        landmarks = landmarks.reshape(-1, 2, 5).transpose(
            (0, 2, 1)).reshape(-1, 10)

        return bounding_boxes, landmarks
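
For intuition, the pyramid loop in forward() above can be run standalone. A minimal sketch (assuming a 640x480 input and the constants from forward()): m = 12/20 = 0.6 maps the smallest detectable face (min_face_size) onto P-Net's fixed 12x12 window, and each further level shrinks the image by sqrt(0.5) until faces would fall below 12 px:

    min_detection_size = 12
    min_face_size = 20.0
    factor = 0.707  # sqrt(0.5): every level halves the image area

    m = min_detection_size / min_face_size  # 0.6
    min_length = min(480, 640) * m          # 288.0

    scales = []
    factor_count = 0
    while min_length > min_detection_size:
        scales.append(m * factor**factor_count)
        min_length *= factor
        factor_count += 1

    print(len(scales), scales[:3])  # 10 levels: 0.6, ~0.424, ~0.3, ...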

modelscope/models/cv/face_detection/mtcnn/models/first_stage.py (+100, -0)

@@ -0,0 +1,100 @@
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
import math

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable

from .box_utils import _preprocess, nms


def run_first_stage(image, net, scale, threshold, device='cuda'):
    """Run P-Net, generate bounding boxes, and do NMS.

    Arguments:
        image: an instance of PIL.Image.
        net: an instance of pytorch's nn.Module, P-Net.
        scale: a float number,
            scale width and height of the image by this number.
        threshold: a float number,
            threshold on the probability of a face when generating
            bounding boxes from predictions of the net.

    Returns:
        a float numpy array of shape [n_boxes, 9],
        bounding boxes with scores and offsets (4 + 1 + 4).
    """

    # scale the image and convert it to a float array
    width, height = image.size
    sw, sh = math.ceil(width * scale), math.ceil(height * scale)
    img = image.resize((sw, sh), Image.BILINEAR)
    img = np.asarray(img, 'float32')

    img = Variable(
        torch.FloatTensor(_preprocess(img)), volatile=True).to(device)
    output = net(img)
    probs = output[1].cpu().data.numpy()[0, 1, :, :]
    offsets = output[0].cpu().data.numpy()
    # probs: probability of a face at each sliding window
    # offsets: transformations to true bounding boxes

    boxes = _generate_bboxes(probs, offsets, scale, threshold)
    if len(boxes) == 0:
        return None

    keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
    return boxes[keep]


def _generate_bboxes(probs, offsets, scale, threshold):
    """Generate bounding boxes at places
    where there is probably a face.

    Arguments:
        probs: a float numpy array of shape [n, m].
        offsets: a float numpy array of shape [1, 4, n, m].
        scale: a float number,
            width and height of the image were scaled by this number.
        threshold: a float number.

    Returns:
        a float numpy array of shape [n_boxes, 9]
    """

    # applying P-Net is equivalent, in some sense, to
    # moving a 12x12 window with stride 2
    stride = 2
    cell_size = 12

    # indices of boxes where there is probably a face
    inds = np.where(probs > threshold)

    if inds[0].size == 0:
        return np.array([])

    # transformations of bounding boxes
    tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
    # they are defined as:
    # w = x2 - x1 + 1
    # h = y2 - y1 + 1
    # x1_true = x1 + tx1*w
    # x2_true = x2 + tx2*w
    # y1_true = y1 + ty1*h
    # y2_true = y2 + ty2*h

    offsets = np.array([tx1, ty1, tx2, ty2])
    score = probs[inds[0], inds[1]]

    # P-Net is applied to scaled images
    # so we need to rescale bounding boxes back
    bounding_boxes = np.vstack([
        np.round((stride * inds[1] + 1.0) / scale),
        np.round((stride * inds[0] + 1.0) / scale),
        np.round((stride * inds[1] + 1.0 + cell_size) / scale),
        np.round((stride * inds[0] + 1.0 + cell_size) / scale), score, offsets
    ])
    # why is one added?

    return bounding_boxes.T
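
A toy check (not part of the commit; it assumes the private helper is importable from this module) of the index-to-box mapping: one above-threshold cell at row 3, column 5 of the P-Net output, with zero regression offsets, maps back to a 12x12 window placed at stride 2 in the unscaled image:

    import numpy as np

    from modelscope.models.cv.face_detection.mtcnn.models.first_stage import \
        _generate_bboxes

    probs = np.zeros((10, 10), dtype='float32')
    probs[3, 5] = 0.9  # a single confident cell
    offsets = np.zeros((1, 4, 10, 10), dtype='float32')

    boxes = _generate_bboxes(probs, offsets, scale=1.0, threshold=0.5)
    print(boxes[0, :4])  # [11. 7. 23. 19.]: x1 = 2*5 + 1, y1 = 2*3 + 1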

modelscope/models/cv/face_detection/mtcnn/models/get_nets.py (+160, -0)

@@ -0,0 +1,160 @@
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, c, h, w].
        Returns:
            a float tensor with shape [batch_size, c*h*w].
        """

        # without this the pretrained model isn't working
        x = x.transpose(3, 2).contiguous()

        return x.view(x.size(0), -1)


class PNet(nn.Module):

    def __init__(self, model_path=None):

        super(PNet, self).__init__()

        # suppose we have input with size HxW, then
        # after first layer: H - 2,
        # after pool: ceil((H - 2)/2),
        # after second conv: ceil((H - 2)/2) - 2,
        # after last conv: ceil((H - 2)/2) - 4,
        # and the same for W

        self.features = nn.Sequential(
            OrderedDict([('conv1', nn.Conv2d(3, 10, 3, 1)),
                         ('prelu1', nn.PReLU(10)),
                         ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
                         ('conv2', nn.Conv2d(10, 16, 3, 1)),
                         ('prelu2', nn.PReLU(16)),
                         ('conv3', nn.Conv2d(16, 32, 3, 1)),
                         ('prelu3', nn.PReLU(32))]))

        self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
        self.conv4_2 = nn.Conv2d(32, 4, 1, 1)

        weights = np.load(model_path, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4, h', w'].
            a: a float tensor with shape [batch_size, 2, h', w'].
        """
        x = self.features(x)
        a = self.conv4_1(x)
        b = self.conv4_2(x)
        a = F.softmax(a)
        return b, a


class RNet(nn.Module):

    def __init__(self, model_path=None):

        super(RNet, self).__init__()

        self.features = nn.Sequential(
            OrderedDict([('conv1', nn.Conv2d(3, 28, 3, 1)),
                         ('prelu1', nn.PReLU(28)),
                         ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
                         ('conv2', nn.Conv2d(28, 48, 3, 1)),
                         ('prelu2', nn.PReLU(48)),
                         ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
                         ('conv3', nn.Conv2d(48, 64, 2, 1)),
                         ('prelu3', nn.PReLU(64)), ('flatten', Flatten()),
                         ('conv4', nn.Linear(576, 128)),
                         ('prelu4', nn.PReLU(128))]))

        self.conv5_1 = nn.Linear(128, 2)
        self.conv5_2 = nn.Linear(128, 4)

        weights = np.load(model_path, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv5_1(x)
        b = self.conv5_2(x)
        a = F.softmax(a)
        return b, a


class ONet(nn.Module):

    def __init__(self, model_path=None):

        super(ONet, self).__init__()

        self.features = nn.Sequential(
            OrderedDict([
                ('conv1', nn.Conv2d(3, 32, 3, 1)),
                ('prelu1', nn.PReLU(32)),
                ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
                ('conv2', nn.Conv2d(32, 64, 3, 1)),
                ('prelu2', nn.PReLU(64)),
                ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
                ('conv3', nn.Conv2d(64, 64, 3, 1)),
                ('prelu3', nn.PReLU(64)),
                ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
                ('conv4', nn.Conv2d(64, 128, 2, 1)),
                ('prelu4', nn.PReLU(128)),
                ('flatten', Flatten()),
                ('conv5', nn.Linear(1152, 256)),
                ('drop5', nn.Dropout(0.25)),
                ('prelu5', nn.PReLU(256)),
            ]))

        self.conv6_1 = nn.Linear(256, 2)
        self.conv6_2 = nn.Linear(256, 4)
        self.conv6_3 = nn.Linear(256, 10)

        weights = np.load(model_path, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            c: a float tensor with shape [batch_size, 10].
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv6_1(x)
        b = self.conv6_2(x)
        c = self.conv6_3(x)
        a = F.softmax(a)
        return c, b, a
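
The shape comment in PNet.__init__ can be checked directly. A small sketch (not part of the commit): for an HxW input the P-Net feature map is ceil((H - 2)/2) - 4 high, so a 12x12 crop yields exactly one output cell, which is what makes the 12x12 sliding-window reading in first_stage.py valid:

    import math

    def pnet_output_size(h, w):
        # conv1 (3x3): -2; pool1 (2x2, ceil_mode=True): ceil(./2);
        # conv2 (3x3): -2; conv3 (3x3): -2
        def side(s):
            return math.ceil((s - 2) / 2) - 4
        return side(h), side(w)

    print(pnet_output_size(12, 12))  # (1, 1): one detection window
    print(pnet_output_size(24, 24))  # (7, 7): 7 = (24 - 12) / 2 + 1 windows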

modelscope/pipelines/cv/__init__.py (+3, -1)

@@ -51,6 +51,7 @@ if TYPE_CHECKING:
     from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
     from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
     from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
+    from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline

 else:
     _import_structure = {
@@ -114,7 +115,8 @@ else:
         'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
         'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
         'facial_expression_recognition_pipelin':
-        ['FacialExpressionRecognitionPipeline']
+        ['FacialExpressionRecognitionPipeline'],
+        'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
     }

     import sys


modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py (+56, -0)

@@ -0,0 +1,56 @@
import os.path as osp
from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_detection import MtcnnFaceDetector
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.mtcnn_face_detection)
class MtcnnFaceDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a face detection pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, './weights')
        logger.info(f'loading model from {ckpt_path}')
        device = torch.device(
            f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
        detector = MtcnnFaceDetector(model_path=ckpt_path, device=device)
        self.detector = detector
        self.device = device
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = self.detector(input)
        assert result is not None
        bboxes = result[0][:, :4].tolist()
        scores = result[0][:, 4].tolist()
        lms = result[1].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: lms,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
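
End-to-end usage, mirroring the unit test below (the model id and test image are the ones the test uses):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    face_detection = pipeline(
        Tasks.face_detection, model='damo/cv_manual_face-detection_mtcnn')
    result = face_detection('data/test/images/mtcnn_face_detection.jpg')
    print(result[OutputKeys.BOXES], result[OutputKeys.KEYPOINTS])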

tests/pipelines/test_mtcnn_face_detection.py (+38, -0)

@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2
from PIL import Image

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_result
from modelscope.utils.test_utils import test_level


class MtcnnFaceDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_manual_face-detection_mtcnn'

    def show_result(self, img_path, detection_result):
        img = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        img_path = 'data/test/images/mtcnn_face_detection.jpg'
        img = Image.open(img_path)

        result_1 = face_detection(img_path)
        self.show_result(img_path, result_1)

        result_2 = face_detection(img)
        self.show_result(img_path, result_2)


if __name__ == '__main__':
    unittest.main()
