Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10897202 * [to #42322933] Add ArcFace face recognition model (master^2)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e37106cf024efd1886b870fa45f69905fcea202db8a848debc4ccd359ea3b21c
size 116248
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:700f7cb3c958fb710d6b863b3c9aa0549f6ab837dfbe3382f8f750f73cec46e3
size 116868
@@ -45,6 +45,8 @@ class Models(object):
    mogface = 'mogface'
    mtcnn = 'mtcnn'
    ulfd = 'ulfd'
    arcface = 'arcface'
    facemask = 'facemask'
    video_inpainting = 'video-inpainting'
    human_wholebody_keypoint = 'human-wholebody-keypoint'
    hand_static = 'hand-static'
@@ -198,6 +200,7 @@ class Pipelines(object):
    realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
    realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo'
    face_recognition = 'ir101-face-recognition-cfglint'
    mask_face_recognition = 'resnet-face-recognition-facemask'
    image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
    image2image_translation = 'image-to-image-translation'
    live_category = 'live-category'
@@ -0,0 +1,213 @@
# The implementation is adopted from InsightFace, made publicly available under the Apache-2.0 license at
# https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import (AdaptiveAvgPool2d, AvgPool2d, BatchNorm1d, BatchNorm2d,
                      Conv2d, Dropout, Dropout2d, Linear, MaxPool2d, Module,
                      Parameter, PReLU, ReLU, Sequential, Sigmoid)

class Flatten(Module):

    def forward(self, input):
        return input.view(input.size(0), -1)


class SEModule(Module):

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels,
            channels // reduction,
            kernel_size=1,
            padding=0,
            bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction,
            channels,
            kernel_size=1,
            padding=0,
            bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x

class BottleneckIR(Module):

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut

class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)
            ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
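
# In each stage built by get_block, only the first unit downsamples (default
# stride 2); the remaining num_units - 1 units run at stride 1, so the four
# stages below reduce a 112x112 input to 7x7 (matching fc_scale = 7 * 7).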

def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 252:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=6),
            get_block(in_channel=64, depth=128, num_units=21),
            get_block(in_channel=128, depth=256, num_units=66),
            get_block(in_channel=256, depth=512, num_units=6)
        ]
    return blocks
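
# num_layers == 252 selects the (6, 21, 66, 6) unit configuration; this is the
# branch actually exercised here, since IResNet below hardcodes num_layers = 252.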

class IResNet(Module):

    def __init__(self,
                 dropout=0,
                 num_features=512,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 fp16=False,
                 with_wcd=False,
                 wrs_M=400,
                 wrs_q=0.9):
        super(IResNet, self).__init__()
        num_layers = 252
        mode = 'ir'
        assert num_layers in [50, 100, 152, 252], \
            'num_layers should be 50, 100, 152 or 252'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        self.fc_scale = 7 * 7
        num_features = 512
        self.fp16 = fp16
        drop_ratio = 0.0
        self.with_wcd = with_wcd
        if self.with_wcd:
            self.wrs_M = wrs_M
            self.wrs_q = wrs_q
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = BottleneckIR
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        self.bn2 = nn.BatchNorm2d(512, eps=1e-05)
        self.dropout = nn.Dropout(p=drop_ratio, inplace=True)
        self.fc = nn.Linear(512 * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.input_layer(x)
            x = self.body(x)
            x = self.bn2(x)
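            # Optional weighted channel dropout: sample wrs_M of the C
            # channels without replacement, with probability proportional to
            # each channel's mean |activation| (weighted random sampling via
            # keys r ** (1 / score)); each sampled channel is then kept with
            # probability wrs_q, and the survivors are rescaled by alpha so
            # the total channel-score mass is approximately preserved.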
            if self.with_wcd:
                B = x.size()[0]
                C = x.size()[1]
                x_abs = torch.abs(x)
                score = torch.nn.functional.adaptive_avg_pool2d(
                    x_abs, 1).reshape((B, C))
                r = torch.rand((B, C), device=x.device)
                key = torch.pow(r, 1. / score)
                _, topidx = torch.topk(key, self.wrs_M, dim=1)
                mask = torch.zeros_like(key, dtype=torch.float32)
                mask.scatter_(1, topidx, 1.)
                maskq = torch.rand((B, C), device=x.device)
                maskq_ones = torch.ones_like(maskq, dtype=torch.float32)
                maskq_zeros = torch.zeros_like(maskq, dtype=torch.float32)
                maskq_m = torch.where(maskq < self.wrs_q, maskq_ones,
                                      maskq_zeros)
                new_mask = mask * maskq_m
                score_sum = torch.sum(score, dim=1, keepdim=True)
                selected_score_sum = torch.sum(
                    new_mask * score, dim=1, keepdim=True)
                alpha = score_sum / (selected_score_sum + 1e-6)
                alpha = alpha.reshape((B, 1, 1, 1))
                new_mask = new_mask.reshape((B, C, 1, 1))
                x = x * new_mask * alpha
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x

def iresnet286(pretrained=False, progress=True, **kwargs):
    model = IResNet(
        dropout=0,
        num_features=512,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        fp16=False,
        with_wcd=False,
        wrs_M=400,
        wrs_q=0.9)
    return model
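
A quick sanity check of the backbone shape contract (a minimal sketch, not part of the diff; it assumes the 112x112 aligned-face input produced by the pipeline below):

    # Hypothetical smoke test: a 112x112 RGB batch should map to a 512-d
    # embedding (fc_scale = 7 * 7 above implies a 16x downsampling to 7x7
    # before the fully-connected layer). eval() is needed so BatchNorm uses
    # running statistics with a batch of one.
    import torch
    model = iresnet286()
    model.eval()
    with torch.no_grad():
        emb = model(torch.randn(1, 3, 112, 112))
    assert emb.shape == (1, 512)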
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
    from .face_detection_pipeline import FaceDetectionPipeline
    from .face_image_generation_pipeline import FaceImageGenerationPipeline
    from .face_recognition_pipeline import FaceRecognitionPipeline
    from .mask_face_recognition_pipeline import MaskFaceRecognitionPipeline
    from .general_recognition_pipeline import GeneralRecognitionPipeline
    from .image_cartoon_pipeline import ImageCartoonPipeline
    from .image_classification_pipeline import GeneralImageClassificationPipeline
@@ -79,6 +80,7 @@ else:
        'face_detection_pipeline': ['FaceDetectionPipeline'],
        'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
        'face_recognition_pipeline': ['FaceRecognitionPipeline'],
        'mask_face_recognition_pipeline': ['MaskFaceRecognitionPipeline'],
        'general_recognition_pipeline': ['GeneralRecognitionPipeline'],
        'image_classification_pipeline':
        ['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'],
@@ -0,0 +1,138 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from collections import OrderedDict
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.models.cv.face_recognition.torchkit.backbone.facemask_backbone import \
    iresnet286
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.face_recognition, module_name=Pipelines.mask_face_recognition)
class MaskFaceRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a mask face recognition pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        # face recognition model
        super().__init__(model=model, **kwargs)
        face_model = iresnet286()
        state_dict = torch.load(osp.join(model, ModelFile.TORCH_MODEL_FILE))
        revised_state_dict = self._prefix_revision(state_dict)
        face_model.load_state_dict(revised_state_dict, strict=True)
        face_model = face_model.to(self.device)
        face_model.eval()
        self.face_model = face_model
        logger.info('face recognition model loaded!')
        # face detection pipeline
        det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
        self.face_detection = pipeline(
            Tasks.face_detection, model=det_model_id)

    def _prefix_revision(self, state_dict):
        # strip the 'module.' prefix left over from DataParallel training
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith('module.'):
                k = k[7:]
            new_state_dict[k] = v
        return new_state_dict

    def _choose_face(self,
                     det_result,
                     min_face=10,
                     top_face=1,
                     center_face=False,
                     img=None):
        '''
        choose face with maximum area
        Args:
            det_result: output of face detection pipeline
            min_face: minimum size of valid face w/h
            top_face: take faces with top max areas
            center_face: choose the most centered face from multi faces, only valid if top_face > 1
            img: original image, used to locate the image center when center_face is True
        '''
        bboxes = np.array(det_result[OutputKeys.BOXES])
        landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
        if bboxes.shape[0] == 0:
            logger.info('No face detected!')
            return None
        # face idx with enough size
        face_idx = []
        for i in range(bboxes.shape[0]):
            box = bboxes[i]
            if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
                face_idx += [i]
        if len(face_idx) == 0:
            logger.info(
                f'Face size not enough, less than {min_face}x{min_face}!')
            return None
        bboxes = bboxes[face_idx]
        landmarks = landmarks[face_idx]
        # find max faces
        boxes = np.array(bboxes)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sort_idx = np.argsort(area)[-top_face:]
        # find center face
        if top_face > 1 and center_face and bboxes.shape[0] > 1:
            img_center = [img.shape[1] // 2, img.shape[0] // 2]
            min_dist = float('inf')
            sel_idx = -1
            for _idx in sort_idx:
                box = boxes[_idx]
                dist = np.square(
                    np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square(
                        np.abs((box[1] + box[3]) / 2 - img_center[1]))
                if dist < min_dist:
                    min_dist = dist
                    sel_idx = _idx
            sort_idx = [sel_idx]
        main_idx = sort_idx[-1]
        return bboxes[main_idx], landmarks[main_idx]

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img[:, :, ::-1]  # rgb to bgr
        det_result = self.face_detection(img.copy())
        rtn = self._choose_face(det_result, img=img)
        face_img = None
        if rtn is not None:
            _, face_lmks = rtn
            face_lmks = face_lmks.reshape(5, 2)
            align_img, _ = align_face(img, (112, 112), face_lmks)
            face_img = align_img[:, :, ::-1]  # to rgb
            face_img = np.transpose(face_img, axes=(2, 0, 1))
            face_img = (face_img / 255. - 0.5) / 0.5
            face_img = face_img.astype(np.float32)
        result = {}
        result['img'] = face_img
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        assert input['img'] is not None
        img = input['img'].unsqueeze(0)
        emb = self.face_model(img).detach().cpu().numpy()
        emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True))  # l2 norm
        return {OutputKeys.IMG_EMBEDDING: emb}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class MaskFaceRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.face_recognition
        self.model_id = 'damo/cv_resnet_face-recognition_facemask'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_face_compare(self):
        img1 = 'data/test/images/mask_face_recognition_1.jpg'
        img2 = 'data/test/images/mask_face_recognition_2.jpg'
        face_recognition = pipeline(
            Tasks.face_recognition, model=self.model_id)
        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
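        # Embeddings are L2-normalized in the pipeline's forward(), so the
        # dot product below is the cosine similarity, in [-1, 1].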
        sim = np.dot(emb1[0], emb2[0])
        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()