
[to #42322933] Add face mask model

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10897202

    * [to #42322933] Add ArcFace face recognition model
master^2
ly261666 yingda.chen 3 years ago
parent commit 492aa98d9a
7 changed files with 399 additions and 0 deletions
  1. +3 -0  data/test/images/mask_face_recognition_1.jpg
  2. +3 -0  data/test/images/mask_face_recognition_2.jpg
  3. +3 -0  modelscope/metainfo.py
  4. +213 -0  modelscope/models/cv/face_recognition/torchkit/backbone/facemask_backbone.py
  5. +2 -0  modelscope/pipelines/cv/__init__.py
  6. +138 -0  modelscope/pipelines/cv/mask_face_recognition_pipeline.py
  7. +37 -0  tests/pipelines/test_mask_face_recognition.py
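
For reference, a minimal usage sketch of the new pipeline, mirroring the added test (the model id and output key come from tests/pipelines/test_mask_face_recognition.py):

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the masked-face recognition pipeline from its hub model id.
face_recognition = pipeline(
    Tasks.face_recognition, model='damo/cv_resnet_face-recognition_facemask')

# Each call detects the main face, aligns it, and returns an embedding.
emb1 = face_recognition('data/test/images/mask_face_recognition_1.jpg')[
    OutputKeys.IMG_EMBEDDING]
emb2 = face_recognition('data/test/images/mask_face_recognition_2.jpg')[
    OutputKeys.IMG_EMBEDDING]

# Embeddings are L2-normalized, so the dot product is the cosine similarity.
print(np.dot(emb1[0], emb2[0]))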

+3 -0 data/test/images/mask_face_recognition_1.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e37106cf024efd1886b870fa45f69905fcea202db8a848debc4ccd359ea3b21c
size 116248

+3 -0 data/test/images/mask_face_recognition_2.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:700f7cb3c958fb710d6b863b3c9aa0549f6ab837dfbe3382f8f750f73cec46e3
size 116868

+3 -0 modelscope/metainfo.py

@@ -45,6 +45,8 @@ class Models(object):
    mogface = 'mogface'
    mtcnn = 'mtcnn'
    ulfd = 'ulfd'
    arcface = 'arcface'
    facemask = 'facemask'
    video_inpainting = 'video-inpainting'
    human_wholebody_keypoint = 'human-wholebody-keypoint'
    hand_static = 'hand-static'
@@ -198,6 +200,7 @@ class Pipelines(object):
    realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
    realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo'
    face_recognition = 'ir101-face-recognition-cfglint'
    mask_face_recognition = 'resnet-face-recognition-facemask'
    image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
    image2image_translation = 'image-to-image-translation'
    live_category = 'live-category'
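
These two registered names are what tie a hub model to the new code path: Models.facemask names the backbone and Pipelines.mask_face_recognition names the pipeline. As a hedged sketch (the configuration layout below follows the usual ModelScope convention and is not part of this diff), a model repo selects the pipeline through its configuration:

from modelscope.metainfo import Pipelines

# Hypothetical excerpt of a model repo's configuration.json, as a dict:
config = {
    'task': 'face-recognition',
    'pipeline': {
        'type': Pipelines.mask_face_recognition  # 'resnet-face-recognition-facemask'
    }
}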


+213 -0 modelscope/models/cv/face_recognition/torchkit/backbone/facemask_backbone.py

@@ -0,0 +1,213 @@
# The implementation is adopted from InsightFace, made publicly available under the Apache-2.0 license at
# https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py

from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import (AdaptiveAvgPool2d, AvgPool2d, BatchNorm1d, BatchNorm2d,
                      Conv2d, Dropout, Dropout2d, Linear, MaxPool2d, Module,
                      Parameter, PReLU, ReLU, Sequential, Sigmoid)


class Flatten(Module):

    def forward(self, input):
        return input.view(input.size(0), -1)


class SEModule(Module):

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels,
            channels // reduction,
            kernel_size=1,
            padding=0,
            bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction,
            channels,
            kernel_size=1,
            padding=0,
            bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class BottleneckIR(Module):

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)
            ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 252:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=6),
            get_block(in_channel=64, depth=128, num_units=21),
            get_block(in_channel=128, depth=256, num_units=66),
            get_block(in_channel=256, depth=512, num_units=6)
        ]
    return blocks


class IResNet(Module):

    def __init__(self,
                 dropout=0,
                 num_features=512,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 fp16=False,
                 with_wcd=False,
                 wrs_M=400,
                 wrs_q=0.9):
        super(IResNet, self).__init__()
        num_layers = 252
        mode = 'ir'
        assert num_layers in [50, 100, 152, 252], \
            'num_layers should be 50, 100, 152 or 252'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        self.fc_scale = 7 * 7
        num_features = 512
        self.fp16 = fp16
        drop_ratio = 0.0
        self.with_wcd = with_wcd
        if self.with_wcd:
            self.wrs_M = wrs_M
            self.wrs_q = wrs_q
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = BottleneckIR
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        self.bn2 = nn.BatchNorm2d(
            512,
            eps=1e-05,
        )
        self.dropout = nn.Dropout(p=drop_ratio, inplace=True)
        self.fc = nn.Linear(512 * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.input_layer(x)
            x = self.body(x)
            x = self.bn2(x)
            if self.with_wcd:
                B = x.size()[0]
                C = x.size()[1]
                x_abs = torch.abs(x)
                score = torch.nn.functional.adaptive_avg_pool2d(
                    x_abs, 1).reshape((B, C))
                r = torch.rand((B, C), device=x.device)
                key = torch.pow(r, 1. / score)
                _, topidx = torch.topk(key, self.wrs_M, dim=1)
                mask = torch.zeros_like(key, dtype=torch.float32)
                mask.scatter_(1, topidx, 1.)
                maskq = torch.rand((B, C), device=x.device)
                maskq_ones = torch.ones_like(maskq, dtype=torch.float32)
                maskq_zeros = torch.zeros_like(maskq, dtype=torch.float32)
                maskq_m = torch.where(maskq < self.wrs_q, maskq_ones,
                                      maskq_zeros)
                new_mask = mask * maskq_m
                score_sum = torch.sum(score, dim=1, keepdim=True)
                selected_score_sum = torch.sum(
                    new_mask * score, dim=1, keepdim=True)
                alpha = score_sum / (selected_score_sum + 1e-6)
                alpha = alpha.reshape((B, 1, 1, 1))
                new_mask = new_mask.reshape((B, C, 1, 1))
                x = x * new_mask * alpha
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def iresnet286(pretrained=False, progress=True, **kwargs):
    model = IResNet(
        dropout=0,
        num_features=512,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        fp16=False,
        with_wcd=False,
        wrs_M=400,
        wrs_q=0.9)
    return model
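
The with_wcd branch of IResNet.forward implements a weighted channel-dropout step: each channel gets an importance score (global average pooling of |x|), wrs_M channels are drawn with probability proportional to that score via weighted-reservoir keys r**(1/score), a second Bernoulli(wrs_q) mask thins the selection, and surviving channels are rescaled so the total score is preserved in expectation. A standalone toy run of the same masking logic (shapes reduced for illustration; not part of this diff):

import torch

B, C, M, q = 2, 16, 8, 0.9
x = torch.randn(B, C, 7, 7)

# Per-channel importance: global average pool of |x|, shape (B, C).
score = torch.abs(x).mean(dim=(2, 3))
# Weighted reservoir sampling: keys r**(1/score); keep the top-M channels.
key = torch.rand(B, C) ** (1.0 / score)
_, topidx = torch.topk(key, M, dim=1)
mask = torch.zeros(B, C).scatter_(1, topidx, 1.0)
# Independently keep each selected channel with probability q.
mask = mask * (torch.rand(B, C) < q).float()
# Rescale so the summed score of kept channels matches the original.
alpha = score.sum(1, keepdim=True) / (
    (mask * score).sum(1, keepdim=True) + 1e-6)
x = x * mask.reshape(B, C, 1, 1) * alpha.reshape(B, 1, 1, 1)
print(x.shape, mask.sum(1))  # torch.Size([2, 16, 7, 7]); <= M channels kept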

+2 -0 modelscope/pipelines/cv/__init__.py

@@ -18,6 +18,7 @@ if TYPE_CHECKING:
    from .face_detection_pipeline import FaceDetectionPipeline
    from .face_image_generation_pipeline import FaceImageGenerationPipeline
    from .face_recognition_pipeline import FaceRecognitionPipeline
    from .mask_face_recognition_pipeline import MaskFaceRecognitionPipeline
    from .general_recognition_pipeline import GeneralRecognitionPipeline
    from .image_cartoon_pipeline import ImageCartoonPipeline
    from .image_classification_pipeline import GeneralImageClassificationPipeline
@@ -79,6 +80,7 @@ else:
        'face_detection_pipeline': ['FaceDetectionPipeline'],
        'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
        'face_recognition_pipeline': ['FaceRecognitionPipeline'],
        'mask_face_recognition_pipeline': ['MaskFaceRecognitionPipeline'],
        'general_recognition_pipeline': ['GeneralRecognitionPipeline'],
        'image_classification_pipeline':
        ['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'],
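
Both branches must stay in sync: the TYPE_CHECKING imports give type checkers and IDEs the real symbols, while the _import_structure mapping drives the package's lazy loader at runtime. A condensed sketch of the pattern this file follows:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Real import, visible only to type checkers and IDEs.
    from .mask_face_recognition_pipeline import MaskFaceRecognitionPipeline
else:
    # At runtime only this mapping is read; the submodule is imported
    # lazily on first attribute access.
    _import_structure = {
        'mask_face_recognition_pipeline': ['MaskFaceRecognitionPipeline'],
    }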


+138 -0 modelscope/pipelines/cv/mask_face_recognition_pipeline.py

@@ -0,0 +1,138 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from collections import OrderedDict
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.models.cv.face_recognition.torchkit.backbone.facemask_backbone import \
    iresnet286
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_recognition, module_name=Pipelines.mask_face_recognition)
class MaskFaceRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a mask face recognition pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """

        # face recognition model
        super().__init__(model=model, **kwargs)
        face_model = iresnet286()
        state_dict = torch.load(osp.join(model, ModelFile.TORCH_MODEL_FILE))
        revised_state_dict = self._prefix_revision(state_dict)
        face_model.load_state_dict(revised_state_dict, strict=True)
        face_model = face_model.to(self.device)
        face_model.eval()
        self.face_model = face_model
        logger.info('face recognition model loaded!')
        # face detection pipeline
        det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
        self.face_detection = pipeline(
            Tasks.face_detection, model=det_model_id)

    def _prefix_revision(self, state_dict):
        # strip the 'module.' prefix left by DataParallel checkpoints
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith('module.'):
                k = k[7:]
            new_state_dict[k] = v
        return new_state_dict

    def _choose_face(self,
                     det_result,
                     img=None,
                     min_face=10,
                     top_face=1,
                     center_face=False):
        '''
        choose face with maximum area
        Args:
            det_result: output of face detection pipeline
            img: original image, only needed when center_face is used
            min_face: minimum size of valid face w/h
            top_face: take faces with top max areas
            center_face: choose the most centered face from multi faces, only valid if top_face > 1
        '''
        bboxes = np.array(det_result[OutputKeys.BOXES])
        landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
        if bboxes.shape[0] == 0:
            logger.info('No face detected!')
            return None
        # face idx with enough size
        face_idx = []
        for i in range(bboxes.shape[0]):
            box = bboxes[i]
            if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
                face_idx += [i]
        if len(face_idx) == 0:
            logger.info(
                f'Face size not enough, less than {min_face}x{min_face}!')
            return None
        bboxes = bboxes[face_idx]
        landmarks = landmarks[face_idx]
        # find max faces
        boxes = np.array(bboxes)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sort_idx = np.argsort(area)[-top_face:]
        # find center face
        if top_face > 1 and center_face and bboxes.shape[0] > 1 \
                and img is not None:
            img_center = [img.shape[1] // 2, img.shape[0] // 2]
            min_dist = float('inf')
            sel_idx = -1
            for _idx in sort_idx:
                box = boxes[_idx]
                dist = np.square(
                    np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square(
                        np.abs((box[1] + box[3]) / 2 - img_center[1]))
                if dist < min_dist:
                    min_dist = dist
                    sel_idx = _idx
            sort_idx = [sel_idx]
        main_idx = sort_idx[-1]
        return bboxes[main_idx], landmarks[main_idx]

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img[:, :, ::-1]
        det_result = self.face_detection(img.copy())
        rtn = self._choose_face(det_result, img=img)
        face_img = None
        if rtn is not None:
            _, face_lmks = rtn
            face_lmks = face_lmks.reshape(5, 2)
            align_img, _ = align_face(img, (112, 112), face_lmks)
            face_img = align_img[:, :, ::-1]  # to rgb
            face_img = np.transpose(face_img, axes=(2, 0, 1))
            face_img = (face_img / 255. - 0.5) / 0.5
            face_img = face_img.astype(np.float32)
        result = {}
        result['img'] = face_img
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        assert input['img'] is not None
        img = input['img'].unsqueeze(0)
        emb = self.face_model(img).detach().cpu().numpy()
        emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True))  # l2 norm
        return {OutputKeys.IMG_EMBEDDING: emb}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
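
Because forward L2-normalizes the embedding before returning it, a plain dot product between two outputs is already the cosine similarity; no further normalization is needed downstream. A small numeric check (illustrative only, random vectors):

import numpy as np

a = np.random.randn(1, 512).astype(np.float32)
a /= np.sqrt(np.sum(a**2, -1, keepdims=True))  # same l2 norm as forward()
b = np.random.randn(1, 512).astype(np.float32)
b /= np.sqrt(np.sum(b**2, -1, keepdims=True))

dot = np.dot(a[0], b[0])
cos = dot / (np.linalg.norm(a[0]) * np.linalg.norm(b[0]))
assert np.isclose(dot, cos)  # unit norms make dot == cosine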

+37 -0 tests/pipelines/test_mask_face_recognition.py

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class MaskFaceRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.face_recognition
        self.model_id = 'damo/cv_resnet_face-recognition_facemask'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_face_compare(self):
        img1 = 'data/test/images/mask_face_recognition_1.jpg'
        img2 = 'data/test/images/mask_face_recognition_2.jpg'

        face_recognition = pipeline(
            Tasks.face_recognition, model=self.model_id)
        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
        sim = np.dot(emb1[0], emb2[0])
        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
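
The test prints the similarity but does not assert a verification decision; in practice the score would be compared against a tuned threshold. A hedged sketch (the 0.3 cut-off is a placeholder, not a value shipped with this model):

import numpy as np


def is_same_person(emb1, emb2, threshold=0.3):
    # Both inputs are L2-normalized embeddings from the pipeline;
    # the threshold here is hypothetical and must be calibrated.
    return float(np.dot(emb1, emb2)) >= threshold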
