1. Completed the MaaS-cv CR standard self-check
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9951519
* [to #42322933] Add the MTCNN face detector
master
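For quick reference, a minimal usage sketch of the new pipeline (not part of the diff). The model id and sample image path are the ones used in the added unit test at the end of this change, and the output keys follow MtcnnFaceDetectionPipeline.forward below:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the MTCNN face detection pipeline registered in this change
# (Pipelines.mtcnn_face_detection / model 'damo/cv_manual_face-detection_mtcnn').
face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_manual_face-detection_mtcnn')

# Returns a dict with OutputKeys.SCORES, OutputKeys.BOXES and OutputKeys.KEYPOINTS.
result = face_detection('data/test/images/mtcnn_face_detection.jpg')
print(result)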
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
size 87228
@@ -35,6 +35,7 @@ class Models(object):
     fer = 'fer'
     retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
+    mtcnn = 'mtcnn'
     ulfd = 'ulfd'
     # EasyCV models
@@ -127,6 +128,7 @@ class Pipelines(object):
     ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
+    mtcnn_face_detection = 'manual-face-detection-mtcnn'
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
@@ -4,12 +4,15 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
+    from .mtcnn import MtcnnFaceDetector
     from .retinaface import RetinaFaceDetection
     from .ulfd_slim import UlfdFaceDetector
 else:
     _import_structure = {
         'ulfd_slim': ['UlfdFaceDetector'],
-        'retinaface': ['RetinaFaceDetection']
+        'retinaface': ['RetinaFaceDetection'],
+        'mtcnn': ['MtcnnFaceDetector']
     }
     import sys
@@ -0,0 +1 @@
from .models.detector import MtcnnFaceDetector
@@ -0,0 +1,240 @@
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
import numpy as np
from PIL import Image


def nms(boxes, overlap_threshold=0.5, mode='union'):
    """Non-maximum suppression.

    Arguments:
        boxes: a float numpy array of shape [n, 5],
            where each row is (xmin, ymin, xmax, ymax, score).
        overlap_threshold: a float number.
        mode: 'union' or 'min'.

    Returns:
        list with indices of the selected boxes
    """
    # if there are no boxes, return the empty list
    if len(boxes) == 0:
        return []

    # list of picked indices
    pick = []

    # grab the coordinates of the bounding boxes
    x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]

    area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
    ids = np.argsort(score)  # in increasing order

    while len(ids) > 0:
        # grab index of the largest value
        last = len(ids) - 1
        i = ids[last]
        pick.append(i)

        # compute intersections
        # of the box with the largest score
        # with the rest of the boxes

        # left top corner of intersection boxes
        ix1 = np.maximum(x1[i], x1[ids[:last]])
        iy1 = np.maximum(y1[i], y1[ids[:last]])

        # right bottom corner of intersection boxes
        ix2 = np.minimum(x2[i], x2[ids[:last]])
        iy2 = np.minimum(y2[i], y2[ids[:last]])

        # width and height of intersection boxes
        w = np.maximum(0.0, ix2 - ix1 + 1.0)
        h = np.maximum(0.0, iy2 - iy1 + 1.0)

        # intersections' areas
        inter = w * h
        if mode == 'min':
            overlap = inter / np.minimum(area[i], area[ids[:last]])
        elif mode == 'union':
            # intersection over union (IoU)
            overlap = inter / (area[i] + area[ids[:last]] - inter)

        # delete all boxes where overlap is too big
        ids = np.delete(
            ids,
            np.concatenate([[last],
                            np.where(overlap > overlap_threshold)[0]]))

    return pick


def convert_to_square(bboxes):
    """Convert bounding boxes to a square form.

    Arguments:
        bboxes: a float numpy array of shape [n, 5].

    Returns:
        a float numpy array of shape [n, 5],
            squared bounding boxes.
    """
    square_bboxes = np.zeros_like(bboxes)
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    h = y2 - y1 + 1.0
    w = x2 - x1 + 1.0
    max_side = np.maximum(h, w)
    square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
    square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
    square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
    square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
    return square_bboxes


def calibrate_box(bboxes, offsets):
    """Transform bounding boxes to be more like true bounding boxes.
    'offsets' is one of the outputs of the nets.

    Arguments:
        bboxes: a float numpy array of shape [n, 5].
        offsets: a float numpy array of shape [n, 4].

    Returns:
        a float numpy array of shape [n, 5].
    """
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    w = x2 - x1 + 1.0
    h = y2 - y1 + 1.0
    w = np.expand_dims(w, 1)
    h = np.expand_dims(h, 1)

    # this is what is happening here:
    # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
    # x1_true = x1 + tx1*w
    # y1_true = y1 + ty1*h
    # x2_true = x2 + tx2*w
    # y2_true = y2 + ty2*h
    # below is just a more compact form of this

    # are offsets always such that
    # x1 < x2 and y1 < y2 ?
    translation = np.hstack([w, h, w, h]) * offsets
    bboxes[:, 0:4] = bboxes[:, 0:4] + translation
    return bboxes


def get_image_boxes(bounding_boxes, img, size=24):
    """Cut out boxes from the image.

    Arguments:
        bounding_boxes: a float numpy array of shape [n, 5].
        img: an instance of PIL.Image.
        size: an integer, size of cutouts.

    Returns:
        a float numpy array of shape [n, 3, size, size].
    """
    num_boxes = len(bounding_boxes)
    width, height = img.size

    [dy, edy, dx, edx, y, ey, x, ex, w,
     h] = correct_bboxes(bounding_boxes, width, height)
    img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')

    for i in range(num_boxes):
        img_box = np.zeros((h[i], w[i], 3), 'uint8')

        img_array = np.asarray(img, 'uint8')
        img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] =\
            img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]

        # resize
        img_box = Image.fromarray(img_box)
        img_box = img_box.resize((size, size), Image.BILINEAR)
        img_box = np.asarray(img_box, 'float32')

        img_boxes[i, :, :, :] = _preprocess(img_box)

    return img_boxes


def correct_bboxes(bboxes, width, height):
    """Crop boxes that are too big and get coordinates
    with respect to cutouts.

    Arguments:
        bboxes: a float numpy array of shape [n, 5],
            where each row is (xmin, ymin, xmax, ymax, score).
        width: a float number.
        height: a float number.

    Returns:
        dy, dx, edy, edx: int numpy arrays of shape [n],
            coordinates of the boxes with respect to the cutouts.
        y, x, ey, ex: int numpy arrays of shape [n],
            corrected ymin, xmin, ymax, xmax.
        h, w: int numpy arrays of shape [n],
            just heights and widths of boxes.
        in the following order:
            [dy, edy, dx, edx, y, ey, x, ex, w, h].
    """
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
    num_boxes = bboxes.shape[0]

    # 'e' stands for end
    # (x, y) -> (ex, ey)
    x, y, ex, ey = x1, y1, x2, y2

    # we need to cut out a box from the image.
    # (x, y, ex, ey) are corrected coordinates of the box
    # in the image.
    # (dx, dy, edx, edy) are coordinates of the box in the cutout
    # from the image.
    dx, dy = np.zeros((num_boxes, )), np.zeros((num_boxes, ))
    edx, edy = w.copy() - 1.0, h.copy() - 1.0

    # if box's bottom right corner is too far right
    ind = np.where(ex > width - 1.0)[0]
    edx[ind] = w[ind] + width - 2.0 - ex[ind]
    ex[ind] = width - 1.0

    # if box's bottom right corner is too low
    ind = np.where(ey > height - 1.0)[0]
    edy[ind] = h[ind] + height - 2.0 - ey[ind]
    ey[ind] = height - 1.0

    # if box's top left corner is too far left
    ind = np.where(x < 0.0)[0]
    dx[ind] = 0.0 - x[ind]
    x[ind] = 0.0

    # if box's top left corner is too high
    ind = np.where(y < 0.0)[0]
    dy[ind] = 0.0 - y[ind]
    y[ind] = 0.0

    return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
    return_list = [i.astype('int32') for i in return_list]

    return return_list


def _preprocess(img):
    """Preprocessing step before feeding the network.

    Arguments:
        img: a float numpy array of shape [h, w, c].

    Returns:
        a float numpy array of shape [1, c, h, w].
    """
    img = img.transpose((2, 0, 1))
    img = np.expand_dims(img, 0)
    img = (img - 127.5) * 0.0078125
    return img
@@ -0,0 +1,149 @@
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torch.autograd import Variable

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from .box_utils import calibrate_box, convert_to_square, get_image_boxes, nms
from .first_stage import run_first_stage
from .get_nets import ONet, PNet, RNet


@MODELS.register_module(Tasks.face_detection, module_name=Models.mtcnn)
class MtcnnFaceDetector(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        torch.set_grad_enabled(False)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device

        self.pnet = PNet(model_path=os.path.join(self.model_path, 'pnet.npy'))
        self.rnet = RNet(model_path=os.path.join(self.model_path, 'rnet.npy'))
        self.onet = ONet(model_path=os.path.join(self.model_path, 'onet.npy'))
        self.pnet = self.pnet.to(device)
        self.rnet = self.rnet.to(device)
        self.onet = self.onet.to(device)

    def forward(self, input):
        image = Image.fromarray(np.uint8(input['img'].cpu().numpy()))
        pnet = self.pnet
        rnet = self.rnet
        onet = self.onet
        onet.eval()

        min_face_size = 20.0
        thresholds = [0.7, 0.8, 0.9]
        nms_thresholds = [0.7, 0.7, 0.7]

        # BUILD AN IMAGE PYRAMID
        width, height = image.size
        min_length = min(height, width)

        min_detection_size = 12
        factor = 0.707  # sqrt(0.5)

        # scales for scaling the image
        scales = []

        m = min_detection_size / min_face_size
        min_length *= m

        factor_count = 0
        while min_length > min_detection_size:
            scales.append(m * factor**factor_count)
            min_length *= factor
            factor_count += 1

        # STAGE 1

        # it will be returned
        bounding_boxes = []

        # run P-Net on different scales
        for s in scales:
            boxes = run_first_stage(
                image,
                pnet,
                scale=s,
                threshold=thresholds[0],
                device=self.device)
            bounding_boxes.append(boxes)

        # collect boxes (and offsets, and scores) from different scales
        bounding_boxes = [i for i in bounding_boxes if i is not None]
        bounding_boxes = np.vstack(bounding_boxes)

        keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
        bounding_boxes = bounding_boxes[keep]

        # use offsets predicted by pnet to transform bounding boxes
        bounding_boxes = calibrate_box(bounding_boxes[:, 0:5],
                                       bounding_boxes[:, 5:])
        # shape [n_boxes, 5]

        bounding_boxes = convert_to_square(bounding_boxes)
        bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

        # STAGE 2
        img_boxes = get_image_boxes(bounding_boxes, image, size=24)
        img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True)
        output = rnet(img_boxes.to(self.device))
        offsets = output[0].cpu().data.numpy()  # shape [n_boxes, 4]
        probs = output[1].cpu().data.numpy()  # shape [n_boxes, 2]

        keep = np.where(probs[:, 1] > thresholds[1])[0]
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, ))
        offsets = offsets[keep]

        keep = nms(bounding_boxes, nms_thresholds[1])
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
        bounding_boxes = convert_to_square(bounding_boxes)
        bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

        # STAGE 3
        img_boxes = get_image_boxes(bounding_boxes, image, size=48)
        if len(img_boxes) == 0:
            return [], []
        img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True)
        output = onet(img_boxes.to(self.device))
        landmarks = output[0].cpu().data.numpy()  # shape [n_boxes, 10]
        offsets = output[1].cpu().data.numpy()  # shape [n_boxes, 4]
        probs = output[2].cpu().data.numpy()  # shape [n_boxes, 2]

        keep = np.where(probs[:, 1] > thresholds[2])[0]
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, ))
        offsets = offsets[keep]
        landmarks = landmarks[keep]

        # compute landmark points
        width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
        height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
        xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
        landmarks[:, 0:5] = np.expand_dims(
            xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
        landmarks[:, 5:10] = np.expand_dims(
            ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]

        bounding_boxes = calibrate_box(bounding_boxes, offsets)
        keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
        bounding_boxes = bounding_boxes[keep]
        landmarks = landmarks[keep]
        landmarks = landmarks.reshape(-1, 2, 5).transpose(
            (0, 2, 1)).reshape(-1, 10)

        return bounding_boxes, landmarks
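A short sketch (not part of the diff) of the image pyramid built in MtcnnFaceDetector.forward: the first scale (12/20 = 0.6) shrinks a min_face_size face down to P-Net's 12-pixel window, and each further level handles faces roughly 1.41x larger. The 400 px shorter side is just an example value:

# Reproduces the scale computation from forward() for a 400 px shorter side.
min_face_size = 20.0
min_detection_size = 12
factor = 0.707  # sqrt(0.5)

min_length = 400.0
scales = []
m = min_detection_size / min_face_size  # 0.6
min_length *= m
factor_count = 0
while min_length > min_detection_size:
    scales.append(m * factor**factor_count)
    min_length *= factor
    factor_count += 1

# 0.6, 0.424, 0.300, 0.212, ... down to the level where the scaled shorter
# side would fall below 12 px (9 levels for this example).
print(scales)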
@@ -0,0 +1,100 @@
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
import math

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable

from .box_utils import _preprocess, nms


def run_first_stage(image, net, scale, threshold, device='cuda'):
    """Run P-Net, generate bounding boxes, and do NMS.

    Arguments:
        image: an instance of PIL.Image.
        net: an instance of pytorch's nn.Module, P-Net.
        scale: a float number,
            scale width and height of the image by this number.
        threshold: a float number,
            threshold on the probability of a face when generating
            bounding boxes from predictions of the net.

    Returns:
        a float numpy array of shape [n_boxes, 9],
            bounding boxes with scores and offsets (4 + 1 + 4).
    """
    # scale the image and convert it to a float array
    width, height = image.size
    sw, sh = math.ceil(width * scale), math.ceil(height * scale)
    img = image.resize((sw, sh), Image.BILINEAR)
    img = np.asarray(img, 'float32')

    img = Variable(
        torch.FloatTensor(_preprocess(img)), volatile=True).to(device)
    output = net(img)
    probs = output[1].cpu().data.numpy()[0, 1, :, :]
    offsets = output[0].cpu().data.numpy()
    # probs: probability of a face at each sliding window
    # offsets: transformations to true bounding boxes

    boxes = _generate_bboxes(probs, offsets, scale, threshold)
    if len(boxes) == 0:
        return None

    keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
    return boxes[keep]


def _generate_bboxes(probs, offsets, scale, threshold):
    """Generate bounding boxes at places
    where there is probably a face.

    Arguments:
        probs: a float numpy array of shape [n, m].
        offsets: a float numpy array of shape [1, 4, n, m].
        scale: a float number,
            width and height of the image were scaled by this number.
        threshold: a float number.

    Returns:
        a float numpy array of shape [n_boxes, 9]
    """
    # applying P-Net is equivalent, in some sense, to
    # moving a 12x12 window with stride 2
    stride = 2
    cell_size = 12

    # indices of boxes where there is probably a face
    inds = np.where(probs > threshold)

    if inds[0].size == 0:
        return np.array([])

    # transformations of bounding boxes
    tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
    # they are defined as:
    # w = x2 - x1 + 1
    # h = y2 - y1 + 1
    # x1_true = x1 + tx1*w
    # x2_true = x2 + tx2*w
    # y1_true = y1 + ty1*h
    # y2_true = y2 + ty2*h

    offsets = np.array([tx1, ty1, tx2, ty2])
    score = probs[inds[0], inds[1]]

    # P-Net is applied to scaled images,
    # so we need to rescale bounding boxes back
    bounding_boxes = np.vstack([
        np.round((stride * inds[1] + 1.0) / scale),
        np.round((stride * inds[0] + 1.0) / scale),
        np.round((stride * inds[1] + 1.0 + cell_size) / scale),
        np.round((stride * inds[0] + 1.0 + cell_size) / scale), score, offsets
    ])
    # why is one added?

    return bounding_boxes.T
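A worked example (not part of the diff) of the index-to-box mapping in _generate_bboxes: a detection at row i, column j of the P-Net probability map corresponds to a 12x12 window with stride 2 on the scaled image, mapped back to original-image coordinates; the index values below are arbitrary illustration:

# One P-Net detection mapped back to original-image coordinates.
stride, cell_size, scale = 2, 12, 0.6  # 0.6 is the first pyramid scale for min_face_size=20
i, j = 5, 7  # row/column on the probability map

xmin = round((stride * j + 1.0) / scale)               # 25
ymin = round((stride * i + 1.0) / scale)               # 18
xmax = round((stride * j + 1.0 + cell_size) / scale)   # 45
ymax = round((stride * i + 1.0 + cell_size) / scale)   # 38
# The window is 12 px on the scaled image, i.e. about 12 / 0.6 = 20 px in the
# original image, which matches the configured min_face_size.
print(xmin, ymin, xmax, ymax)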
@@ -0,0 +1,160 @@
# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, c, h, w].
        Returns:
            a float tensor with shape [batch_size, c*h*w].
        """
        # without this, the pretrained model doesn't work
        x = x.transpose(3, 2).contiguous()
        return x.view(x.size(0), -1)


class PNet(nn.Module):

    def __init__(self, model_path=None):
        super(PNet, self).__init__()

        # suppose we have input with size HxW, then
        # after first layer: H - 2,
        # after pool: ceil((H - 2)/2),
        # after second conv: ceil((H - 2)/2) - 2,
        # after last conv: ceil((H - 2)/2) - 4,
        # and the same for W
        self.features = nn.Sequential(
            OrderedDict([('conv1', nn.Conv2d(3, 10, 3, 1)),
                         ('prelu1', nn.PReLU(10)),
                         ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
                         ('conv2', nn.Conv2d(10, 16, 3, 1)),
                         ('prelu2', nn.PReLU(16)),
                         ('conv3', nn.Conv2d(16, 32, 3, 1)),
                         ('prelu3', nn.PReLU(32))]))

        self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
        self.conv4_2 = nn.Conv2d(32, 4, 1, 1)

        weights = np.load(model_path, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4, h', w'].
            a: a float tensor with shape [batch_size, 2, h', w'].
        """
        x = self.features(x)
        a = self.conv4_1(x)
        b = self.conv4_2(x)
        a = F.softmax(a)
        return b, a


class RNet(nn.Module):

    def __init__(self, model_path=None):
        super(RNet, self).__init__()

        self.features = nn.Sequential(
            OrderedDict([('conv1', nn.Conv2d(3, 28, 3, 1)),
                         ('prelu1', nn.PReLU(28)),
                         ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
                         ('conv2', nn.Conv2d(28, 48, 3, 1)),
                         ('prelu2', nn.PReLU(48)),
                         ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
                         ('conv3', nn.Conv2d(48, 64, 2, 1)),
                         ('prelu3', nn.PReLU(64)), ('flatten', Flatten()),
                         ('conv4', nn.Linear(576, 128)),
                         ('prelu4', nn.PReLU(128))]))

        self.conv5_1 = nn.Linear(128, 2)
        self.conv5_2 = nn.Linear(128, 4)

        weights = np.load(model_path, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv5_1(x)
        b = self.conv5_2(x)
        a = F.softmax(a)
        return b, a


class ONet(nn.Module):

    def __init__(self, model_path=None):
        super(ONet, self).__init__()

        self.features = nn.Sequential(
            OrderedDict([
                ('conv1', nn.Conv2d(3, 32, 3, 1)),
                ('prelu1', nn.PReLU(32)),
                ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
                ('conv2', nn.Conv2d(32, 64, 3, 1)),
                ('prelu2', nn.PReLU(64)),
                ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
                ('conv3', nn.Conv2d(64, 64, 3, 1)),
                ('prelu3', nn.PReLU(64)),
                ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
                ('conv4', nn.Conv2d(64, 128, 2, 1)),
                ('prelu4', nn.PReLU(128)),
                ('flatten', Flatten()),
                ('conv5', nn.Linear(1152, 256)),
                ('drop5', nn.Dropout(0.25)),
                ('prelu5', nn.PReLU(256)),
            ]))

        self.conv6_1 = nn.Linear(256, 2)
        self.conv6_2 = nn.Linear(256, 4)
        self.conv6_3 = nn.Linear(256, 10)

        weights = np.load(model_path, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            c: a float tensor with shape [batch_size, 10].
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv6_1(x)
        b = self.conv6_2(x)
        c = self.conv6_3(x)
        a = F.softmax(a)
        return c, b, a
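A sanity-check sketch (not part of the diff) for the receptive-field comment in PNet: using the same layer hyper-parameters but random weights (the .npy loading is skipped here), a 12x12 input yields a 1x1 probability map, which is why first_stage.py treats every output cell as one 12x12 window with stride 2:

import torch
import torch.nn as nn

# P-Net-like feature stack with the hyper-parameters above; weights are untrained.
features = nn.Sequential(
    nn.Conv2d(3, 10, 3, 1), nn.PReLU(10),
    nn.MaxPool2d(2, 2, ceil_mode=True),
    nn.Conv2d(10, 16, 3, 1), nn.PReLU(16),
    nn.Conv2d(16, 32, 3, 1), nn.PReLU(32))
probs_head = nn.Conv2d(32, 2, 1, 1)

x = torch.zeros(1, 3, 12, 12)
print(probs_head(features(x)).shape)  # torch.Size([1, 2, 1, 1]): ceil((12 - 2)/2) - 4 = 1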
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
     from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
     from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
     from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
+    from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline
 else:
     _import_structure = {
@@ -114,7 +115,8 @@ else:
         'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
         'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
         'facial_expression_recognition_pipelin':
-        ['FacialExpressionRecognitionPipeline']
+        ['FacialExpressionRecognitionPipeline'],
+        'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
     }
     import sys
@@ -0,0 +1,56 @@
import os.path as osp
from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_detection import MtcnnFaceDetector
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.mtcnn_face_detection)
class MtcnnFaceDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a face detection pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, './weights')
        logger.info(f'loading model from {ckpt_path}')
        device = torch.device(
            f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
        detector = MtcnnFaceDetector(model_path=ckpt_path, device=device)
        self.detector = detector
        self.device = device
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = self.detector(input)
        assert result is not None
        bboxes = result[0][:, :4].tolist()
        scores = result[0][:, 4].tolist()
        lms = result[1].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: lms,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2
from PIL import Image

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_result
from modelscope.utils.test_utils import test_level


class MtcnnFaceDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_manual_face-detection_mtcnn'

    def show_result(self, img_path, detection_result):
        img = draw_face_detection_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        img_path = 'data/test/images/mtcnn_face_detection.jpg'
        img = Image.open(img_path)
        result_1 = face_detection(img_path)
        self.show_result(img_path, result_1)
        result_2 = face_detection(img)
        self.show_result(img_path, result_2)


if __name__ == '__main__':
    unittest.main()