ly261666 wenmeng.zwm 3 years ago
parent
commit
9d43823f36
10 changed files with 389 additions and 5 deletions
  1. +2
    -0
      modelscope/metainfo.py
  2. +2
    -1
      modelscope/models/cv/face_detection/__init__.py
  3. +1
    -0
      modelscope/models/cv/face_detection/scrfd/__init__.py
  4. +2
    -1
      modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py
  5. +99
    -0
      modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/mobilenet.py
  6. +2
    -1
      modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py
  7. +148
    -0
      modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/tinymog.py
  8. +67
    -0
      modelscope/models/cv/face_detection/scrfd/tinymog_detect.py
  9. +9
    -2
      modelscope/pipelines/cv/face_detection_pipeline.py
  10. +57
    -0
      tests/pipelines/test_tinymog_face_detection.py

+ 2
- 0
modelscope/metainfo.py View File

@@ -47,6 +47,7 @@ class Models(object):
ulfd = 'ulfd' ulfd = 'ulfd'
arcface = 'arcface' arcface = 'arcface'
facemask = 'facemask' facemask = 'facemask'
tinymog = 'tinymog'
video_inpainting = 'video-inpainting' video_inpainting = 'video-inpainting'
human_wholebody_keypoint = 'human-wholebody-keypoint' human_wholebody_keypoint = 'human-wholebody-keypoint'
hand_static = 'hand-static' hand_static = 'hand-static'
@@ -182,6 +183,7 @@ class Pipelines(object):
face_detection = 'resnet-face-detection-scrfd10gkps' face_detection = 'resnet-face-detection-scrfd10gkps'
card_detection = 'resnet-card-detection-scrfd34gkps' card_detection = 'resnet-card-detection-scrfd34gkps'
ulfd_face_detection = 'manual-face-detection-ulfd' ulfd_face_detection = 'manual-face-detection-ulfd'
tinymog_face_detection = 'manual-face-detection-tinymog'
facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
retina_face_detection = 'resnet50-face-detection-retinaface' retina_face_detection = 'resnet50-face-detection-retinaface'
mog_face_detection = 'resnet101-face-detection-cvpr22papermogface' mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'


+ 2
- 1
modelscope/models/cv/face_detection/__init__.py View File

@@ -9,13 +9,14 @@ if TYPE_CHECKING:
from .retinaface import RetinaFaceDetection from .retinaface import RetinaFaceDetection
from .ulfd_slim import UlfdFaceDetector from .ulfd_slim import UlfdFaceDetector
from .scrfd import ScrfdDetect from .scrfd import ScrfdDetect
from .scrfd import TinyMogDetect
else: else:
_import_structure = { _import_structure = {
'ulfd_slim': ['UlfdFaceDetector'], 'ulfd_slim': ['UlfdFaceDetector'],
'retinaface': ['RetinaFaceDetection'], 'retinaface': ['RetinaFaceDetection'],
'mtcnn': ['MtcnnFaceDetector'], 'mtcnn': ['MtcnnFaceDetector'],
'mogface': ['MogFaceDetector'], 'mogface': ['MogFaceDetector'],
'scrfd': ['ScrfdDetect']
'scrfd': ['TinyMogDetect', 'ScrfdDetect'],
} }


import sys import sys


+ 1
- 0
modelscope/models/cv/face_detection/scrfd/__init__.py View File

@@ -1,2 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
from .scrfd_detect import ScrfdDetect from .scrfd_detect import ScrfdDetect
from .tinymog_detect import TinyMogDetect

+ 2
- 1
modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py View File

@@ -2,6 +2,7 @@
The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones
""" """
from .mobilenet import MobileNetV1
from .resnet import ResNetV1e from .resnet import ResNetV1e


__all__ = ['ResNetV1e']
__all__ = ['ResNetV1e', 'MobileNetV1']

+ 99
- 0
modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/mobilenet.py View File

@@ -0,0 +1,99 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/mobilenet.py
"""

import torch
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmdet.models.builder import BACKBONES
from mmdet.utils import get_root_logger
from torch.nn.modules.batchnorm import _BatchNorm


@BACKBONES.register_module()
class MobileNetV1(nn.Module):

def __init__(self,
in_channels=3,
block_cfg=None,
num_stages=4,
out_indices=(0, 1, 2, 3)):
super(MobileNetV1, self).__init__()
self.out_indices = out_indices

def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup), nn.ReLU(inplace=True))

def conv_dw(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True),
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
)

if block_cfg is None:
stage_planes = [8, 16, 32, 64, 128, 256]
stage_blocks = [2, 4, 4, 2]
else:
stage_planes = block_cfg['stage_planes']
stage_blocks = block_cfg['stage_blocks']
assert len(stage_planes) == 6
assert len(stage_blocks) == 4
self.stem = nn.Sequential(
conv_bn(3, stage_planes[0], 2),
conv_dw(stage_planes[0], stage_planes[1], 1),
)
self.stage_layers = []
for i, num_blocks in enumerate(stage_blocks):
_layers = []
for n in range(num_blocks):
if n == 0:
_layer = conv_dw(stage_planes[i + 1], stage_planes[i + 2],
2)
else:
_layer = conv_dw(stage_planes[i + 2], stage_planes[i + 2],
1)
_layers.append(_layer)

_block = nn.Sequential(*_layers)
layer_name = f'layer{i + 1}'
self.add_module(layer_name, _block)
self.stage_layers.append(layer_name)

def forward(self, x):
output = []
x = self.stem(x)
for i, layer_name in enumerate(self.stage_layers):
stage_layer = getattr(self, layer_name)
x = stage_layer(x)
if i in self.out_indices:
output.append(x)

return tuple(output)

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')

+ 2
- 1
modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py View File

@@ -3,5 +3,6 @@ The implementation here is modified based on insightface, originally MIT license
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors
""" """
from .scrfd import SCRFD from .scrfd import SCRFD
from .tinymog import TinyMog


__all__ = ['SCRFD']
__all__ = ['SCRFD', 'TinyMog']

+ 148
- 0
modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/tinymog.py View File

@@ -0,0 +1,148 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py
"""
import torch
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.single_stage import SingleStageDetector

from ....mmdet_patch.core.bbox import bbox2result


@DETECTORS.register_module()
class TinyMog(SingleStageDetector):

def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(TinyMog, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)

def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_keypointss=None,
gt_bboxes_ignore=None):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
gt_bboxes (list[Tensor]): Each item are the truth boxes for each
image in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): Class indices corresponding to each box
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.

Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
super(SingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_keypointss,
gt_bboxes_ignore)
return losses

def simple_test(self,
img,
img_metas,
rescale=False,
repeat_head=1,
output_kps_var=0,
output_results=1):
"""Test function without test time augmentation.

Args:
imgs (list[torch.Tensor]): List of multiple images
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
repeat_head (int): repeat inference times in head
output_kps_var (int): whether output kps var to calculate quality
output_results (int): 0: nothing 1: bbox 2: both bbox and kps

Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
x = self.extract_feat(img)
assert repeat_head >= 1
kps_out0 = []
kps_out1 = []
kps_out2 = []
for i in range(repeat_head):
outs = self.bbox_head(x)
kps_out0 += [outs[2][0].detach().cpu().numpy()]
kps_out1 += [outs[2][1].detach().cpu().numpy()]
kps_out2 += [outs[2][2].detach().cpu().numpy()]
if output_kps_var:
var0 = np.var(np.vstack(kps_out0), axis=0).mean()
var1 = np.var(np.vstack(kps_out1), axis=0).mean()
var2 = np.var(np.vstack(kps_out2), axis=0).mean()
var = np.mean([var0, var1, var2])
else:
var = None

if output_results > 0:
if torch.onnx.is_in_onnx_export():
cls_score, bbox_pred, kps_pred = outs
for c in cls_score:
print(c.shape)
for c in bbox_pred:
print(c.shape)
if self.bbox_head.use_kps:
for c in kps_pred:
print(c.shape)
return (cls_score, bbox_pred, kps_pred)
else:
return (cls_score, bbox_pred)
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)

# return kps if use_kps
if len(bbox_list[0]) == 2:
bbox_results = [
bbox2result(det_bboxes, det_labels,
self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list
]
elif len(bbox_list[0]) == 3:
if output_results == 2:
bbox_results = [
bbox2result(
det_bboxes,
det_labels,
self.bbox_head.num_classes,
kps=det_kps,
num_kps=self.bbox_head.NK)
for det_bboxes, det_labels, det_kps in bbox_list
]
elif output_results == 1:
bbox_results = [
bbox2result(det_bboxes, det_labels,
self.bbox_head.num_classes)
for det_bboxes, det_labels, _ in bbox_list
]
else:
bbox_results = None
if var is not None:
return bbox_results, var
else:
return bbox_results

def feature_test(self, img):
x = self.extract_feat(img)
outs = self.bbox_head(x)
return outs

+ 67
- 0
modelscope/models/cv/face_detection/scrfd/tinymog_detect.py View File

@@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from copy import deepcopy
from typing import Any, Dict

import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['TinyMogDetect']


@MODELS.register_module(Tasks.face_detection, module_name=Models.tinymog)
class TinyMogDetect(TorchModel):

def __init__(self, model_dir, *args, **kwargs):
"""
initialize the tinymog face detection model from the `model_dir` path.
"""
super().__init__(model_dir)
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD
cfg = Config.fromfile(osp.join(model_dir, 'mmcv_tinymog.py'))
ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3)
detector = build_detector(cfg.model)
logger.info(f'loading model from {ckpt_path}')
load_checkpoint(detector, ckpt_path, map_location='cpu')
detector = MMDataParallel(detector)
detector.eval()
self.detector = detector
logger.info('load model done')

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
result = self.detector(
return_loss=False,
rescale=True,
img=[input['img'][0].unsqueeze(0)],
img_metas=[[dict(input['img_metas'][0].data)]],
output_results=2)
assert result is not None
result = result[0][0]
bboxes = result[:, :4].tolist()
kpss = result[:, 5:].tolist()
scores = result[:, 4].tolist()
return {
OutputKeys.SCORES: scores,
OutputKeys.BOXES: bboxes,
OutputKeys.KEYPOINTS: kpss
}

def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return input

+ 9
- 2
modelscope/pipelines/cv/face_detection_pipeline.py View File

@@ -8,11 +8,12 @@ import PIL
import torch import torch


from modelscope.metainfo import Pipelines from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_detection import ScrfdDetect
from modelscope.models.cv.face_detection import ScrfdDetect, TinyMogDetect
from modelscope.outputs import OutputKeys from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger


@@ -30,7 +31,13 @@ class FaceDetectionPipeline(Pipeline):
model: model id on modelscope hub. model: model id on modelscope hub.
""" """
super().__init__(model=model, **kwargs) super().__init__(model=model, **kwargs)
detector = ScrfdDetect(model_dir=model, **kwargs)
config_path = osp.join(model, ModelFile.CONFIGURATION)
cfg = Config.from_file(config_path)
cfg_model = getattr(cfg, 'model', None)
if cfg_model is None:
detector = ScrfdDetect(model_dir=model, **kwargs)
elif cfg_model.type == 'tinymog':
detector = self.model.to(self.device)
self.detector = detector self.detector = detector


def preprocess(self, input: Input) -> Dict[str, Any]: def preprocess(self, input: Input) -> Dict[str, Any]:


+ 57
- 0
tests/pipelines/test_tinymog_face_detection.py View File

@@ -0,0 +1,57 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2

from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_face_detection_result
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class TinyMogFaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = Tasks.face_detection
self.model_id = 'damo/cv_manual_face-detection_tinymog'
self.img_path = 'data/test/images/mog_face_detection.jpg'

def show_result(self, img_path, detection_result):
img = draw_face_detection_result(img_path, detection_result)
cv2.imwrite('result.png', img)
print(f'output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_dataset(self):
input_location = ['data/test/images/mog_face_detection.jpg']

dataset = MsDataset.load(input_location, target='image')
face_detection = pipeline(Tasks.face_detection, model=self.model_id)
# note that for dataset output, the inference-output is a Generator that can be iterated.
result = face_detection(dataset)
result = next(result)
self.show_result(input_location[0], result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
face_detection = pipeline(Tasks.face_detection, model=self.model_id)

result = face_detection(self.img_path)
self.show_result(self.img_path, result)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
face_detection = pipeline(Tasks.face_detection)
result = face_detection(self.img_path)
self.show_result(self.img_path, result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_demo_compatibility(self):
self.compatibility_check()


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save