Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10726376master^2
@@ -40,6 +40,7 @@ class Models(object):
     resnet50_bert = 'resnet50-bert'
     referring_video_object_segmentation = 'swinT-referring-video-object-segmentation'
     fer = 'fer'
+    fairface = 'fairface'
     retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
     mogface = 'mogface'
@@ -185,6 +186,7 @@ class Pipelines(object):
     ulfd_face_detection = 'manual-face-detection-ulfd'
     tinymog_face_detection = 'manual-face-detection-tinymog'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
+    face_attribute_recognition = 'resnet34-face-attribute-recognition-fairface'
     retina_face_detection = 'resnet50-face-detection-retinaface'
     mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'
     mtcnn_face_detection = 'manual-face-detection-mtcnn'
@@ -0,0 +1,20 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .fair_face import FaceAttributeRecognition
+else:
+    _import_structure = {'fair_face': ['FaceAttributeRecognition']}
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
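Reviewer note: the lazy-import shim above keeps importing the package cheap; the heavy dependencies in `fair_face` are only pulled in when the symbol is first resolved. A minimal sketch of the effect:

    # Importing the package does not yet import torch/cv2; resolving the
    # attribute below triggers the real import of .fair_face.
    from modelscope.models.cv.face_attribute_recognition import \
        FaceAttributeRecognition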
@@ -0,0 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .face_attribute_recognition import FaceAttributeRecognition
| @@ -0,0 +1,79 @@ | |||||
| # The implementation is based on FairFace, available at | |||||
| # https://github.com/dchen236/FairFace | |||||
| import os | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.backends.cudnn as cudnn | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import torchvision | |||||
| from PIL import Image | |||||
| from torch.autograd import Variable | |||||
| from torchvision import datasets, models, transforms | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Tensor, TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
+
+
+@MODELS.register_module(
+    Tasks.face_attribute_recognition, module_name=Models.fairface)
+class FaceAttributeRecognition(TorchModel):
+
+    def __init__(self, model_path, device='cuda'):
+        super().__init__(model_path)
+        cudnn.benchmark = True
+        self.model_path = model_path
+        self.device = device
+        self.cfg_path = model_path.replace(ModelFile.TORCH_MODEL_FILE,
+                                           ModelFile.CONFIGURATION)
+        # FairFace is a ResNet-34 with an 18-way linear head
+        # (7 race + 2 gender + 9 age logits).
+        fair_face = torchvision.models.resnet34(pretrained=False)
+        fair_face.fc = nn.Linear(fair_face.fc.in_features, 18)
+        self.net = fair_face
+        self.load_model()
+        self.net = self.net.to(device)
+        self.trans = transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    def load_model(self, load_to_cpu=False):
+        pretrained_dict = torch.load(
+            self.model_path, map_location=torch.device('cpu'))
+        self.net.load_state_dict(pretrained_dict, strict=True)
+        self.net.eval()
+
+    def forward(self, img):
+        """ FairFace model forward process.
+
+        Args:
+            img: [h, w, c] BGR image tensor
+        Return:
+            list of attribute scores: [gender_score, age_score]
+        """
+        img = cv2.cvtColor(img.cpu().numpy(), cv2.COLOR_BGR2RGB)
+        img = img.astype(np.uint8)
+        inputs = self.trans(img)
+        c, h, w = inputs.shape
+        inputs = inputs.view(-1, c, h, w)
+        inputs = inputs.to(self.device)
+        # torch.autograd.Variable(volatile=True) is deprecated; use no_grad.
+        with torch.no_grad():
+            outputs = self.net(inputs)[0]
+        # Head layout: [0:7] race, [7:9] gender, [9:18] age.
+        gender_outputs = outputs[7:9]
+        age_outputs = outputs[9:18]
+        gender_score = F.softmax(gender_outputs, dim=0).cpu().tolist()
+        age_score = F.softmax(age_outputs, dim=0).cpu().tolist()
+        return [gender_score, age_score]
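Reviewer note: a minimal sketch of driving the model directly, outside the pipeline; the checkpoint path and image file below are illustrative assumptions, not part of this PR. The model expects an already-cropped BGR face:

    import cv2
    import torch

    from modelscope.models.cv.face_attribute_recognition import \
        FaceAttributeRecognition

    # 'pytorch_model.pt' stands in for a locally downloaded FairFace checkpoint.
    model = FaceAttributeRecognition('pytorch_model.pt', device='cpu')
    face = cv2.imread('face_crop.png')  # hypothetical BGR face crop
    gender_score, age_score = model(torch.from_numpy(face))
    print(gender_score)  # 2 probabilities over ['Male', 'Female']
    print(age_score)     # 9 probabilities over the age buckets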
@@ -137,6 +137,13 @@ TASK_OUTPUTS = {
     Tasks.facial_expression_recognition:
     [OutputKeys.SCORES, OutputKeys.LABELS],
+
+    # face attribute recognition result for single sample
+    # {
+    #     "scores": [[0.9, 0.1], [0.92, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]],
+    #     "labels": [['Male', 'Female'], ['0-2', '3-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70+']]
+    # }
+    Tasks.face_attribute_recognition: [OutputKeys.SCORES, OutputKeys.LABELS],
     # face recognition result for single sample
     # {
     #     "img_embedding": np.array with shape [1, D],
@@ -61,6 +61,8 @@ TASK_INPUTS = {
     InputType.IMAGE,
     Tasks.facial_expression_recognition:
     InputType.IMAGE,
+    Tasks.face_attribute_recognition:
+    InputType.IMAGE,
     Tasks.face_recognition:
     InputType.IMAGE,
     Tasks.human_detection:
@@ -135,6 +135,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.facial_expression_recognition:
     (Pipelines.facial_expression_recognition,
      'damo/cv_vgg19_facial-expression-recognition_fer'),
+    Tasks.face_attribute_recognition:
+    (Pipelines.face_attribute_recognition,
+     'damo/cv_resnet34_face-attribute-recognition_fairface'),
     Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints,
                               'damo/cv_mobilenet_face-2d-keypoints_alignment'),
     Tasks.video_multi_modal_embedding:
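Reviewer note: with this default-model entry, callers can build the pipeline from the task name alone; a minimal sketch (the image path is the one used by the unit test at the bottom of this review):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Resolves to damo/cv_resnet34_face-attribute-recognition_fairface.
    far = pipeline(Tasks.face_attribute_recognition)
    print(far('data/test/images/face_recognition_1.png'))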
@@ -59,6 +59,7 @@ if TYPE_CHECKING:
     from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
     from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
     from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
+    from .face_attribute_recognition_pipeline import FaceAttributeRecognitionPipeline
     from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline
     from .hand_static_pipeline import HandStaticPipeline
     from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
@@ -132,8 +133,11 @@ else:
         'mog_face_detection_pipeline': ['MogFaceDetectionPipeline'],
         'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'],
         'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'],
-        'facial_expression_recognition_pipelin':
+        'facial_expression_recognition_pipeline':
         ['FacialExpressionRecognitionPipeline'],
+        'face_attribute_recognition_pipeline': [
+            'FaceAttributeRecognitionPipeline'
+        ],
         'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
         'hand_static_pipeline': ['HandStaticPipeline'],
         'referring_video_object_segmentation_pipeline': [
@@ -0,0 +1,131 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.face_attribute_recognition import \
+    FaceAttributeRecognition
+from modelscope.models.cv.face_recognition.align_face import align_face
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.face_attribute_recognition,
+    module_name=Pipelines.face_attribute_recognition)
+class FaceAttributeRecognitionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Use `model` to create a face attribute recognition pipeline for prediction.
+        Args:
+            model: model id on the ModelScope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
+        logger.info(f'loading model from {ckpt_path}')
+        device = torch.device(
+            f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
+        fairface = FaceAttributeRecognition(
+            model_path=ckpt_path, device=device)
+        self.fairface = fairface
+        self.device = device
+        logger.info('load model done')
+        # face detection pipeline used to locate and align the face
+        det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
+        gender_list = ['Male', 'Female']
+        age_list = [
+            '0-2', '3-9', '10-19', '20-29', '30-39', '40-49', '50-59',
+            '60-69', '70+'
+        ]
+        self.map_list = [gender_list, age_list]
+        self.face_detection = pipeline(
+            Tasks.face_detection, model=det_model_id)
+
+    def _choose_face(self,
+                     det_result,
+                     min_face=10,
+                     top_face=1,
+                     center_face=False,
+                     img=None):
+        '''
+        Choose the face with the maximum area.
+        Args:
+            det_result: output of the face detection pipeline
+            min_face: minimum valid face width/height in pixels
+            top_face: take the faces with the top-k largest areas
+            center_face: choose the face closest to the image center among
+                multiple faces; only valid if top_face > 1 and img is given
+            img: source image, needed to compute the image center when
+                center_face is True
+        '''
+        bboxes = np.array(det_result[OutputKeys.BOXES])
+        landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
+        if bboxes.shape[0] == 0:
+            logger.info('Warning: No face detected!')
+            return None
+        # indices of faces that are large enough
+        face_idx = []
+        for i in range(bboxes.shape[0]):
+            box = bboxes[i]
+            if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
+                face_idx += [i]
+        if len(face_idx) == 0:
+            logger.info(
+                f'Warning: Face size is too small, less than {min_face}x{min_face}!'
+            )
+            return None
+        bboxes = bboxes[face_idx]
+        landmarks = landmarks[face_idx]
+        # find the largest faces
+        boxes = np.array(bboxes)
+        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+        sort_idx = np.argsort(area)[-top_face:]
+        # find the most centered face
+        if top_face > 1 and center_face and img is not None \
+                and bboxes.shape[0] > 1:
+            img_center = [img.shape[1] // 2, img.shape[0] // 2]
+            min_dist = float('inf')
+            sel_idx = -1
+            for _idx in sort_idx:
+                box = boxes[_idx]
+                dist = np.square(
+                    np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square(
+                        np.abs((box[1] + box[3]) / 2 - img_center[1]))
+                if dist < min_dist:
+                    min_dist = dist
+                    sel_idx = _idx
+            sort_idx = [sel_idx]
+        main_idx = sort_idx[-1]
+        return bboxes[main_idx], landmarks[main_idx]
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        img = LoadImage.convert_to_ndarray(input)
+        img = img[:, :, ::-1]  # RGB -> BGR
+        det_result = self.face_detection(img.copy())
+        rtn = self._choose_face(det_result, img=img)
+        face_img = None
+        if rtn is not None:
+            _, face_lmks = rtn
+            face_lmks = face_lmks.reshape(5, 2)
+            # align and crop the face to 112x112 using the 5 landmarks
+            face_img, _ = align_face(img, (112, 112), face_lmks)
+            face_img = face_img.astype(np.float32)
+        result = {}
+        result['img'] = face_img
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        assert input['img'] is not None, 'No valid face detected!'
+        scores = self.fairface(input['img'])
+        return {OutputKeys.SCORES: scores, OutputKeys.LABELS: self.map_list}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
@@ -25,6 +25,7 @@ class CVTasks(object):
     card_detection = 'card-detection'
     face_recognition = 'face-recognition'
     facial_expression_recognition = 'facial-expression-recognition'
+    face_attribute_recognition = 'face-attribute-recognition'
     face_2d_keypoints = 'face-2d-keypoints'
     human_detection = 'human-detection'
     human_object_interaction = 'human-object-interaction'
@@ -6,6 +6,9 @@ import numpy as np
 from modelscope.outputs import OutputKeys
 from modelscope.preprocessors.image import load_image
+from modelscope.utils import logger as logging
+
+logger = logging.get_logger()
 
 
 def numpy_to_cv2img(img_array):
@@ -195,6 +198,33 @@ def draw_facial_expression_result(img_path, facial_expression_result):
     return img
 
 
+def draw_face_attribute_result(img_path, face_attribute_result):
+    scores = face_attribute_result[OutputKeys.SCORES]
+    labels = face_attribute_result[OutputKeys.LABELS]
+    label_gender = labels[0][np.argmax(scores[0])]
+    label_age = labels[1][np.argmax(scores[1])]
+    img = cv2.imread(img_path)
+    assert img is not None, f"Can't read img: {img_path}"
+    cv2.putText(
+        img,
+        'face gender: {}'.format(label_gender), (10, 10),
+        cv2.FONT_HERSHEY_PLAIN,
+        1.0, (0, 255, 0),
+        thickness=1,
+        lineType=cv2.LINE_8)
+    cv2.putText(
+        img,
+        'face age interval: {}'.format(label_age), (10, 40),
+        cv2.FONT_HERSHEY_PLAIN,
+        1.0, (255, 0, 0),
+        thickness=1,
+        lineType=cv2.LINE_8)
+    logger.info('face gender: {}'.format(label_gender))
+    logger.info('face age interval: {}'.format(label_age))
+    return img
+
+
 def draw_face_detection_result(img_path, detection_result):
     bboxes = np.array(detection_result[OutputKeys.BOXES])
     kpss = np.array(detection_result[OutputKeys.KEYPOINTS])
@@ -0,0 +1,36 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+import unittest
+
+import cv2
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import draw_face_attribute_result
+from modelscope.utils.test_utils import test_level
+
+
+class FaceAttributeRecognitionTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_resnet34_face-attribute-recognition_fairface'
+
+    def show_result(self, img_path, face_attribute_result):
+        img = draw_face_attribute_result(img_path, face_attribute_result)
+        cv2.imwrite('result.png', img)
+        print(f'output written to {osp.abspath("result.png")}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        fair_face = pipeline(
+            Tasks.face_attribute_recognition, model=self.model_id)
+        img_path = 'data/test/images/face_recognition_1.png'
+        result = fair_face(img_path)
+        self.show_result(img_path, result)
+
+
+if __name__ == '__main__':
+    unittest.main()