diff --git a/data/test/images/mog_face_detection.jpg b/data/test/images/mog_face_detection.jpg
new file mode 100644
index 00000000..c95881fe
--- /dev/null
+++ b/data/test/images/mog_face_detection.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
+size 87228
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index d7594794..270c5aaf 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -35,6 +35,7 @@ class Models(object):
     fer = 'fer'
     retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
+    mogface = 'mogface'
     mtcnn = 'mtcnn'
     ulfd = 'ulfd'
@@ -128,6 +129,7 @@ class Pipelines(object):
     ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
+    mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'
     mtcnn_face_detection = 'manual-face-detection-mtcnn'
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
diff --git a/modelscope/models/cv/face_detection/__init__.py b/modelscope/models/cv/face_detection/__init__.py
index ed8832c2..a2a845d2 100644
--- a/modelscope/models/cv/face_detection/__init__.py
+++ b/modelscope/models/cv/face_detection/__init__.py
@@ -4,15 +4,16 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
+    from .mogface import MogFaceDetector
     from .mtcnn import MtcnnFaceDetector
     from .retinaface import RetinaFaceDetection
     from .ulfd_slim import UlfdFaceDetector
-
 else:
     _import_structure = {
         'ulfd_slim': ['UlfdFaceDetector'],
         'retinaface': ['RetinaFaceDetection'],
-        'mtcnn': ['MtcnnFaceDetector']
+        'mtcnn': ['MtcnnFaceDetector'],
+        'mogface': ['MogFaceDetector']
     }
 
     import sys
diff --git a/modelscope/models/cv/face_detection/mogface/__init__.py b/modelscope/models/cv/face_detection/mogface/__init__.py
new file mode 100644
index 00000000..8190b649
--- /dev/null
+++ b/modelscope/models/cv/face_detection/mogface/__init__.py
@@ -0,0 +1 @@
+from .models.detectors import MogFaceDetector
diff --git a/modelscope/models/cv/face_detection/mogface/models/__init__.py b/modelscope/models/cv/face_detection/mogface/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/cv/face_detection/mogface/models/detectors.py b/modelscope/models/cv/face_detection/mogface/models/detectors.py
new file mode 100644
index 00000000..5ae67104
--- /dev/null
+++ b/modelscope/models/cv/face_detection/mogface/models/detectors.py
@@ -0,0 +1,97 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+from .mogface import MogFace
+from .utils import MogPriorBox, mogdecode, py_cpu_nms
+
+
+@MODELS.register_module(Tasks.face_detection, module_name=Models.mogface)
+class MogFaceDetector(TorchModel):
+
+    def __init__(self, model_path, device='cuda'):
+        super().__init__(model_path)
+        torch.set_grad_enabled(False)
+        cudnn.benchmark = True
+        self.model_path = model_path
+        self.device = device
+        self.net = MogFace()
+        self.load_model()
+        self.net = self.net.to(device)
+
+        self.mean = np.array([[104, 117, 123]])
+
+    def load_model(self, load_to_cpu=False):
+        pretrained_dict = torch.load(
+            self.model_path, map_location=torch.device('cpu'))
+        self.net.load_state_dict(pretrained_dict, strict=False)
+        self.net.eval()
+
+    def forward(self, input):
+        img_raw = input['img']
+        img = np.array(img_raw.cpu().detach())
+        img = img[:, :, ::-1]
+
+        im_height, im_width = img.shape[:2]
+        ss = 1.0
+        # downscale very large inputs so the long side is about 1000 px;
+        # detections are mapped back to the original frame at the end
+        if max(im_height, im_width) > 1500:
+            ss = 1000.0 / max(im_height, im_width)
+            img = cv2.resize(img, (0, 0), fx=ss, fy=ss)
+            im_height, im_width = img.shape[:2]
+
+        scale = torch.Tensor(
+            [img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
+        img -= np.array([[103.53, 116.28, 123.675]])
+        img /= np.array([[57.375, 57.120003, 58.395]])
+        img /= 255
+        img = img[:, :, ::-1].copy()
+        img = img.transpose(2, 0, 1)
+        img = torch.from_numpy(img).unsqueeze(0)
+        img = img.to(self.device)
+        scale = scale.to(self.device)
+
+        conf, loc = self.net(img)  # forward pass
+
+        confidence_threshold = 0.82
+        nms_threshold = 0.4
+        top_k = 5000
+        keep_top_k = 750
+
+        priorbox = MogPriorBox(scale_list=[0.68])
+        priors = priorbox(im_height, im_width)
+        priors = torch.tensor(priors).to(self.device)
+        prior_data = priors.data
+
+        boxes = mogdecode(loc.data.squeeze(0), prior_data)
+        boxes = boxes.cpu().numpy()
+        scores = conf.squeeze(0).data.cpu().numpy()[:, 0]
+
+        # ignore low scores
+        inds = np.where(scores > confidence_threshold)[0]
+        boxes = boxes[inds]
+        scores = scores[inds]
+
+        # keep top-K before NMS
+        order = scores.argsort()[::-1][:top_k]
+        boxes = boxes[order]
+        scores = scores[order]
+
+        # do NMS
+        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
+            np.float32, copy=False)
+        keep = py_cpu_nms(dets, nms_threshold)
+        dets = dets[keep, :]
+
+        # keep top-K after NMS
+        dets = dets[:keep_top_k, :]
+
+        return dets / ss
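Note: the `dets / ss` at the end of MogFaceDetector.forward maps detections
from the (possibly downscaled) inference frame back to the original image. A
minimal sketch of that bookkeeping, with hypothetical image sizes:

    import numpy as np

    h, w = 2000, 3000  # hypothetical original size; long side exceeds 1500
    ss = 1000.0 / max(h, w) if max(h, w) > 1500 else 1.0
    box_resized = np.array([100., 150., 200., 250.])  # xyxy in resized frame
    box_original = box_resized / ss  # xyxy in the original 3000x2000 frame
    assert np.allclose(box_original, box_resized * 3.0)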
diff --git a/modelscope/models/cv/face_detection/mogface/models/mogface.py b/modelscope/models/cv/face_detection/mogface/models/mogface.py
new file mode 100644
index 00000000..294c2c6b
--- /dev/null
+++ b/modelscope/models/cv/face_detection/mogface/models/mogface.py
@@ -0,0 +1,136 @@
+# --------------------------------------------------------
+# The implementation is also open-sourced by the author, Yang Liu,
+# and is publicly available at https://github.com/damo-cv/MogFace
+# --------------------------------------------------------
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .mogprednet import MogPredNet
+from .resnet import ResNet
+
+
+class MogFace(nn.Module):
+
+    def __init__(self):
+        super(MogFace, self).__init__()
+        self.backbone = ResNet(depth=101)
+        self.fpn = LFPN()
+        self.pred_net = MogPredNet()
+
+    def forward(self, x):
+        feature_list = self.backbone(x)
+        fpn_list = self.fpn(feature_list)
+        pyramid_feature_list = fpn_list[0]
+        conf, loc = self.pred_net(pyramid_feature_list)
+        return conf, loc
+
+
+class FeatureFusion(nn.Module):
+
+    def __init__(self, lat_ch=256, **channels):
+        super(FeatureFusion, self).__init__()
+        self.main_conv = nn.Conv2d(channels['main'], lat_ch, kernel_size=1)
+
+    def forward(self, up, main):
+        main = self.main_conv(main)
+        _, _, H, W = main.size()
+        res = F.interpolate(
+            up, scale_factor=2, mode='bilinear', align_corners=False)
+        if res.size(2) != main.size(2) or res.size(3) != main.size(3):
+            res = res[:, :, 0:H, 0:W]
+        res = res + main
+        return res
+
+
+class LFPN(nn.Module):
+
+    def __init__(self,
+                 c2_out_ch=256,
+                 c3_out_ch=512,
+                 c4_out_ch=1024,
+                 c5_out_ch=2048,
+                 c6_mid_ch=512,
+                 c6_out_ch=512,
+                 c7_mid_ch=128,
+                 c7_out_ch=256,
+                 out_dsfd_ft=True):
+        super(LFPN, self).__init__()
+        self.out_dsfd_ft = out_dsfd_ft
+        if self.out_dsfd_ft:
+            dsfd_module = []
+            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
+            dsfd_module.append(nn.Conv2d(512, 256, kernel_size=3, padding=1))
+            dsfd_module.append(nn.Conv2d(1024, 256, kernel_size=3, padding=1))
+            dsfd_module.append(nn.Conv2d(2048, 256, kernel_size=3, padding=1))
+            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
+            dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1))
+            self.dsfd_modules = nn.ModuleList(dsfd_module)
+
+        c6_input_ch = c5_out_ch
+        self.c6 = nn.Sequential(*[
+            nn.Conv2d(
+                c6_input_ch,
+                c6_mid_ch,
+                kernel_size=1,
+            ),
+            nn.BatchNorm2d(c6_mid_ch),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(
+                c6_mid_ch, c6_out_ch, kernel_size=3, padding=1, stride=2),
+            nn.BatchNorm2d(c6_out_ch),
+            nn.ReLU(inplace=True)
+        ])
+        self.c7 = nn.Sequential(*[
+            nn.Conv2d(
+                c6_out_ch,
+                c7_mid_ch,
+                kernel_size=1,
+            ),
+            nn.BatchNorm2d(c7_mid_ch),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(
+                c7_mid_ch, c7_out_ch, kernel_size=3, padding=1, stride=2),
+            nn.BatchNorm2d(c7_out_ch),
+            nn.ReLU(inplace=True)
+        ])
+
+        self.p2_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.p3_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.p4_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+
+        self.c5_lat = nn.Conv2d(c6_input_ch, 256, kernel_size=3, padding=1)
+        self.c6_lat = nn.Conv2d(c6_out_ch, 256, kernel_size=3, padding=1)
+        self.c7_lat = nn.Conv2d(c7_out_ch, 256, kernel_size=3, padding=1)
+
+        self.ff_c5_c4 = FeatureFusion(main=c4_out_ch)
+        self.ff_c4_c3 = FeatureFusion(main=c3_out_ch)
+        self.ff_c3_c2 = FeatureFusion(main=c2_out_ch)
+
+    def forward(self, feature_list):
+        c2, c3, c4, c5 = feature_list
+        c6 = self.c6(c5)
+        c7 = self.c7(c6)
+
+        c5 = self.c5_lat(c5)
+        c6 = self.c6_lat(c6)
+        c7 = self.c7_lat(c7)
+
+        if self.out_dsfd_ft:
+            dsfd_fts = []
+            dsfd_fts.append(self.dsfd_modules[0](c2))
+            dsfd_fts.append(self.dsfd_modules[1](c3))
+            dsfd_fts.append(self.dsfd_modules[2](c4))
+            dsfd_fts.append(self.dsfd_modules[3](feature_list[-1]))
+            dsfd_fts.append(self.dsfd_modules[4](c6))
+            dsfd_fts.append(self.dsfd_modules[5](c7))
+
+        p4 = self.ff_c5_c4(c5, c4)
+        p3 = self.ff_c4_c3(p4, c3)
+        p2 = self.ff_c3_c2(p3, c2)
+
+        p2 = self.p2_lat(p2)
+        p3 = self.p3_lat(p3)
+        p4 = self.p4_lat(p4)
+
+        if self.out_dsfd_ft:
+            return ([p2, p3, p4, c5, c6, c7], dsfd_fts)
+        return ([p2, p3, p4, c5, c6, c7], None)
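Note: an illustrative smoke test of the architecture above (random weights,
not the released checkpoint; assumes the patch is applied so the module path
exists). With a 640x640 input, the six pyramid levels at strides 4..128 give
160^2 + 80^2 + 40^2 + 20^2 + 10^2 + 5^2 = 34125 anchor positions:

    import torch

    from modelscope.models.cv.face_detection.mogface.models.mogface import \
        MogFace

    net = MogFace().eval()
    with torch.no_grad():
        conf, loc = net(torch.randn(1, 3, 640, 640))
    print(conf.shape)  # torch.Size([1, 34125, 1]) -- sigmoid face scores
    print(loc.shape)   # torch.Size([1, 34125, 4]) -- box regression offsets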
diff --git a/modelscope/models/cv/face_detection/mogface/models/mogprednet.py b/modelscope/models/cv/face_detection/mogface/models/mogprednet.py
new file mode 100644
index 00000000..31384976
--- /dev/null
+++ b/modelscope/models/cv/face_detection/mogface/models/mogprednet.py
@@ -0,0 +1,167 @@
+# --------------------------------------------------------
+# The implementation is also open-sourced by the author, Yang Liu,
+# and is publicly available at https://github.com/damo-cv/MogFace
+# --------------------------------------------------------
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class conv_bn(nn.Module):
+    """A convolution followed by batch normalization."""
+
+    def __init__(self, in_plane, out_plane, kernel_size, stride, padding):
+        super(conv_bn, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_plane,
+            out_plane,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding)
+        self.bn1 = nn.BatchNorm2d(out_plane)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        return self.bn1(x)
+
+
+class SSHContext(nn.Module):
+
+    def __init__(self, channels, Xchannels=256):
+        super(SSHContext, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            channels, Xchannels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(
+            channels,
+            Xchannels // 2,
+            kernel_size=3,
+            dilation=2,
+            stride=1,
+            padding=2)
+        self.conv2_1 = nn.Conv2d(
+            Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1)
+        self.conv2_2 = nn.Conv2d(
+            Xchannels // 2,
+            Xchannels // 2,
+            kernel_size=3,
+            dilation=2,
+            stride=1,
+            padding=2)
+        self.conv2_2_1 = nn.Conv2d(
+            Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        x1 = F.relu(self.conv1(x), inplace=True)
+        x2 = F.relu(self.conv2(x), inplace=True)
+        x2_1 = F.relu(self.conv2_1(x2), inplace=True)
+        x2_2 = F.relu(self.conv2_2(x2), inplace=True)
+        x2_2 = F.relu(self.conv2_2_1(x2_2), inplace=True)
+
+        return torch.cat([x1, x2_1, x2_2], 1)
+
+
+class DeepHead(nn.Module):
+
+    def __init__(self,
+                 in_channel=256,
+                 out_channel=256,
+                 use_gn=False,
+                 num_conv=4):
+        super(DeepHead, self).__init__()
+        self.use_gn = use_gn
+        self.num_conv = num_conv
+        self.conv1 = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
+        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
+        self.conv3 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
+        self.conv4 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
+        if self.use_gn:
+            self.gn1 = nn.GroupNorm(16, out_channel)
+            self.gn2 = nn.GroupNorm(16, out_channel)
+            self.gn3 = nn.GroupNorm(16, out_channel)
+            self.gn4 = nn.GroupNorm(16, out_channel)
+
+    def forward(self, x):
+        # Note: conv1 is applied at every step below; conv2-conv4 are defined
+        # but unused. This mirrors the released implementation and matches
+        # the pretrained checkpoint layout, so it is kept as-is.
+        if self.use_gn:
+            x1 = F.relu(self.gn1(self.conv1(x)), inplace=True)
+            x2 = F.relu(self.gn2(self.conv1(x1)), inplace=True)
+            x3 = F.relu(self.gn3(self.conv1(x2)), inplace=True)
+            x4 = F.relu(self.gn4(self.conv1(x3)), inplace=True)
+        else:
+            x1 = F.relu(self.conv1(x), inplace=True)
+            x2 = F.relu(self.conv1(x1), inplace=True)
+            if self.num_conv == 2:
+                return x2
+            x3 = F.relu(self.conv1(x2), inplace=True)
+            x4 = F.relu(self.conv1(x3), inplace=True)
+
+        return x4
+
+
+class MogPredNet(nn.Module):
+
+    def __init__(self,
+                 num_anchor_per_pixel=1,
+                 num_classes=1,
+                 input_ch_list=[256, 256, 256, 256, 256, 256],
+                 use_deep_head=True,
+                 deep_head_with_gn=True,
+                 use_ssh=True,
+                 deep_head_ch=512):
+        super(MogPredNet, self).__init__()
+        self.num_classes = num_classes
+        self.use_deep_head = use_deep_head
+        self.deep_head_with_gn = deep_head_with_gn
+
+        self.use_ssh = use_ssh
+
+        self.deep_head_ch = deep_head_ch
+
+        if self.use_ssh:
+            self.conv_SSH = SSHContext(input_ch_list[0],
+                                       self.deep_head_ch // 2)
+
+        if self.use_deep_head:
+            if self.deep_head_with_gn:
+                self.deep_loc_head = DeepHead(
+                    self.deep_head_ch, self.deep_head_ch, use_gn=True)
+                self.deep_cls_head = DeepHead(
+                    self.deep_head_ch, self.deep_head_ch, use_gn=True)
+
+            self.pred_cls = nn.Conv2d(self.deep_head_ch,
+                                      1 * num_anchor_per_pixel, 3, 1, 1)
+            self.pred_loc = nn.Conv2d(self.deep_head_ch,
+                                      4 * num_anchor_per_pixel, 3, 1, 1)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, pyramid_feature_list, dsfd_ft_list=None):
+        loc = []
+        conf = []
+
+        if self.use_deep_head:
+            for x in pyramid_feature_list:
+                if self.use_ssh:
+                    x = self.conv_SSH(x)
+                x_cls = self.deep_cls_head(x)
+                x_loc = self.deep_loc_head(x)
+
+                conf.append(
+                    self.pred_cls(x_cls).permute(0, 2, 3, 1).contiguous())
+                loc.append(
+                    self.pred_loc(x_loc).permute(0, 2, 3, 1).contiguous())
+
+        loc = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1)
+        conf = torch.cat(
+            [o.view(o.size(0), -1, self.num_classes) for o in conf], 1)
+        output = (
+            self.sigmoid(conf.view(conf.size(0), -1, self.num_classes)),
+            loc.view(loc.size(0), -1, 4),
+        )
+
+        return output
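Note: channel bookkeeping for SSHContext above, as an illustrative check
(random weights): with Xchannels=256 the three branches contribute
256 + 128 + 128 = 512 channels, which is exactly MogPredNet's deep_head_ch.

    import torch

    from modelscope.models.cv.face_detection.mogface.models.mogprednet import \
        SSHContext

    ssh = SSHContext(channels=256, Xchannels=256)
    out = ssh(torch.randn(1, 256, 40, 40))
    print(out.shape)  # torch.Size([1, 512, 40, 40])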
diff --git a/modelscope/models/cv/face_detection/mogface/models/resnet.py b/modelscope/models/cv/face_detection/mogface/models/resnet.py
new file mode 100644
index 00000000..045f6fa3
--- /dev/null
+++ b/modelscope/models/cv/face_detection/mogface/models/resnet.py
@@ -0,0 +1,191 @@
+# The implementation is modified from the original ResNet implementation,
+# which is also open-sourced by the author, Yang Liu,
+# and is publicly available at https://github.com/damo-cv/MogFace
+
+import torch.nn as nn
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        groups=groups,
+        bias=False,
+        dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 downsample=None,
+                 groups=1,
+                 base_width=64,
+                 dilation=1,
+                 norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self,
+                 depth=50,
+                 groups=1,
+                 width_per_group=64,
+                 replace_stride_with_dilation=None,
+                 norm_layer=None,
+                 inplanes=64,
+                 shrink_ch_ratio=1):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        if depth == 50:
+            block = Bottleneck
+            layers = [3, 4, 6, 3]
+        elif depth == 101:
+            block = Bottleneck
+            layers = [3, 4, 23, 3]
+        elif depth == 152:
+            block = Bottleneck
+            layers = [3, 4, 36, 3]
+        else:
+            # depth 18 would need a BasicBlock, which this file does not define
+            raise ValueError('only support depth in [50, 101, 152]')
+
+        shrink_input_ch = int(inplanes * shrink_ch_ratio)
+        self.inplanes = int(inplanes * shrink_ch_ratio)
+        if shrink_ch_ratio == 0.125:
+            layers = [2, 3, 3, 3]
+
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError('replace_stride_with_dilation should be None '
+                             'or a 3-element tuple, got {}'.format(
+                                 replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(
+            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, shrink_input_ch, layers[0])
+        self.layer2 = self._make_layer(
+            block,
+            shrink_input_ch * 2,
+            layers[1],
+            stride=2,
+            dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(
+            block,
+            shrink_input_ch * 4,
+            layers[2],
+            stride=2,
+            dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(
+            block,
+            shrink_input_ch * 8,
+            layers[3],
+            stride=2,
+            dilate=replace_stride_with_dilation[2])
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(
+            block(self.inplanes, planes, stride, downsample, self.groups,
+                  self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        four_conv_layer = []
+        x = self.layer1(x)
+        four_conv_layer.append(x)
+        x = self.layer2(x)
+        four_conv_layer.append(x)
+        x = self.layer3(x)
+        four_conv_layer.append(x)
+        x = self.layer4(x)
+        four_conv_layer.append(x)
+
+        return four_conv_layer
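Note: the detector consumes the four stage outputs C2..C5 of this backbone at
strides 4/8/16/32. An illustrative shape check (random weights):

    import torch

    from modelscope.models.cv.face_detection.mogface.models.resnet import \
        ResNet

    backbone = ResNet(depth=101).eval()
    with torch.no_grad():
        c2, c3, c4, c5 = backbone(torch.randn(1, 3, 640, 640))
    print(c2.shape)  # torch.Size([1, 256, 160, 160])
    print(c5.shape)  # torch.Size([1, 2048, 20, 20])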
diff --git a/modelscope/models/cv/face_detection/mogface/models/utils.py b/modelscope/models/cv/face_detection/mogface/models/utils.py
new file mode 100755
index 00000000..377ceb3d
--- /dev/null
+++ b/modelscope/models/cv/face_detection/mogface/models/utils.py
@@ -0,0 +1,212 @@
+# Modified from https://github.com/biubug6/Pytorch_Retinaface
+
+import math
+from itertools import product as product
+from math import ceil
+
+import numpy as np
+import torch
+
+
+def transform_anchor(anchors):
+    """
+    from [x0, y0, x1, y1] to [cx, cy, w, h]
+    x1 = x0 + w - 1
+    cx = (x0 + x1) / 2 = (2x0 + w - 1) / 2 = x0 + (w - 1) / 2
+    """
+    return np.concatenate(((anchors[:, :2] + anchors[:, 2:]) / 2,
+                           anchors[:, 2:] - anchors[:, :2] + 1),
+                          axis=1)
+
+
+def normalize_anchor(anchors):
+    """
+    from [cx, cy, w, h] to [x0, y0, x1, y1]
+    """
+    item_1 = anchors[:, :2] - (anchors[:, 2:] - 1) / 2
+    item_2 = anchors[:, :2] + (anchors[:, 2:] - 1) / 2
+    return np.concatenate((item_1, item_2), axis=1)
+
+
+class MogPriorBox(object):
+    """
+    Anchor generator for both FPN and single-layer settings (single layer untested).
+    Returns (np.ndarray) [num_anchors, 4] in [cx, cy, w, h] form.
+    """
+
+    def __init__(self,
+                 scale_list=[1.],
+                 aspect_ratio_list=[1.0],
+                 stride_list=[4, 8, 16, 32, 64, 128],
+                 anchor_size_list=[16, 32, 64, 128, 256, 512]):
+        self.scale_list = scale_list
+        self.aspect_ratio_list = aspect_ratio_list
+        self.stride_list = stride_list
+        self.anchor_size_list = anchor_size_list
+
+    def __call__(self, img_height, img_width):
+        final_anchor_list = []
+
+        for idx, stride in enumerate(self.stride_list):
+            anchor_list = []
+            cur_img_height = img_height
+            cur_img_width = img_width
+            tmp_stride = stride
+
+            while tmp_stride != 1:
+                tmp_stride = tmp_stride // 2
+                cur_img_height = (cur_img_height + 1) // 2
+                cur_img_width = (cur_img_width + 1) // 2
+
+            for i in range(cur_img_height):
+                for j in range(cur_img_width):
+                    for scale in self.scale_list:
+                        cx = (j + 0.5) * stride
+                        cy = (i + 0.5) * stride
+                        side_x = self.anchor_size_list[idx] * scale
+                        side_y = self.anchor_size_list[idx] * scale
+                        for ratio in self.aspect_ratio_list:
+                            anchor_list.append([
+                                cx, cy, side_x / math.sqrt(ratio),
+                                side_y * math.sqrt(ratio)
+                            ])
+
+            final_anchor_list.append(anchor_list)
+        final_anchor_arr = np.concatenate(final_anchor_list, axis=0)
+        normalized_anchor_arr = normalize_anchor(final_anchor_arr).astype(
+            'float32')
+        transformed_anchor = transform_anchor(normalized_anchor_arr)
+
+        return transformed_anchor
+
+
+class PriorBox(object):
+
+    def __init__(self, cfg, image_size=None, phase='train'):
+        super(PriorBox, self).__init__()
+        self.min_sizes = cfg['min_sizes']
+        self.steps = cfg['steps']
+        self.clip = cfg['clip']
+        self.image_size = image_size
+        self.feature_maps = [[
+            ceil(self.image_size[0] / step),
+            ceil(self.image_size[1] / step)
+        ] for step in self.steps]
+        self.name = 's'
+
+    def forward(self):
+        anchors = []
+        for k, f in enumerate(self.feature_maps):
+            min_sizes = self.min_sizes[k]
+            for i, j in product(range(f[0]), range(f[1])):
+                for min_size in min_sizes:
+                    s_kx = min_size / self.image_size[1]
+                    s_ky = min_size / self.image_size[0]
+                    dense_cx = [
+                        x * self.steps[k] / self.image_size[1]
+                        for x in [j + 0.5]
+                    ]
+                    dense_cy = [
+                        y * self.steps[k] / self.image_size[0]
+                        for y in [i + 0.5]
+                    ]
+                    for cy, cx in product(dense_cy, dense_cx):
+                        anchors += [cx, cy, s_kx, s_ky]
+
+        # back to torch land
+        output = torch.Tensor(anchors).view(-1, 4)
+        if self.clip:
+            output.clamp_(max=1, min=0)
+        return output
+
+
+def py_cpu_nms(dets, thresh):
+    """Pure Python NMS baseline."""
+    x1 = dets[:, 0]
+    y1 = dets[:, 1]
+    x2 = dets[:, 2]
+    y2 = dets[:, 3]
+    scores = dets[:, 4]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= thresh)[0]
+        order = order[inds + 1]
+
+    return keep
+
+
+def mogdecode(loc, anchors):
+    """
+    loc: torch.Tensor
+    anchors: 2-d, torch.Tensor (cx, cy, w, h)
+    boxes: 2-d, torch.Tensor (x0, y0, x1, y1)
+    """
+
+    boxes = torch.cat((anchors[:, :2] + loc[:, :2] * anchors[:, 2:],
+                       anchors[:, 2:] * torch.exp(loc[:, 2:])), 1)
+
+    boxes[:, 0] -= (boxes[:, 2] - 1) / 2
+    boxes[:, 1] -= (boxes[:, 3] - 1) / 2
+    boxes[:, 2] += boxes[:, 0] - 1
+    boxes[:, 3] += boxes[:, 1] - 1
+
+    return boxes
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+    """Decode locations from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        loc (tensor): location predictions for loc layers,
+            Shape: [num_priors,4]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded bounding box predictions
+    """
+
+    boxes = torch.cat(
+        (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+         priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+    boxes[:, :2] -= boxes[:, 2:] / 2
+    boxes[:, 2:] += boxes[:, :2]
+    return boxes
+
+
+def decode_landm(pre, priors, variances):
+    """Decode landm from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        pre (tensor): landm predictions for loc layers,
+            Shape: [num_priors,10]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded landm predictions
+    """
+    a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:]
+    b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:]
+    c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:]
+    d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:]
+    e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:]
+    landms = torch.cat((a, b, c, d, e), dim=1)
+    return landms
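Note: how MogPriorBox and mogdecode compose, as an illustrative sketch (zero
regression offsets decode back to the priors themselves, in corner form):

    import torch

    from modelscope.models.cv.face_detection.mogface.models.utils import (
        MogPriorBox, mogdecode)

    priors = torch.tensor(MogPriorBox(scale_list=[0.68])(640, 640))
    print(priors.shape)  # torch.Size([34125, 4]), (cx, cy, w, h)
    boxes = mogdecode(torch.zeros_like(priors), priors)
    print(boxes.shape)   # torch.Size([34125, 4]), (x0, y0, x1, y1)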
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 3eb5cd82..a9dc05f2 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -48,6 +48,7 @@ if TYPE_CHECKING:
     from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline
     from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline
     from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline
+    from .mog_face_detection_pipeline import MogFaceDetectionPipeline
     from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
     from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
     from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
@@ -112,6 +113,7 @@ else:
         ['TextDrivenSegmentationPipeline'],
         'movie_scene_segmentation_pipeline':
         ['MovieSceneSegmentationPipeline'],
+        'mog_face_detection_pipeline': ['MogFaceDetectionPipeline'],
         'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
         'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
         'facial_expression_recognition_pipelin':
diff --git a/modelscope/pipelines/cv/mog_face_detection_pipeline.py b/modelscope/pipelines/cv/mog_face_detection_pipeline.py
new file mode 100644
index 00000000..8797ad12
--- /dev/null
+++ b/modelscope/pipelines/cv/mog_face_detection_pipeline.py
@@ -0,0 +1,54 @@
+import os.path as osp
+from typing import Any, Dict
+
+import numpy as np
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.face_detection import MogFaceDetector
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.face_detection, module_name=Pipelines.mog_face_detection)
+class MogFaceDetectionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Use `model` to create a face detection pipeline for prediction.
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
+        logger.info(f'loading model from {ckpt_path}')
+        detector = MogFaceDetector(model_path=ckpt_path, device=self.device)
+        self.detector = detector
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        img = LoadImage.convert_to_ndarray(input)
+        img = img.astype(np.float32)
+        result = {'img': img}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+
+        result = self.detector(input)
+        assert result is not None
+        bboxes = result[:, :4].tolist()
+        scores = result[:, 4].tolist()
+        return {
+            OutputKeys.SCORES: scores,
+            OutputKeys.BOXES: bboxes,
+            OutputKeys.KEYPOINTS: None,
+        }
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
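Note: a hypothetical quick-start for the new pipeline (the model id mirrors
the unit test below and requires access to the ModelScope hub):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    face_detection = pipeline(
        Tasks.face_detection,
        model='damo/cv_resnet101_face-detection_cvpr22papermogface')
    result = face_detection('data/test/images/mog_face_detection.jpg')
    print(result['scores'])  # per-face confidences
    print(result['boxes'])   # xyxy boxes in the original image frame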
diff --git a/tests/pipelines/test_mog_face_detection.py b/tests/pipelines/test_mog_face_detection.py
new file mode 100644
index 00000000..5c6d97c2
--- /dev/null
+++ b/tests/pipelines/test_mog_face_detection.py
@@ -0,0 +1,33 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+import unittest
+
+import cv2
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result
+from modelscope.utils.test_utils import test_level
+
+
+class MogFaceDetectionTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_resnet101_face-detection_cvpr22papermogface'
+
+    def show_result(self, img_path, detection_result):
+        img = draw_face_detection_no_lm_result(img_path, detection_result)
+        cv2.imwrite('result.png', img)
+        print(f'output written to {osp.abspath("result.png")}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
+        img_path = 'data/test/images/mog_face_detection.jpg'
+
+        result = face_detection(img_path)
+        self.show_result(img_path, result)
+
+
+if __name__ == '__main__':
+    unittest.main()