Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10792564 (master^2)
@@ -47,6 +47,7 @@ class Models(object):
     ulfd = 'ulfd'
     arcface = 'arcface'
     facemask = 'facemask'
+    tinymog = 'tinymog'
     video_inpainting = 'video-inpainting'
     human_wholebody_keypoint = 'human-wholebody-keypoint'
     hand_static = 'hand-static'
@@ -182,6 +183,7 @@ class Pipelines(object):
     face_detection = 'resnet-face-detection-scrfd10gkps'
     card_detection = 'resnet-card-detection-scrfd34gkps'
     ulfd_face_detection = 'manual-face-detection-ulfd'
+    tinymog_face_detection = 'manual-face-detection-tinymog'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
     mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'
@@ -9,13 +9,14 @@ if TYPE_CHECKING:
     from .retinaface import RetinaFaceDetection
     from .ulfd_slim import UlfdFaceDetector
     from .scrfd import ScrfdDetect
+    from .scrfd import TinyMogDetect
 else:
     _import_structure = {
         'ulfd_slim': ['UlfdFaceDetector'],
         'retinaface': ['RetinaFaceDetection'],
         'mtcnn': ['MtcnnFaceDetector'],
         'mogface': ['MogFaceDetector'],
-        'scrfd': ['ScrfdDetect']
+        'scrfd': ['TinyMogDetect', 'ScrfdDetect'],
     }

     import sys
@@ -1,2 +1,3 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .scrfd_detect import ScrfdDetect
+from .tinymog_detect import TinyMogDetect
@@ -2,6 +2,7 @@
 The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
 https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones
 """
+from .mobilenet import MobileNetV1
 from .resnet import ResNetV1e

-__all__ = ['ResNetV1e']
+__all__ = ['ResNetV1e', 'MobileNetV1']
@@ -0,0 +1,99 @@
+"""
+The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
+https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/mobilenet.py
+"""
+import torch
+import torch.nn as nn
+from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
+                      constant_init, kaiming_init)
+from mmcv.runner import load_checkpoint
+from mmdet.models.builder import BACKBONES
+from mmdet.utils import get_root_logger
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+@BACKBONES.register_module()
+class MobileNetV1(nn.Module):
+
+    def __init__(self,
+                 in_channels=3,
+                 block_cfg=None,
+                 num_stages=4,
+                 out_indices=(0, 1, 2, 3)):
+        super(MobileNetV1, self).__init__()
+        self.out_indices = out_indices
+
+        def conv_bn(inp, oup, stride):
+            return nn.Sequential(
+                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+                nn.BatchNorm2d(oup), nn.ReLU(inplace=True))
+
+        def conv_dw(inp, oup, stride):
+            return nn.Sequential(
+                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+                nn.BatchNorm2d(inp),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+                nn.ReLU(inplace=True),
+            )
+
+        if block_cfg is None:
+            stage_planes = [8, 16, 32, 64, 128, 256]
+            stage_blocks = [2, 4, 4, 2]
+        else:
+            stage_planes = block_cfg['stage_planes']
+            stage_blocks = block_cfg['stage_blocks']
+        assert len(stage_planes) == 6
+        assert len(stage_blocks) == 4
+        self.stem = nn.Sequential(
+            conv_bn(3, stage_planes[0], 2),
+            conv_dw(stage_planes[0], stage_planes[1], 1),
+        )
+        self.stage_layers = []
+        for i, num_blocks in enumerate(stage_blocks):
+            _layers = []
+            for n in range(num_blocks):
+                if n == 0:
+                    _layer = conv_dw(stage_planes[i + 1], stage_planes[i + 2],
+                                     2)
+                else:
+                    _layer = conv_dw(stage_planes[i + 2], stage_planes[i + 2],
+                                     1)
+                _layers.append(_layer)
+            _block = nn.Sequential(*_layers)
+            layer_name = f'layer{i + 1}'
+            self.add_module(layer_name, _block)
+            self.stage_layers.append(layer_name)
+
+    def forward(self, x):
+        output = []
+        x = self.stem(x)
+        for i, layer_name in enumerate(self.stage_layers):
+            stage_layer = getattr(self, layer_name)
+            x = stage_layer(x)
+            if i in self.out_indices:
+                output.append(x)
+        return tuple(output)
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
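As a quick sanity check on the new backbone, here is a minimal smoke-test sketch (not part of the review): it instantiates MobileNetV1 with the defaults and prints the stage outputs. The import path is assumed from the mmdet_patch imports used later in this review; the expected shapes follow from the stride-2 stem plus one stride-2 layer at the start of each stage.

    import torch

    # Import path assumed from tinymog_detect.py further down in this review.
    from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import MobileNetV1

    backbone = MobileNetV1()  # defaults: stage_planes=[8, 16, 32, 64, 128, 256], stage_blocks=[2, 4, 4, 2]
    backbone.init_weights()   # Kaiming init for convs, constant init for norm layers
    backbone.eval()

    x = torch.randn(1, 3, 224, 224)  # dummy NCHW image
    with torch.no_grad():
        feats = backbone(x)          # tuple of 4 maps for out_indices=(0, 1, 2, 3)
    for f in feats:
        print(tuple(f.shape))
    # Expected: (1, 32, 56, 56), (1, 64, 28, 28), (1, 128, 14, 14), (1, 256, 7, 7),
    # i.e. strides 4/8/16/32 relative to the input.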
@@ -3,5 +3,6 @@ The implementation here is modified based on insightface, originally MIT license
 https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors
 """
 from .scrfd import SCRFD
+from .tinymog import TinyMog

-__all__ = ['SCRFD']
+__all__ = ['SCRFD', 'TinyMog']
@@ -0,0 +1,148 @@
+"""
+The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
+https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py
+"""
+import numpy as np
+import torch
+from mmdet.models.builder import DETECTORS
+from mmdet.models.detectors.single_stage import SingleStageDetector
+
+from ....mmdet_patch.core.bbox import bbox2result
+
+
+@DETECTORS.register_module()
+class TinyMog(SingleStageDetector):
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(TinyMog, self).__init__(backbone, neck, bbox_head, train_cfg,
+                                      test_cfg, pretrained)
+
+    def forward_train(self,
+                      img,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_keypointss=None,
+                      gt_bboxes_ignore=None):
+        """
+        Args:
+            img (Tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A list of image info dicts where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Each item contains the ground-truth
+                boxes for one image in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): Class indices corresponding to each box.
+            gt_keypointss (None | list[Tensor]): Ground-truth keypoints for
+                each box.
+            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+                boxes can be ignored when computing the loss.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        super(SingleStageDetector, self).forward_train(img, img_metas)
+        x = self.extract_feat(img)
+        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
+                                              gt_labels, gt_keypointss,
+                                              gt_bboxes_ignore)
+        return losses
+
+    def simple_test(self,
+                    img,
+                    img_metas,
+                    rescale=False,
+                    repeat_head=1,
+                    output_kps_var=0,
+                    output_results=1):
+        """Test function without test-time augmentation.
+
+        Args:
+            img (list[torch.Tensor]): List of multiple images.
+            img_metas (list[dict]): List of image information.
+            rescale (bool, optional): Whether to rescale the results.
+                Defaults to False.
+            repeat_head (int): Number of repeated inference passes in the head.
+            output_kps_var (int): Whether to output the keypoint variance used
+                to estimate quality.
+            output_results (int): 0: nothing; 1: bbox; 2: both bbox and kps.
+
+        Returns:
+            list[list[np.ndarray]]: BBox results of each image and classes.
+                The outer list corresponds to each image. The inner list
+                corresponds to each class.
+        """
+        x = self.extract_feat(img)
+        assert repeat_head >= 1
+        kps_out0 = []
+        kps_out1 = []
+        kps_out2 = []
+        for i in range(repeat_head):
+            outs = self.bbox_head(x)
+            kps_out0 += [outs[2][0].detach().cpu().numpy()]
+            kps_out1 += [outs[2][1].detach().cpu().numpy()]
+            kps_out2 += [outs[2][2].detach().cpu().numpy()]
+        if output_kps_var:
+            var0 = np.var(np.vstack(kps_out0), axis=0).mean()
+            var1 = np.var(np.vstack(kps_out1), axis=0).mean()
+            var2 = np.var(np.vstack(kps_out2), axis=0).mean()
+            var = np.mean([var0, var1, var2])
+        else:
+            var = None
+        if output_results > 0:
+            if torch.onnx.is_in_onnx_export():
+                cls_score, bbox_pred, kps_pred = outs
+                for c in cls_score:
+                    print(c.shape)
+                for c in bbox_pred:
+                    print(c.shape)
+                if self.bbox_head.use_kps:
+                    for c in kps_pred:
+                        print(c.shape)
+                    return (cls_score, bbox_pred, kps_pred)
+                else:
+                    return (cls_score, bbox_pred)
+            bbox_list = self.bbox_head.get_bboxes(
+                *outs, img_metas, rescale=rescale)
+            # return kps if use_kps
+            if len(bbox_list[0]) == 2:
+                bbox_results = [
+                    bbox2result(det_bboxes, det_labels,
+                                self.bbox_head.num_classes)
+                    for det_bboxes, det_labels in bbox_list
+                ]
+            elif len(bbox_list[0]) == 3:
+                if output_results == 2:
+                    bbox_results = [
+                        bbox2result(
+                            det_bboxes,
+                            det_labels,
+                            self.bbox_head.num_classes,
+                            kps=det_kps,
+                            num_kps=self.bbox_head.NK)
+                        for det_bboxes, det_labels, det_kps in bbox_list
+                    ]
+                elif output_results == 1:
+                    bbox_results = [
+                        bbox2result(det_bboxes, det_labels,
+                                    self.bbox_head.num_classes)
+                        for det_bboxes, det_labels, _ in bbox_list
+                    ]
+        else:
+            bbox_results = None
+        if var is not None:
+            return bbox_results, var
+        else:
+            return bbox_results
+
+    def feature_test(self, img):
+        x = self.extract_feat(img)
+        outs = self.bbox_head(x)
+        return outs
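One detail worth spelling out: with output_kps_var set, simple_test runs the head repeat_head times and scores prediction stability as the mean variance of the keypoint outputs across passes. The standalone numpy sketch below mirrors that reduction, with hypothetical random arrays standing in for the three FPN-level keypoint maps (outs[2][0..2]):

    import numpy as np

    # Hypothetical stand-ins for outs[2][i], collected over repeat_head=3 passes.
    rng = np.random.default_rng(0)
    kps_out0 = [rng.normal(size=(1, 10, 20, 20)) for _ in range(3)]
    kps_out1 = [rng.normal(size=(1, 10, 10, 10)) for _ in range(3)]
    kps_out2 = [rng.normal(size=(1, 10, 5, 5)) for _ in range(3)]

    # Stack the passes per level, take the variance across passes, average it,
    # then average the three levels; this is exactly the reduction in simple_test.
    var0 = np.var(np.vstack(kps_out0), axis=0).mean()
    var1 = np.var(np.vstack(kps_out1), axis=0).mean()
    var2 = np.var(np.vstack(kps_out2), axis=0).mean()
    var = np.mean([var0, var1, var2])
    print(var)  # lower variance across passes suggests a more stable prediction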
@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from copy import deepcopy
+from typing import Any, Dict
+
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['TinyMogDetect']
+
+
+@MODELS.register_module(Tasks.face_detection, module_name=Models.tinymog)
+class TinyMogDetect(TorchModel):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        """
+        Initialize the tinymog face detection model from the `model_dir` path.
+        """
+        super().__init__(model_dir)
+        from mmcv import Config
+        from mmcv.parallel import MMDataParallel
+        from mmcv.runner import load_checkpoint
+        from mmdet.models import build_detector
+        # Importing these modules registers the patched datasets, backbones,
+        # heads and detectors with the mmdet registries before build_detector.
+        from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset
+        from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop
+        from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e
+        from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead
+        from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD
+        cfg = Config.fromfile(osp.join(model_dir, 'mmcv_tinymog.py'))
+        ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+        cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3)
+        detector = build_detector(cfg.model)
+        logger.info(f'loading model from {ckpt_path}')
+        load_checkpoint(detector, ckpt_path, map_location='cpu')
+        detector = MMDataParallel(detector)
+        detector.eval()
+        self.detector = detector
+        logger.info('load model done')
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        result = self.detector(
+            return_loss=False,
+            rescale=True,
+            img=[input['img'][0].unsqueeze(0)],
+            img_metas=[[dict(input['img_metas'][0].data)]],
+            output_results=2)
+        assert result is not None
+        result = result[0][0]
+        bboxes = result[:, :4].tolist()
+        kpss = result[:, 5:].tolist()
+        scores = result[:, 4].tolist()
+        return {
+            OutputKeys.SCORES: scores,
+            OutputKeys.BOXES: bboxes,
+            OutputKeys.KEYPOINTS: kpss
+        }
+
+    def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        return input
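For completeness, a hedged construction sketch for the new model class. It assumes the model id used by the test file at the end of this review is published on the hub, and that the downloaded snapshot ships mmcv_tinymog.py plus the checkpoint named by ModelFile.TORCH_MODEL_FILE:

    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.models.cv.face_detection import TinyMogDetect

    # Model id assumed from the test file below.
    model_dir = snapshot_download('damo/cv_manual_face-detection_tinymog')
    detector = TinyMogDetect(model_dir=model_dir, score_thr=0.3)  # score_thr feeds cfg.model.test_cfg

In practice the pipeline is the intended entry point, since forward() expects the mmdet-style img/img_metas dict produced by the pipeline's preprocess step.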
@@ -8,11 +8,12 @@ import PIL
 import torch

 from modelscope.metainfo import Pipelines
-from modelscope.models.cv.face_detection import ScrfdDetect
+from modelscope.models.cv.face_detection import ScrfdDetect, TinyMogDetect
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import LoadImage
+from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
@@ -30,7 +31,13 @@ class FaceDetectionPipeline(Pipeline):
             model: model id on modelscope hub.
         """
         super().__init__(model=model, **kwargs)
-        detector = ScrfdDetect(model_dir=model, **kwargs)
+        config_path = osp.join(model, ModelFile.CONFIGURATION)
+        cfg = Config.from_file(config_path)
+        cfg_model = getattr(cfg, 'model', None)
+        if cfg_model is None:
+            detector = ScrfdDetect(model_dir=model, **kwargs)
+        elif cfg_model.type == 'tinymog':
+            detector = self.model.to(self.device)
         self.detector = detector

     def preprocess(self, input: Input) -> Dict[str, Any]:
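The branch above keeps the legacy behaviour (a configuration.json without a model section still builds ScrfdDetect directly), while tinymog models reuse the TinyMogDetect instance the base Pipeline already constructed and only move it to the device. A usage sketch, with the model id assumed from the test file below:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    face_detection = pipeline(Tasks.face_detection, model='damo/cv_manual_face-detection_tinymog')
    result = face_detection('data/test/images/mog_face_detection.jpg')
    # keys match OutputKeys.SCORES / BOXES / KEYPOINTS
    print(result['scores'], result['boxes'], result['keypoints'])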
@@ -0,0 +1,57 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+import unittest
+
+import cv2
+
+from modelscope.msdatasets import MsDataset
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import draw_face_detection_result
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class TinyMogFaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = Tasks.face_detection
+        self.model_id = 'damo/cv_manual_face-detection_tinymog'
+        self.img_path = 'data/test/images/mog_face_detection.jpg'
+
+    def show_result(self, img_path, detection_result):
+        img = draw_face_detection_result(img_path, detection_result)
+        cv2.imwrite('result.png', img)
+        print(f'output written to {osp.abspath("result.png")}')
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_dataset(self):
+        input_location = ['data/test/images/mog_face_detection.jpg']
+        dataset = MsDataset.load(input_location, target='image')
+        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
+        # note that for dataset output, the inference-output is a Generator that can be iterated.
+        result = face_detection(dataset)
+        result = next(result)
+        self.show_result(input_location[0], result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        face_detection = pipeline(Tasks.face_detection, model=self.model_id)
+        result = face_detection(self.img_path)
+        self.show_result(self.img_path, result)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        face_detection = pipeline(Tasks.face_detection)
+        result = face_detection(self.img_path)
+        self.show_result(self.img_path, result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()