1. Completed the MaaS-cv CR standard self-check
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9957634
master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
size 87228
@@ -35,6 +35,7 @@ class Models(object):
     fer = 'fer'
     retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
+    ulfd = 'ulfd'

     # EasyCV models
     yolox = 'YOLOX'
@@ -122,6 +123,7 @@ class Pipelines(object):
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
+    ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
     live_category = 'live-category'
@@ -5,10 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .retinaface import RetinaFaceDetection
+    from .ulfd_slim import UlfdFaceDetector
 else:
     _import_structure = {
-        'retinaface': ['RetinaFaceDetection']
+        'retinaface': ['RetinaFaceDetection'],
+        'ulfd_slim': ['UlfdFaceDetector'],
     }
     import sys
@@ -0,0 +1 @@
from .detection import UlfdFaceDetector
@@ -0,0 +1,44 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import os

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .vision.ssd.fd_config import define_img_size
from .vision.ssd.mb_tiny_fd import (create_mb_tiny_fd,
                                    create_mb_tiny_fd_predictor)

# Configure the prior boxes for a 640x480 input before the model is built.
define_img_size(640)


@MODELS.register_module(Tasks.face_detection, module_name=Models.ulfd)
class UlfdFaceDetector(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        torch.set_grad_enabled(False)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device
        # Build the Mb-Tiny SSD in test mode and wrap it in a predictor that
        # handles preprocessing, box decoding and NMS.
        self.net = create_mb_tiny_fd(2, is_test=True, device=device)
        self.predictor = create_mb_tiny_fd_predictor(
            self.net, candidate_size=1500, device=device)
        self.net.load(model_path)
        self.net = self.net.to(device)

    def forward(self, input):
        img_raw = input['img']
        img = np.array(img_raw.cpu().detach())
        # Reverse the channel order expected by the predictor.
        img = img[:, :, ::-1]
        prob_th = 0.85
        keep_top_k = 750
        boxes, labels, probs = self.predictor.predict(img, keep_top_k, prob_th)
        return boxes, probs
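Reviewer note: a minimal sketch of driving the detector directly, mirroring how the pipeline below calls it; the checkpoint and image paths are placeholders, and in normal use the pipeline resolves them from the model hub.

    import cv2
    import torch

    from modelscope.models.cv.face_detection import UlfdFaceDetector

    detector = UlfdFaceDetector('pytorch_model.pt', device='cpu')  # placeholder path
    img = cv2.imread('face.jpg')  # placeholder image, HxWx3 uint8
    # forward() expects a tensor image, as produced by the pipeline's preprocess step.
    boxes, probs = detector({'img': torch.from_numpy(img).float()})
    print(boxes.shape, probs.shape)  # (N, 4) pixel corner boxes, (N,) confidences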
@@ -0,0 +1,124 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import math

import torch


def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
    """Perform hard non-maximum suppression.

    Args:
        box_scores (N, 5): boxes in corner-form and probabilities.
        iou_threshold: intersection over union threshold.
        top_k: keep top_k results. If k <= 0, keep all the results.
        candidate_size: only consider the candidates with the highest scores.
    Returns:
        the kept rows of box_scores, shape (K, 5).
    """
    scores = box_scores[:, -1]
    boxes = box_scores[:, :-1]
    picked = []
    _, indexes = scores.sort(descending=True)
    indexes = indexes[:candidate_size]
    while len(indexes) > 0:
        current = indexes[0]
        picked.append(current.item())
        if 0 < top_k == len(picked) or len(indexes) == 1:
            break
        current_box = boxes[current, :]
        indexes = indexes[1:]
        rest_boxes = boxes[indexes, :]
        iou = iou_of(
            rest_boxes,
            current_box.unsqueeze(0),
        )
        indexes = indexes[iou <= iou_threshold]
    return box_scores[picked, :]
def nms(box_scores,
        nms_method=None,
        score_threshold=None,
        iou_threshold=None,
        sigma=0.5,
        top_k=-1,
        candidate_size=200):
    # Only hard NMS is implemented; nms_method, score_threshold and sigma
    # are accepted for interface compatibility but ignored.
    return hard_nms(
        box_scores, iou_threshold, top_k, candidate_size=candidate_size)
def generate_priors(feature_map_list,
                    shrinkage_list,
                    image_size,
                    min_boxes,
                    clamp=True) -> torch.Tensor:
    priors = []
    for index in range(0, len(feature_map_list[0])):
        scale_w = image_size[0] / shrinkage_list[0][index]
        scale_h = image_size[1] / shrinkage_list[1][index]
        for j in range(0, feature_map_list[1][index]):
            for i in range(0, feature_map_list[0][index]):
                x_center = (i + 0.5) / scale_w
                y_center = (j + 0.5) / scale_h
                for min_box in min_boxes[index]:
                    w = min_box / image_size[0]
                    h = min_box / image_size[1]
                    priors.append([x_center, y_center, w, h])
    priors = torch.tensor(priors)
    if clamp:
        torch.clamp(priors, 0.0, 1.0, out=priors)
    return priors
def convert_locations_to_boxes(locations, priors, center_variance,
                               size_variance):
    # priors can have one dimension less.
    if priors.dim() + 1 == locations.dim():
        priors = priors.unsqueeze(0)
    centers = locations[..., :2] * center_variance * priors[..., 2:] \
        + priors[..., :2]
    sizes = torch.exp(locations[..., 2:] * size_variance) * priors[..., 2:]
    return torch.cat([centers, sizes], dim=locations.dim() - 1)


def center_form_to_corner_form(locations):
    left_top = locations[..., :2] - locations[..., 2:] / 2
    right_bottom = locations[..., :2] + locations[..., 2:] / 2
    return torch.cat([left_top, right_bottom], locations.dim() - 1)
def iou_of(boxes0, boxes1, eps=1e-5):
    """Return intersection-over-union (Jaccard index) of boxes.

    Args:
        boxes0 (N, 4): ground truth boxes.
        boxes1 (N or 1, 4): predicted boxes.
        eps: a small number to avoid 0 as denominator.
    Returns:
        iou (N): IoU values.
    """
    overlap_left_top = torch.max(boxes0[..., :2], boxes1[..., :2])
    overlap_right_bottom = torch.min(boxes0[..., 2:], boxes1[..., 2:])

    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
    return overlap_area / (area0 + area1 - overlap_area + eps)


def area_of(left_top, right_bottom) -> torch.Tensor:
    """Compute the areas of rectangles given two corners.

    Args:
        left_top (N, 2): left top corner.
        right_bottom (N, 2): right bottom corner.
    Returns:
        area (N): return the area.
    """
    hw = torch.clamp(right_bottom - left_top, min=0.0)
    return hw[..., 0] * hw[..., 1]
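Reviewer note: a quick sanity check of these helpers with toy boxes (the values are mine, not from the CR; iou_of and hard_nms from the file above are assumed in scope).

    import torch

    # Two unit squares overlapping by half: intersection 0.5, union 1.5, IoU = 1/3.
    a = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
    b = torch.tensor([[0.5, 0.0, 1.5, 1.0]])
    print(iou_of(a, b))  # tensor([0.3333])

    # hard_nms keeps the higher-scoring of two heavily overlapping detections.
    box_scores = torch.tensor([
        [0.00, 0.0, 1.00, 1.0, 0.9],
        [0.05, 0.0, 1.05, 1.0, 0.8],  # IoU with the first row is ~0.9
    ])
    print(hard_nms(box_scores, iou_threshold=0.3))  # only the 0.9 row survives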
@@ -0,0 +1,49 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import torch.nn as nn
import torch.nn.functional as F


class Mb_Tiny(nn.Module):

    def __init__(self, num_classes=2):
        super(Mb_Tiny, self).__init__()
        self.base_channel = 8 * 2

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup), nn.ReLU(inplace=True))

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        # The spatial sizes in the comments assume a 320x240 input.
        self.model = nn.Sequential(
            conv_bn(3, self.base_channel, 2),  # 160*120
            conv_dw(self.base_channel, self.base_channel * 2, 1),
            conv_dw(self.base_channel * 2, self.base_channel * 2, 2),  # 80*60
            conv_dw(self.base_channel * 2, self.base_channel * 2, 1),
            conv_dw(self.base_channel * 2, self.base_channel * 4, 2),  # 40*30
            conv_dw(self.base_channel * 4, self.base_channel * 4, 1),
            conv_dw(self.base_channel * 4, self.base_channel * 4, 1),
            conv_dw(self.base_channel * 4, self.base_channel * 4, 1),
            conv_dw(self.base_channel * 4, self.base_channel * 8, 2),  # 20*15
            conv_dw(self.base_channel * 8, self.base_channel * 8, 1),
            conv_dw(self.base_channel * 8, self.base_channel * 8, 1),
            conv_dw(self.base_channel * 8, self.base_channel * 16, 2),  # 10*8
            conv_dw(self.base_channel * 16, self.base_channel * 16, 1))
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        # NOTE: this classifier head is unused by the detector, which only
        # consumes self.model (see create_mb_tiny_fd).
        x = self.model(x)
        x = F.avg_pool2d(x, 7)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x
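Reviewer note: since only the backbone is used, a quick shape check confirms the stride-32 layout at the 640x480 resolution configured in detection.py (a sketch; the module path is inferred from the relative imports above, and the shapes follow my own stride trace).

    import torch

    from modelscope.models.cv.face_detection.ulfd_slim.vision.mb_tiny import Mb_Tiny

    backbone = Mb_Tiny().model
    backbone.eval()
    with torch.no_grad():
        feat = backbone(torch.randn(1, 3, 480, 640))  # NCHW dummy image
    print(feat.shape)  # torch.Size([1, 256, 15, 20]): 16*16 channels, /32 spatial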
@@ -0,0 +1,18 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
from ..transforms import Compose, Resize, SubtractMeans, ToTensor


class PredictionTransform:

    def __init__(self, size, mean=0.0, std=1.0):
        self.transform = Compose([
            Resize(size),
            SubtractMeans(mean),
            lambda img, boxes=None, labels=None: (img / std, boxes, labels),
            ToTensor()
        ])

    def __call__(self, image):
        image, _, _ = self.transform(image)
        return image
@@ -0,0 +1,49 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import numpy as np

from ..box_utils import generate_priors

image_mean_test = image_mean = np.array([127, 127, 127])
image_std = 128.0
iou_threshold = 0.3
center_variance = 0.1
size_variance = 0.2

min_boxes = [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]]
shrinkage_list = []
image_size = [320, 240]  # default input size 320*240
feature_map_w_h_list = [[40, 20, 10, 5],
                        [30, 15, 8, 4]]  # default feature map size
priors = []


def define_img_size(size):
    global image_size, feature_map_w_h_list, shrinkage_list, priors
    img_size_dict = {
        128: [128, 96],
        160: [160, 120],
        320: [320, 240],
        480: [480, 360],
        640: [640, 480],
        1280: [1280, 960]
    }
    image_size = img_size_dict[size]

    feature_map_w_h_list_dict = {
        128: [[16, 8, 4, 2], [12, 6, 3, 2]],
        160: [[20, 10, 5, 3], [15, 8, 4, 2]],
        320: [[40, 20, 10, 5], [30, 15, 8, 4]],
        480: [[60, 30, 15, 8], [45, 23, 12, 6]],
        640: [[80, 40, 20, 10], [60, 30, 15, 8]],
        1280: [[160, 80, 40, 20], [120, 60, 30, 15]]
    }
    feature_map_w_h_list = feature_map_w_h_list_dict[size]

    # Rebuild (rather than append to) shrinkage_list so that calling
    # define_img_size more than once does not accumulate stale entries.
    shrinkage_list = []
    for i in range(0, len(image_size)):
        item_list = []
        for k in range(0, len(feature_map_w_h_list[i])):
            item_list.append(image_size[i] / feature_map_w_h_list[i][k])
        shrinkage_list.append(item_list)
    priors = generate_priors(feature_map_w_h_list, shrinkage_list, image_size,
                             min_boxes)
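Reviewer note: the prior count for the 640 configuration follows directly from the tables above; the arithmetic below is mine, as a sketch.

    # Feature maps for size 640 are 80x60, 40x30, 20x15 and 10x8, with
    # len(min_boxes[i]) anchors per cell: 3, 2, 2, 3.
    w_h = [(80, 60), (40, 30), (20, 15), (10, 8)]
    anchors = [3, 2, 2, 3]
    total = sum(w * h * n for (w, h), n in zip(w_h, anchors))
    print(total)  # 17640, matching priors.shape[0] after define_img_size(640)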
@@ -0,0 +1,124 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
from torch.nn import Conv2d, ModuleList, ReLU, Sequential

from ..mb_tiny import Mb_Tiny
from . import fd_config as config
from .predictor import Predictor
from .ssd import SSD


def SeperableConv2d(in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0):
    """Replace Conv2d with a depthwise Conv2d and a pointwise Conv2d."""
    return Sequential(
        Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding),
        ReLU(),
        Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1),
    )
def create_mb_tiny_fd(num_classes, is_test=False, device='cuda'):
    base_net = Mb_Tiny(2)
    # Only the convolutional backbone is used; the fc classifier head is dropped.
    base_net_model = base_net.model
    source_layer_indexes = [8, 11, 13]
    extras = ModuleList([
        Sequential(
            Conv2d(
                in_channels=base_net.base_channel * 16,
                out_channels=base_net.base_channel * 4,
                kernel_size=1), ReLU(),
            SeperableConv2d(
                in_channels=base_net.base_channel * 4,
                out_channels=base_net.base_channel * 16,
                kernel_size=3,
                stride=2,
                padding=1), ReLU())
    ])

    regression_headers = ModuleList([
        SeperableConv2d(
            in_channels=base_net.base_channel * 4,
            out_channels=3 * 4,
            kernel_size=3,
            padding=1),
        SeperableConv2d(
            in_channels=base_net.base_channel * 8,
            out_channels=2 * 4,
            kernel_size=3,
            padding=1),
        SeperableConv2d(
            in_channels=base_net.base_channel * 16,
            out_channels=2 * 4,
            kernel_size=3,
            padding=1),
        Conv2d(
            in_channels=base_net.base_channel * 16,
            out_channels=3 * 4,
            kernel_size=3,
            padding=1)
    ])

    classification_headers = ModuleList([
        SeperableConv2d(
            in_channels=base_net.base_channel * 4,
            out_channels=3 * num_classes,
            kernel_size=3,
            padding=1),
        SeperableConv2d(
            in_channels=base_net.base_channel * 8,
            out_channels=2 * num_classes,
            kernel_size=3,
            padding=1),
        SeperableConv2d(
            in_channels=base_net.base_channel * 16,
            out_channels=2 * num_classes,
            kernel_size=3,
            padding=1),
        Conv2d(
            in_channels=base_net.base_channel * 16,
            out_channels=3 * num_classes,
            kernel_size=3,
            padding=1)
    ])

    return SSD(
        num_classes,
        base_net_model,
        source_layer_indexes,
        extras,
        classification_headers,
        regression_headers,
        is_test=is_test,
        config=config,
        device=device)
def create_mb_tiny_fd_predictor(net,
                                candidate_size=200,
                                nms_method=None,
                                sigma=0.5,
                                device=None):
    predictor = Predictor(
        net,
        config.image_size,
        config.image_mean_test,
        config.image_std,
        nms_method=nms_method,
        iou_threshold=config.iou_threshold,
        candidate_size=candidate_size,
        sigma=sigma,
        device=device)
    return predictor
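Reviewer note: a minimal smoke test for the assembled network (a sketch; create_mb_tiny_fd from this file is assumed in scope, and the 17640 figure assumes the 640x480 configuration and the prior count computed above).

    import torch

    net = create_mb_tiny_fd(2, is_test=False, device='cpu')
    net.eval()
    with torch.no_grad():
        confidences, locations = net(torch.randn(1, 3, 480, 640))
    print(confidences.shape)  # torch.Size([1, 17640, 2])
    print(locations.shape)    # torch.Size([1, 17640, 4])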
@@ -0,0 +1,80 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import torch

from .. import box_utils
from .data_preprocessing import PredictionTransform


class Predictor:

    def __init__(self,
                 net,
                 size,
                 mean=0.0,
                 std=1.0,
                 nms_method=None,
                 iou_threshold=0.3,
                 filter_threshold=0.85,
                 candidate_size=200,
                 sigma=0.5,
                 device=None):
        self.net = net
        self.transform = PredictionTransform(size, mean, std)
        self.iou_threshold = iou_threshold
        self.filter_threshold = filter_threshold
        self.candidate_size = candidate_size
        self.nms_method = nms_method
        self.sigma = sigma
        if device:
            self.device = device
        else:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.net.to(self.device)
        self.net.eval()

    def predict(self, image, top_k=-1, prob_threshold=None):
        height, width, _ = image.shape
        image = self.transform(image)
        images = image.unsqueeze(0)
        images = images.to(self.device)
        with torch.no_grad():
            scores, boxes = self.net.forward(images)
        boxes = boxes[0]
        scores = scores[0]
        if not prob_threshold:
            prob_threshold = self.filter_threshold
        # this version of nms is slower on GPU, so we move data to CPU.
        picked_box_probs = []
        picked_labels = []
        # Class 0 is background; run NMS per foreground class.
        for class_index in range(1, scores.size(1)):
            probs = scores[:, class_index]
            mask = probs > prob_threshold
            probs = probs[mask]
            if probs.size(0) == 0:
                continue
            subset_boxes = boxes[mask, :]
            box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1)
            box_probs = box_utils.nms(
                box_probs,
                self.nms_method,
                score_threshold=prob_threshold,
                iou_threshold=self.iou_threshold,
                sigma=self.sigma,
                top_k=top_k,
                candidate_size=self.candidate_size)
            picked_box_probs.append(box_probs)
            picked_labels.extend([class_index] * box_probs.size(0))
        if not picked_box_probs:
            return torch.tensor([]), torch.tensor([]), torch.tensor([])
        picked_box_probs = torch.cat(picked_box_probs)
        # Boxes are normalized to [0, 1]; scale back to input-image pixels.
        picked_box_probs[:, 0] *= width
        picked_box_probs[:, 1] *= height
        picked_box_probs[:, 2] *= width
        picked_box_probs[:, 3] *= height
        return (picked_box_probs[:, :4], torch.tensor(picked_labels),
                picked_box_probs[:, 4])
@@ -0,0 +1,129 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
from collections import namedtuple
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import box_utils

GraphPath = namedtuple('GraphPath', ['s0', 'name', 's1'])


class SSD(nn.Module):

    def __init__(self,
                 num_classes: int,
                 base_net: nn.ModuleList,
                 source_layer_indexes: List[int],
                 extras: nn.ModuleList,
                 classification_headers: nn.ModuleList,
                 regression_headers: nn.ModuleList,
                 is_test=False,
                 config=None,
                 device=None):
        """Compose an SSD model using the given components."""
        super(SSD, self).__init__()

        self.num_classes = num_classes
        self.base_net = base_net
        self.source_layer_indexes = source_layer_indexes
        self.extras = extras
        self.classification_headers = classification_headers
        self.regression_headers = regression_headers
        self.is_test = is_test
        self.config = config

        # register layers in source_layer_indexes by adding them to a module list
        self.source_layer_add_ons = nn.ModuleList([
            t[1] for t in source_layer_indexes
            if isinstance(t, tuple) and not isinstance(t, GraphPath)
        ])
        if device:
            self.device = device
        else:
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu')
        if is_test:
            self.config = config
            self.priors = config.priors.to(self.device)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        confidences = []
        locations = []
        start_layer_index = 0
        header_index = 0
        end_layer_index = 0
        for end_layer_index in self.source_layer_indexes:
            if isinstance(end_layer_index, GraphPath):
                path = end_layer_index
                end_layer_index = end_layer_index.s0
                added_layer = None
            elif isinstance(end_layer_index, tuple):
                added_layer = end_layer_index[1]
                end_layer_index = end_layer_index[0]
                path = None
            else:
                added_layer = None
                path = None
            for layer in self.base_net[start_layer_index:end_layer_index]:
                x = layer(x)
            if added_layer:
                y = added_layer(x)
            else:
                y = x
            if path:
                sub = getattr(self.base_net[end_layer_index], path.name)
                for layer in sub[:path.s1]:
                    x = layer(x)
                y = x
                for layer in sub[path.s1:]:
                    x = layer(x)
                end_layer_index += 1
            start_layer_index = end_layer_index
            confidence, location = self.compute_header(header_index, y)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        for layer in self.base_net[end_layer_index:]:
            x = layer(x)

        for layer in self.extras:
            x = layer(x)
            confidence, location = self.compute_header(header_index, x)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.is_test:
            # Decode regressed offsets against the priors into corner-form boxes.
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.config.center_variance,
                self.config.size_variance)
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            return confidences, locations

    def compute_header(self, i, x):
        confidence = self.classification_headers[i](x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)

        location = self.regression_headers[i](x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        return confidence, location

    def load(self, model):
        self.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))
@@ -0,0 +1,56 @@
# The implementation is based on ULFD, available at
# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
import types

import cv2
import numpy as np
import torch
from numpy import random


class Compose(object):
    """Composes several augmentations together.

    Args:
        transforms (List[Transform]): list of transforms to compose.
    Example:
        >>> augmentations.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, boxes=None, labels=None):
        for t in self.transforms:
            img, boxes, labels = t(img, boxes, labels)
        return img, boxes, labels


class SubtractMeans(object):

    def __init__(self, mean):
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image, boxes=None, labels=None):
        image = image.astype(np.float32)
        image -= self.mean
        return image.astype(np.float32), boxes, labels


class Resize(object):

    def __init__(self, size=(300, 300)):
        self.size = size

    def __call__(self, image, boxes=None, labels=None):
        image = cv2.resize(image, (self.size[0], self.size[1]))
        return image, boxes, labels


class ToTensor(object):

    def __call__(self, cvimage, boxes=None, labels=None):
        return torch.from_numpy(cvimage.astype(np.float32)).permute(
            2, 0, 1), boxes, labels
@@ -46,8 +46,9 @@ if TYPE_CHECKING:
     from .virtual_try_on_pipeline import VirtualTryonPipeline
     from .shop_segmentation_pipleline import ShopSegmentationPipeline
     from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline
-    from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipleline
+    from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline
     from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline
+    from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
     from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
     from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
@@ -110,6 +111,7 @@ else:
         ['TextDrivenSegmentationPipeline'],
         'movie_scene_segmentation_pipeline':
         ['MovieSceneSegmentationPipeline'],
+        'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
         'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
         'facial_expression_recognition_pipelin':
         ['FacialExpressionRecognitionPipeline']
@@ -0,0 +1,56 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_detection import UlfdFaceDetector
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_detection, module_name=Pipelines.ulfd_face_detection)
class UlfdFaceDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a face detection pipeline for prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {ckpt_path}')
        detector = UlfdFaceDetector(model_path=ckpt_path, device=self.device)
        self.detector = detector
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img.astype(np.float32)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = self.detector(input)
        assert result is not None
        bboxes = result[0].tolist()
        scores = result[1].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: bboxes,
            OutputKeys.KEYPOINTS: None,  # ULFD predicts no landmarks
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
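Reviewer note: end to end, the new pipeline is driven like any other modelscope pipeline; the model id and test image path below come from the unit test in this CR, and reading the output keys is a sketch.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    face_detection = pipeline(
        Tasks.face_detection, model='damo/cv_manual_face-detection_ulfd')
    result = face_detection('data/test/images/ulfd_face_detection.jpg')
    # BOXES are [x1, y1, x2, y2] in input-image pixels; SCORES are face confidences.
    print(result[OutputKeys.BOXES], result[OutputKeys.SCORES])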
@@ -89,6 +89,27 @@ def draw_keypoints(output, original_image):
     return image


+def draw_face_detection_no_lm_result(img_path, detection_result):
+    bboxes = np.array(detection_result[OutputKeys.BOXES])
+    scores = np.array(detection_result[OutputKeys.SCORES])
+    img = cv2.imread(img_path)
+    assert img is not None, f"Can't read img: {img_path}"
+    for i in range(len(scores)):
+        bbox = bboxes[i].astype(np.int32)
+        x1, y1, x2, y2 = bbox
+        score = scores[i]
+        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        cv2.putText(
+            img,
+            f'{score:.2f}', (x1, y2),
+            1,
+            1.0, (0, 255, 0),
+            thickness=1,
+            lineType=8)
+    print(f'Found {len(scores)} faces')
+    return img
+
+
 def draw_facial_expression_result(img_path, facial_expression_result):
     label_idx = facial_expression_result[OutputKeys.LABELS]
     map_list = [
@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2
import numpy as np

from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result
from modelscope.utils.test_utils import test_level


class UlfdFaceDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_manual_face-detection_ulfd'

    def show_result(self, img_path, detection_result):
        img = draw_face_detection_no_lm_result(img_path, detection_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
        img_path = 'data/test/images/ulfd_face_detection.jpg'
        result = face_detection(img_path)
        self.show_result(img_path, result)


if __name__ == '__main__':
    unittest.main()