Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10726376master^2
| @@ -40,6 +40,7 @@ class Models(object): | |||
| resnet50_bert = 'resnet50-bert' | |||
| referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' | |||
| fer = 'fer' | |||
| fairface = 'fairface' | |||
| retinaface = 'retinaface' | |||
| shop_segmentation = 'shop-segmentation' | |||
| mogface = 'mogface' | |||
| @@ -185,6 +186,7 @@ class Pipelines(object): | |||
| ulfd_face_detection = 'manual-face-detection-ulfd' | |||
| tinymog_face_detection = 'manual-face-detection-tinymog' | |||
| facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | |||
| face_attribute_recognition = 'resnet34-face-attribute-recognition-fairface' | |||
| retina_face_detection = 'resnet50-face-detection-retinaface' | |||
| mog_face_detection = 'resnet101-face-detection-cvpr22papermogface' | |||
| mtcnn_face_detection = 'manual-face-detection-mtcnn' | |||
| @@ -0,0 +1,20 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .fair_face import FaceAttributeRecognition | |||
| else: | |||
| _import_structure = {'fair_face': ['FaceAttributeRecognition']} | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
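For context, the lazy registration above means that importing the package stays cheap; the torch-heavy `fair_face` submodule is only imported on first attribute access. A minimal sketch of what this implies for callers (module path as registered above):

    # importing the package itself does not load `fair_face` yet
    import modelscope.models.cv.face_attribute_recognition as far_module

    # the first attribute access makes LazyImportModule import `fair_face`
    FaceAttributeRecognition = far_module.FaceAttributeRecognition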
| @@ -0,0 +1,2 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .face_attribute_recognition import FaceAttributeRecognition | |||
| @@ -0,0 +1,79 @@ | |||
| # The implementation is based on FairFace, available at | |||
| # https://github.com/dchen236/FairFace | |||
| import os | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| import torch.backends.cudnn as cudnn | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchvision | |||
| from PIL import Image | |||
| from torch.autograd import Variable | |||
| from torchvision import datasets, models, transforms | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Tensor, TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| @MODELS.register_module( | |||
| Tasks.face_attribute_recognition, module_name=Models.fairface) | |||
| class FaceAttributeRecognition(TorchModel): | |||
| def __init__(self, model_path, device='cuda'): | |||
| super().__init__(model_path) | |||
| cudnn.benchmark = True | |||
| self.model_path = model_path | |||
| self.device = device | |||
| self.cfg_path = model_path.replace(ModelFile.TORCH_MODEL_FILE, | |||
| ModelFile.CONFIGURATION) | |||
| fair_face = torchvision.models.resnet34(pretrained=False) | |||
| fair_face.fc = nn.Linear(fair_face.fc.in_features, 18) | |||
| self.net = fair_face | |||
| self.load_model() | |||
| self.net = self.net.to(device) | |||
| self.trans = transforms.Compose([ | |||
| transforms.ToPILImage(), | |||
| transforms.Resize((224, 224)), | |||
| transforms.ToTensor(), | |||
| transforms.Normalize( | |||
| mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||
| ]) | |||
| def load_model(self, load_to_cpu=False): | |||
| pretrained_dict = torch.load( | |||
| self.model_path, map_location=torch.device('cpu')) | |||
| self.net.load_state_dict(pretrained_dict, strict=True) | |||
| self.net.eval() | |||
| def forward(self, img): | |||
| """ FariFace model forward process. | |||
| Args: | |||
| img: [h, w, c] | |||
| Return: | |||
| list of attribute result: [gender_score, age_score] | |||
| """ | |||
| img = cv2.cvtColor(img.cpu().numpy(), cv2.COLOR_BGR2RGB) | |||
| img = img.astype(np.uint8) | |||
| inputs = self.trans(img) | |||
| c, h, w = inputs.shape | |||
| inputs = inputs.view(-1, c, h, w) | |||
| inputs = inputs.to(self.device) | |||
| with torch.no_grad(): | |||
| outputs = self.net(inputs)[0] | |||
| gender_outputs = outputs[7:9] | |||
| age_outputs = outputs[9:18] | |||
| gender_score = F.softmax(gender_outputs, dim=0).cpu().tolist() | |||
| age_score = F.softmax(age_outputs, dim=0).cpu().tolist() | |||
| return [gender_score, age_score] | |||
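As a rough illustration of how this class can be driven directly, outside the pipeline, here is a minimal sketch; the checkpoint and image paths are hypothetical placeholders, and the face crop is assumed to be already detected and aligned:

    import cv2
    import torch

    from modelscope.models.cv.face_attribute_recognition import FaceAttributeRecognition

    # hypothetical paths, for illustration only
    ckpt_path = 'pytorch_model.pt'
    face_path = 'aligned_face.jpg'

    model = FaceAttributeRecognition(model_path=ckpt_path, device='cpu')
    face_bgr = cv2.imread(face_path)  # uint8 BGR face crop
    gender_score, age_score = model.forward(torch.from_numpy(face_bgr))
    print(gender_score)  # [p_male, p_female]
    print(age_score)     # probabilities over the nine age intervals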
| @@ -137,6 +137,13 @@ TASK_OUTPUTS = { | |||
| Tasks.facial_expression_recognition: | |||
| [OutputKeys.SCORES, OutputKeys.LABELS], | |||
| # face attribute recognition result for single sample | |||
| # { | |||
| # "scores": [[0.9, 0.1], [0.92, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01] | |||
| # "labels": [['Male', 'Female'], [0-2, 3-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+]] | |||
| # } | |||
| Tasks.face_attribute_recognition: [OutputKeys.SCORES, OutputKeys.LABELS], | |||
| # face recognition result for single sample | |||
| # { | |||
| # "img_embedding": np.array with shape [1, D], | |||
| @@ -61,6 +61,8 @@ TASK_INPUTS = { | |||
| InputType.IMAGE, | |||
| Tasks.facial_expression_recognition: | |||
| InputType.IMAGE, | |||
| Tasks.face_attribute_recognition: | |||
| InputType.IMAGE, | |||
| Tasks.face_recognition: | |||
| InputType.IMAGE, | |||
| Tasks.human_detection: | |||
| @@ -135,6 +135,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.facial_expression_recognition: | |||
| (Pipelines.facial_expression_recognition, | |||
| 'damo/cv_vgg19_facial-expression-recognition_fer'), | |||
| Tasks.face_attribute_recognition: | |||
| (Pipelines.face_attribute_recognition, | |||
| 'damo/cv_resnet34_face-attribute-recognition_fairface'), | |||
| Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints, | |||
| 'damo/cv_mobilenet_face-2d-keypoints_alignment'), | |||
| Tasks.video_multi_modal_embedding: | |||
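With this entry registered, constructing the pipeline by task alone should fall back to the fairface model; a brief sketch under that assumption:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # no model id given: the default from DEFAULT_MODEL_FOR_PIPELINE is used
    fair_face = pipeline(Tasks.face_attribute_recognition)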
| @@ -59,6 +59,7 @@ if TYPE_CHECKING: | |||
| from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline | |||
| from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline | |||
| from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline | |||
| from .face_attribute_recognition_pipeline import FaceAttributeRecognitionPipeline | |||
| from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline | |||
| from .hand_static_pipeline import HandStaticPipeline | |||
| from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline | |||
| @@ -132,8 +133,11 @@ else: | |||
| 'mog_face_detection_pipeline': ['MogFaceDetectionPipeline'], | |||
| 'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'], | |||
| 'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'], | |||
| 'facial_expression_recognition_pipelin': | |||
| 'facial_expression_recognition_pipeline': | |||
| ['FacialExpressionRecognitionPipeline'], | |||
| 'face_attribute_recognition_pipeline': [ | |||
| 'FaceAttributeRecognitionPipeline' | |||
| ], | |||
| 'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], | |||
| 'hand_static_pipeline': ['HandStaticPipeline'], | |||
| 'referring_video_object_segmentation_pipeline': [ | |||
| @@ -0,0 +1,131 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.face_attribute_recognition import \ | |||
| FaceAttributeRecognition | |||
| from modelscope.models.cv.face_recognition.align_face import align_face | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.face_attribute_recognition, | |||
| module_name=Pipelines.face_attribute_recognition) | |||
| class FaceAttributeRecognitionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a face attribute recognition pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, **kwargs) | |||
| ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) | |||
| logger.info(f'loading model from {ckpt_path}') | |||
| device = torch.device( | |||
| f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||
| fairface = FaceAttributeRecognition( | |||
| model_path=ckpt_path, device=device) | |||
| self.fairface = fairface | |||
| self.device = device | |||
| logger.info('load model done') | |||
| # face detect pipeline | |||
| det_model_id = 'damo/cv_resnet50_face-detection_retinaface' | |||
| gender_list = ['Male', 'Female'] | |||
| age_list = [ | |||
| '0-2', '3-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', | |||
| '70+' | |||
| ] | |||
| self.map_list = [gender_list, age_list] | |||
| self.face_detection = pipeline( | |||
| Tasks.face_detection, model=det_model_id) | |||
| def _choose_face(self, | |||
| det_result, | |||
| min_face=10, | |||
| top_face=1, | |||
| center_face=False, | |||
| img=None): | |||
| ''' | |||
| Choose the face with the maximum area. | |||
| Args: | |||
| det_result: output of the face detection pipeline | |||
| min_face: minimum valid face width/height in pixels | |||
| top_face: number of largest faces to keep | |||
| center_face: choose the face closest to the image center; only valid if top_face > 1 | |||
| img: original image, only needed when center_face is True | |||
| ''' | |||
| bboxes = np.array(det_result[OutputKeys.BOXES]) | |||
| landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) | |||
| if bboxes.shape[0] == 0: | |||
| logger.info('Warning: No face detected!') | |||
| return None | |||
| # face idx with enough size | |||
| face_idx = [] | |||
| for i in range(bboxes.shape[0]): | |||
| box = bboxes[i] | |||
| if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: | |||
| face_idx += [i] | |||
| if len(face_idx) == 0: | |||
| logger.info( | |||
| f'Warning: face size is too small (less than {min_face}x{min_face})!' | |||
| ) | |||
| return None | |||
| bboxes = bboxes[face_idx] | |||
| landmarks = landmarks[face_idx] | |||
| # find max faces | |||
| boxes = np.array(bboxes) | |||
| area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) | |||
| sort_idx = np.argsort(area)[-top_face:] | |||
| # find center face | |||
| if top_face > 1 and center_face and img is not None and bboxes.shape[0] > 1: | |||
| img_center = [img.shape[1] // 2, img.shape[0] // 2] | |||
| min_dist = float('inf') | |||
| sel_idx = -1 | |||
| for _idx in sort_idx: | |||
| box = boxes[_idx] | |||
| dist = np.square( | |||
| np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square( | |||
| np.abs((box[1] + box[3]) / 2 - img_center[1])) | |||
| if dist < min_dist: | |||
| min_dist = dist | |||
| sel_idx = _idx | |||
| sort_idx = [sel_idx] | |||
| main_idx = sort_idx[-1] | |||
| return bboxes[main_idx], landmarks[main_idx] | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| img = LoadImage.convert_to_ndarray(input) | |||
| img = img[:, :, ::-1] | |||
| det_result = self.face_detection(img.copy()) | |||
| rtn = self._choose_face(det_result) | |||
| face_img = None | |||
| if rtn is not None: | |||
| _, face_lmks = rtn | |||
| face_lmks = face_lmks.reshape(5, 2) | |||
| face_img, _ = align_face(img, (112, 112), face_lmks) | |||
| face_img = face_img.astype(np.float32) | |||
| result = {} | |||
| result['img'] = face_img | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| scores = self.fairface(input['img']) | |||
| assert scores is not None | |||
| return {OutputKeys.SCORES: scores, OutputKeys.LABELS: self.map_list} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
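For completeness, a minimal end-to-end sketch of driving this pipeline; the model id matches the default registered above and the image path is taken from the test data used later in this change:

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    fair_face = pipeline(
        Tasks.face_attribute_recognition,
        model='damo/cv_resnet34_face-attribute-recognition_fairface')
    result = fair_face('data/test/images/face_recognition_1.png')
    print(result[OutputKeys.SCORES])  # [[gender probs], [age interval probs]]
    print(result[OutputKeys.LABELS])  # [['Male', 'Female'], ['0-2', ..., '70+']]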
| @@ -25,6 +25,7 @@ class CVTasks(object): | |||
| card_detection = 'card-detection' | |||
| face_recognition = 'face-recognition' | |||
| facial_expression_recognition = 'facial-expression-recognition' | |||
| face_attribute_recognition = 'face-attribute-recognition' | |||
| face_2d_keypoints = 'face-2d-keypoints' | |||
| human_detection = 'human-detection' | |||
| human_object_interaction = 'human-object-interaction' | |||
| @@ -6,6 +6,9 @@ import numpy as np | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.preprocessors.image import load_image | |||
| from modelscope.utils import logger as logging | |||
| logger = logging.get_logger(__name__) | |||
| def numpy_to_cv2img(img_array): | |||
| @@ -195,6 +198,33 @@ def draw_facial_expression_result(img_path, facial_expression_result): | |||
| return img | |||
| def draw_face_attribute_result(img_path, face_attribute_result): | |||
| scores = face_attribute_result[OutputKeys.SCORES] | |||
| labels = face_attribute_result[OutputKeys.LABELS] | |||
| label_gender = labels[0][np.argmax(scores[0])] | |||
| label_age = labels[1][np.argmax(scores[1])] | |||
| img = cv2.imread(img_path) | |||
| assert img is not None, f"Can't read img: {img_path}" | |||
| cv2.putText( | |||
| img, | |||
| 'face gender: {}'.format(label_gender), (10, 10), | |||
| 1, | |||
| 1.0, (0, 255, 0), | |||
| thickness=1, | |||
| lineType=8) | |||
| cv2.putText( | |||
| img, | |||
| 'face age interval: {}'.format(label_age), (10, 40), | |||
| 1, | |||
| 1.0, (255, 0, 0), | |||
| thickness=1, | |||
| lineType=8) | |||
| logger.info('face gender: {}'.format(label_gender)) | |||
| logger.info('face age interval: {}'.format(label_age)) | |||
| return img | |||
| def draw_face_detection_result(img_path, detection_result): | |||
| bboxes = np.array(detection_result[OutputKeys.BOXES]) | |||
| kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) | |||
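A short usage sketch for the drawing helper added above, assuming a pipeline result produced as in the test below; the output file name is illustrative:

    import cv2

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks
    from modelscope.utils.cv.image_utils import draw_face_attribute_result

    img_path = 'data/test/images/face_recognition_1.png'
    fair_face = pipeline(
        Tasks.face_attribute_recognition,
        model='damo/cv_resnet34_face-attribute-recognition_fairface')
    vis = draw_face_attribute_result(img_path, fair_face(img_path))
    cv2.imwrite('face_attribute_result.png', vis)  # illustrative output name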
| @@ -0,0 +1,36 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import unittest | |||
| import cv2 | |||
| import numpy as np | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.cv.image_utils import draw_face_attribute_result | |||
| from modelscope.utils.test_utils import test_level | |||
| class FaceAttributeRecognitionTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/cv_resnet34_face-attribute-recognition_fairface' | |||
| def show_result(self, img_path, face_attribute_result): | |||
| img = draw_face_attribute_result(img_path, face_attribute_result) | |||
| cv2.imwrite('result.png', img) | |||
| print(f'output written to {osp.abspath("result.png")}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| fair_face = pipeline( | |||
| Tasks.face_attribute_recognition, model=self.model_id) | |||
| img_path = 'data/test/images/face_recognition_1.png' | |||
| result = fair_face(img_path) | |||
| self.show_result(img_path, result) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||