
[to #42322933] Add ULFD face detector

1. Completed the MaaS-cv CR-standard self-check
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9957634
Branch: master
ly261666 · yingda.chen · 3 years ago
Commit 9cbf246a8c
19 changed files with 798 additions and 3 deletions
  1. data/test/images/ulfd_face_detection.jpg (+3, -0)
  2. modelscope/metainfo.py (+2, -0)
  3. modelscope/models/cv/face_detection/__init__.py (+3, -2)
  4. modelscope/models/cv/face_detection/ulfd_slim/__init__.py (+1, -0)
  5. modelscope/models/cv/face_detection/ulfd_slim/detection.py (+44, -0)
  6. modelscope/models/cv/face_detection/ulfd_slim/vision/__init__.py (+0, -0)
  7. modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py (+124, -0)
  8. modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py (+49, -0)
  9. modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/__init__.py (+0, -0)
  10. modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py (+18, -0)
  11. modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py (+49, -0)
  12. modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/mb_tiny_fd.py (+124, -0)
  13. modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py (+80, -0)
  14. modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/ssd.py (+129, -0)
  15. modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py (+56, -0)
  16. modelscope/pipelines/cv/__init__.py (+3, -1)
  17. modelscope/pipelines/cv/ulfd_face_detection_pipeline.py (+56, -0)
  18. modelscope/utils/cv/image_utils.py (+21, -0)
  19. tests/pipelines/test_ulfd_face_detection.py (+36, -0)

data/test/images/ulfd_face_detection.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
size 87228

modelscope/metainfo.py (+2, -0)

@@ -35,6 +35,7 @@ class Models(object):
     fer = 'fer'
     retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
+    ulfd = 'ulfd'

     # EasyCV models
     yolox = 'YOLOX'
@@ -122,6 +123,7 @@ class Pipelines(object):
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
+    ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
     live_category = 'live-category'


modelscope/models/cv/face_detection/__init__.py (+3, -2)

@@ -5,10 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .retinaface import RetinaFaceDetection
+    from .ulfd_slim import UlfdFaceDetector
 else:
     _import_structure = {
-        'retinaface': ['RetinaFaceDetection']
+        'retinaface': ['RetinaFaceDetection'],
+        'ulfd_slim': ['UlfdFaceDetector'],
     }

 import sys


modelscope/models/cv/face_detection/ulfd_slim/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .detection import UlfdFaceDetector

modelscope/models/cv/face_detection/ulfd_slim/detection.py (+44, -0)

@@ -0,0 +1,44 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import os

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .vision.ssd.fd_config import define_img_size
from .vision.ssd.mb_tiny_fd import (create_mb_tiny_fd,
                                    create_mb_tiny_fd_predictor)

define_img_size(640)


@MODELS.register_module(Tasks.face_detection, module_name=Models.ulfd)
class UlfdFaceDetector(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        torch.set_grad_enabled(False)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device
        self.net = create_mb_tiny_fd(2, is_test=True, device=device)
        self.predictor = create_mb_tiny_fd_predictor(
            self.net, candidate_size=1500, device=device)
        self.net.load(model_path)
        self.net = self.net.to(device)

    def forward(self, input):
        img_raw = input['img']
        img = np.array(img_raw.cpu().detach())
        img = img[:, :, ::-1]  # reverse channel order (RGB -> BGR)
        prob_th = 0.85
        keep_top_k = 750
        boxes, labels, probs = self.predictor.predict(img, keep_top_k, prob_th)
        return boxes, probs
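
For reference, a minimal sketch of driving the detector directly, assuming this commit is installed and a ULFD checkpoint is available locally ('path/to/pytorch_model.pt' below is a placeholder; the pipeline further down wraps exactly this call):

import numpy as np
import torch

from modelscope.models.cv.face_detection import UlfdFaceDetector

# Placeholder path to a downloaded ULFD checkpoint.
detector = UlfdFaceDetector('path/to/pytorch_model.pt', device='cpu')

# forward() expects a float (H, W, 3) tensor and reverses the channel
# order itself before prediction.
img = torch.from_numpy(
    np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8).astype(np.float32))
boxes, probs = detector.forward({'img': img})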

modelscope/models/cv/face_detection/ulfd_slim/vision/__init__.py (+0, -0, new empty file)


modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py (+124, -0)

@@ -0,0 +1,124 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import math

import torch


def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
    """Perform hard (greedy) non-maximum suppression.

    Args:
        box_scores (N, 5): boxes in corner-form and probabilities.
        iou_threshold: intersection over union threshold.
        top_k: keep top_k results. If k <= 0, keep all the results.
        candidate_size: only consider the candidates with the highest scores.
    Returns:
        box_scores (K, 5): the kept boxes together with their probabilities.
    """
    scores = box_scores[:, -1]
    boxes = box_scores[:, :-1]
    picked = []
    _, indexes = scores.sort(descending=True)
    indexes = indexes[:candidate_size]
    while len(indexes) > 0:
        current = indexes[0]
        picked.append(current.item())
        if 0 < top_k == len(picked) or len(indexes) == 1:
            break
        current_box = boxes[current, :]
        indexes = indexes[1:]
        rest_boxes = boxes[indexes, :]
        iou = iou_of(rest_boxes, current_box.unsqueeze(0))
        indexes = indexes[iou <= iou_threshold]

    return box_scores[picked, :]


def nms(box_scores,
        nms_method=None,
        score_threshold=None,
        iou_threshold=None,
        sigma=0.5,
        top_k=-1,
        candidate_size=200):
    # Only hard NMS is implemented; nms_method, score_threshold and sigma
    # are accepted for API compatibility and ignored.
    return hard_nms(
        box_scores, iou_threshold, top_k, candidate_size=candidate_size)


def generate_priors(feature_map_list,
                    shrinkage_list,
                    image_size,
                    min_boxes,
                    clamp=True) -> torch.Tensor:
    priors = []
    for index in range(0, len(feature_map_list[0])):
        scale_w = image_size[0] / shrinkage_list[0][index]
        scale_h = image_size[1] / shrinkage_list[1][index]
        for j in range(0, feature_map_list[1][index]):
            for i in range(0, feature_map_list[0][index]):
                x_center = (i + 0.5) / scale_w
                y_center = (j + 0.5) / scale_h

                for min_box in min_boxes[index]:
                    w = min_box / image_size[0]
                    h = min_box / image_size[1]
                    priors.append([x_center, y_center, w, h])
    priors = torch.tensor(priors)
    if clamp:
        torch.clamp(priors, 0.0, 1.0, out=priors)
    return priors


def convert_locations_to_boxes(locations, priors, center_variance,
                               size_variance):
    # priors can have one dimension less.
    if priors.dim() + 1 == locations.dim():
        priors = priors.unsqueeze(0)
    a = (locations[..., :2] * center_variance * priors[..., 2:]
         + priors[..., :2])
    b = torch.exp(locations[..., 2:] * size_variance) * priors[..., 2:]

    return torch.cat([a, b], dim=locations.dim() - 1)


def center_form_to_corner_form(locations):
    a = locations[..., :2] - locations[..., 2:] / 2
    b = locations[..., :2] + locations[..., 2:] / 2
    return torch.cat([a, b], locations.dim() - 1)


def iou_of(boxes0, boxes1, eps=1e-5):
    """Return intersection-over-union (Jaccard index) of boxes.

    Args:
        boxes0 (N, 4): ground truth boxes.
        boxes1 (N or 1, 4): predicted boxes.
        eps: a small number to avoid 0 as denominator.
    Returns:
        iou (N): IoU values.
    """
    overlap_left_top = torch.max(boxes0[..., :2], boxes1[..., :2])
    overlap_right_bottom = torch.min(boxes0[..., 2:], boxes1[..., 2:])

    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
    return overlap_area / (area0 + area1 - overlap_area + eps)


def area_of(left_top, right_bottom) -> torch.Tensor:
    """Compute the areas of rectangles given two corners.

    Args:
        left_top (N, 2): left top corner.
        right_bottom (N, 2): right bottom corner.

    Returns:
        area (N): return the area.
    """
    hw = torch.clamp(right_bottom - left_top, min=0.0)
    return hw[..., 0] * hw[..., 1]
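
A quick, self-contained check of the utilities above (a sketch assuming this commit is installed so the import path resolves):

import torch

from modelscope.models.cv.face_detection.ulfd_slim.vision.box_utils import (
    hard_nms, iou_of)

# Two heavily overlapping boxes and one separate box, in corner form
# (x1, y1, x2, y2), each followed by its score.
box_scores = torch.tensor([
    [0.10, 0.10, 0.50, 0.50, 0.90],
    [0.12, 0.11, 0.52, 0.49, 0.80],  # overlaps the first box
    [0.60, 0.60, 0.90, 0.90, 0.70],
])

print(iou_of(box_scores[1:2, :4], box_scores[0:1, :4]))  # ~0.86
kept = hard_nms(box_scores, iou_threshold=0.3)
print(kept.shape)  # torch.Size([2, 5]): the duplicate is suppressed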

modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py (+49, -0)

@@ -0,0 +1,49 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import torch.nn as nn
import torch.nn.functional as F


class Mb_Tiny(nn.Module):

    def __init__(self, num_classes=2):
        super(Mb_Tiny, self).__init__()
        self.base_channel = 8 * 2

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup), nn.ReLU(inplace=True))

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(3, self.base_channel, 2),  # 160*120
            conv_dw(self.base_channel, self.base_channel * 2, 1),
            conv_dw(self.base_channel * 2, self.base_channel * 2, 2),  # 80*60
            conv_dw(self.base_channel * 2, self.base_channel * 2, 1),
            conv_dw(self.base_channel * 2, self.base_channel * 4, 2),  # 40*30
            conv_dw(self.base_channel * 4, self.base_channel * 4, 1),
            conv_dw(self.base_channel * 4, self.base_channel * 4, 1),
            conv_dw(self.base_channel * 4, self.base_channel * 4, 1),
            conv_dw(self.base_channel * 4, self.base_channel * 8, 2),  # 20*15
            conv_dw(self.base_channel * 8, self.base_channel * 8, 1),
            conv_dw(self.base_channel * 8, self.base_channel * 8, 1),
            conv_dw(self.base_channel * 8, self.base_channel * 16, 2),  # 10*8
            conv_dw(self.base_channel * 16, self.base_channel * 16, 1))
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.model(x)
        x = F.avg_pool2d(x, 7)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x
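
The backbone halves the spatial resolution five times; the inline comments track a 320x240 input. A small shape check, assuming this commit is installed:

import torch

from modelscope.models.cv.face_detection.ulfd_slim.vision.mb_tiny import Mb_Tiny

net = Mb_Tiny(num_classes=2).eval()
with torch.no_grad():
    # Only the feature extractor is used by the detector; the fc head
    # exists for standalone classification.
    feats = net.model(torch.randn(1, 3, 240, 320))  # NCHW, a 320x240 image
print(feats.shape)  # torch.Size([1, 256, 8, 10]): base_channel * 16 = 256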

modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/__init__.py (+0, -0, new empty file)


modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py (+18, -0)

@@ -0,0 +1,18 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
from ..transforms import Compose, Resize, SubtractMeans, ToTensor


class PredictionTransform:

    def __init__(self, size, mean=0.0, std=1.0):
        self.transform = Compose([
            Resize(size),
            SubtractMeans(mean),
            lambda img, boxes=None, labels=None: (img / std, boxes, labels),
            ToTensor()
        ])

    def __call__(self, image):
        image, _, _ = self.transform(image)
        return image
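
The transform resizes, subtracts the mean, divides by std, and emits a channels-first tensor. A minimal sketch, assuming this commit is installed (the mean/std values mirror fd_config):

import numpy as np

from modelscope.models.cv.face_detection.ulfd_slim.vision.ssd.data_preprocessing import \
    PredictionTransform

transform = PredictionTransform((320, 240), mean=127.0, std=128.0)
img = np.random.randint(0, 256, (720, 1280, 3), dtype=np.uint8)
tensor = transform(img)
print(tensor.shape)  # torch.Size([3, 240, 320]), values roughly in [-1, 1]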

modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py (+49, -0)

@@ -0,0 +1,49 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import numpy as np

from ..box_utils import generate_priors

image_mean_test = image_mean = np.array([127, 127, 127])
image_std = 128.0
iou_threshold = 0.3
center_variance = 0.1
size_variance = 0.2

min_boxes = [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]]
shrinkage_list = []
image_size = [320, 240]  # default input size 320*240
feature_map_w_h_list = [[40, 20, 10, 5],
                        [30, 15, 8, 4]]  # default feature map sizes
priors = []


def define_img_size(size):
    global image_size, feature_map_w_h_list, priors
    img_size_dict = {
        128: [128, 96],
        160: [160, 120],
        320: [320, 240],
        480: [480, 360],
        640: [640, 480],
        1280: [1280, 960]
    }
    image_size = img_size_dict[size]

    feature_map_w_h_list_dict = {
        128: [[16, 8, 4, 2], [12, 6, 3, 2]],
        160: [[20, 10, 5, 3], [15, 8, 4, 2]],
        320: [[40, 20, 10, 5], [30, 15, 8, 4]],
        480: [[60, 30, 15, 8], [45, 23, 12, 6]],
        640: [[80, 40, 20, 10], [60, 30, 15, 8]],
        1280: [[160, 80, 40, 20], [120, 60, 30, 15]]
    }
    feature_map_w_h_list = feature_map_w_h_list_dict[size]

    for i in range(0, len(image_size)):
        item_list = []
        for k in range(0, len(feature_map_w_h_list[i])):
            item_list.append(image_size[i] / feature_map_w_h_list[i][k])
        shrinkage_list.append(item_list)
    priors = generate_priors(feature_map_w_h_list, shrinkage_list, image_size,
                             min_boxes)
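
define_img_size rebuilds the module-level priors for one of the supported input widths. For the 320x240 configuration the prior count follows directly from the feature-map sizes, as the sketch below shows (assumes this commit is installed; note that shrinkage_list is appended to rather than reset, so it accumulates entries across calls):

from modelscope.models.cv.face_detection.ulfd_slim.vision.ssd import fd_config

fd_config.define_img_size(320)
# One prior per min_boxes entry per feature-map cell:
# 40*30*3 + 20*15*2 + 10*8*2 + 5*4*3 = 4420
print(fd_config.priors.shape)  # torch.Size([4420, 4]), in center form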

modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/mb_tiny_fd.py (+124, -0)

@@ -0,0 +1,124 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
from torch.nn import Conv2d, ModuleList, ReLU, Sequential

from ..mb_tiny import Mb_Tiny
from . import fd_config as config
from .predictor import Predictor
from .ssd import SSD


def SeperableConv2d(in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0):
    """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d.
    """
    return Sequential(
        Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding),
        ReLU(),
        Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1),
    )


def create_mb_tiny_fd(num_classes, is_test=False, device='cuda'):
    base_net = Mb_Tiny(2)
    base_net_model = base_net.model  # disable dropout layer

    source_layer_indexes = [8, 11, 13]
    extras = ModuleList([
        Sequential(
            Conv2d(
                in_channels=base_net.base_channel * 16,
                out_channels=base_net.base_channel * 4,
                kernel_size=1), ReLU(),
            SeperableConv2d(
                in_channels=base_net.base_channel * 4,
                out_channels=base_net.base_channel * 16,
                kernel_size=3,
                stride=2,
                padding=1), ReLU())
    ])

    regression_headers = ModuleList([
        SeperableConv2d(
            in_channels=base_net.base_channel * 4,
            out_channels=3 * 4,
            kernel_size=3,
            padding=1),
        SeperableConv2d(
            in_channels=base_net.base_channel * 8,
            out_channels=2 * 4,
            kernel_size=3,
            padding=1),
        SeperableConv2d(
            in_channels=base_net.base_channel * 16,
            out_channels=2 * 4,
            kernel_size=3,
            padding=1),
        Conv2d(
            in_channels=base_net.base_channel * 16,
            out_channels=3 * 4,
            kernel_size=3,
            padding=1)
    ])

    classification_headers = ModuleList([
        SeperableConv2d(
            in_channels=base_net.base_channel * 4,
            out_channels=3 * num_classes,
            kernel_size=3,
            padding=1),
        SeperableConv2d(
            in_channels=base_net.base_channel * 8,
            out_channels=2 * num_classes,
            kernel_size=3,
            padding=1),
        SeperableConv2d(
            in_channels=base_net.base_channel * 16,
            out_channels=2 * num_classes,
            kernel_size=3,
            padding=1),
        Conv2d(
            in_channels=base_net.base_channel * 16,
            out_channels=3 * num_classes,
            kernel_size=3,
            padding=1)
    ])

    return SSD(
        num_classes,
        base_net_model,
        source_layer_indexes,
        extras,
        classification_headers,
        regression_headers,
        is_test=is_test,
        config=config,
        device=device)


def create_mb_tiny_fd_predictor(net,
                                candidate_size=200,
                                nms_method=None,
                                sigma=0.5,
                                device=None):
    predictor = Predictor(
        net,
        config.image_size,
        config.image_mean_test,
        config.image_std,
        nms_method=nms_method,
        iou_threshold=config.iou_threshold,
        candidate_size=candidate_size,
        sigma=sigma,
        device=device)
    return predictor
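
The factory yields a four-header SSD whose anchor count follows from the feature-map sizes. A shape sketch for the 640x480 configuration that detection.py selects at import time (assumes this commit is installed):

import torch

from modelscope.models.cv.face_detection.ulfd_slim.vision.ssd.mb_tiny_fd import \
    create_mb_tiny_fd

net = create_mb_tiny_fd(2, is_test=False, device='cpu').eval()
print(sum(p.numel() for p in net.parameters()))  # well under 1M parameters

with torch.no_grad():
    confidences, locations = net(torch.randn(1, 3, 480, 640))  # H=480, W=640
# 80*60*3 + 40*30*2 + 20*15*2 + 10*8*3 = 17640 anchors
print(confidences.shape, locations.shape)  # (1, 17640, 2), (1, 17640, 4)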

modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py (+80, -0)

@@ -0,0 +1,80 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import torch

from .. import box_utils
from .data_preprocessing import PredictionTransform


class Predictor:

    def __init__(self,
                 net,
                 size,
                 mean=0.0,
                 std=1.0,
                 nms_method=None,
                 iou_threshold=0.3,
                 filter_threshold=0.85,
                 candidate_size=200,
                 sigma=0.5,
                 device=None):
        self.net = net
        self.transform = PredictionTransform(size, mean, std)
        self.iou_threshold = iou_threshold
        self.filter_threshold = filter_threshold
        self.candidate_size = candidate_size
        self.nms_method = nms_method

        self.sigma = sigma
        if device:
            self.device = device
        else:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu')

        self.net.to(self.device)
        self.net.eval()

    def predict(self, image, top_k=-1, prob_threshold=None):
        height, width, _ = image.shape
        image = self.transform(image)
        images = image.unsqueeze(0)
        images = images.to(self.device)
        with torch.no_grad():
            scores, boxes = self.net.forward(images)
        boxes = boxes[0]
        scores = scores[0]
        if not prob_threshold:
            prob_threshold = self.filter_threshold
        # this version of nms is slower on GPU, so we move data to CPU.
        picked_box_probs = []
        picked_labels = []
        for class_index in range(1, scores.size(1)):
            probs = scores[:, class_index]
            mask = probs > prob_threshold
            probs = probs[mask]
            if probs.size(0) == 0:
                continue
            subset_boxes = boxes[mask, :]
            box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1)
            box_probs = box_utils.nms(
                box_probs,
                self.nms_method,
                score_threshold=prob_threshold,
                iou_threshold=self.iou_threshold,
                sigma=self.sigma,
                top_k=top_k,
                candidate_size=self.candidate_size)
            picked_box_probs.append(box_probs)
            picked_labels.extend([class_index] * box_probs.size(0))
        if not picked_box_probs:
            return torch.tensor([]), torch.tensor([]), torch.tensor([])
        picked_box_probs = torch.cat(picked_box_probs)
        # scale normalized corner-form boxes back to pixel coordinates
        picked_box_probs[:, 0] *= width
        picked_box_probs[:, 1] *= height
        picked_box_probs[:, 2] *= width
        picked_box_probs[:, 3] *= height
        return picked_box_probs[:, :4], torch.tensor(
            picked_labels), picked_box_probs[:, 4]
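
End-to-end sketch of the predict contract, mirroring the thresholds used by UlfdFaceDetector above (assumes this commit is installed; with randomly initialized weights the detections are meaningless, so this only exercises the call path):

import numpy as np

from modelscope.models.cv.face_detection.ulfd_slim.vision.ssd.mb_tiny_fd import (
    create_mb_tiny_fd, create_mb_tiny_fd_predictor)

net = create_mb_tiny_fd(2, is_test=True, device='cpu')
predictor = create_mb_tiny_fd_predictor(net, candidate_size=1500, device='cpu')

img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # BGR, HWC
boxes, labels, probs = predictor.predict(img, top_k=750, prob_threshold=0.85)
# boxes are pixel-coordinate corner form (x1, y1, x2, y2); with an
# untrained net the result is most likely empty.
print(boxes.shape, labels.shape, probs.shape)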

modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/ssd.py (+129, -0)

@@ -0,0 +1,129 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
from collections import namedtuple
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import box_utils

GraphPath = namedtuple('GraphPath', ['s0', 'name', 's1'])


class SSD(nn.Module):

    def __init__(self,
                 num_classes: int,
                 base_net: nn.ModuleList,
                 source_layer_indexes: List[int],
                 extras: nn.ModuleList,
                 classification_headers: nn.ModuleList,
                 regression_headers: nn.ModuleList,
                 is_test=False,
                 config=None,
                 device=None):
        """Compose an SSD model using the given components.
        """
        super(SSD, self).__init__()

        self.num_classes = num_classes
        self.base_net = base_net
        self.source_layer_indexes = source_layer_indexes
        self.extras = extras
        self.classification_headers = classification_headers
        self.regression_headers = regression_headers
        self.is_test = is_test
        self.config = config

        # register layers in source_layer_indexes by adding them to a module list
        self.source_layer_add_ons = nn.ModuleList([
            t[1] for t in source_layer_indexes
            if isinstance(t, tuple) and not isinstance(t, GraphPath)
        ])
        if device:
            self.device = device
        else:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu')
        if is_test:
            self.config = config
            self.priors = config.priors.to(self.device)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        confidences = []
        locations = []
        start_layer_index = 0
        header_index = 0
        end_layer_index = 0
        for end_layer_index in self.source_layer_indexes:
            if isinstance(end_layer_index, GraphPath):
                path = end_layer_index
                end_layer_index = end_layer_index.s0
                added_layer = None
            elif isinstance(end_layer_index, tuple):
                added_layer = end_layer_index[1]
                end_layer_index = end_layer_index[0]
                path = None
            else:
                added_layer = None
                path = None
            for layer in self.base_net[start_layer_index:end_layer_index]:
                x = layer(x)
            if added_layer:
                y = added_layer(x)
            else:
                y = x
            if path:
                sub = getattr(self.base_net[end_layer_index], path.name)
                for layer in sub[:path.s1]:
                    x = layer(x)
                y = x
                for layer in sub[path.s1:]:
                    x = layer(x)
                end_layer_index += 1
            start_layer_index = end_layer_index
            confidence, location = self.compute_header(header_index, y)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        for layer in self.base_net[end_layer_index:]:
            x = layer(x)

        for layer in self.extras:
            x = layer(x)
            confidence, location = self.compute_header(header_index, x)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.is_test:
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.config.center_variance,
                self.config.size_variance)
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            return confidences, locations

    def compute_header(self, i, x):
        confidence = self.classification_headers[i](x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)

        location = self.regression_headers[i](x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)

        return confidence, location

    def load(self, model):
        self.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))

modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py (+56, -0)

@@ -0,0 +1,56 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import types

import cv2
import numpy as np
import torch
from numpy import random


class Compose(object):
    """Composes several augmentations together.

    Args:
        transforms (List[Transform]): list of transforms to compose.
    Example:
        >>> augmentations.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, boxes=None, labels=None):
        for t in self.transforms:
            img, boxes, labels = t(img, boxes, labels)
        return img, boxes, labels


class SubtractMeans(object):

    def __init__(self, mean):
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image, boxes=None, labels=None):
        image = image.astype(np.float32)
        image -= self.mean
        return image.astype(np.float32), boxes, labels


class Resize(object):

    def __init__(self, size=(300, 300)):
        self.size = size

    def __call__(self, image, boxes=None, labels=None):
        image = cv2.resize(image, (self.size[0], self.size[1]))
        return image, boxes, labels


class ToTensor(object):

    def __call__(self, cvimage, boxes=None, labels=None):
        return torch.from_numpy(cvimage.astype(np.float32)).permute(
            2, 0, 1), boxes, labels
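
These transforms thread (image, boxes, labels) through every step, so the same chain could also serve training-time augmentation. A minimal run, assuming this commit is installed:

import numpy as np

from modelscope.models.cv.face_detection.ulfd_slim.vision.transforms import (
    Compose, Resize, SubtractMeans, ToTensor)

transform = Compose([Resize((320, 240)), SubtractMeans(127.0), ToTensor()])
img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
tensor, boxes, labels = transform(img)  # boxes/labels pass through as None
print(tensor.shape)  # torch.Size([3, 240, 320]), channels-first float32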

modelscope/pipelines/cv/__init__.py (+3, -1)

@@ -46,8 +46,9 @@ if TYPE_CHECKING:
     from .virtual_try_on_pipeline import VirtualTryonPipeline
     from .shop_segmentation_pipleline import ShopSegmentationPipeline
     from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline
-    from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipleline
+    from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline
     from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline
+    from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
     from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
     from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline

@@ -110,6 +111,7 @@ else:
         ['TextDrivenSegmentationPipeline'],
         'movie_scene_segmentation_pipeline':
         ['MovieSceneSegmentationPipeline'],
+        'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
         'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
         'facial_expression_recognition_pipelin':
         ['FacialExpressionRecognitionPipeline']


modelscope/pipelines/cv/ulfd_face_detection_pipeline.py (+56, -0)

@@ -0,0 +1,56 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_detection import UlfdFaceDetector
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.ulfd_face_detection)
class UlfdFaceDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a face detection pipeline for prediction.

        Args:
            model: model id on the modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {ckpt_path}')
        detector = UlfdFaceDetector(model_path=ckpt_path, device=self.device)
        self.detector = detector
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img.astype(np.float32)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = self.detector(input)
        assert result is not None
        bboxes = result[0].tolist()
        scores = result[1].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: None,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
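
End-to-end usage matches the test at the bottom of this commit; in short:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_manual_face-detection_ulfd')
result = face_detection('data/test/images/ulfd_face_detection.jpg')
print(result[OutputKeys.BOXES])   # list of [x1, y1, x2, y2]
print(result[OutputKeys.SCORES])  # matching confidence scores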

modelscope/utils/cv/image_utils.py (+21, -0)

@@ -89,6 +89,27 @@ def draw_keypoints(output, original_image):
     return image


+def draw_face_detection_no_lm_result(img_path, detection_result):
+    bboxes = np.array(detection_result[OutputKeys.BOXES])
+    scores = np.array(detection_result[OutputKeys.SCORES])
+    img = cv2.imread(img_path)
+    assert img is not None, f"Can't read img: {img_path}"
+    for i in range(len(scores)):
+        bbox = bboxes[i].astype(np.int32)
+        x1, y1, x2, y2 = bbox
+        score = scores[i]
+        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        cv2.putText(
+            img,
+            f'{score:.2f}', (x1, y2),
+            1,
+            1.0, (0, 255, 0),
+            thickness=1,
+            lineType=8)
+    print(f'Found {len(scores)} faces')
+    return img
+
+
 def draw_facial_expression_result(img_path, facial_expression_result):
     label_idx = facial_expression_result[OutputKeys.LABELS]
     map_list = [


tests/pipelines/test_ulfd_face_detection.py (+36, -0)

@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2
import numpy as np

from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result
from modelscope.utils.test_utils import test_level


class UlfdFaceDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_manual_face-detection_ulfd'

    def show_result(self, img_path, detection_result):
        img = draw_face_detection_no_lm_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        img_path = 'data/test/images/ulfd_face_detection.jpg'

        result = face_detection(img_path)
        self.show_result(img_path, result)


if __name__ == '__main__':
    unittest.main()
