Completed the MaaS-cv CR self-review. This adds a new Task; the product team has confirmed it can be added and the approval process is in progress. It is not yet listed at https://aone.alibaba-inc.com/v2/project/1181559/req#viewIdentifier=d7f112f9d023e2108fa1b0d8 and will be added there later. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9976346 (master)
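For reviewers' reference, the intended end-to-end usage of the new task, as exercised by the unit test at the bottom of this diff, is roughly the following sketch (model id and test image path are taken from the diff below; not a spec of the final API):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

fer = pipeline(Tasks.facial_expression_recognition,
               model='damo/cv_vgg19_facial-expression-recognition_fer')
result = fer('data/test/images/facial_expression_recognition.jpg')
print(result[OutputKeys.SCORES], result[OutputKeys.LABELS])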
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bdb1cef5a5fd5f938a856311011c4820ddc45946a470b9929c61e59b6a065633
size 161535
| @@ -32,6 +32,7 @@ class Models(object): | |||||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
| resnet50_bert = 'resnet50-bert' | resnet50_bert = 'resnet50-bert' | ||||
| fer = 'fer' | |||||
| retinaface = 'retinaface' | retinaface = 'retinaface' | ||||
| shop_segmentation = 'shop-segmentation' | shop_segmentation = 'shop-segmentation' | ||||
| @@ -119,6 +120,7 @@ class Pipelines(object): | |||||
| salient_detection = 'u2net-salient-detection' | salient_detection = 'u2net-salient-detection' | ||||
| image_classification = 'image-classification' | image_classification = 'image-classification' | ||||
| face_detection = 'resnet-face-detection-scrfd10gkps' | face_detection = 'resnet-face-detection-scrfd10gkps' | ||||
| facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | |||||
| retina_face_detection = 'resnet50-face-detection-retinaface' | retina_face_detection = 'resnet50-face-detection-retinaface' | ||||
| live_category = 'live-category' | live_category = 'live-category' | ||||
| general_image_classification = 'vit-base_image-classification_ImageNet-labels' | general_image_classification = 'vit-base_image-classification_ImageNet-labels' | ||||
| @@ -0,0 +1,72 @@ | |||||
| # The implementation is based on Facial-Expression-Recognition, available at | |||||
| # https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch | |||||
| import os | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.backends.cudnn as cudnn | |||||
| import torch.nn.functional as F | |||||
| from PIL import Image | |||||
| from torch.autograd import Variable | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Tensor, TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from . import transforms | |||||
| from .vgg import VGG | |||||
| @MODELS.register_module( | |||||
| Tasks.facial_expression_recognition, module_name=Models.fer) | |||||
| class FacialExpressionRecognition(TorchModel): | |||||
| def __init__(self, model_path, device='cuda'): | |||||
| super().__init__(model_path) | |||||
| torch.set_grad_enabled(False) | |||||
| cudnn.benchmark = True | |||||
| self.model_path = model_path | |||||
| self.device = device | |||||
| self.cfg_path = model_path.replace(ModelFile.TORCH_MODEL_FILE, | |||||
| ModelFile.CONFIGURATION) | |||||
| self.net = VGG('VGG19', cfg_path=self.cfg_path) | |||||
| self.load_model() | |||||
| self.net = self.net.to(device) | |||||
| self.transform_test = transforms.Compose([ | |||||
| transforms.TenCrop(44), | |||||
| transforms.Lambda(lambda crops: torch.stack( | |||||
| [transforms.ToTensor()(crop) for crop in crops])), | |||||
| ]) | |||||
| self.mean = np.array([[104, 117, 123]]) | |||||
| def load_model(self, load_to_cpu=False): | |||||
| pretrained_dict = torch.load( | |||||
| self.model_path, map_location=torch.device('cpu')) | |||||
| self.net.load_state_dict(pretrained_dict['net'], strict=True) | |||||
| self.net.eval() | |||||
| def forward(self, input): | |||||
| img = input['img'] | |||||
| img = cv2.cvtColor(img.cpu().numpy(), cv2.COLOR_BGR2GRAY) | |||||
| img = cv2.resize(img, (48, 48)) | |||||
| img = img[:, :, np.newaxis] | |||||
| img = np.concatenate((img, img, img), axis=2) | |||||
| img = Image.fromarray(np.uint8(img)) | |||||
| inputs = self.transform_test(img) | |||||
| ncrops, c, h, w = inputs.shape | |||||
| inputs = inputs.view(-1, c, h, w) | |||||
| inputs = inputs.to(self.device) | |||||
| inputs = Variable(inputs, volatile=True) | |||||
| outputs = self.net(inputs) | |||||
| outputs_avg = outputs.view(ncrops, -1).mean(0) # avg over crops | |||||
| score = F.softmax(outputs_avg) | |||||
| _, predicted = torch.max(outputs_avg.data, 0) | |||||
| return score, predicted | |||||
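A quick sketch of what forward() returns, mirroring how the pipeline further down calls the model; `fer_model` and the zero-filled face crop are hypothetical stand-ins, not part of this CR:

import numpy as np
import torch

# stand-in for a 112x112 BGR float32 face crop, the shape produced by the
# pipeline's preprocess() below
face = torch.from_numpy(np.zeros((112, 112, 3), dtype=np.float32))
scores, predicted = fer_model({'img': face})  # fer_model: a loaded FacialExpressionRecognition
print(scores.shape, int(predicted))           # torch.Size([7]) and a class index in [0, 6]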
| @@ -0,0 +1,118 @@ | |||||
| # The implementation is based on Facial-Expression-Recognition, available at | |||||
| # https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch | |||||
| import numbers | |||||
| import types | |||||
| import numpy as np | |||||
| import torch | |||||
| from PIL import Image | |||||
| def to_tensor(pic): | |||||
| # handle PIL Image | |||||
| if pic.mode == 'I': | |||||
| img = torch.from_numpy(np.array(pic, np.int32, copy=False)) | |||||
| elif pic.mode == 'I;16': | |||||
| img = torch.from_numpy(np.array(pic, np.int16, copy=False)) | |||||
| else: | |||||
| img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) | |||||
| # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK | |||||
| if pic.mode == 'YCbCr': | |||||
| nchannel = 3 | |||||
| elif pic.mode == 'I;16': | |||||
| nchannel = 1 | |||||
| else: | |||||
| nchannel = len(pic.mode) | |||||
| img = img.view(pic.size[1], pic.size[0], nchannel) | |||||
| # put it from HWC to CHW format | |||||
| # yikes, this transpose takes 80% of the loading time/CPU | |||||
| img = img.transpose(0, 1).transpose(0, 2).contiguous() | |||||
| if isinstance(img, torch.ByteTensor): | |||||
| return img.float().div(255) | |||||
| else: | |||||
| return img | |||||
| def center_crop(img, output_size): | |||||
| if isinstance(output_size, numbers.Number): | |||||
| output_size = (int(output_size), int(output_size)) | |||||
| w, h = img.size | |||||
| th, tw = output_size | |||||
| i = int(round((h - th) / 2.)) | |||||
| j = int(round((w - tw) / 2.)) | |||||
| return img.crop((j, i, j + tw, i + th)) | |||||
| def five_crop(img, size): | |||||
| if isinstance(size, numbers.Number): | |||||
| size = (int(size), int(size)) | |||||
| else: | |||||
| assert len( | |||||
| size) == 2, 'Please provide only two dimensions (h, w) for size.' | |||||
| w, h = img.size | |||||
| crop_h, crop_w = size | |||||
| if crop_w > w or crop_h > h: | |||||
| raise ValueError( | |||||
| 'Requested crop size {} is bigger than input size {}'.format( | |||||
| size, (h, w))) | |||||
| tl = img.crop((0, 0, crop_w, crop_h)) | |||||
| tr = img.crop((w - crop_w, 0, w, crop_h)) | |||||
| bl = img.crop((0, h - crop_h, crop_w, h)) | |||||
| br = img.crop((w - crop_w, h - crop_h, w, h)) | |||||
| center = center_crop(img, (crop_h, crop_w)) | |||||
| return (tl, tr, bl, br, center) | |||||
| class TenCrop(object): | |||||
| def __init__(self, size, vertical_flip=False): | |||||
| self.size = size | |||||
| if isinstance(size, numbers.Number): | |||||
| self.size = (int(size), int(size)) | |||||
| else: | |||||
| assert len( | |||||
| size | |||||
| ) == 2, 'Please provide only two dimensions (h, w) for size.' | |||||
| self.size = size | |||||
| self.vertical_flip = vertical_flip | |||||
| def __call__(self, img): | |||||
| first_five = five_crop(img, self.size) | |||||
| if self.vertical_flip: | |||||
| img = img.transpose(Image.FLIP_TOP_BOTTOM) | |||||
| else: | |||||
| img = img.transpose(Image.FLIP_LEFT_RIGHT) | |||||
| second_five = five_crop(img, self.size) | |||||
| return first_five + second_five | |||||
| class Compose(object): | |||||
| def __init__(self, transforms): | |||||
| self.transforms = transforms | |||||
| def __call__(self, img): | |||||
| for t in self.transforms: | |||||
| img = t(img) | |||||
| return img | |||||
| class ToTensor(object): | |||||
| def __call__(self, pic): | |||||
| return to_tensor(pic) | |||||
| class Lambda(object): | |||||
| def __init__(self, lambd): | |||||
| assert isinstance(lambd, types.LambdaType) | |||||
| self.lambd = lambd | |||||
| def __call__(self, img): | |||||
| return self.lambd(img) | |||||
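A minimal sanity-check sketch of this transform chain as the model above composes it (run inside, or importing from, this transforms module; the zero-filled 48x48 image is a hypothetical stand-in):

import numpy as np
import torch
from PIL import Image

# 48x48 RGB input, the size the model resizes faces to before ten-cropping
img = Image.fromarray(np.zeros((48, 48, 3), dtype=np.uint8))
transform = Compose([
    TenCrop(44),
    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])),
])
print(transform(img).shape)  # torch.Size([10, 3, 44, 44])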
| @@ -0,0 +1,40 @@ | |||||
| # The implementation is based on Facial-Expression-Recognition, available at | |||||
| # https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from torch.autograd import Variable | |||||
| from modelscope.utils.config import Config | |||||
| class VGG(nn.Module): | |||||
| def __init__(self, vgg_name, cfg_path): | |||||
| super(VGG, self).__init__() | |||||
| model_cfg = Config.from_file(cfg_path)['models'] | |||||
| self.features = self._make_layers(model_cfg[vgg_name]) | |||||
| self.classifier = nn.Linear(512, 7) | |||||
| def forward(self, x): | |||||
| out = self.features(x) | |||||
| out = out.view(out.size(0), -1) | |||||
| out = F.dropout(out, p=0.5, training=self.training) | |||||
| out = self.classifier(out) | |||||
| return out | |||||
| def _make_layers(self, cfg): | |||||
| layers = [] | |||||
| in_channels = 3 | |||||
| for x in cfg: | |||||
| if x == 'M': | |||||
| layers += [nn.MaxPool2d(kernel_size=2, stride=2)] | |||||
| else: | |||||
| layers += [ | |||||
| nn.Conv2d(in_channels, x, kernel_size=3, padding=1), | |||||
| nn.BatchNorm2d(x), | |||||
| nn.ReLU(inplace=True) | |||||
| ] | |||||
| in_channels = x | |||||
| layers += [nn.AvgPool2d(kernel_size=1, stride=1)] | |||||
| return nn.Sequential(*layers) | |||||
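The configuration.json that Config.from_file reads here is not part of this diff; assuming the standard VGG19 layout (which also makes the flattened feature size come out to 512 for 44x44 inputs), the value passed to _make_layers would look roughly like this hypothetical sketch:

# hypothetical contents; the actual configuration file is not included in this CR
model_cfg = {
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
              512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}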
| @@ -85,6 +85,14 @@ TASK_OUTPUTS = { | |||||
| Tasks.face_detection: | Tasks.face_detection: | ||||
| [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | ||||
| # facial expression recognition result for single sample | |||||
| # { | |||||
| # "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02], | |||||
| # "labels": ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral'] | |||||
| # } | |||||
| Tasks.facial_expression_recognition: | |||||
| [OutputKeys.SCORES, OutputKeys.LABELS], | |||||
| # face recognition result for single sample | # face recognition result for single sample | ||||
| # { | # { | ||||
| # "img_embedding": np.array with shape [1, D], | # "img_embedding": np.array with shape [1, D], | ||||
| @@ -103,6 +103,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_resnet_facedetection_scrfd10gkps'), | 'damo/cv_resnet_facedetection_scrfd10gkps'), | ||||
| Tasks.face_recognition: (Pipelines.face_recognition, | Tasks.face_recognition: (Pipelines.face_recognition, | ||||
| 'damo/cv_ir101_facerecognition_cfglint'), | 'damo/cv_ir101_facerecognition_cfglint'), | ||||
| Tasks.facial_expression_recognition: | |||||
| (Pipelines.facial_expression_recognition, | |||||
| 'damo/cv_vgg19_facial-expression-recognition_fer'), | |||||
| Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints, | Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints, | ||||
| 'damo/cv_mobilenet_face-2d-keypoints_alignment'), | 'damo/cv_mobilenet_face-2d-keypoints_alignment'), | ||||
| Tasks.video_multi_modal_embedding: | Tasks.video_multi_modal_embedding: | ||||
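Registering a default model here means the task can be built without an explicit model id; a minimal sketch of what that enables:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# with the default above, this should resolve to
# 'damo/cv_vgg19_facial-expression-recognition_fer'
fer = pipeline(Tasks.facial_expression_recognition)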
| @@ -0,0 +1,128 @@ | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import PIL | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.face_recognition.align_face import align_face | |||||
| from modelscope.models.cv.facial_expression_recognition.fer.facial_expression_recognition import \ | |||||
| FacialExpressionRecognition | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.facial_expression_recognition, | |||||
| module_name=Pipelines.facial_expression_recognition) | |||||
| class FacialExpressionRecognitionPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| use `model` to create a face detection pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, **kwargs) | |||||
| ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) | |||||
| logger.info(f'loading model from {ckpt_path}') | |||||
| device = torch.device( | |||||
| f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||||
| fer = FacialExpressionRecognition(model_path=ckpt_path, device=device) | |||||
| self.fer = fer | |||||
| self.device = device | |||||
| logger.info('load model done') | |||||
| # face detect pipeline | |||||
| det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||||
| self.face_detection = pipeline( | |||||
| Tasks.face_detection, model=det_model_id) | |||||
| def _choose_face(self, | |||||
| det_result, | |||||
| min_face=10, | |||||
| top_face=1, | |||||
| center_face=False): | |||||
| ''' | |||||
| choose face with maximum area | |||||
| Args: | |||||
| det_result: output of face detection pipeline | |||||
| min_face: minimum size of valid face w/h | |||||
| top_face: take faces with top max areas | |||||
| center_face: choose the most centerd face from multi faces, only valid if top_face > 1 | |||||
| ''' | |||||
| bboxes = np.array(det_result[OutputKeys.BOXES]) | |||||
| landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) | |||||
| if bboxes.shape[0] == 0: | |||||
| logger.info('Warning: No face detected!') | |||||
| return None | |||||
| # face idx with enough size | |||||
| face_idx = [] | |||||
| for i in range(bboxes.shape[0]): | |||||
| box = bboxes[i] | |||||
| if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: | |||||
| face_idx += [i] | |||||
| if len(face_idx) == 0: | |||||
| logger.info( | |||||
| f'Warning: Face size not enough, less than {min_face}x{min_face}!' | |||||
| ) | |||||
| return None | |||||
| bboxes = bboxes[face_idx] | |||||
| landmarks = landmarks[face_idx] | |||||
| # find max faces | |||||
| boxes = np.array(bboxes) | |||||
| area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) | |||||
| sort_idx = np.argsort(area)[-top_face:] | |||||
| # find center face | |||||
| if top_face > 1 and center_face and bboxes.shape[0] > 1: | |||||
| img_center = [img.shape[1] // 2, img.shape[0] // 2] | |||||
| min_dist = float('inf') | |||||
| sel_idx = -1 | |||||
| for _idx in sort_idx: | |||||
| box = boxes[_idx] | |||||
| dist = np.square( | |||||
| np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square( | |||||
| np.abs((box[1] + box[3]) / 2 - img_center[1])) | |||||
| if dist < min_dist: | |||||
| min_dist = dist | |||||
| sel_idx = _idx | |||||
| sort_idx = [sel_idx] | |||||
| main_idx = sort_idx[-1] | |||||
| return bboxes[main_idx], landmarks[main_idx] | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_ndarray(input) | |||||
| img = img[:, :, ::-1] | |||||
| det_result = self.face_detection(img.copy()) | |||||
| rtn = self._choose_face(det_result) | |||||
| face_img = None | |||||
| if rtn is not None: | |||||
| _, face_lmks = rtn | |||||
| face_lmks = face_lmks.reshape(5, 2) | |||||
| face_img, _ = align_face(img, (112, 112), face_lmks) | |||||
| face_img = face_img.astype(np.float32) | |||||
| result = {} | |||||
| result['img'] = face_img | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| result = self.fer(input) | |||||
| assert result is not None | |||||
| scores = result[0].tolist() | |||||
| labels = result[1].tolist() | |||||
| return { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.LABELS: labels, | |||||
| } | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
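A hedged sketch of interpreting this pipeline's raw output: per forward() above, 'labels' carries the predicted class index, and the name list below matches the one used in image_utils later in this review. `fer_pipeline` and the image path are hypothetical placeholders:

from modelscope.outputs import OutputKeys

emotions = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
result = fer_pipeline('some_face.jpg')  # fer_pipeline built via pipeline(...) as above
print(emotions[result[OutputKeys.LABELS]], max(result[OutputKeys.SCORES]))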
| @@ -20,6 +20,7 @@ class CVTasks(object): | |||||
| animal_recognition = 'animal-recognition' | animal_recognition = 'animal-recognition' | ||||
| face_detection = 'face-detection' | face_detection = 'face-detection' | ||||
| face_recognition = 'face-recognition' | face_recognition = 'face-recognition' | ||||
| facial_expression_recognition = 'facial-expression-recognition' | |||||
| face_2d_keypoints = 'face-2d-keypoints' | face_2d_keypoints = 'face-2d-keypoints' | ||||
| human_detection = 'human-detection' | human_detection = 'human-detection' | ||||
| human_object_interaction = 'human-object-interaction' | human_object_interaction = 'human-object-interaction' | ||||
| @@ -89,6 +89,26 @@ def draw_keypoints(output, original_image): | |||||
| return image | return image | ||||
| def draw_facial_expression_result(img_path, facial_expression_result): | |||||
| label_idx = facial_expression_result[OutputKeys.LABELS] | |||||
| map_list = [ | |||||
| 'Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral' | |||||
| ] | |||||
| label = map_list[label_idx] | |||||
| img = cv2.imread(img_path) | |||||
| assert img is not None, f"Can't read img: {img_path}" | |||||
| cv2.putText( | |||||
| img, | |||||
| 'facial expression: {}'.format(label), (10, 10), | |||||
| 1, | |||||
| 1.0, (0, 255, 0), | |||||
| thickness=1, | |||||
| lineType=8) | |||||
| print('facial expression: {}'.format(label)) | |||||
| return img | |||||
| def draw_face_detection_result(img_path, detection_result): | def draw_face_detection_result(img_path, detection_result): | ||||
| bboxes = np.array(detection_result[OutputKeys.BOXES]) | bboxes = np.array(detection_result[OutputKeys.BOXES]) | ||||
| kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) | kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) | ||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| import unittest | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.cv.image_utils import draw_facial_expression_result | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class FacialExpressionRecognitionTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/cv_vgg19_facial-expression-recognition_fer' | |||||
| def show_result(self, img_path, facial_expression_result): | |||||
| img = draw_facial_expression_result(img_path, facial_expression_result) | |||||
| cv2.imwrite('result.png', img) | |||||
| print(f'output written to {osp.abspath("result.png")}') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| fer = pipeline( | |||||
| Tasks.facial_expression_recognition, model=self.model_id) | |||||
| img_path = 'data/test/images/facial_expression_recognition.jpg' | |||||
| result = fer(img_path) | |||||
| self.show_result(img_path, result) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||