From c12957a9eb0753b61285cfa44f0d34d72b3e52ba Mon Sep 17 00:00:00 2001 From: ly261666 Date: Tue, 6 Sep 2022 22:53:55 +0800 Subject: [PATCH] =?UTF-8?q?[to=20#42322933]=20=E6=96=B0=E5=A2=9EMtcnn?= =?UTF-8?q?=E4=BA=BA=E8=84=B8=E6=A3=80=E6=B5=8B=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 完成Maas-cv CR标准 自查 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9951519 * [to #42322933] 新增Mtcnn人脸检测器 --- data/test/images/mtcnn_face_detection.jpg | 3 + modelscope/metainfo.py | 2 + .../models/cv/face_detection/__init__.py | 5 +- .../cv/face_detection/mtcnn/__init__.py | 1 + .../face_detection/mtcnn/models/__init__.py | 0 .../face_detection/mtcnn/models/box_utils.py | 240 ++++++++++++++++++ .../face_detection/mtcnn/models/detector.py | 149 +++++++++++ .../mtcnn/models/first_stage.py | 100 ++++++++ .../face_detection/mtcnn/models/get_nets.py | 160 ++++++++++++ modelscope/pipelines/cv/__init__.py | 4 +- .../cv/mtcnn_face_detection_pipeline.py | 56 ++++ tests/pipelines/test_mtcnn_face_detection.py | 38 +++ 12 files changed, 756 insertions(+), 2 deletions(-) create mode 100644 data/test/images/mtcnn_face_detection.jpg create mode 100644 modelscope/models/cv/face_detection/mtcnn/__init__.py create mode 100644 modelscope/models/cv/face_detection/mtcnn/models/__init__.py create mode 100644 modelscope/models/cv/face_detection/mtcnn/models/box_utils.py create mode 100644 modelscope/models/cv/face_detection/mtcnn/models/detector.py create mode 100644 modelscope/models/cv/face_detection/mtcnn/models/first_stage.py create mode 100644 modelscope/models/cv/face_detection/mtcnn/models/get_nets.py create mode 100644 modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py create mode 100644 tests/pipelines/test_mtcnn_face_detection.py diff --git a/data/test/images/mtcnn_face_detection.jpg b/data/test/images/mtcnn_face_detection.jpg new file mode 100644 index 00000000..c95881fe --- /dev/null +++ b/data/test/images/mtcnn_face_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9 +size 87228 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index d7217d57..d7594794 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -35,6 +35,7 @@ class Models(object): fer = 'fer' retinaface = 'retinaface' shop_segmentation = 'shop-segmentation' + mtcnn = 'mtcnn' ulfd = 'ulfd' # EasyCV models @@ -127,6 +128,7 @@ class Pipelines(object): ulfd_face_detection = 'manual-face-detection-ulfd' facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' retina_face_detection = 'resnet50-face-detection-retinaface' + mtcnn_face_detection = 'manual-face-detection-mtcnn' live_category = 'live-category' general_image_classification = 'vit-base_image-classification_ImageNet-labels' daily_image_classification = 'vit-base_image-classification_Dailylife-labels' diff --git a/modelscope/models/cv/face_detection/__init__.py b/modelscope/models/cv/face_detection/__init__.py index 63ff1b83..ed8832c2 100644 --- a/modelscope/models/cv/face_detection/__init__.py +++ b/modelscope/models/cv/face_detection/__init__.py @@ -4,12 +4,15 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: + from .mtcnn import MtcnnFaceDetector from .retinaface import RetinaFaceDetection from .ulfd_slim import UlfdFaceDetector + else: _import_structure = { 'ulfd_slim': ['UlfdFaceDetector'], - 'retinaface': 
['RetinaFaceDetection'] + 'retinaface': ['RetinaFaceDetection'], + 'mtcnn': ['MtcnnFaceDetector'] } import sys diff --git a/modelscope/models/cv/face_detection/mtcnn/__init__.py b/modelscope/models/cv/face_detection/mtcnn/__init__.py new file mode 100644 index 00000000..b11c4740 --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/__init__.py @@ -0,0 +1 @@ +from .models.detector import MtcnnFaceDetector diff --git a/modelscope/models/cv/face_detection/mtcnn/models/__init__.py b/modelscope/models/cv/face_detection/mtcnn/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/mtcnn/models/box_utils.py b/modelscope/models/cv/face_detection/mtcnn/models/box_utils.py new file mode 100644 index 00000000..f6a27b05 --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/models/box_utils.py @@ -0,0 +1,240 @@ +# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch +import numpy as np +from PIL import Image + + +def nms(boxes, overlap_threshold=0.5, mode='union'): + """Non-maximum suppression. + + Arguments: + boxes: a float numpy array of shape [n, 5], + where each row is (xmin, ymin, xmax, ymax, score). + overlap_threshold: a float number. + mode: 'union' or 'min'. + + Returns: + list with indices of the selected boxes + """ + + # if there are no boxes, return the empty list + if len(boxes) == 0: + return [] + + # list of picked indices + pick = [] + + # grab the coordinates of the bounding boxes + x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] + + area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0) + ids = np.argsort(score) # in increasing order + + while len(ids) > 0: + + # grab index of the largest value + last = len(ids) - 1 + i = ids[last] + pick.append(i) + + # compute intersections + # of the box with the largest score + # with the rest of boxes + + # left top corner of intersection boxes + ix1 = np.maximum(x1[i], x1[ids[:last]]) + iy1 = np.maximum(y1[i], y1[ids[:last]]) + + # right bottom corner of intersection boxes + ix2 = np.minimum(x2[i], x2[ids[:last]]) + iy2 = np.minimum(y2[i], y2[ids[:last]]) + + # width and height of intersection boxes + w = np.maximum(0.0, ix2 - ix1 + 1.0) + h = np.maximum(0.0, iy2 - iy1 + 1.0) + + # intersections' areas + inter = w * h + if mode == 'min': + overlap = inter / np.minimum(area[i], area[ids[:last]]) + elif mode == 'union': + # intersection over union (IoU) + overlap = inter / (area[i] + area[ids[:last]] - inter) + + # delete all boxes where overlap is too big + ids = np.delete( + ids, + np.concatenate([[last], + np.where(overlap > overlap_threshold)[0]])) + + return pick + + +def convert_to_square(bboxes): + """Convert bounding boxes to a square form. + + Arguments: + bboxes: a float numpy array of shape [n, 5]. + + Returns: + a float numpy array of shape [n, 5], + squared bounding boxes. + """ + + square_bboxes = np.zeros_like(bboxes) + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + h = y2 - y1 + 1.0 + w = x2 - x1 + 1.0 + max_side = np.maximum(h, w) + square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5 + square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5 + square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 + square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 + return square_bboxes + + +def calibrate_box(bboxes, offsets): + """Transform bounding boxes to be more like true bounding boxes. + 'offsets' is one of the outputs of the nets. + + Arguments: + bboxes: a float numpy array of shape [n, 5]. 
+ offsets: a float numpy array of shape [n, 4]. + + Returns: + a float numpy array of shape [n, 5]. + """ + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + w = x2 - x1 + 1.0 + h = y2 - y1 + 1.0 + w = np.expand_dims(w, 1) + h = np.expand_dims(h, 1) + + # this is what happening here: + # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] + # x1_true = x1 + tx1*w + # y1_true = y1 + ty1*h + # x2_true = x2 + tx2*w + # y2_true = y2 + ty2*h + # below is just more compact form of this + + # are offsets always such that + # x1 < x2 and y1 < y2 ? + + translation = np.hstack([w, h, w, h]) * offsets + bboxes[:, 0:4] = bboxes[:, 0:4] + translation + return bboxes + + +def get_image_boxes(bounding_boxes, img, size=24): + """Cut out boxes from the image. + + Arguments: + bounding_boxes: a float numpy array of shape [n, 5]. + img: an instance of PIL.Image. + size: an integer, size of cutouts. + + Returns: + a float numpy array of shape [n, 3, size, size]. + """ + + num_boxes = len(bounding_boxes) + width, height = img.size + + [dy, edy, dx, edx, y, ey, x, ex, w, + h] = correct_bboxes(bounding_boxes, width, height) + img_boxes = np.zeros((num_boxes, 3, size, size), 'float32') + + for i in range(num_boxes): + img_box = np.zeros((h[i], w[i], 3), 'uint8') + + img_array = np.asarray(img, 'uint8') + img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] =\ + img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :] + + # resize + img_box = Image.fromarray(img_box) + img_box = img_box.resize((size, size), Image.BILINEAR) + img_box = np.asarray(img_box, 'float32') + + img_boxes[i, :, :, :] = _preprocess(img_box) + + return img_boxes + + +def correct_bboxes(bboxes, width, height): + """Crop boxes that are too big and get coordinates + with respect to cutouts. + + Arguments: + bboxes: a float numpy array of shape [n, 5], + where each row is (xmin, ymin, xmax, ymax, score). + width: a float number. + height: a float number. + + Returns: + dy, dx, edy, edx: a int numpy arrays of shape [n], + coordinates of the boxes with respect to the cutouts. + y, x, ey, ex: a int numpy arrays of shape [n], + corrected ymin, xmin, ymax, xmax. + h, w: a int numpy arrays of shape [n], + just heights and widths of boxes. + + in the following order: + [dy, edy, dx, edx, y, ey, x, ex, w, h]. + """ + + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 + num_boxes = bboxes.shape[0] + + # 'e' stands for end + # (x, y) -> (ex, ey) + x, y, ex, ey = x1, y1, x2, y2 + + # we need to cut out a box from the image. + # (x, y, ex, ey) are corrected coordinates of the box + # in the image. + # (dx, dy, edx, edy) are coordinates of the box in the cutout + # from the image. + dx, dy = np.zeros((num_boxes, )), np.zeros((num_boxes, )) + edx, edy = w.copy() - 1.0, h.copy() - 1.0 + + # if box's bottom right corner is too far right + ind = np.where(ex > width - 1.0)[0] + edx[ind] = w[ind] + width - 2.0 - ex[ind] + ex[ind] = width - 1.0 + + # if box's bottom right corner is too low + ind = np.where(ey > height - 1.0)[0] + edy[ind] = h[ind] + height - 2.0 - ey[ind] + ey[ind] = height - 1.0 + + # if box's top left corner is too far left + ind = np.where(x < 0.0)[0] + dx[ind] = 0.0 - x[ind] + x[ind] = 0.0 + + # if box's top left corner is too high + ind = np.where(y < 0.0)[0] + dy[ind] = 0.0 - y[ind] + y[ind] = 0.0 + + return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] + return_list = [i.astype('int32') for i in return_list] + + return return_list + + +def _preprocess(img): + """Preprocessing step before feeding the network. 
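+    The scaling below maps uint8 pixel values from [0, 255] to roughly
+    [-1, 1]: 127.5 is the mid-point of the range and 0.0078125 == 1/128.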
+ + Arguments: + img: a float numpy array of shape [h, w, c]. + + Returns: + a float numpy array of shape [1, c, h, w]. + """ + img = img.transpose((2, 0, 1)) + img = np.expand_dims(img, 0) + img = (img - 127.5) * 0.0078125 + return img diff --git a/modelscope/models/cv/face_detection/mtcnn/models/detector.py b/modelscope/models/cv/face_detection/mtcnn/models/detector.py new file mode 100644 index 00000000..9c3aca3a --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/models/detector.py @@ -0,0 +1,149 @@ +# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch +import os + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from PIL import Image +from torch.autograd import Variable + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from .box_utils import calibrate_box, convert_to_square, get_image_boxes, nms +from .first_stage import run_first_stage +from .get_nets import ONet, PNet, RNet + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.mtcnn) +class MtcnnFaceDetector(TorchModel): + + def __init__(self, model_path, device='cuda'): + super().__init__(model_path) + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.device = device + + self.pnet = PNet(model_path=os.path.join(self.model_path, 'pnet.npy')) + self.rnet = RNet(model_path=os.path.join(self.model_path, 'rnet.npy')) + self.onet = ONet(model_path=os.path.join(self.model_path, 'onet.npy')) + + self.pnet = self.pnet.to(device) + self.rnet = self.rnet.to(device) + self.onet = self.onet.to(device) + + def forward(self, input): + image = Image.fromarray(np.uint8(input['img'].cpu().numpy())) + pnet = self.pnet + rnet = self.rnet + onet = self.onet + onet.eval() + + min_face_size = 20.0 + thresholds = [0.7, 0.8, 0.9] + nms_thresholds = [0.7, 0.7, 0.7] + + # BUILD AN IMAGE PYRAMID + width, height = image.size + min_length = min(height, width) + + min_detection_size = 12 + factor = 0.707 # sqrt(0.5) + + # scales for scaling the image + scales = [] + + m = min_detection_size / min_face_size + min_length *= m + + factor_count = 0 + while min_length > min_detection_size: + scales.append(m * factor**factor_count) + min_length *= factor + factor_count += 1 + + # STAGE 1 + + # it will be returned + bounding_boxes = [] + + # run P-Net on different scales + for s in scales: + boxes = run_first_stage( + image, + pnet, + scale=s, + threshold=thresholds[0], + device=self.device) + bounding_boxes.append(boxes) + + # collect boxes (and offsets, and scores) from different scales + bounding_boxes = [i for i in bounding_boxes if i is not None] + bounding_boxes = np.vstack(bounding_boxes) + + keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) + bounding_boxes = bounding_boxes[keep] + + # use offsets predicted by pnet to transform bounding boxes + bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], + bounding_boxes[:, 5:]) + # shape [n_boxes, 5] + + bounding_boxes = convert_to_square(bounding_boxes) + bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) + + # STAGE 2 + + img_boxes = get_image_boxes(bounding_boxes, image, size=24) + img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True) + output = rnet(img_boxes.to(self.device)) + offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] + probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] + + keep = np.where(probs[:, 
1] > thresholds[1])[0] + bounding_boxes = bounding_boxes[keep] + bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, )) + offsets = offsets[keep] + + keep = nms(bounding_boxes, nms_thresholds[1]) + bounding_boxes = bounding_boxes[keep] + bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) + bounding_boxes = convert_to_square(bounding_boxes) + bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) + + # STAGE 3 + + img_boxes = get_image_boxes(bounding_boxes, image, size=48) + if len(img_boxes) == 0: + return [], [] + img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True) + output = onet(img_boxes.to(self.device)) + landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] + offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] + probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] + + keep = np.where(probs[:, 1] > thresholds[2])[0] + bounding_boxes = bounding_boxes[keep] + bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, )) + offsets = offsets[keep] + landmarks = landmarks[keep] + + # compute landmark points + width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 + height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 + xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] + landmarks[:, 0:5] = np.expand_dims( + xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] + landmarks[:, 5:10] = np.expand_dims( + ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] + + bounding_boxes = calibrate_box(bounding_boxes, offsets) + keep = nms(bounding_boxes, nms_thresholds[2], mode='min') + bounding_boxes = bounding_boxes[keep] + landmarks = landmarks[keep] + landmarks = landmarks.reshape(-1, 2, 5).transpose( + (0, 2, 1)).reshape(-1, 10) + + return bounding_boxes, landmarks diff --git a/modelscope/models/cv/face_detection/mtcnn/models/first_stage.py b/modelscope/models/cv/face_detection/mtcnn/models/first_stage.py new file mode 100644 index 00000000..e2aba47e --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/models/first_stage.py @@ -0,0 +1,100 @@ +# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch +import math + +import numpy as np +import torch +from PIL import Image +from torch.autograd import Variable + +from .box_utils import _preprocess, nms + + +def run_first_stage(image, net, scale, threshold, device='cuda'): + """Run P-Net, generate bounding boxes, and do NMS. + + Arguments: + image: an instance of PIL.Image. + net: an instance of pytorch's nn.Module, P-Net. + scale: a float number, + scale width and height of the image by this number. + threshold: a float number, + threshold on the probability of a face when generating + bounding boxes from predictions of the net. + + Returns: + a float numpy array of shape [n_boxes, 9], + bounding boxes with scores and offsets (4 + 1 + 4). 
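+        Column layout (built in _generate_bboxes): [x1, y1, x2, y2, score,
+        tx1, ty1, tx2, ty2], where tx1..ty2 are the P-Net box offsets.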
+ """ + + # scale the image and convert it to a float array + width, height = image.size + sw, sh = math.ceil(width * scale), math.ceil(height * scale) + img = image.resize((sw, sh), Image.BILINEAR) + img = np.asarray(img, 'float32') + + img = Variable( + torch.FloatTensor(_preprocess(img)), volatile=True).to(device) + output = net(img) + probs = output[1].cpu().data.numpy()[0, 1, :, :] + offsets = output[0].cpu().data.numpy() + # probs: probability of a face at each sliding window + # offsets: transformations to true bounding boxes + + boxes = _generate_bboxes(probs, offsets, scale, threshold) + if len(boxes) == 0: + return None + + keep = nms(boxes[:, 0:5], overlap_threshold=0.5) + return boxes[keep] + + +def _generate_bboxes(probs, offsets, scale, threshold): + """Generate bounding boxes at places + where there is probably a face. + + Arguments: + probs: a float numpy array of shape [n, m]. + offsets: a float numpy array of shape [1, 4, n, m]. + scale: a float number, + width and height of the image were scaled by this number. + threshold: a float number. + + Returns: + a float numpy array of shape [n_boxes, 9] + """ + + # applying P-Net is equivalent, in some sense, to + # moving 12x12 window with stride 2 + stride = 2 + cell_size = 12 + + # indices of boxes where there is probably a face + inds = np.where(probs > threshold) + + if inds[0].size == 0: + return np.array([]) + + # transformations of bounding boxes + tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] + # they are defined as: + # w = x2 - x1 + 1 + # h = y2 - y1 + 1 + # x1_true = x1 + tx1*w + # x2_true = x2 + tx2*w + # y1_true = y1 + ty1*h + # y2_true = y2 + ty2*h + + offsets = np.array([tx1, ty1, tx2, ty2]) + score = probs[inds[0], inds[1]] + + # P-Net is applied to scaled images + # so we need to rescale bounding boxes back + bounding_boxes = np.vstack([ + np.round((stride * inds[1] + 1.0) / scale), + np.round((stride * inds[0] + 1.0) / scale), + np.round((stride * inds[1] + 1.0 + cell_size) / scale), + np.round((stride * inds[0] + 1.0 + cell_size) / scale), score, offsets + ]) + # why one is added? + + return bounding_boxes.T diff --git a/modelscope/models/cv/face_detection/mtcnn/models/get_nets.py b/modelscope/models/cv/face_detection/mtcnn/models/get_nets.py new file mode 100644 index 00000000..5fbbd33b --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/models/get_nets.py @@ -0,0 +1,160 @@ +# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Flatten(nn.Module): + + def __init__(self): + super(Flatten, self).__init__() + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, c, h, w]. + Returns: + a float tensor with shape [batch_size, c*h*w]. 
+ """ + + # without this pretrained model isn't working + x = x.transpose(3, 2).contiguous() + + return x.view(x.size(0), -1) + + +class PNet(nn.Module): + + def __init__(self, model_path=None): + + super(PNet, self).__init__() + + # suppose we have input with size HxW, then + # after first layer: H - 2, + # after pool: ceil((H - 2)/2), + # after second conv: ceil((H - 2)/2) - 2, + # after last conv: ceil((H - 2)/2) - 4, + # and the same for W + + self.features = nn.Sequential( + OrderedDict([('conv1', nn.Conv2d(3, 10, 3, 1)), + ('prelu1', nn.PReLU(10)), + ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), + ('conv2', nn.Conv2d(10, 16, 3, 1)), + ('prelu2', nn.PReLU(16)), + ('conv3', nn.Conv2d(16, 32, 3, 1)), + ('prelu3', nn.PReLU(32))])) + + self.conv4_1 = nn.Conv2d(32, 2, 1, 1) + self.conv4_2 = nn.Conv2d(32, 4, 1, 1) + + weights = np.load(model_path, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + b: a float tensor with shape [batch_size, 4, h', w']. + a: a float tensor with shape [batch_size, 2, h', w']. + """ + x = self.features(x) + a = self.conv4_1(x) + b = self.conv4_2(x) + a = F.softmax(a) + return b, a + + +class RNet(nn.Module): + + def __init__(self, model_path=None): + + super(RNet, self).__init__() + + self.features = nn.Sequential( + OrderedDict([('conv1', nn.Conv2d(3, 28, 3, 1)), + ('prelu1', nn.PReLU(28)), + ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), + ('conv2', nn.Conv2d(28, 48, 3, 1)), + ('prelu2', nn.PReLU(48)), + ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), + ('conv3', nn.Conv2d(48, 64, 2, 1)), + ('prelu3', nn.PReLU(64)), ('flatten', Flatten()), + ('conv4', nn.Linear(576, 128)), + ('prelu4', nn.PReLU(128))])) + + self.conv5_1 = nn.Linear(128, 2) + self.conv5_2 = nn.Linear(128, 4) + + weights = np.load(model_path, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + b: a float tensor with shape [batch_size, 4]. + a: a float tensor with shape [batch_size, 2]. + """ + x = self.features(x) + a = self.conv5_1(x) + b = self.conv5_2(x) + a = F.softmax(a) + return b, a + + +class ONet(nn.Module): + + def __init__(self, model_path=None): + + super(ONet, self).__init__() + + self.features = nn.Sequential( + OrderedDict([ + ('conv1', nn.Conv2d(3, 32, 3, 1)), + ('prelu1', nn.PReLU(32)), + ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), + ('conv2', nn.Conv2d(32, 64, 3, 1)), + ('prelu2', nn.PReLU(64)), + ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), + ('conv3', nn.Conv2d(64, 64, 3, 1)), + ('prelu3', nn.PReLU(64)), + ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), + ('conv4', nn.Conv2d(64, 128, 2, 1)), + ('prelu4', nn.PReLU(128)), + ('flatten', Flatten()), + ('conv5', nn.Linear(1152, 256)), + ('drop5', nn.Dropout(0.25)), + ('prelu5', nn.PReLU(256)), + ])) + + self.conv6_1 = nn.Linear(256, 2) + self.conv6_2 = nn.Linear(256, 4) + self.conv6_3 = nn.Linear(256, 10) + + weights = np.load(model_path, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + c: a float tensor with shape [batch_size, 10]. + b: a float tensor with shape [batch_size, 4]. + a: a float tensor with shape [batch_size, 2]. 
+ """ + x = self.features(x) + a = self.conv6_1(x) + b = self.conv6_2(x) + c = self.conv6_3(x) + a = F.softmax(a) + return c, b, a diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 02682fa0..3eb5cd82 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -51,6 +51,7 @@ if TYPE_CHECKING: from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline + from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline else: _import_structure = { @@ -114,7 +115,8 @@ else: 'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'], 'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'], 'facial_expression_recognition_pipelin': - ['FacialExpressionRecognitionPipeline'] + ['FacialExpressionRecognitionPipeline'], + 'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], } import sys diff --git a/modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py b/modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py new file mode 100644 index 00000000..57bf9920 --- /dev/null +++ b/modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py @@ -0,0 +1,56 @@ +import os.path as osp +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection import MtcnnFaceDetector +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_detection, module_name=Pipelines.mtcnn_face_detection) +class MtcnnFaceDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + ckpt_path = osp.join(model, './weights') + logger.info(f'loading model from {ckpt_path}') + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + detector = MtcnnFaceDetector(model_path=ckpt_path, device=device) + self.detector = detector + self.device = device + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector(input) + assert result is not None + bboxes = result[0][:, :4].tolist() + scores = result[0][:, 4].tolist() + lms = result[1].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: lms, + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/tests/pipelines/test_mtcnn_face_detection.py b/tests/pipelines/test_mtcnn_face_detection.py new file mode 100644 index 00000000..5afb5588 --- /dev/null +++ b/tests/pipelines/test_mtcnn_face_detection.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os.path as osp +import unittest + +import cv2 +from PIL import Image + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_result +from modelscope.utils.test_utils import test_level + + +class MtcnnFaceDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_manual_face-detection_mtcnn' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + img_path = 'data/test/images/mtcnn_face_detection.jpg' + img = Image.open(img_path) + + result_1 = face_detection(img_path) + self.show_result(img_path, result_1) + + result_2 = face_detection(img) + self.show_result(img_path, result_2) + + +if __name__ == '__main__': + unittest.main()
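Usage sketch (a minimal example assuming the model id and the bundled test image referenced above; output keys follow the pipeline's forward()):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the MTCNN face detection pipeline from the hub model exercised in the tests.
face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_manual_face-detection_mtcnn')

result = face_detection('data/test/images/mtcnn_face_detection.jpg')
boxes = result[OutputKeys.BOXES]          # [x1, y1, x2, y2] per detected face
scores = result[OutputKeys.SCORES]        # detection confidence per face
landmarks = result[OutputKeys.KEYPOINTS]  # 5 (x, y) landmark pairs per face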