From 492aa98d9a82e5342405b37e99c2942da3539021 Mon Sep 17 00:00:00 2001
From: ly261666
Date: Sun, 4 Dec 2022 15:25:27 +0800
Subject: [PATCH] [to #42322933] Add mask face recognition model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10897202
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [to #42322933] Add the ArcFace face recognition model
---
 data/test/images/mask_face_recognition_1.jpg  |   3 +
 data/test/images/mask_face_recognition_2.jpg  |   3 +
 modelscope/metainfo.py                        |   3 +
 .../torchkit/backbone/facemask_backbone.py    | 213 ++++++++++++
 modelscope/pipelines/cv/__init__.py           |   2 +
 .../cv/mask_face_recognition_pipeline.py      | 138 ++++++++++++
 tests/pipelines/test_mask_face_recognition.py |  37 +++
 7 files changed, 399 insertions(+)
 create mode 100644 data/test/images/mask_face_recognition_1.jpg
 create mode 100644 data/test/images/mask_face_recognition_2.jpg
 create mode 100644 modelscope/models/cv/face_recognition/torchkit/backbone/facemask_backbone.py
 create mode 100644 modelscope/pipelines/cv/mask_face_recognition_pipeline.py
 create mode 100644 tests/pipelines/test_mask_face_recognition.py

diff --git a/data/test/images/mask_face_recognition_1.jpg b/data/test/images/mask_face_recognition_1.jpg
new file mode 100644
index 00000000..ffdff3e0
--- /dev/null
+++ b/data/test/images/mask_face_recognition_1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e37106cf024efd1886b870fa45f69905fcea202db8a848debc4ccd359ea3b21c
+size 116248
diff --git a/data/test/images/mask_face_recognition_2.jpg b/data/test/images/mask_face_recognition_2.jpg
new file mode 100644
index 00000000..ccc0d238
--- /dev/null
+++ b/data/test/images/mask_face_recognition_2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:700f7cb3c958fb710d6b863b3c9aa0549f6ab837dfbe3382f8f750f73cec46e3
+size 116868
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 9ee4091f..12274fb9 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -45,6 +45,8 @@ class Models(object):
     mogface = 'mogface'
     mtcnn = 'mtcnn'
     ulfd = 'ulfd'
+    arcface = 'arcface'
+    facemask = 'facemask'
     video_inpainting = 'video-inpainting'
     human_wholebody_keypoint = 'human-wholebody-keypoint'
     hand_static = 'hand-static'
@@ -198,6 +200,7 @@ class Pipelines(object):
     realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
     realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo'
     face_recognition = 'ir101-face-recognition-cfglint'
+    mask_face_recognition = 'resnet-face-recognition-facemask'
     image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
     image2image_translation = 'image-to-image-translation'
     live_category = 'live-category'
diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/facemask_backbone.py b/modelscope/models/cv/face_recognition/torchkit/backbone/facemask_backbone.py
new file mode 100644
index 00000000..c9e01367
--- /dev/null
+++ b/modelscope/models/cv/face_recognition/torchkit/backbone/facemask_backbone.py
@@ -0,0 +1,213 @@
+# The implementation is adapted from InsightFace, made publicly available under the Apache-2.0 license at
+# https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py
+
+from collections import namedtuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import (AdaptiveAvgPool2d, AvgPool2d, BatchNorm1d, BatchNorm2d,
+                      Conv2d, Dropout, Dropout2d, Linear, MaxPool2d, Module,
+                      Parameter, PReLU, ReLU, Sequential, Sigmoid)
+
+
+class Flatten(Module):
+
+    def forward(self, input):
+        return input.view(input.size(0), -1)
+
+
+class SEModule(Module):
+
+    def __init__(self, channels, reduction):
+        super(SEModule, self).__init__()
+        self.avg_pool = AdaptiveAvgPool2d(1)
+        self.fc1 = Conv2d(
+            channels,
+            channels // reduction,
+            kernel_size=1,
+            padding=0,
+            bias=False)
+        self.relu = ReLU(inplace=True)
+        self.fc2 = Conv2d(
+            channels // reduction,
+            channels,
+            kernel_size=1,
+            padding=0,
+            bias=False)
+        self.sigmoid = Sigmoid()
+
+    def forward(self, x):
+        module_input = x
+        x = self.avg_pool(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.sigmoid(x)
+        return module_input * x
+
+
+class BottleneckIR(Module):
+
+    def __init__(self, in_channel, depth, stride):
+        super(BottleneckIR, self).__init__()
+        if in_channel == depth:
+            self.shortcut_layer = MaxPool2d(1, stride)
+        else:
+            self.shortcut_layer = Sequential(
+                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                BatchNorm2d(depth))
+        self.res_layer = Sequential(
+            BatchNorm2d(in_channel),
+            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+            PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+            BatchNorm2d(depth))
+
+    def forward(self, x):
+        shortcut = self.shortcut_layer(x)
+        res = self.res_layer(x)
+        return res + shortcut
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+    '''A named tuple describing a ResNet block.'''
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+    return [Bottleneck(in_channel, depth, stride)
+            ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+    if num_layers == 50:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=4),
+            get_block(in_channel=128, depth=256, num_units=14),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    elif num_layers == 100:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=13),
+            get_block(in_channel=128, depth=256, num_units=30),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    elif num_layers == 152:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=8),
+            get_block(in_channel=128, depth=256, num_units=36),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    elif num_layers == 252:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=6),
+            get_block(in_channel=64, depth=128, num_units=21),
+            get_block(in_channel=128, depth=256, num_units=66),
+            get_block(in_channel=256, depth=512, num_units=6)
+        ]
+    return blocks
+
+
+class IResNet(Module):
+
+    def __init__(self,
+                 dropout=0,
+                 num_features=512,
+                 zero_init_residual=False,
+                 groups=1,
+                 width_per_group=64,
+                 replace_stride_with_dilation=None,
+                 fp16=False,
+                 with_wcd=False,
+                 wrs_M=400,
+                 wrs_q=0.9):
+        super(IResNet, self).__init__()
+        num_layers = 252
+        mode = 'ir'
+        assert num_layers in [50, 100, 152,
+                              252], 'num_layers should be 50, 100, 152 or 252'
+        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+        self.fc_scale = 7 * 7
+        num_features = 512
+        self.fp16 = fp16
+        drop_ratio = 0.0
+        self.with_wcd = with_wcd
+        if self.with_wcd:
+            self.wrs_M = wrs_M
+            self.wrs_q = wrs_q
+        blocks = get_blocks(num_layers)
+        if mode == 'ir':
+            unit_module = BottleneckIR
+        self.input_layer = Sequential(
+            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
+            PReLU(64))
+        self.bn2 = nn.BatchNorm2d(
+            512,
+            eps=1e-05,
+        )
+        self.dropout = nn.Dropout(p=drop_ratio, inplace=True)
+        self.fc = nn.Linear(512 * self.fc_scale, num_features)
+        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+        nn.init.constant_(self.features.weight, 1.0)
+        self.features.weight.requires_grad = False
+
+        modules = []
+        for block in blocks:
+            for bottleneck in block:
+                modules.append(
+                    unit_module(bottleneck.in_channel, bottleneck.depth,
+                                bottleneck.stride))
+        self.body = Sequential(*modules)
+
+    def forward(self, x):
+        with torch.cuda.amp.autocast(self.fp16):
+            x = self.input_layer(x)
+            x = self.body(x)
+            x = self.bn2(x)
+            if self.with_wcd:
+                # Weighted channel dropout: score each channel by its mean
+                # absolute activation, then draw wrs_M channels without
+                # replacement via weighted random sampling (keys r**(1/score),
+                # keep the top-M keys).
+                B = x.size()[0]
+                C = x.size()[1]
+                x_abs = torch.abs(x)
+                score = torch.nn.functional.adaptive_avg_pool2d(x_abs,
+                                                                1).reshape(
+                                                                    (B, C))
+                r = torch.rand((B, C), device=x.device)
+                key = torch.pow(r, 1. / score)
+                _, topidx = torch.topk(key, self.wrs_M, dim=1)
+                mask = torch.zeros_like(key, dtype=torch.float32)
+                mask.scatter_(1, topidx, 1.)
+                # Independently keep each selected channel with
+                # probability wrs_q.
+                maskq = torch.rand((B, C), device=x.device)
+                maskq_ones = torch.ones_like(maskq, dtype=torch.float32)
+                maskq_zeros = torch.zeros_like(maskq, dtype=torch.float32)
+                maskq_m = torch.where(maskq < self.wrs_q, maskq_ones,
+                                      maskq_zeros)
+                new_mask = mask * maskq_m
+                # Rescale the surviving channels so the total channel score
+                # is preserved in expectation.
+                score_sum = torch.sum(score, dim=1, keepdim=True)
+                selected_score_sum = torch.sum(
+                    new_mask * score, dim=1, keepdim=True)
+                alpha = score_sum / (selected_score_sum + 1e-6)
+                alpha = alpha.reshape((B, 1, 1, 1))
+                new_mask = new_mask.reshape((B, C, 1, 1))
+                x = x * new_mask * alpha
+            x = torch.flatten(x, 1)
+            x = self.dropout(x)
+        x = self.fc(x.float() if self.fp16 else x)
+        x = self.features(x)
+        return x
+
+
+def iresnet286(pretrained=False, progress=True, **kwargs):
+    model = IResNet(
+        dropout=0,
+        num_features=512,
+        zero_init_residual=False,
+        groups=1,
+        width_per_group=64,
+        replace_stride_with_dilation=None,
+        fp16=False,
+        with_wcd=False,
+        wrs_M=400,
+        wrs_q=0.9)
+    return model
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index e5bebe5f..75de5805 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
     from .face_detection_pipeline import FaceDetectionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
     from .face_recognition_pipeline import FaceRecognitionPipeline
+    from .mask_face_recognition_pipeline import MaskFaceRecognitionPipeline
    from .general_recognition_pipeline import GeneralRecognitionPipeline
     from .image_cartoon_pipeline import ImageCartoonPipeline
     from .image_classification_pipeline import GeneralImageClassificationPipeline
@@ -79,6 +80,7 @@ else:
         'face_detection_pipeline': ['FaceDetectionPipeline'],
         'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
         'face_recognition_pipeline': ['FaceRecognitionPipeline'],
+        'mask_face_recognition_pipeline': ['MaskFaceRecognitionPipeline'],
         'general_recognition_pipeline': ['GeneralRecognitionPipeline'],
         'image_classification_pipeline':
         ['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'],
diff --git a/modelscope/pipelines/cv/mask_face_recognition_pipeline.py b/modelscope/pipelines/cv/mask_face_recognition_pipeline.py
new file mode 100644
index 00000000..2190b6d0
--- /dev/null
+++ b/modelscope/pipelines/cv/mask_face_recognition_pipeline.py
@@ -0,0 +1,138 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from collections import OrderedDict
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import PIL
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.face_recognition.align_face import align_face
+from modelscope.models.cv.face_recognition.torchkit.backbone.facemask_backbone import \
+    iresnet286
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.face_recognition, module_name=Pipelines.mask_face_recognition)
+class MaskFaceRecognitionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Use `model` to create a mask face recognition pipeline for prediction.
+        Args:
+            model: model id on modelscope hub.
+        """
+
+        # face recognition model
+        super().__init__(model=model, **kwargs)
+        face_model = iresnet286()
+        # load on cpu first, then move to the pipeline device
+        state_dict = torch.load(
+            osp.join(model, ModelFile.TORCH_MODEL_FILE), map_location='cpu')
+        revised_state_dict = self._prefix_revision(state_dict)
+        face_model.load_state_dict(revised_state_dict, strict=True)
+        face_model = face_model.to(self.device)
+        face_model.eval()
+        self.face_model = face_model
+        logger.info('face recognition model loaded!')
+        # face detection pipeline
+        det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
+        self.face_detection = pipeline(
+            Tasks.face_detection, model=det_model_id)
+
+    def _prefix_revision(self, state_dict):
+        # strip the 'module.' prefix left by DataParallel checkpoints
+        new_state_dict = OrderedDict()
+        for k, v in state_dict.items():
+            if k.startswith('module.'):
+                k = k[7:]
+            new_state_dict[k] = v
+        return new_state_dict
+
+    def _choose_face(self,
+                     det_result,
+                     min_face=10,
+                     top_face=1,
+                     center_face=False,
+                     img=None):
+        '''
+        Choose the face with the maximum area.
+        Args:
+            det_result: output of face detection pipeline
+            min_face: minimum size of valid face w/h
+            top_face: take faces with top max areas
+            center_face: choose the most centered face from multiple faces,
+                only valid if top_face > 1
+            img: original image, required when center_face is True
+        '''
+        bboxes = np.array(det_result[OutputKeys.BOXES])
+        landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
+        if bboxes.shape[0] == 0:
+            logger.info('No face detected!')
+            return None
+        # face idx with enough size
+        face_idx = []
+        for i in range(bboxes.shape[0]):
+            box = bboxes[i]
+            if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
+                face_idx += [i]
+        if len(face_idx) == 0:
+            logger.info(
+                f'Face size not enough, less than {min_face}x{min_face}!')
+            return None
+        bboxes = bboxes[face_idx]
+        landmarks = landmarks[face_idx]
+        # find max faces
+        boxes = np.array(bboxes)
+        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+        sort_idx = np.argsort(area)[-top_face:]
+        # find center face
+        if top_face > 1 and center_face and bboxes.shape[0] > 1:
+            assert img is not None, 'center_face requires the original image'
+            img_center = [img.shape[1] // 2, img.shape[0] // 2]
+            min_dist = float('inf')
+            sel_idx = -1
+            for _idx in sort_idx:
+                box = boxes[_idx]
+                dist = np.square(
+                    np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square(
+                        np.abs((box[1] + box[3]) / 2 - img_center[1]))
+                if dist < min_dist:
+                    min_dist = dist
+                    sel_idx = _idx
+            sort_idx = [sel_idx]
+        main_idx = sort_idx[-1]
+        return bboxes[main_idx], landmarks[main_idx]
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        img = LoadImage.convert_to_ndarray(input)
+        img = img[:, :, ::-1]  # rgb to bgr
+        det_result = self.face_detection(img.copy())
+        rtn = self._choose_face(det_result, img=img)
+        face_img = None
+        if rtn is not None:
+            _, face_lmks = rtn
+            face_lmks = face_lmks.reshape(5, 2)
+            align_img, _ = align_face(img, (112, 112), face_lmks)
+            face_img = align_img[:, :, ::-1]  # to rgb
+            face_img = np.transpose(face_img, axes=(2, 0, 1))
+            face_img = (face_img / 255. - 0.5) / 0.5
+            face_img = face_img.astype(np.float32)
+        result = {}
+        result['img'] = face_img
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        assert input['img'] is not None, 'no valid face found in the input'
+        img = input['img'].unsqueeze(0)
+        emb = self.face_model(img).detach().cpu().numpy()
+        emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True))  # l2 norm
+        return {OutputKeys.IMG_EMBEDDING: emb}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/tests/pipelines/test_mask_face_recognition.py b/tests/pipelines/test_mask_face_recognition.py
new file mode 100644
index 00000000..550e80e4
--- /dev/null
+++ b/tests/pipelines/test_mask_face_recognition.py
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+import numpy as np
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class MaskFaceRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.task = Tasks.face_recognition
+        self.model_id = 'damo/cv_resnet_face-recognition_facemask'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_face_compare(self):
+        img1 = 'data/test/images/mask_face_recognition_1.jpg'
+        img2 = 'data/test/images/mask_face_recognition_2.jpg'
+
+        face_recognition = pipeline(
+            Tasks.face_recognition, model=self.model_id)
+        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
+        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
+        sim = np.dot(emb1[0], emb2[0])
+        print(f'Cosine similarity={sim:.3f}, img1:{img1} img2:{img2}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()
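
For reference, a minimal usage sketch mirroring the test above. It assumes the
'damo/cv_resnet_face-recognition_facemask' model is available on the ModelScope
hub and that the two LFS test images have been fetched:

    import numpy as np

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the pipeline from the hub model id; the registered
    # MaskFaceRecognitionPipeline is resolved from the model configuration.
    face_recognition = pipeline(
        Tasks.face_recognition,
        model='damo/cv_resnet_face-recognition_facemask')

    # Each call returns an L2-normalized 512-d embedding, so the dot product
    # of two embeddings is their cosine similarity.
    emb1 = face_recognition(
        'data/test/images/mask_face_recognition_1.jpg')[
            OutputKeys.IMG_EMBEDDING]
    emb2 = face_recognition(
        'data/test/images/mask_face_recognition_2.jpg')[
            OutputKeys.IMG_EMBEDDING]
    print(f'cosine similarity: {np.dot(emb1[0], emb2[0]):.3f}')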