From 9cbf246a8c4a09be20d5a32cea728f2250faa305 Mon Sep 17 00:00:00 2001
From: ly261666
Date: Tue, 6 Sep 2022 10:02:49 +0800
Subject: [PATCH] [to #42322933] Add the ULFD face detector
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Completed the MaaS-cv CR standard self-check

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9957634
---
 data/test/images/ulfd_face_detection.jpg     |   3 +
 modelscope/metainfo.py                       |   2 +
 .../models/cv/face_detection/__init__.py     |   5 +-
 .../cv/face_detection/ulfd_slim/__init__.py  |   1 +
 .../cv/face_detection/ulfd_slim/detection.py |  44 ++++++
 .../ulfd_slim/vision/__init__.py             |   0
 .../ulfd_slim/vision/box_utils.py            | 124 +++++++++++++++++
 .../ulfd_slim/vision/mb_tiny.py              |  49 +++++++
 .../ulfd_slim/vision/ssd/__init__.py         |   0
 .../vision/ssd/data_preprocessing.py         |  18 +++
 .../ulfd_slim/vision/ssd/fd_config.py        |  49 +++++++
 .../ulfd_slim/vision/ssd/mb_tiny_fd.py       | 124 +++++++++++++++++
 .../ulfd_slim/vision/ssd/predictor.py        |  80 +++++++++++
 .../ulfd_slim/vision/ssd/ssd.py              | 129 ++++++++++++++++++
 .../ulfd_slim/vision/transforms.py           |  56 ++++++++
 modelscope/pipelines/cv/__init__.py          |   4 +-
 .../cv/ulfd_face_detection_pipeline.py       |  56 ++++++++
 modelscope/utils/cv/image_utils.py           |  21 +++
 tests/pipelines/test_ulfd_face_detection.py  |  36 +++++
 19 files changed, 798 insertions(+), 3 deletions(-)
 create mode 100644 data/test/images/ulfd_face_detection.jpg
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/__init__.py
 create mode 100755 modelscope/models/cv/face_detection/ulfd_slim/detection.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/__init__.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/__init__.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/mb_tiny_fd.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/ssd.py
 create mode 100644 modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py
 create mode 100644 modelscope/pipelines/cv/ulfd_face_detection_pipeline.py
 create mode 100644 tests/pipelines/test_ulfd_face_detection.py

diff --git a/data/test/images/ulfd_face_detection.jpg b/data/test/images/ulfd_face_detection.jpg
new file mode 100644
index 00000000..c95881fe
--- /dev/null
+++ b/data/test/images/ulfd_face_detection.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
+size 87228
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 792bd708..22c2d99e 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -35,6 +35,7 @@ class Models(object):
     fer = 'fer'
     retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
+    ulfd = 'ulfd'
 
     # EasyCV models
     yolox = 'YOLOX'
@@ -122,6 +123,7 @@ class Pipelines(object):
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
+    ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
     live_category = 'live-category'
diff --git a/modelscope/models/cv/face_detection/__init__.py b/modelscope/models/cv/face_detection/__init__.py
index a3c47164..63ff1b83 100644
--- a/modelscope/models/cv/face_detection/__init__.py
+++ b/modelscope/models/cv/face_detection/__init__.py
@@ -5,10 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
     from .retinaface import RetinaFaceDetection
-
+    from .ulfd_slim import UlfdFaceDetector
 else:
     _import_structure = {
-        'retinaface': ['RetinaFaceDetection'],
+        'ulfd_slim': ['UlfdFaceDetector'],
+        'retinaface': ['RetinaFaceDetection']
     }
 
     import sys
diff --git a/modelscope/models/cv/face_detection/ulfd_slim/__init__.py b/modelscope/models/cv/face_detection/ulfd_slim/__init__.py
new file mode 100644
index 00000000..41a2226a
--- /dev/null
+++ b/modelscope/models/cv/face_detection/ulfd_slim/__init__.py
@@ -0,0 +1 @@
+from .detection import UlfdFaceDetector
diff --git a/modelscope/models/cv/face_detection/ulfd_slim/detection.py b/modelscope/models/cv/face_detection/ulfd_slim/detection.py
new file mode 100755
index 00000000..c0e2da6e
--- /dev/null
+++ b/modelscope/models/cv/face_detection/ulfd_slim/detection.py
@@ -0,0 +1,44 @@
+# The implementation is based on ULFD, available at
+# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
+import os
+
+import cv2
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Tensor, TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+from .vision.ssd.fd_config import define_img_size
+from .vision.ssd.mb_tiny_fd import (create_mb_tiny_fd,
+                                    create_mb_tiny_fd_predictor)
+
+define_img_size(640)  # build anchor priors for 640x480 input at import time
+
+
+@MODELS.register_module(Tasks.face_detection, module_name=Models.ulfd)
+class UlfdFaceDetector(TorchModel):
+
+    def __init__(self, model_path, device='cuda'):
+        super().__init__(model_path)
+        torch.set_grad_enabled(False)
+        cudnn.benchmark = True
+        self.model_path = model_path
+        self.device = device
+        self.net = create_mb_tiny_fd(2, is_test=True, device=device)
+        self.predictor = create_mb_tiny_fd_predictor(
+            self.net, candidate_size=1500, device=device)
+        self.net.load(model_path)
+        self.net = self.net.to(device)
+
+    def forward(self, input):
+        img_raw = input['img']
+        img = np.array(img_raw.cpu().detach())
+        img = img[:, :, ::-1]  # flip channel order for the detector
+        prob_th = 0.85
+        keep_top_k = 750
+        boxes, labels, probs = self.predictor.predict(img, keep_top_k, prob_th)
+        return boxes, probs
diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/__init__.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/__init__.py
new file mode 100644
index 00000000..e69de29b
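For orientation, here is a minimal standalone sketch of how the detector above can be driven. The checkpoint path is a placeholder and the zero-filled tensor stands in for a real RGB image, so this is illustrative only; the pipeline added later in this patch is the intended entry point.

import numpy as np
import torch

from modelscope.models.cv.face_detection import UlfdFaceDetector

# hypothetical checkpoint path; forward() expects an RGB HxWxC float tensor
detector = UlfdFaceDetector('/path/to/pytorch_model.pt', device='cpu')
img = torch.from_numpy(np.zeros((480, 640, 3), dtype=np.float32))
boxes, probs = detector.forward({'img': img})  # pixel-space boxes and scores
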
diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py
new file mode 100644
index 00000000..46d3b890
--- /dev/null
+++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py
@@ -0,0 +1,124 @@
+# The implementation is based on ULFD, available at
+# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
+import math
+
+import torch
+
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+    """
+
+    Args:
+        box_scores (N, 5): boxes in corner-form and probabilities.
+        iou_threshold: intersection over union threshold.
+        top_k: keep top_k results. If k <= 0, keep all the results.
+        candidate_size: only consider the candidates with the highest scores.
+    Returns:
+        box_scores (K, 5): the kept boxes together with their probabilities.
+    """
+    scores = box_scores[:, -1]
+    boxes = box_scores[:, :-1]
+    picked = []
+    _, indexes = scores.sort(descending=True)
+    indexes = indexes[:candidate_size]
+    while len(indexes) > 0:
+        current = indexes[0]
+        picked.append(current.item())
+        if 0 < top_k == len(picked) or len(indexes) == 1:
+            break
+        current_box = boxes[current, :]
+        indexes = indexes[1:]
+        rest_boxes = boxes[indexes, :]
+        iou = iou_of(
+            rest_boxes,
+            current_box.unsqueeze(0),
+        )
+        indexes = indexes[iou <= iou_threshold]
+
+    return box_scores[picked, :]
+
+
+def nms(box_scores,
+        nms_method=None,
+        score_threshold=None,
+        iou_threshold=None,
+        sigma=0.5,
+        top_k=-1,
+        candidate_size=200):
+    return hard_nms(
+        box_scores, iou_threshold, top_k, candidate_size=candidate_size)
+
+
+def generate_priors(feature_map_list,
+                    shrinkage_list,
+                    image_size,
+                    min_boxes,
+                    clamp=True) -> torch.Tensor:
+    priors = []
+    for index in range(0, len(feature_map_list[0])):
+        scale_w = image_size[0] / shrinkage_list[0][index]
+        scale_h = image_size[1] / shrinkage_list[1][index]
+        for j in range(0, feature_map_list[1][index]):
+            for i in range(0, feature_map_list[0][index]):
+                x_center = (i + 0.5) / scale_w
+                y_center = (j + 0.5) / scale_h
+
+                for min_box in min_boxes[index]:
+                    w = min_box / image_size[0]
+                    h = min_box / image_size[1]
+                    priors.append([x_center, y_center, w, h])
+    priors = torch.tensor(priors)
+    if clamp:
+        torch.clamp(priors, 0.0, 1.0, out=priors)
+    return priors
+
+
+def convert_locations_to_boxes(locations, priors, center_variance,
+                               size_variance):
+    # priors can have one dimension less.
+    if priors.dim() + 1 == locations.dim():
+        priors = priors.unsqueeze(0)
+    a = locations[..., :2] * center_variance * priors[..., 2:] \
+        + priors[..., :2]
+    b = torch.exp(locations[..., 2:] * size_variance) * priors[..., 2:]
+
+    return torch.cat([a, b], dim=locations.dim() - 1)
+
+
+def center_form_to_corner_form(locations):
+    a = locations[..., :2] - locations[..., 2:] / 2
+    b = locations[..., :2] + locations[..., 2:] / 2
+    return torch.cat([a, b], locations.dim() - 1)
+
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+    """Return intersection-over-union (Jaccard index) of boxes.
+
+    Args:
+        boxes0 (N, 4): ground truth boxes.
+        boxes1 (N or 1, 4): predicted boxes.
+        eps: a small number to avoid 0 as denominator.
+    Returns:
+        iou (N): IoU values.
+    """
+    overlap_left_top = torch.max(boxes0[..., :2], boxes1[..., :2])
+    overlap_right_bottom = torch.min(boxes0[..., 2:], boxes1[..., 2:])
+
+    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+    return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
+def area_of(left_top, right_bottom) -> torch.Tensor:
+    """Compute the areas of rectangles given two corners.
+
+    Args:
+        left_top (N, 2): left top corner.
+        right_bottom (N, 2): right bottom corner.
+
+    Returns:
+        area (N): return the area.
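A small sanity check for the helpers above: two heavily overlapping boxes and one disjoint box in corner form, with scores in the fifth column. The numbers are made up purely for illustration.

import torch

from modelscope.models.cv.face_detection.ulfd_slim.vision import box_utils

boxes = torch.tensor([[0., 0., 10., 10., 0.9],
                      [1., 1., 10., 10., 0.8],
                      [20., 20., 30., 30., 0.7]])
print(box_utils.iou_of(boxes[0:1, :4], boxes[1:2, :4]))  # ~0.81 overlap
kept = box_utils.hard_nms(boxes, iou_threshold=0.5)
print(kept)  # rows 0 and 2; the second box is suppressed by the first
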
+ """ + hw = torch.clamp(right_bottom - left_top, min=0.0) + return hw[..., 0] * hw[..., 1] diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py new file mode 100644 index 00000000..8bbcef41 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py @@ -0,0 +1,49 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import torch.nn as nn +import torch.nn.functional as F + + +class Mb_Tiny(nn.Module): + + def __init__(self, num_classes=2): + super(Mb_Tiny, self).__init__() + self.base_channel = 8 * 2 + + def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), nn.ReLU(inplace=True)) + + def conv_dw(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True), + ) + + self.model = nn.Sequential( + conv_bn(3, self.base_channel, 2), # 160*120 + conv_dw(self.base_channel, self.base_channel * 2, 1), + conv_dw(self.base_channel * 2, self.base_channel * 2, 2), # 80*60 + conv_dw(self.base_channel * 2, self.base_channel * 2, 1), + conv_dw(self.base_channel * 2, self.base_channel * 4, 2), # 40*30 + conv_dw(self.base_channel * 4, self.base_channel * 4, 1), + conv_dw(self.base_channel * 4, self.base_channel * 4, 1), + conv_dw(self.base_channel * 4, self.base_channel * 4, 1), + conv_dw(self.base_channel * 4, self.base_channel * 8, 2), # 20*15 + conv_dw(self.base_channel * 8, self.base_channel * 8, 1), + conv_dw(self.base_channel * 8, self.base_channel * 8, 1), + conv_dw(self.base_channel * 8, self.base_channel * 16, 2), # 10*8 + conv_dw(self.base_channel * 16, self.base_channel * 16, 1)) + self.fc = nn.Linear(1024, num_classes) + + def forward(self, x): + x = self.model(x) + x = F.avg_pool2d(x, 7) + x = x.view(-1, 1024) + x = self.fc(x) + return x diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/__init__.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py new file mode 100644 index 00000000..9251d67f --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py @@ -0,0 +1,18 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +from ..transforms import Compose, Resize, SubtractMeans, ToTensor + + +class PredictionTransform: + + def __init__(self, size, mean=0.0, std=1.0): + self.transform = Compose([ + Resize(size), + SubtractMeans(mean), lambda img, boxes=None, labels=None: + (img / std, boxes, labels), + ToTensor() + ]) + + def __call__(self, image): + image, _, _ = self.transform(image) + return image diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py new file mode 100644 index 00000000..495a2fcd --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py @@ -0,0 +1,49 @@ +# The implementation is based on ULFD, available at +# 
diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py
new file mode 100644
index 00000000..495a2fcd
--- /dev/null
+++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py
@@ -0,0 +1,49 @@
+# The implementation is based on ULFD, available at
+# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
+import numpy as np
+
+from ..box_utils import generate_priors
+
+image_mean_test = image_mean = np.array([127, 127, 127])
+image_std = 128.0
+iou_threshold = 0.3
+center_variance = 0.1
+size_variance = 0.2
+
+min_boxes = [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]]
+shrinkage_list = []
+image_size = [320, 240]  # default input size 320*240
+feature_map_w_h_list = [[40, 20, 10, 5], [30, 15, 8,
+                                          4]]  # default feature map size
+priors = []
+
+
+def define_img_size(size):
+    global image_size, feature_map_w_h_list, priors, shrinkage_list
+    img_size_dict = {
+        128: [128, 96],
+        160: [160, 120],
+        320: [320, 240],
+        480: [480, 360],
+        640: [640, 480],
+        1280: [1280, 960]
+    }
+    image_size = img_size_dict[size]
+
+    feature_map_w_h_list_dict = {
+        128: [[16, 8, 4, 2], [12, 6, 3, 2]],
+        160: [[20, 10, 5, 3], [15, 8, 4, 2]],
+        320: [[40, 20, 10, 5], [30, 15, 8, 4]],
+        480: [[60, 30, 15, 8], [45, 23, 12, 6]],
+        640: [[80, 40, 20, 10], [60, 30, 15, 8]],
+        1280: [[160, 80, 40, 20], [120, 60, 30, 15]]
+    }
+    feature_map_w_h_list = feature_map_w_h_list_dict[size]
+
+    # rebuild shrinkage_list so repeated calls do not accumulate stale entries
+    shrinkage_list = []
+    for i in range(0, len(image_size)):
+        item_list = [image_size[i] / fm for fm in feature_map_w_h_list[i]]
+        shrinkage_list.append(item_list)
+    priors = generate_priors(feature_map_w_h_list, shrinkage_list, image_size,
+                             min_boxes)
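define_img_size rewires the module-level globals in one shot. For example, the default 320 setting yields the following per-feature-map strides and 4420 priors in normalized (cx, cy, w, h) form:

from modelscope.models.cv.face_detection.ulfd_slim.vision.ssd import fd_config

fd_config.define_img_size(320)
print(fd_config.image_size)      # [320, 240]
print(fd_config.shrinkage_list)  # [[8.0, 16.0, 32.0, 64.0],
                                 #  [8.0, 16.0, 30.0, 60.0]]
print(fd_config.priors.shape)    # torch.Size([4420, 4])
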
+ """ + return Sequential( + Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + groups=in_channels, + stride=stride, + padding=padding), + ReLU(), + Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1), + ) + + +def create_mb_tiny_fd(num_classes, is_test=False, device='cuda'): + base_net = Mb_Tiny(2) + base_net_model = base_net.model # disable dropout layer + + source_layer_indexes = [8, 11, 13] + extras = ModuleList([ + Sequential( + Conv2d( + in_channels=base_net.base_channel * 16, + out_channels=base_net.base_channel * 4, + kernel_size=1), ReLU(), + SeperableConv2d( + in_channels=base_net.base_channel * 4, + out_channels=base_net.base_channel * 16, + kernel_size=3, + stride=2, + padding=1), ReLU()) + ]) + + regression_headers = ModuleList([ + SeperableConv2d( + in_channels=base_net.base_channel * 4, + out_channels=3 * 4, + kernel_size=3, + padding=1), + SeperableConv2d( + in_channels=base_net.base_channel * 8, + out_channels=2 * 4, + kernel_size=3, + padding=1), + SeperableConv2d( + in_channels=base_net.base_channel * 16, + out_channels=2 * 4, + kernel_size=3, + padding=1), + Conv2d( + in_channels=base_net.base_channel * 16, + out_channels=3 * 4, + kernel_size=3, + padding=1) + ]) + + classification_headers = ModuleList([ + SeperableConv2d( + in_channels=base_net.base_channel * 4, + out_channels=3 * num_classes, + kernel_size=3, + padding=1), + SeperableConv2d( + in_channels=base_net.base_channel * 8, + out_channels=2 * num_classes, + kernel_size=3, + padding=1), + SeperableConv2d( + in_channels=base_net.base_channel * 16, + out_channels=2 * num_classes, + kernel_size=3, + padding=1), + Conv2d( + in_channels=base_net.base_channel * 16, + out_channels=3 * num_classes, + kernel_size=3, + padding=1) + ]) + + return SSD( + num_classes, + base_net_model, + source_layer_indexes, + extras, + classification_headers, + regression_headers, + is_test=is_test, + config=config, + device=device) + + +def create_mb_tiny_fd_predictor(net, + candidate_size=200, + nms_method=None, + sigma=0.5, + device=None): + predictor = Predictor( + net, + config.image_size, + config.image_mean_test, + config.image_std, + nms_method=nms_method, + iou_threshold=config.iou_threshold, + candidate_size=candidate_size, + sigma=sigma, + device=device) + return predictor diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py new file mode 100644 index 00000000..f71820a5 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py @@ -0,0 +1,80 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import torch + +from .. 
diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py
new file mode 100644
index 00000000..f71820a5
--- /dev/null
+++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py
@@ -0,0 +1,80 @@
+# The implementation is based on ULFD, available at
+# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
+import torch
+
+from .. import box_utils
+from .data_preprocessing import PredictionTransform
+
+
+class Predictor:
+
+    def __init__(self,
+                 net,
+                 size,
+                 mean=0.0,
+                 std=1.0,
+                 nms_method=None,
+                 iou_threshold=0.3,
+                 filter_threshold=0.85,
+                 candidate_size=200,
+                 sigma=0.5,
+                 device=None):
+        self.net = net
+        self.transform = PredictionTransform(size, mean, std)
+        self.iou_threshold = iou_threshold
+        self.filter_threshold = filter_threshold
+        self.candidate_size = candidate_size
+        self.nms_method = nms_method
+
+        self.sigma = sigma
+        if device:
+            self.device = device
+        else:
+            self.device = torch.device(
+                'cuda:0' if torch.cuda.is_available() else 'cpu')
+
+        self.net.to(self.device)
+        self.net.eval()
+
+    def predict(self, image, top_k=-1, prob_threshold=None):
+        height, width, _ = image.shape
+        image = self.transform(image)
+        images = image.unsqueeze(0)
+        images = images.to(self.device)
+        with torch.no_grad():
+            # a single forward pass yields per-prior scores and boxes
+            scores, boxes = self.net.forward(images)
+        boxes = boxes[0]
+        scores = scores[0]
+        if not prob_threshold:
+            prob_threshold = self.filter_threshold
+        # run class-wise hard NMS on the score-thresholded candidates
+        picked_box_probs = []
+        picked_labels = []
+        for class_index in range(1, scores.size(1)):
+            probs = scores[:, class_index]
+            mask = probs > prob_threshold
+            probs = probs[mask]
+            if probs.size(0) == 0:
+                continue
+            subset_boxes = boxes[mask, :]
+            box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1)
+            box_probs = box_utils.nms(
+                box_probs,
+                self.nms_method,
+                score_threshold=prob_threshold,
+                iou_threshold=self.iou_threshold,
+                sigma=self.sigma,
+                top_k=top_k,
+                candidate_size=self.candidate_size)
+            picked_box_probs.append(box_probs)
+            picked_labels.extend([class_index] * box_probs.size(0))
+        if not picked_box_probs:
+            return torch.tensor([]), torch.tensor([]), torch.tensor([])
+        picked_box_probs = torch.cat(picked_box_probs)
+        picked_box_probs[:, 0] *= width
+        picked_box_probs[:, 1] *= height
+        picked_box_probs[:, 2] *= width
+        picked_box_probs[:, 3] *= height
+        return picked_box_probs[:, :4], torch.tensor(
+            picked_labels), picked_box_probs[:, 4]
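End to end, Predictor.predict takes a raw HxWx3 image, applies the transform, runs the network, thresholds the per-class scores and returns pixel-space boxes. With random weights the result may legitimately be empty, so this sketch only exercises the plumbing:

import numpy as np

from modelscope.models.cv.face_detection.ulfd_slim.vision.ssd import fd_config
from modelscope.models.cv.face_detection.ulfd_slim.vision.ssd.mb_tiny_fd import (
    create_mb_tiny_fd, create_mb_tiny_fd_predictor)

fd_config.define_img_size(320)
net = create_mb_tiny_fd(2, is_test=True, device='cpu')  # untrained weights
predictor = create_mb_tiny_fd_predictor(net, candidate_size=200, device='cpu')
image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
boxes, labels, probs = predictor.predict(image, top_k=10, prob_threshold=0.5)
print(len(boxes), 'detections')
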
+ """ + super(SSD, self).__init__() + + self.num_classes = num_classes + self.base_net = base_net + self.source_layer_indexes = source_layer_indexes + self.extras = extras + self.classification_headers = classification_headers + self.regression_headers = regression_headers + self.is_test = is_test + self.config = config + + # register layers in source_layer_indexes by adding them to a module list + self.source_layer_add_ons = nn.ModuleList([ + t[1] for t in source_layer_indexes + if isinstance(t, tuple) and not isinstance(t, GraphPath) + ]) + if device: + self.device = device + else: + self.device = torch.device( + 'cuda:0' if torch.cuda.is_available() else 'cpu') + if is_test: + self.config = config + self.priors = config.priors.to(self.device) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + confidences = [] + locations = [] + start_layer_index = 0 + header_index = 0 + end_layer_index = 0 + for end_layer_index in self.source_layer_indexes: + if isinstance(end_layer_index, GraphPath): + path = end_layer_index + end_layer_index = end_layer_index.s0 + added_layer = None + elif isinstance(end_layer_index, tuple): + added_layer = end_layer_index[1] + end_layer_index = end_layer_index[0] + path = None + else: + added_layer = None + path = None + for layer in self.base_net[start_layer_index:end_layer_index]: + x = layer(x) + if added_layer: + y = added_layer(x) + else: + y = x + if path: + sub = getattr(self.base_net[end_layer_index], path.name) + for layer in sub[:path.s1]: + x = layer(x) + y = x + for layer in sub[path.s1:]: + x = layer(x) + end_layer_index += 1 + start_layer_index = end_layer_index + confidence, location = self.compute_header(header_index, y) + header_index += 1 + confidences.append(confidence) + locations.append(location) + + for layer in self.base_net[end_layer_index:]: + x = layer(x) + + for layer in self.extras: + x = layer(x) + confidence, location = self.compute_header(header_index, x) + header_index += 1 + confidences.append(confidence) + locations.append(location) + + confidences = torch.cat(confidences, 1) + locations = torch.cat(locations, 1) + + if self.is_test: + confidences = F.softmax(confidences, dim=2) + boxes = box_utils.convert_locations_to_boxes( + locations, self.priors, self.config.center_variance, + self.config.size_variance) + boxes = box_utils.center_form_to_corner_form(boxes) + return confidences, boxes + else: + return confidences, locations + + def compute_header(self, i, x): + confidence = self.classification_headers[i](x) + confidence = confidence.permute(0, 2, 3, 1).contiguous() + confidence = confidence.view(confidence.size(0), -1, self.num_classes) + + location = self.regression_headers[i](x) + location = location.permute(0, 2, 3, 1).contiguous() + location = location.view(location.size(0), -1, 4) + + return confidence, location + + def load(self, model): + self.load_state_dict( + torch.load(model, map_location=lambda storage, loc: storage)) diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py new file mode 100644 index 00000000..7c5331f1 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py @@ -0,0 +1,56 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import types + +import cv2 +import numpy as np +import torch +from numpy import random + + +class Compose(object): + """Composes several augmentations 
diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py
new file mode 100644
index 00000000..7c5331f1
--- /dev/null
+++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py
@@ -0,0 +1,56 @@
+# The implementation is based on ULFD, available at
+# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
+import types
+
+import cv2
+import numpy as np
+import torch
+from numpy import random
+
+
+class Compose(object):
+    """Composes several augmentations together.
+    Args:
+        transforms (List[Transform]): list of transforms to compose.
+    Example:
+        >>> augmentations.Compose([
+        >>>     transforms.CenterCrop(10),
+        >>>     transforms.ToTensor(),
+        >>> ])
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img, boxes=None, labels=None):
+        for t in self.transforms:
+            img, boxes, labels = t(img, boxes, labels)
+        return img, boxes, labels
+
+
+class SubtractMeans(object):
+
+    def __init__(self, mean):
+        self.mean = np.array(mean, dtype=np.float32)
+
+    def __call__(self, image, boxes=None, labels=None):
+        image = image.astype(np.float32)
+        image -= self.mean
+        return image.astype(np.float32), boxes, labels
+
+
+class Resize(object):
+
+    def __init__(self, size=(300, 300)):
+        self.size = size
+
+    def __call__(self, image, boxes=None, labels=None):
+        image = cv2.resize(image, (self.size[0], self.size[1]))
+        return image, boxes, labels
+
+
+class ToTensor(object):
+
+    def __call__(self, cvimage, boxes=None, labels=None):
+        return torch.from_numpy(cvimage.astype(np.float32)).permute(
+            2, 0, 1), boxes, labels
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 72a225ff..02682fa0 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -46,8 +46,9 @@
     from .virtual_try_on_pipeline import VirtualTryonPipeline
     from .shop_segmentation_pipleline import ShopSegmentationPipeline
     from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline
-    from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipleline
+    from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline
     from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline
+    from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
     from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
     from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
 
@@ -110,6 +111,7 @@
         ['TextDrivenSegmentationPipeline'],
         'movie_scene_segmentation_pipeline':
         ['MovieSceneSegmentationPipeline'],
+        'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
         'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
         'facial_expression_recognition_pipelin':
         ['FacialExpressionRecognitionPipeline']
diff --git a/modelscope/pipelines/cv/ulfd_face_detection_pipeline.py b/modelscope/pipelines/cv/ulfd_face_detection_pipeline.py
new file mode 100644
index 00000000..1263082b
--- /dev/null
+++ b/modelscope/pipelines/cv/ulfd_face_detection_pipeline.py
@@ -0,0 +1,56 @@
+import os.path as osp
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import PIL
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.face_detection import UlfdFaceDetector
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.face_detection, module_name=Pipelines.ulfd_face_detection)
+class UlfdFaceDetectionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """Use `model` to create a face detection pipeline for prediction.
+
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
+        logger.info(f'loading model from {ckpt_path}')
+        detector = UlfdFaceDetector(model_path=ckpt_path, device=self.device)
+        self.detector = detector
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        img = LoadImage.convert_to_ndarray(input)
+        img = img.astype(np.float32)
+        result = {'img': img}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        result = self.detector(input)
+        assert result is not None
+        bboxes = result[0].tolist()
+        scores = result[1].tolist()
+        return {
+            OutputKeys.SCORES: scores,
+            OutputKeys.BOXES: bboxes,
+            OutputKeys.KEYPOINTS: None,
+        }
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
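With the registrations above in place, the detector is reachable through the standard pipeline API. The model id below is the one exercised by the unit test at the end of this patch and requires access to the ModelScope hub:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_manual_face-detection_ulfd')
result = face_detection('data/test/images/ulfd_face_detection.jpg')
print(result[OutputKeys.BOXES])   # [x1, y1, x2, y2] per face
print(result[OutputKeys.SCORES])  # one confidence per box
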
+ """ + super().__init__(model=model, **kwargs) + ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {ckpt_path}') + detector = UlfdFaceDetector(model_path=ckpt_path, device=self.device) + self.detector = detector + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float32) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector(input) + assert result is not None + bboxes = result[0].tolist() + scores = result[1].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: None, + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index cb07ba1a..6175a53f 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -89,6 +89,27 @@ def draw_keypoints(output, original_image): return image +def draw_face_detection_no_lm_result(img_path, detection_result): + bboxes = np.array(detection_result[OutputKeys.BOXES]) + scores = np.array(detection_result[OutputKeys.SCORES]) + img = cv2.imread(img_path) + assert img is not None, f"Can't read img: {img_path}" + for i in range(len(scores)): + bbox = bboxes[i].astype(np.int32) + x1, y1, x2, y2 = bbox + score = scores[i] + cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2) + cv2.putText( + img, + f'{score:.2f}', (x1, y2), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + print(f'Found {len(scores)} faces') + return img + + def draw_facial_expression_result(img_path, facial_expression_result): label_idx = facial_expression_result[OutputKeys.LABELS] map_list = [ diff --git a/tests/pipelines/test_ulfd_face_detection.py b/tests/pipelines/test_ulfd_face_detection.py new file mode 100644 index 00000000..0ffa688c --- /dev/null +++ b/tests/pipelines/test_ulfd_face_detection.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 +import numpy as np + +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result +from modelscope.utils.test_utils import test_level + + +class UlfdFaceDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_manual_face-detection_ulfd' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_no_lm_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + img_path = 'data/test/images/ulfd_face_detection.jpg' + + result = face_detection(img_path) + self.show_result(img_path, result) + + +if __name__ == '__main__': + unittest.main()