Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9921926master
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9 | |||||
| size 87228 | |||||
| @@ -35,6 +35,7 @@ class Models(object): | |||||
| fer = 'fer' | fer = 'fer' | ||||
| retinaface = 'retinaface' | retinaface = 'retinaface' | ||||
| shop_segmentation = 'shop-segmentation' | shop_segmentation = 'shop-segmentation' | ||||
| mogface = 'mogface' | |||||
| mtcnn = 'mtcnn' | mtcnn = 'mtcnn' | ||||
| ulfd = 'ulfd' | ulfd = 'ulfd' | ||||
| @@ -128,6 +129,7 @@ class Pipelines(object): | |||||
| ulfd_face_detection = 'manual-face-detection-ulfd' | ulfd_face_detection = 'manual-face-detection-ulfd' | ||||
| facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | ||||
| retina_face_detection = 'resnet50-face-detection-retinaface' | retina_face_detection = 'resnet50-face-detection-retinaface' | ||||
| mog_face_detection = 'resnet101-face-detection-cvpr22papermogface' | |||||
| mtcnn_face_detection = 'manual-face-detection-mtcnn' | mtcnn_face_detection = 'manual-face-detection-mtcnn' | ||||
| live_category = 'live-category' | live_category = 'live-category' | ||||
| general_image_classification = 'vit-base_image-classification_ImageNet-labels' | general_image_classification = 'vit-base_image-classification_ImageNet-labels' | ||||
| @@ -4,15 +4,16 @@ from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | from modelscope.utils.import_utils import LazyImportModule | ||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .mogface import MogFaceDetector | |||||
| from .mtcnn import MtcnnFaceDetector | from .mtcnn import MtcnnFaceDetector | ||||
| from .retinaface import RetinaFaceDetection | from .retinaface import RetinaFaceDetection | ||||
| from .ulfd_slim import UlfdFaceDetector | from .ulfd_slim import UlfdFaceDetector | ||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'ulfd_slim': ['UlfdFaceDetector'], | 'ulfd_slim': ['UlfdFaceDetector'], | ||||
| 'retinaface': ['RetinaFaceDetection'], | 'retinaface': ['RetinaFaceDetection'], | ||||
| 'mtcnn': ['MtcnnFaceDetector'] | |||||
| 'mtcnn': ['MtcnnFaceDetector'], | |||||
| 'mogface': ['MogFaceDetector'] | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1 @@ | |||||
| from .models.detectors import MogFaceDetector | |||||
| @@ -0,0 +1,96 @@ | |||||
| import os | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.backends.cudnn as cudnn | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import Tasks | |||||
| from .mogface import MogFace | |||||
| from .utils import MogPriorBox, mogdecode, py_cpu_nms | |||||
| @MODELS.register_module(Tasks.face_detection, module_name=Models.mogface) | |||||
| class MogFaceDetector(TorchModel): | |||||
| def __init__(self, model_path, device='cuda'): | |||||
| super().__init__(model_path) | |||||
| torch.set_grad_enabled(False) | |||||
| cudnn.benchmark = True | |||||
| self.model_path = model_path | |||||
| self.device = device | |||||
| self.net = MogFace() | |||||
| self.load_model() | |||||
| self.net = self.net.to(device) | |||||
| self.mean = np.array([[104, 117, 123]]) | |||||
| def load_model(self, load_to_cpu=False): | |||||
| pretrained_dict = torch.load( | |||||
| self.model_path, map_location=torch.device('cpu')) | |||||
| self.net.load_state_dict(pretrained_dict, strict=False) | |||||
| self.net.eval() | |||||
| def forward(self, input): | |||||
| img_raw = input['img'] | |||||
| img = np.array(img_raw.cpu().detach()) | |||||
| img = img[:, :, ::-1] | |||||
| im_height, im_width = img.shape[:2] | |||||
| ss = 1.0 | |||||
| # downscale images whose longer side exceeds 1500 px so that it becomes 1000 px | |||||
| if max(im_height, im_width) > 1500: | |||||
| ss = 1000.0 / max(im_height, im_width) | |||||
| img = cv2.resize(img, (0, 0), fx=ss, fy=ss) | |||||
| im_height, im_width = img.shape[:2] | |||||
| scale = torch.Tensor( | |||||
| [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) | |||||
| img -= np.array([[103.53, 116.28, 123.675]]) | |||||
| img /= np.array([[57.375, 57.120003, 58.395]]) | |||||
| img /= 255 | |||||
| img = img[:, :, ::-1].copy() | |||||
| img = img.transpose(2, 0, 1) | |||||
| img = torch.from_numpy(img).unsqueeze(0) | |||||
| img = img.to(self.device) | |||||
| scale = scale.to(self.device) | |||||
| conf, loc = self.net(img) # forward pass | |||||
| confidence_threshold = 0.82 | |||||
| nms_threshold = 0.4 | |||||
| top_k = 5000 | |||||
| keep_top_k = 750 | |||||
| priorbox = MogPriorBox(scale_list=[0.68]) | |||||
| priors = priorbox(im_height, im_width) | |||||
| priors = torch.tensor(priors).to(self.device) | |||||
| prior_data = priors.data | |||||
| boxes = mogdecode(loc.data.squeeze(0), prior_data) | |||||
| boxes = boxes.cpu().numpy() | |||||
| scores = conf.squeeze(0).data.cpu().numpy()[:, 0] | |||||
| # ignore low scores | |||||
| inds = np.where(scores > confidence_threshold)[0] | |||||
| boxes = boxes[inds] | |||||
| scores = scores[inds] | |||||
| # keep top-K before NMS | |||||
| order = scores.argsort()[::-1][:top_k] | |||||
| boxes = boxes[order] | |||||
| scores = scores[order] | |||||
| # do NMS | |||||
| dets = np.hstack((boxes, scores[:, np.newaxis])).astype( | |||||
| np.float32, copy=False) | |||||
| keep = py_cpu_nms(dets, nms_threshold) | |||||
| dets = dets[keep, :] | |||||
| # keep top-K after NMS | |||||
| dets = dets[:keep_top_k, :] | |||||
| dets[:, :4] = dets[:, :4] / ss  # map boxes back to original-image coordinates; keep scores unchanged | |||||
| return dets | |||||
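The detector returns a [num_faces, 5] array whose rows are [x0, y0, x1, y1, score] in original-image coordinates. A minimal usage sketch (the checkpoint and image paths are hypothetical placeholders; the input must be a float32 RGB tensor, matching what the pipeline further below feeds in):

    import cv2
    import torch

    # 'pytorch_model.pt' and 'face.jpg' are hypothetical paths
    detector = MogFaceDetector(model_path='pytorch_model.pt', device='cuda')  # use 'cpu' without a GPU
    img = cv2.imread('face.jpg')[:, :, ::-1].astype('float32')  # BGR -> RGB, float32
    dets = detector({'img': torch.from_numpy(img.copy())})
    for x0, y0, x1, y1, score in dets:
        print(f'face ({x0:.0f}, {y0:.0f}, {x1:.0f}, {y1:.0f}) score={score:.2f}')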
| @@ -0,0 +1,135 @@ | |||||
| # -------------------------------------------------------- | |||||
| # The implementation is also open-sourced by the author, Yang Liu, and is publicly available at | |||||
| # https://github.com/damo-cv/MogFace | |||||
| # -------------------------------------------------------- | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from .mogprednet import MogPredNet | |||||
| from .resnet import ResNet | |||||
| class MogFace(nn.Module): | |||||
| def __init__(self): | |||||
| super(MogFace, self).__init__() | |||||
| self.backbone = ResNet(depth=101) | |||||
| self.fpn = LFPN() | |||||
| self.pred_net = MogPredNet() | |||||
| def forward(self, x): | |||||
| feature_list = self.backbone(x) | |||||
| fpn_list = self.fpn(feature_list) | |||||
| pyramid_feature_list = fpn_list[0] | |||||
| conf, loc = self.pred_net(pyramid_feature_list) | |||||
| return conf, loc | |||||
| class FeatureFusion(nn.Module): | |||||
| def __init__(self, lat_ch=256, **channels): | |||||
| super(FeatureFusion, self).__init__() | |||||
| self.main_conv = nn.Conv2d(channels['main'], lat_ch, kernel_size=1) | |||||
| def forward(self, up, main): | |||||
| main = self.main_conv(main) | |||||
| _, _, H, W = main.size() | |||||
| res = F.upsample(up, scale_factor=2, mode='bilinear') | |||||
| if res.size(2) != main.size(2) or res.size(3) != main.size(3): | |||||
| res = res[:, :, 0:H, 0:W] | |||||
| res = res + main | |||||
| return res | |||||
| class LFPN(nn.Module): | |||||
| def __init__(self, | |||||
| c2_out_ch=256, | |||||
| c3_out_ch=512, | |||||
| c4_out_ch=1024, | |||||
| c5_out_ch=2048, | |||||
| c6_mid_ch=512, | |||||
| c6_out_ch=512, | |||||
| c7_mid_ch=128, | |||||
| c7_out_ch=256, | |||||
| out_dsfd_ft=True): | |||||
| super(LFPN, self).__init__() | |||||
| self.out_dsfd_ft = out_dsfd_ft | |||||
| if self.out_dsfd_ft: | |||||
| dsfd_module = [] | |||||
| dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) | |||||
| dsfd_module.append(nn.Conv2d(512, 256, kernel_size=3, padding=1)) | |||||
| dsfd_module.append(nn.Conv2d(1024, 256, kernel_size=3, padding=1)) | |||||
| dsfd_module.append(nn.Conv2d(2048, 256, kernel_size=3, padding=1)) | |||||
| dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) | |||||
| dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) | |||||
| self.dsfd_modules = nn.ModuleList(dsfd_module) | |||||
| c6_input_ch = c5_out_ch | |||||
| self.c6 = nn.Sequential(*[ | |||||
| nn.Conv2d( | |||||
| c6_input_ch, | |||||
| c6_mid_ch, | |||||
| kernel_size=1, | |||||
| ), | |||||
| nn.BatchNorm2d(c6_mid_ch), | |||||
| nn.ReLU(inplace=True), | |||||
| nn.Conv2d( | |||||
| c6_mid_ch, c6_out_ch, kernel_size=3, padding=1, stride=2), | |||||
| nn.BatchNorm2d(c6_out_ch), | |||||
| nn.ReLU(inplace=True) | |||||
| ]) | |||||
| self.c7 = nn.Sequential(*[ | |||||
| nn.Conv2d( | |||||
| c6_out_ch, | |||||
| c7_mid_ch, | |||||
| kernel_size=1, | |||||
| ), | |||||
| nn.BatchNorm2d(c7_mid_ch), | |||||
| nn.ReLU(inplace=True), | |||||
| nn.Conv2d( | |||||
| c7_mid_ch, c7_out_ch, kernel_size=3, padding=1, stride=2), | |||||
| nn.BatchNorm2d(c7_out_ch), | |||||
| nn.ReLU(inplace=True) | |||||
| ]) | |||||
| self.p2_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) | |||||
| self.p3_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) | |||||
| self.p4_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) | |||||
| self.c5_lat = nn.Conv2d(c6_input_ch, 256, kernel_size=3, padding=1) | |||||
| self.c6_lat = nn.Conv2d(c6_out_ch, 256, kernel_size=3, padding=1) | |||||
| self.c7_lat = nn.Conv2d(c7_out_ch, 256, kernel_size=3, padding=1) | |||||
| self.ff_c5_c4 = FeatureFusion(main=c4_out_ch) | |||||
| self.ff_c4_c3 = FeatureFusion(main=c3_out_ch) | |||||
| self.ff_c3_c2 = FeatureFusion(main=c2_out_ch) | |||||
| def forward(self, feature_list): | |||||
| c2, c3, c4, c5 = feature_list | |||||
| c6 = self.c6(c5) | |||||
| c7 = self.c7(c6) | |||||
| c5 = self.c5_lat(c5) | |||||
| c6 = self.c6_lat(c6) | |||||
| c7 = self.c7_lat(c7) | |||||
| if self.out_dsfd_ft: | |||||
| dsfd_fts = [] | |||||
| dsfd_fts.append(self.dsfd_modules[0](c2)) | |||||
| dsfd_fts.append(self.dsfd_modules[1](c3)) | |||||
| dsfd_fts.append(self.dsfd_modules[2](c4)) | |||||
| dsfd_fts.append(self.dsfd_modules[3](feature_list[-1])) | |||||
| dsfd_fts.append(self.dsfd_modules[4](c6)) | |||||
| dsfd_fts.append(self.dsfd_modules[5](c7)) | |||||
| p4 = self.ff_c5_c4(c5, c4) | |||||
| p3 = self.ff_c4_c3(p4, c3) | |||||
| p2 = self.ff_c3_c2(p3, c2) | |||||
| p2 = self.p2_lat(p2) | |||||
| p3 = self.p3_lat(p3) | |||||
| p4 = self.p4_lat(p4) | |||||
| if self.out_dsfd_ft: | |||||
| return ([p2, p3, p4, c5, c6, c7], dsfd_fts) | |||||
| @@ -0,0 +1,164 @@ | |||||
| # -------------------------------------------------------- | |||||
| # The implementation is also open-sourced by the author, Yang Liu, and is publicly available at | |||||
| # https://github.com/damo-cv/MogFace | |||||
| # -------------------------------------------------------- | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| class conv_bn(nn.Module): | |||||
| """docstring for conv""" | |||||
| def __init__(self, in_plane, out_plane, kernel_size, stride, padding): | |||||
| super(conv_bn, self).__init__() | |||||
| self.conv1 = nn.Conv2d( | |||||
| in_plane, | |||||
| out_plane, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| padding=padding) | |||||
| self.bn1 = nn.BatchNorm2d(out_plane) | |||||
| def forward(self, x): | |||||
| x = self.conv1(x) | |||||
| return self.bn1(x) | |||||
| class SSHContext(nn.Module): | |||||
| def __init__(self, channels, Xchannels=256): | |||||
| super(SSHContext, self).__init__() | |||||
| self.conv1 = nn.Conv2d( | |||||
| channels, Xchannels, kernel_size=3, stride=1, padding=1) | |||||
| self.conv2 = nn.Conv2d( | |||||
| channels, | |||||
| Xchannels // 2, | |||||
| kernel_size=3, | |||||
| dilation=2, | |||||
| stride=1, | |||||
| padding=2) | |||||
| self.conv2_1 = nn.Conv2d( | |||||
| Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1) | |||||
| self.conv2_2 = nn.Conv2d( | |||||
| Xchannels // 2, | |||||
| Xchannels // 2, | |||||
| kernel_size=3, | |||||
| dilation=2, | |||||
| stride=1, | |||||
| padding=2) | |||||
| self.conv2_2_1 = nn.Conv2d( | |||||
| Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1) | |||||
| def forward(self, x): | |||||
| x1 = F.relu(self.conv1(x), inplace=True) | |||||
| x2 = F.relu(self.conv2(x), inplace=True) | |||||
| x2_1 = F.relu(self.conv2_1(x2), inplace=True) | |||||
| x2_2 = F.relu(self.conv2_2(x2), inplace=True) | |||||
| x2_2 = F.relu(self.conv2_2_1(x2_2), inplace=True) | |||||
| return torch.cat([x1, x2_1, x2_2], 1) | |||||
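SSHContext concatenates Xchannels + 2 * (Xchannels // 2) output channels, so with Xchannels = 256 (deep_head_ch // 2 below) it emits 512 channels at the input resolution. A quick editorial shape check, not part of the change:

    import torch

    ssh = SSHContext(channels=256, Xchannels=256)
    out = ssh(torch.randn(1, 256, 40, 40))
    print(out.shape)  # expected: torch.Size([1, 512, 40, 40])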
| class DeepHead(nn.Module): | |||||
| def __init__(self, | |||||
| in_channel=256, | |||||
| out_channel=256, | |||||
| use_gn=False, | |||||
| num_conv=4): | |||||
| super(DeepHead, self).__init__() | |||||
| self.use_gn = use_gn | |||||
| self.num_conv = num_conv | |||||
| self.conv1 = nn.Conv2d(in_channel, out_channel, 3, 1, 1) | |||||
| self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) | |||||
| self.conv3 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) | |||||
| self.conv4 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) | |||||
| if self.use_gn: | |||||
| self.gn1 = nn.GroupNorm(16, out_channel) | |||||
| self.gn2 = nn.GroupNorm(16, out_channel) | |||||
| self.gn3 = nn.GroupNorm(16, out_channel) | |||||
| self.gn4 = nn.GroupNorm(16, out_channel) | |||||
| def forward(self, x): | |||||
| if self.use_gn: | |||||
| x1 = F.relu(self.gn1(self.conv1(x)), inplace=True) | |||||
| x2 = F.relu(self.gn2(self.conv1(x1)), inplace=True) | |||||
| x3 = F.relu(self.gn3(self.conv1(x2)), inplace=True) | |||||
| x4 = F.relu(self.gn4(self.conv1(x3)), inplace=True) | |||||
| else: | |||||
| x1 = F.relu(self.conv1(x), inplace=True) | |||||
| x2 = F.relu(self.conv1(x1), inplace=True) | |||||
| if self.num_conv == 2: | |||||
| return x2 | |||||
| x3 = F.relu(self.conv1(x2), inplace=True) | |||||
| x4 = F.relu(self.conv1(x3), inplace=True) | |||||
| return x4 | |||||
| class MogPredNet(nn.Module): | |||||
| def __init__(self, | |||||
| num_anchor_per_pixel=1, | |||||
| num_classes=1, | |||||
| input_ch_list=[256, 256, 256, 256, 256, 256], | |||||
| use_deep_head=True, | |||||
| deep_head_with_gn=True, | |||||
| use_ssh=True, | |||||
| deep_head_ch=512): | |||||
| super(MogPredNet, self).__init__() | |||||
| self.num_classes = num_classes | |||||
| self.use_deep_head = use_deep_head | |||||
| self.deep_head_with_gn = deep_head_with_gn | |||||
| self.use_ssh = use_ssh | |||||
| self.deep_head_ch = deep_head_ch | |||||
| if self.use_ssh: | |||||
| self.conv_SSH = SSHContext(input_ch_list[0], | |||||
| self.deep_head_ch // 2) | |||||
| if self.use_deep_head: | |||||
| if self.deep_head_with_gn: | |||||
| self.deep_loc_head = DeepHead( | |||||
| self.deep_head_ch, self.deep_head_ch, use_gn=True) | |||||
| self.deep_cls_head = DeepHead( | |||||
| self.deep_head_ch, self.deep_head_ch, use_gn=True) | |||||
| self.pred_cls = nn.Conv2d(self.deep_head_ch, | |||||
| 1 * num_anchor_per_pixel, 3, 1, 1) | |||||
| self.pred_loc = nn.Conv2d(self.deep_head_ch, | |||||
| 4 * num_anchor_per_pixel, 3, 1, 1) | |||||
| self.sigmoid = nn.Sigmoid() | |||||
| def forward(self, pyramid_feature_list, dsfd_ft_list=None): | |||||
| loc = [] | |||||
| conf = [] | |||||
| if self.use_deep_head: | |||||
| for x in pyramid_feature_list: | |||||
| if self.use_ssh: | |||||
| x = self.conv_SSH(x) | |||||
| x_cls = self.deep_cls_head(x) | |||||
| x_loc = self.deep_loc_head(x) | |||||
| conf.append( | |||||
| self.pred_cls(x_cls).permute(0, 2, 3, 1).contiguous()) | |||||
| loc.append( | |||||
| self.pred_loc(x_loc).permute(0, 2, 3, 1).contiguous()) | |||||
| loc = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1) | |||||
| conf = torch.cat( | |||||
| [o.view(o.size(0), -1, self.num_classes) for o in conf], 1) | |||||
| output = ( | |||||
| self.sigmoid(conf.view(conf.size(0), -1, self.num_classes)), | |||||
| loc.view(loc.size(0), -1, 4), | |||||
| ) | |||||
| return output | |||||
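With one anchor per location over strides [4, 8, 16, 32, 64, 128], a 640x640 input yields 160^2 + 80^2 + 40^2 + 20^2 + 10^2 + 5^2 = 34125 predictions, so conf and loc come out as [1, 34125, 1] and [1, 34125, 4]. A quick sanity sketch on a randomly initialised network, assuming MogFace from the file above is importable:

    import torch

    net = MogFace().eval()
    with torch.no_grad():
        conf, loc = net(torch.randn(1, 3, 640, 640))
    print(conf.shape)  # expected: torch.Size([1, 34125, 1])
    print(loc.shape)   # expected: torch.Size([1, 34125, 4])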
| @@ -0,0 +1,193 @@ | |||||
| # The implementation is modified from the original ResNet implementation, which is | |||||
| # also open-sourced by the author, Yang Liu, | |||||
| # and is publicly available at https://github.com/damo-cv/MogFace | |||||
| import torch.nn as nn | |||||
| def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |||||
| """3x3 convolution with padding""" | |||||
| return nn.Conv2d( | |||||
| in_planes, | |||||
| out_planes, | |||||
| kernel_size=3, | |||||
| stride=stride, | |||||
| padding=dilation, | |||||
| groups=groups, | |||||
| bias=False, | |||||
| dilation=dilation) | |||||
| def conv1x1(in_planes, out_planes, stride=1): | |||||
| """1x1 convolution""" | |||||
| return nn.Conv2d( | |||||
| in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||||
| class Bottleneck(nn.Module): | |||||
| expansion = 4 | |||||
| def __init__(self, | |||||
| inplanes, | |||||
| planes, | |||||
| stride=1, | |||||
| downsample=None, | |||||
| groups=1, | |||||
| base_width=64, | |||||
| dilation=1, | |||||
| norm_layer=None): | |||||
| super(Bottleneck, self).__init__() | |||||
| if norm_layer is None: | |||||
| norm_layer = nn.BatchNorm2d | |||||
| width = int(planes * (base_width / 64.)) * groups | |||||
| # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||||
| self.conv1 = conv1x1(inplanes, width) | |||||
| self.bn1 = norm_layer(width) | |||||
| self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||||
| self.bn2 = norm_layer(width) | |||||
| self.conv3 = conv1x1(width, planes * self.expansion) | |||||
| self.bn3 = norm_layer(planes * self.expansion) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.downsample = downsample | |||||
| self.stride = stride | |||||
| def forward(self, x): | |||||
| identity = x | |||||
| out = self.conv1(x) | |||||
| out = self.bn1(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv2(out) | |||||
| out = self.bn2(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv3(out) | |||||
| out = self.bn3(out) | |||||
| if self.downsample is not None: | |||||
| identity = self.downsample(x) | |||||
| out += identity | |||||
| out = self.relu(out) | |||||
| return out | |||||
| class ResNet(nn.Module): | |||||
| def __init__(self, | |||||
| depth=50, | |||||
| groups=1, | |||||
| width_per_group=64, | |||||
| replace_stride_with_dilation=None, | |||||
| norm_layer=None, | |||||
| inplanes=64, | |||||
| shrink_ch_ratio=1): | |||||
| super(ResNet, self).__init__() | |||||
| if norm_layer is None: | |||||
| norm_layer = nn.BatchNorm2d | |||||
| self._norm_layer = norm_layer | |||||
| if depth == 50: | |||||
| block = Bottleneck | |||||
| layers = [3, 4, 6, 3] | |||||
| elif depth == 101: | |||||
| block = Bottleneck | |||||
| layers = [3, 4, 23, 3] | |||||
| elif depth == 152: | |||||
| block = Bottleneck | |||||
| layers = [3, 4, 36, 3] | |||||
| elif depth == 18: | |||||
| block = BasicBlock  # note: BasicBlock is not defined or imported in this file, so depth=18 would fail | |||||
| layers = [2, 2, 2, 2] | |||||
| else: | |||||
| raise ValueError('only support depth in [18, 50, 101, 152]') | |||||
| shrink_input_ch = int(inplanes * shrink_ch_ratio) | |||||
| self.inplanes = int(inplanes * shrink_ch_ratio) | |||||
| if shrink_ch_ratio == 0.125: | |||||
| layers = [2, 3, 3, 3] | |||||
| self.dilation = 1 | |||||
| if replace_stride_with_dilation is None: | |||||
| # each element in the tuple indicates if we should replace | |||||
| # the 2x2 stride with a dilated convolution instead | |||||
| replace_stride_with_dilation = [False, False, False] | |||||
| if len(replace_stride_with_dilation) != 3: | |||||
| raise ValueError('replace_stride_with_dilation should be None ' | |||||
| 'or a 3-element tuple, got {}'.format( | |||||
| replace_stride_with_dilation)) | |||||
| self.groups = groups | |||||
| self.base_width = width_per_group | |||||
| self.conv1 = nn.Conv2d( | |||||
| 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) | |||||
| self.bn1 = norm_layer(self.inplanes) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||||
| self.layer1 = self._make_layer(block, shrink_input_ch, layers[0]) | |||||
| self.layer2 = self._make_layer( | |||||
| block, | |||||
| shrink_input_ch * 2, | |||||
| layers[1], | |||||
| stride=2, | |||||
| dilate=replace_stride_with_dilation[0]) | |||||
| self.layer3 = self._make_layer( | |||||
| block, | |||||
| shrink_input_ch * 4, | |||||
| layers[2], | |||||
| stride=2, | |||||
| dilate=replace_stride_with_dilation[1]) | |||||
| self.layer4 = self._make_layer( | |||||
| block, | |||||
| shrink_input_ch * 8, | |||||
| layers[3], | |||||
| stride=2, | |||||
| dilate=replace_stride_with_dilation[2]) | |||||
| def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | |||||
| norm_layer = self._norm_layer | |||||
| downsample = None | |||||
| previous_dilation = self.dilation | |||||
| if dilate: | |||||
| self.dilation *= stride | |||||
| stride = 1 | |||||
| if stride != 1 or self.inplanes != planes * block.expansion: | |||||
| downsample = nn.Sequential( | |||||
| conv1x1(self.inplanes, planes * block.expansion, stride), | |||||
| norm_layer(planes * block.expansion), | |||||
| ) | |||||
| layers = [] | |||||
| layers.append( | |||||
| block(self.inplanes, planes, stride, downsample, self.groups, | |||||
| self.base_width, previous_dilation, norm_layer)) | |||||
| self.inplanes = planes * block.expansion | |||||
| for _ in range(1, blocks): | |||||
| layers.append( | |||||
| block( | |||||
| self.inplanes, | |||||
| planes, | |||||
| groups=self.groups, | |||||
| base_width=self.base_width, | |||||
| dilation=self.dilation, | |||||
| norm_layer=norm_layer)) | |||||
| return nn.Sequential(*layers) | |||||
| def forward(self, x): | |||||
| x = self.conv1(x) | |||||
| x = self.bn1(x) | |||||
| x = self.relu(x) | |||||
| x = self.maxpool(x) | |||||
| four_conv_layer = [] | |||||
| x = self.layer1(x) | |||||
| four_conv_layer.append(x) | |||||
| x = self.layer2(x) | |||||
| four_conv_layer.append(x) | |||||
| x = self.layer3(x) | |||||
| four_conv_layer.append(x) | |||||
| x = self.layer4(x) | |||||
| four_conv_layer.append(x) | |||||
| return four_conv_layer | |||||
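The backbone returns the four stage outputs (C2-C5) instead of classification logits; for depth=101 the channel/stride layout follows the standard ResNet design. A small editorial sanity check:

    import torch

    backbone = ResNet(depth=101).eval()
    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])
    # expected: [(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]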
| @@ -0,0 +1,212 @@ | |||||
| # Modified from https://github.com/biubug6/Pytorch_Retinaface | |||||
| import math | |||||
| from itertools import product as product | |||||
| from math import ceil | |||||
| import numpy as np | |||||
| import torch | |||||
| def transform_anchor(anchors): | |||||
| """ | |||||
| from [x0, y0, x1, y1] to [c_x, c_y, w, h] | |||||
| x1 = x0 + w - 1 | |||||
| c_x = (x0 + x1) / 2 = (2x0 + w - 1) / 2 = x0 + (w - 1) / 2 | |||||
| """ | |||||
| return np.concatenate(((anchors[:, :2] + anchors[:, 2:]) / 2, | |||||
| anchors[:, 2:] - anchors[:, :2] + 1), | |||||
| axis=1) | |||||
| def normalize_anchor(anchors): | |||||
| """ | |||||
| from [c_x, c_y, w, h] to [x0, y0, x1, y1] | |||||
| """ | |||||
| item_1 = anchors[:, :2] - (anchors[:, 2:] - 1) / 2 | |||||
| item_2 = anchors[:, :2] + (anchors[:, 2:] - 1) / 2 | |||||
| return np.concatenate((item_1, item_2), axis=1) | |||||
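Both helpers use the inclusive-pixel convention (w = x1 - x0 + 1), so a box round-trips exactly between corner and center form. For example:

    import numpy as np

    corner = np.array([[0., 0., 15., 15.]])  # (x0, y0, x1, y1): a 16x16 box
    center = transform_anchor(corner)
    print(center)                    # -> [[7.5, 7.5, 16., 16.]], i.e. (c_x, c_y, w, h)
    print(normalize_anchor(center))  # -> [[0., 0., 15., 15.]], back to corner form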
| class MogPriorBox(object): | |||||
| """ | |||||
| works both for FPN and single-layer settings; the single-layer case still needs to be tested | |||||
| return (np.array) [num_anchors, 4] in (c_x, c_y, w, h) form | |||||
| """ | |||||
| def __init__(self, | |||||
| scale_list=[1.], | |||||
| aspect_ratio_list=[1.0], | |||||
| stride_list=[4, 8, 16, 32, 64, 128], | |||||
| anchor_size_list=[16, 32, 64, 128, 256, 512]): | |||||
| self.scale_list = scale_list | |||||
| self.aspect_ratio_list = aspect_ratio_list | |||||
| self.stride_list = stride_list | |||||
| self.anchor_size_list = anchor_size_list | |||||
| def __call__(self, img_height, img_width): | |||||
| final_anchor_list = [] | |||||
| for idx, stride in enumerate(self.stride_list): | |||||
| anchor_list = [] | |||||
| cur_img_height = img_height | |||||
| cur_img_width = img_width | |||||
| tmp_stride = stride | |||||
| while tmp_stride != 1: | |||||
| tmp_stride = tmp_stride // 2 | |||||
| cur_img_height = (cur_img_height + 1) // 2 | |||||
| cur_img_width = (cur_img_width + 1) // 2 | |||||
| for i in range(cur_img_height): | |||||
| for j in range(cur_img_width): | |||||
| for scale in self.scale_list: | |||||
| cx = (j + 0.5) * stride | |||||
| cy = (i + 0.5) * stride | |||||
| side_x = self.anchor_size_list[idx] * scale | |||||
| side_y = self.anchor_size_list[idx] * scale | |||||
| for ratio in self.aspect_ratio_list: | |||||
| anchor_list.append([ | |||||
| cx, cy, side_x / math.sqrt(ratio), | |||||
| side_y * math.sqrt(ratio) | |||||
| ]) | |||||
| final_anchor_list.append(anchor_list) | |||||
| final_anchor_arr = np.concatenate(final_anchor_list, axis=0) | |||||
| normalized_anchor_arr = normalize_anchor(final_anchor_arr).astype( | |||||
| 'float32') | |||||
| transformed_anchor = transform_anchor(normalized_anchor_arr) | |||||
| return transformed_anchor | |||||
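For a 640x640 image the strides above give grids of 160, 80, 40, 20, 10 and 5 cells per side, so with a single scale and aspect ratio the generator produces 160^2 + 80^2 + 40^2 + 20^2 + 10^2 + 5^2 = 34125 anchors, matching the number of network predictions. For example, assuming MogPriorBox above is importable:

    priorbox = MogPriorBox(scale_list=[0.68])
    priors = priorbox(640, 640)
    print(priors.shape, priors.dtype)  # expected: (34125, 4) float32, rows in (c_x, c_y, w, h) form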
| class PriorBox(object): | |||||
| def __init__(self, cfg, image_size=None, phase='train'): | |||||
| super(PriorBox, self).__init__() | |||||
| self.min_sizes = cfg['min_sizes'] | |||||
| self.steps = cfg['steps'] | |||||
| self.clip = cfg['clip'] | |||||
| self.image_size = image_size | |||||
| self.feature_maps = [[ | |||||
| ceil(self.image_size[0] / step), | |||||
| ceil(self.image_size[1] / step) | |||||
| ] for step in self.steps] | |||||
| self.name = 's' | |||||
| def forward(self): | |||||
| anchors = [] | |||||
| for k, f in enumerate(self.feature_maps): | |||||
| min_sizes = self.min_sizes[k] | |||||
| for i, j in product(range(f[0]), range(f[1])): | |||||
| for min_size in min_sizes: | |||||
| s_kx = min_size / self.image_size[1] | |||||
| s_ky = min_size / self.image_size[0] | |||||
| dense_cx = [ | |||||
| x * self.steps[k] / self.image_size[1] | |||||
| for x in [j + 0.5] | |||||
| ] | |||||
| dense_cy = [ | |||||
| y * self.steps[k] / self.image_size[0] | |||||
| for y in [i + 0.5] | |||||
| ] | |||||
| for cy, cx in product(dense_cy, dense_cx): | |||||
| anchors += [cx, cy, s_kx, s_ky] | |||||
| # back to torch land | |||||
| output = torch.Tensor(anchors).view(-1, 4) | |||||
| if self.clip: | |||||
| output.clamp_(max=1, min=0) | |||||
| return output | |||||
| def py_cpu_nms(dets, thresh): | |||||
| """Pure Python NMS baseline.""" | |||||
| x1 = dets[:, 0] | |||||
| y1 = dets[:, 1] | |||||
| x2 = dets[:, 2] | |||||
| y2 = dets[:, 3] | |||||
| scores = dets[:, 4] | |||||
| areas = (x2 - x1 + 1) * (y2 - y1 + 1) | |||||
| order = scores.argsort()[::-1] | |||||
| keep = [] | |||||
| while order.size > 0: | |||||
| i = order[0] | |||||
| keep.append(i) | |||||
| xx1 = np.maximum(x1[i], x1[order[1:]]) | |||||
| yy1 = np.maximum(y1[i], y1[order[1:]]) | |||||
| xx2 = np.minimum(x2[i], x2[order[1:]]) | |||||
| yy2 = np.minimum(y2[i], y2[order[1:]]) | |||||
| w = np.maximum(0.0, xx2 - xx1 + 1) | |||||
| h = np.maximum(0.0, yy2 - yy1 + 1) | |||||
| inter = w * h | |||||
| ovr = inter / (areas[i] + areas[order[1:]] - inter) | |||||
| inds = np.where(ovr <= thresh)[0] | |||||
| order = order[inds + 1] | |||||
| return keep | |||||
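A small illustration of the suppression behaviour: the second box overlaps the top-scoring box with an IoU of roughly 0.70, so at a threshold of 0.4 only the first and third boxes survive.

    import numpy as np

    dets = np.array([
        [0, 0, 10, 10, 0.9],    # highest score, kept
        [1, 1, 11, 11, 0.8],    # IoU ~ 0.70 with the first box, suppressed
        [20, 20, 30, 30, 0.7],  # disjoint, kept
    ], dtype=np.float32)
    print(py_cpu_nms(dets, 0.4))  # keeps indices 0 and 2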
| def mogdecode(loc, anchors): | |||||
| """ | |||||
| loc: torch.Tensor | |||||
| anchors: 2-d, torch.Tensor (cx, cy, w, h) | |||||
| boxes: 2-d, torch.Tensor (x0, y0, x1, y1) | |||||
| """ | |||||
| boxes = torch.cat((anchors[:, :2] + loc[:, :2] * anchors[:, 2:], | |||||
| anchors[:, 2:] * torch.exp(loc[:, 2:])), 1) | |||||
| boxes[:, 0] -= (boxes[:, 2] - 1) / 2 | |||||
| boxes[:, 1] -= (boxes[:, 3] - 1) / 2 | |||||
| boxes[:, 2] += boxes[:, 0] - 1 | |||||
| boxes[:, 3] += boxes[:, 1] - 1 | |||||
| return boxes | |||||
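With zero offsets, decoding just converts the anchor from center form to corner form under the same w = x1 - x0 + 1 convention, which makes the arithmetic easy to check by hand:

    import torch

    anchors = torch.tensor([[100., 100., 32., 32.]])  # (c_x, c_y, w, h)
    loc = torch.zeros(1, 4)
    print(mogdecode(loc, anchors))
    # expected: x0 = y0 = 84.5, x1 = y1 = 115.5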
| # Adapted from https://github.com/Hakuyume/chainer-ssd | |||||
| def decode(loc, priors, variances): | |||||
| """Decode locations from predictions using priors to undo | |||||
| the encoding we did for offset regression at train time. | |||||
| Args: | |||||
| loc (tensor): location predictions for loc layers, | |||||
| Shape: [num_priors,4] | |||||
| priors (tensor): Prior boxes in center-offset form. | |||||
| Shape: [num_priors,4]. | |||||
| variances: (list[float]) Variances of priorboxes | |||||
| Return: | |||||
| decoded bounding box predictions | |||||
| """ | |||||
| boxes = torch.cat( | |||||
| (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], | |||||
| priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) | |||||
| boxes[:, :2] -= boxes[:, 2:] / 2 | |||||
| boxes[:, 2:] += boxes[:, :2] | |||||
| return boxes | |||||
| def decode_landm(pre, priors, variances): | |||||
| """Decode landm from predictions using priors to undo | |||||
| the encoding we did for offset regression at train time. | |||||
| Args: | |||||
| pre (tensor): landm predictions for loc layers, | |||||
| Shape: [num_priors,10] | |||||
| priors (tensor): Prior boxes in center-offset form. | |||||
| Shape: [num_priors,4]. | |||||
| variances: (list[float]) Variances of priorboxes | |||||
| Return: | |||||
| decoded landm predictions | |||||
| """ | |||||
| a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] | |||||
| b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] | |||||
| c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] | |||||
| d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] | |||||
| e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] | |||||
| landms = torch.cat((a, b, c, d, e), dim=1) | |||||
| return landms | |||||
| @@ -48,6 +48,7 @@ if TYPE_CHECKING: | |||||
| from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline | from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline | ||||
| from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline | from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline | ||||
| from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline | ||||
| from .mog_face_detection_pipeline import MogFaceDetectionPipeline | |||||
| from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline | from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline | ||||
| from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline | from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline | ||||
| from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline | from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline | ||||
| @@ -112,6 +113,7 @@ else: | |||||
| ['TextDrivenSegmentationPipeline'], | ['TextDrivenSegmentationPipeline'], | ||||
| 'movie_scene_segmentation_pipeline': | 'movie_scene_segmentation_pipeline': | ||||
| ['MovieSceneSegmentationPipeline'], | ['MovieSceneSegmentationPipeline'], | ||||
| 'mog_face_detection_pipeline': ['MogFaceDetectionPipeline'], | |||||
| 'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'], | 'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'], | ||||
| 'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'], | 'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'], | ||||
| 'facial_expression_recognition_pipelin': | 'facial_expression_recognition_pipelin': | ||||
| @@ -0,0 +1,54 @@ | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.face_detection import MogFaceDetector | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.face_detection, module_name=Pipelines.mog_face_detection) | |||||
| class MogFaceDetectionPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| use `model` to create a face detection pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, **kwargs) | |||||
| ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) | |||||
| logger.info(f'loading model from {ckpt_path}') | |||||
| detector = MogFaceDetector(model_path=ckpt_path, device=self.device) | |||||
| self.detector = detector | |||||
| logger.info('load model done') | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_ndarray(input) | |||||
| img = img.astype(np.float32) | |||||
| result = {'img': img} | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| result = self.detector(input) | |||||
| assert result is not None | |||||
| bboxes = result[:, :4].tolist() | |||||
| scores = result[:, 4].tolist() | |||||
| return { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.BOXES: bboxes, | |||||
| OutputKeys.KEYPOINTS: None, | |||||
| } | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -0,0 +1,33 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| import unittest | |||||
| import cv2 | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class MogFaceDetectionTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/cv_resnet101_face-detection_cvpr22papermogface' | |||||
| def show_result(self, img_path, detection_result): | |||||
| img = draw_face_detection_no_lm_result(img_path, detection_result) | |||||
| cv2.imwrite('result.png', img) | |||||
| print(f'output written to {osp.abspath("result.png")}') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||||
| img_path = 'data/test/images/mog_face_detection.jpg' | |||||
| result = face_detection(img_path) | |||||
| self.show_result(img_path, result) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||