Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10792564master^2
| @@ -47,6 +47,7 @@ class Models(object): | |||
| ulfd = 'ulfd' | |||
| arcface = 'arcface' | |||
| facemask = 'facemask' | |||
| tinymog = 'tinymog' | |||
| video_inpainting = 'video-inpainting' | |||
| human_wholebody_keypoint = 'human-wholebody-keypoint' | |||
| hand_static = 'hand-static' | |||
| @@ -182,6 +183,7 @@ class Pipelines(object): | |||
| face_detection = 'resnet-face-detection-scrfd10gkps' | |||
| card_detection = 'resnet-card-detection-scrfd34gkps' | |||
| ulfd_face_detection = 'manual-face-detection-ulfd' | |||
| tinymog_face_detection = 'manual-face-detection-tinymog' | |||
| facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | |||
| retina_face_detection = 'resnet50-face-detection-retinaface' | |||
| mog_face_detection = 'resnet101-face-detection-cvpr22papermogface' | |||
| @@ -9,13 +9,14 @@ if TYPE_CHECKING: | |||
| from .retinaface import RetinaFaceDetection | |||
| from .ulfd_slim import UlfdFaceDetector | |||
| from .scrfd import ScrfdDetect | |||
| from .scrfd import TinyMogDetect | |||
| else: | |||
| _import_structure = { | |||
| 'ulfd_slim': ['UlfdFaceDetector'], | |||
| 'retinaface': ['RetinaFaceDetection'], | |||
| 'mtcnn': ['MtcnnFaceDetector'], | |||
| 'mogface': ['MogFaceDetector'], | |||
| 'scrfd': ['ScrfdDetect'] | |||
| 'scrfd': ['TinyMogDetect', 'ScrfdDetect'], | |||
| } | |||
| import sys | |||
| @@ -1,2 +1,3 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .scrfd_detect import ScrfdDetect | |||
| from .tinymog_detect import TinyMogDetect | |||
| @@ -2,6 +2,7 @@ | |||
| The implementation here is modified based on insightface, originally MIT licensed and publicly available at | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones | |||
| """ | |||
| from .mobilenet import MobileNetV1 | |||
| from .resnet import ResNetV1e | |||
| __all__ = ['ResNetV1e'] | |||
| __all__ = ['ResNetV1e', 'MobileNetV1'] | |||
| @@ -0,0 +1,99 @@ | |||
| """ | |||
| The implementation here is modified based on insightface, originally MIT licensed and publicly available at | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/mobilenet.py | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| from mmcv.cnn import constant_init, kaiming_init | |||
| from mmcv.runner import load_checkpoint | |||
| from mmdet.models.builder import BACKBONES | |||
| from mmdet.utils import get_root_logger | |||
| from torch.nn.modules.batchnorm import _BatchNorm | |||
| @BACKBONES.register_module() | |||
| class MobileNetV1(nn.Module): | |||
| def __init__(self, | |||
| in_channels=3, | |||
| block_cfg=None, | |||
| num_stages=4, | |||
| out_indices=(0, 1, 2, 3)): | |||
| super(MobileNetV1, self).__init__() | |||
| self.out_indices = out_indices | |||
| def conv_bn(inp, oup, stride): | |||
| return nn.Sequential( | |||
| nn.Conv2d(inp, oup, 3, stride, 1, bias=False), | |||
| nn.BatchNorm2d(oup), nn.ReLU(inplace=True)) | |||
| def conv_dw(inp, oup, stride): | |||
| return nn.Sequential( | |||
| nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), | |||
| nn.BatchNorm2d(inp), | |||
| nn.ReLU(inplace=True), | |||
| nn.Conv2d(inp, oup, 1, 1, 0, bias=False), | |||
| nn.BatchNorm2d(oup), | |||
| nn.ReLU(inplace=True), | |||
| ) | |||
| if block_cfg is None: | |||
| stage_planes = [8, 16, 32, 64, 128, 256] | |||
| stage_blocks = [2, 4, 4, 2] | |||
| else: | |||
| stage_planes = block_cfg['stage_planes'] | |||
| stage_blocks = block_cfg['stage_blocks'] | |||
| assert len(stage_planes) == 6 | |||
| assert len(stage_blocks) == 4 | |||
| self.stem = nn.Sequential( | |||
| conv_bn(3, stage_planes[0], 2), | |||
| conv_dw(stage_planes[0], stage_planes[1], 1), | |||
| ) | |||
| self.stage_layers = [] | |||
| for i, num_blocks in enumerate(stage_blocks): | |||
| _layers = [] | |||
| for n in range(num_blocks): | |||
| if n == 0: | |||
| _layer = conv_dw(stage_planes[i + 1], stage_planes[i + 2], | |||
| 2) | |||
| else: | |||
| _layer = conv_dw(stage_planes[i + 2], stage_planes[i + 2], | |||
| 1) | |||
| _layers.append(_layer) | |||
| _block = nn.Sequential(*_layers) | |||
| layer_name = f'layer{i + 1}' | |||
| self.add_module(layer_name, _block) | |||
| self.stage_layers.append(layer_name) | |||
| def forward(self, x): | |||
| output = [] | |||
| x = self.stem(x) | |||
| for i, layer_name in enumerate(self.stage_layers): | |||
| stage_layer = getattr(self, layer_name) | |||
| x = stage_layer(x) | |||
| if i in self.out_indices: | |||
| output.append(x) | |||
| return tuple(output) | |||
| def init_weights(self, pretrained=None): | |||
| """Initialize the weights in backbone. | |||
| Args: | |||
| pretrained (str, optional): Path to pre-trained weights. | |||
| Defaults to None. | |||
| """ | |||
| if isinstance(pretrained, str): | |||
| logger = get_root_logger() | |||
| load_checkpoint(self, pretrained, strict=False, logger=logger) | |||
| elif pretrained is None: | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d): | |||
| kaiming_init(m) | |||
| elif isinstance(m, (_BatchNorm, nn.GroupNorm)): | |||
| constant_init(m, 1) | |||
| else: | |||
| raise TypeError('pretrained must be a str or None') | |||
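A minimal smoke test for the new backbone — a sketch assuming the default block_cfg (the import path follows the patched backbones package used elsewhere in this diff): the stem downsamples by 2 and the first block of each of the four stages uses stride 2, so a 640x640 input should come out at strides 4/8/16/32 with 32/64/128/256 channels.

import torch
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import MobileNetV1

# forward a dummy batch through the default-config backbone
model = MobileNetV1()
model.init_weights()
feats = model(torch.randn(1, 3, 640, 640))
print([tuple(f.shape) for f in feats])
# expected with the defaults:
# [(1, 32, 160, 160), (1, 64, 80, 80), (1, 128, 40, 40), (1, 256, 20, 20)]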
| @@ -3,5 +3,6 @@ The implementation here is modified based on insightface, originally MIT license | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors | |||
| """ | |||
| from .scrfd import SCRFD | |||
| from .tinymog import TinyMog | |||
| __all__ = ['SCRFD'] | |||
| __all__ = ['SCRFD', 'TinyMog'] | |||
| @@ -0,0 +1,148 @@ | |||
| """ | |||
| The implementation here is modified based on insightface, originally MIT licensed and publicly available at | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| from mmdet.models.builder import DETECTORS | |||
| from mmdet.models.detectors.single_stage import SingleStageDetector | |||
| from ....mmdet_patch.core.bbox import bbox2result | |||
| @DETECTORS.register_module() | |||
| class TinyMog(SingleStageDetector): | |||
| def __init__(self, | |||
| backbone, | |||
| neck, | |||
| bbox_head, | |||
| train_cfg=None, | |||
| test_cfg=None, | |||
| pretrained=None): | |||
| super(TinyMog, self).__init__(backbone, neck, bbox_head, train_cfg, | |||
| test_cfg, pretrained) | |||
| def forward_train(self, | |||
| img, | |||
| img_metas, | |||
| gt_bboxes, | |||
| gt_labels, | |||
| gt_keypointss=None, | |||
| gt_bboxes_ignore=None): | |||
| """ | |||
| Args: | |||
| img (Tensor): Input images of shape (N, C, H, W). | |||
| Typically these should be mean centered and std scaled. | |||
| img_metas (list[dict]): A List of image info dict where each dict | |||
| has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
| For details on the values of these keys see | |||
| :class:`mmdet.datasets.pipelines.Collect`. | |||
| gt_bboxes (list[Tensor]): Each item is the ground-truth boxes for one | |||
| image in [tl_x, tl_y, br_x, br_y] format. | |||
| gt_labels (list[Tensor]): Class indices corresponding to each box | |||
| gt_bboxes_ignore (None | list[Tensor]): Specify which bounding | |||
| boxes can be ignored when computing the loss. | |||
| Returns: | |||
| dict[str, Tensor]: A dictionary of loss components. | |||
| """ | |||
| super(SingleStageDetector, self).forward_train(img, img_metas) | |||
| x = self.extract_feat(img) | |||
| losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, | |||
| gt_labels, gt_keypointss, | |||
| gt_bboxes_ignore) | |||
| return losses | |||
| def simple_test(self, | |||
| img, | |||
| img_metas, | |||
| rescale=False, | |||
| repeat_head=1, | |||
| output_kps_var=0, | |||
| output_results=1): | |||
| """Test function without test time augmentation. | |||
| Args: | |||
| img (torch.Tensor): Input image batch. | |||
| img_metas (list[dict]): List of image information. | |||
| rescale (bool, optional): Whether to rescale the results. | |||
| Defaults to False. | |||
| repeat_head (int): number of times to repeat the head inference | |||
| output_kps_var (int): whether to output keypoint variance as a quality score | |||
| output_results (int): 0: nothing, 1: bbox only, 2: both bbox and kps | |||
| Returns: | |||
| list[list[np.ndarray]]: BBox results of each image and classes. | |||
| The outer list corresponds to each image. The inner list | |||
| corresponds to each class. | |||
| """ | |||
| x = self.extract_feat(img) | |||
| assert repeat_head >= 1 | |||
| kps_out0 = [] | |||
| kps_out1 = [] | |||
| kps_out2 = [] | |||
| for i in range(repeat_head): | |||
| outs = self.bbox_head(x) | |||
| kps_out0 += [outs[2][0].detach().cpu().numpy()] | |||
| kps_out1 += [outs[2][1].detach().cpu().numpy()] | |||
| kps_out2 += [outs[2][2].detach().cpu().numpy()] | |||
| if output_kps_var: | |||
| var0 = np.var(np.vstack(kps_out0), axis=0).mean() | |||
| var1 = np.var(np.vstack(kps_out1), axis=0).mean() | |||
| var2 = np.var(np.vstack(kps_out2), axis=0).mean() | |||
| var = np.mean([var0, var1, var2]) | |||
| else: | |||
| var = None | |||
| if output_results > 0: | |||
| if torch.onnx.is_in_onnx_export(): | |||
| cls_score, bbox_pred, kps_pred = outs | |||
| for c in cls_score: | |||
| print(c.shape) | |||
| for c in bbox_pred: | |||
| print(c.shape) | |||
| if self.bbox_head.use_kps: | |||
| for c in kps_pred: | |||
| print(c.shape) | |||
| return (cls_score, bbox_pred, kps_pred) | |||
| else: | |||
| return (cls_score, bbox_pred) | |||
| bbox_list = self.bbox_head.get_bboxes( | |||
| *outs, img_metas, rescale=rescale) | |||
| # return kps if use_kps | |||
| if len(bbox_list[0]) == 2: | |||
| bbox_results = [ | |||
| bbox2result(det_bboxes, det_labels, | |||
| self.bbox_head.num_classes) | |||
| for det_bboxes, det_labels in bbox_list | |||
| ] | |||
| elif len(bbox_list[0]) == 3: | |||
| if output_results == 2: | |||
| bbox_results = [ | |||
| bbox2result( | |||
| det_bboxes, | |||
| det_labels, | |||
| self.bbox_head.num_classes, | |||
| kps=det_kps, | |||
| num_kps=self.bbox_head.NK) | |||
| for det_bboxes, det_labels, det_kps in bbox_list | |||
| ] | |||
| elif output_results == 1: | |||
| bbox_results = [ | |||
| bbox2result(det_bboxes, det_labels, | |||
| self.bbox_head.num_classes) | |||
| for det_bboxes, det_labels, _ in bbox_list | |||
| ] | |||
| else: | |||
| bbox_results = None | |||
| if var is not None: | |||
| return bbox_results, var | |||
| else: | |||
| return bbox_results | |||
| def feature_test(self, img): | |||
| x = self.extract_feat(img) | |||
| outs = self.bbox_head(x) | |||
| return outs | |||
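On the output_kps_var path: the head runs repeat_head times, the per-level keypoint maps are stacked, and the variance across passes is averaged into a single quality scalar; note this is only non-zero if the head is stochastic at inference time (e.g. dropout left enabled). A small numpy sketch of the reduction, with hypothetical shapes:

import numpy as np

# three hypothetical passes over one FPN level, each of shape (1, 10, H, W)
kps_out = [np.random.rand(1, 10, 20, 20) for _ in range(3)]
# vstack -> (3, 10, 20, 20); variance across the pass axis, then mean to a
# scalar, mirroring the np.var(np.vstack(...), axis=0).mean() reduction above
level_var = np.var(np.vstack(kps_out), axis=0).mean()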
| @@ -0,0 +1,67 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from copy import deepcopy | |||
| from typing import Any, Dict | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| __all__ = ['TinyMogDetect'] | |||
| @MODELS.register_module(Tasks.face_detection, module_name=Models.tinymog) | |||
| class TinyMogDetect(TorchModel): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| """ | |||
| Initialize the TinyMog face detection model from the `model_dir` path. | |||
| """ | |||
| super().__init__(model_dir) | |||
| from mmcv import Config | |||
| from mmcv.parallel import MMDataParallel | |||
| from mmcv.runner import load_checkpoint | |||
| from mmdet.models import build_detector | |||
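| # these imports register the patched mmdet modules; via the package __init__ files they also pull in TinyMog and MobileNetV1 | |||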
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD | |||
| cfg = Config.fromfile(osp.join(model_dir, 'mmcv_tinymog.py')) | |||
| ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
| cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) | |||
| detector = build_detector(cfg.model) | |||
| logger.info(f'loading model from {ckpt_path}') | |||
| load_checkpoint(detector, ckpt_path, map_location='cpu') | |||
| detector = MMDataParallel(detector) | |||
| detector.eval() | |||
| self.detector = detector | |||
| logger.info('load model done') | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| result = self.detector( | |||
| return_loss=False, | |||
| rescale=True, | |||
| img=[input['img'][0].unsqueeze(0)], | |||
| img_metas=[[dict(input['img_metas'][0].data)]], | |||
| output_results=2) | |||
| assert result is not None | |||
| result = result[0][0] | |||
| bboxes = result[:, :4].tolist() | |||
| kpss = result[:, 5:].tolist() | |||
| scores = result[:, 4].tolist() | |||
| return { | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.BOXES: bboxes, | |||
| OutputKeys.KEYPOINTS: kpss | |||
| } | |||
| def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||
| return input | |||
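For driving the model directly rather than through the pipeline — a sketch assuming a locally downloaded model_dir containing mmcv_tinymog.py plus the checkpoint (the path below is hypothetical), and an img/img_metas pair already produced by the pipeline's preprocess step:

from modelscope.models.cv.face_detection import TinyMogDetect
from modelscope.outputs import OutputKeys

detector = TinyMogDetect(model_dir='/path/to/cv_manual_face-detection_tinymog',
                         score_thr=0.3)
# img and img_metas assumed preprocessed as in FaceDetectionPipeline.preprocess
outputs = detector.forward({'img': img, 'img_metas': img_metas})
print(outputs[OutputKeys.SCORES], outputs[OutputKeys.BOXES])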
| @@ -8,11 +8,12 @@ import PIL | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.face_detection import ScrfdDetect | |||
| from modelscope.models.cv.face_detection import ScrfdDetect, TinyMogDetect | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -30,7 +31,13 @@ class FaceDetectionPipeline(Pipeline): | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, **kwargs) | |||
| detector = ScrfdDetect(model_dir=model, **kwargs) | |||
| config_path = osp.join(model, ModelFile.CONFIGURATION) | |||
| cfg = Config.from_file(config_path) | |||
| cfg_model = getattr(cfg, 'model', None) | |||
| if cfg_model is None: | |||
| detector = ScrfdDetect(model_dir=model, **kwargs) | |||
| elif cfg_model.type == 'tinymog': | |||
| detector = self.model.to(self.device) | |||
| else: | |||
| # fall back to the default SCRFD detector for other model types | |||
| detector = ScrfdDetect(model_dir=model, **kwargs) | |||
| self.detector = detector | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
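End to end, the change is exercised through the standard pipeline entry point; this mirrors the new test file below, so the model id and image path are taken from it.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_manual_face-detection_tinymog')
result = face_detection('data/test/images/mog_face_detection.jpg')
# result carries OutputKeys.SCORES / BOXES / KEYPOINTS from TinyMogDetect.forward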
| @@ -0,0 +1,57 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import unittest | |||
| import cv2 | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.cv.image_utils import draw_face_detection_result | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| class TinyMogFaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.task = Tasks.face_detection | |||
| self.model_id = 'damo/cv_manual_face-detection_tinymog' | |||
| self.img_path = 'data/test/images/mog_face_detection.jpg' | |||
| def show_result(self, img_path, detection_result): | |||
| img = draw_face_detection_result(img_path, detection_result) | |||
| cv2.imwrite('result.png', img) | |||
| print(f'output written to {osp.abspath("result.png")}') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_dataset(self): | |||
| input_location = ['data/test/images/mog_face_detection.jpg'] | |||
| dataset = MsDataset.load(input_location, target='image') | |||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
| # note that for dataset input, the pipeline output is a generator that can be iterated. | |||
| result = face_detection(dataset) | |||
| result = next(result) | |||
| self.show_result(input_location[0], result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
| result = face_detection(self.img_path) | |||
| self.show_result(self.img_path, result) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| face_detection = pipeline(Tasks.face_detection) | |||
| result = face_detection(self.img_path) | |||
| self.show_result(self.img_path, result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_demo_compatibility(self): | |||
| self.compatibility_check() | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||