Completed self-review of the MaaS-cv CR. A new Task has been added; the product team has confirmed it can be added and the approval process is underway. It is not yet listed at https://aone.alibaba-inc.com/v2/project/1181559/req#viewIdentifier=d7f112f9d023e2108fa1b0d8 and will be added there later. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9976346 (target branch: master)
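For context while reviewing: once this CR is merged, the new task should be callable through the standard pipeline API. A minimal usage sketch (the model id, task constant, output keys and test image path are all taken from the diff below; nothing else is assumed):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

fer = pipeline(
    Tasks.facial_expression_recognition,
    model='damo/cv_vgg19_facial-expression-recognition_fer')
result = fer('data/test/images/facial_expression_recognition.jpg')
print(result[OutputKeys.SCORES])  # 7 softmax scores, one per expression
print(result[OutputKeys.LABELS])  # index of the predicted expression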
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bdb1cef5a5fd5f938a856311011c4820ddc45946a470b9929c61e59b6a065633
size 161535
@@ -32,6 +32,7 @@ class Models(object):
    vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
    text_driven_segmentation = 'text-driven-segmentation'
    resnet50_bert = 'resnet50-bert'
    fer = 'fer'
    retinaface = 'retinaface'
    shop_segmentation = 'shop-segmentation'
@@ -119,6 +120,7 @@ class Pipelines(object):
    salient_detection = 'u2net-salient-detection'
    image_classification = 'image-classification'
    face_detection = 'resnet-face-detection-scrfd10gkps'
    facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
    retina_face_detection = 'resnet50-face-detection-retinaface'
    live_category = 'live-category'
    general_image_classification = 'vit-base_image-classification_ImageNet-labels'
@@ -0,0 +1,72 @@
# The implementation is based on Facial-Expression-Recognition, available at
# https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch
import os

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from PIL import Image

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from . import transforms
from .vgg import VGG


@MODELS.register_module(
    Tasks.facial_expression_recognition, module_name=Models.fer)
class FacialExpressionRecognition(TorchModel):

    def __init__(self, model_path, device='cuda'):
        super().__init__(model_path)
        torch.set_grad_enabled(False)
        cudnn.benchmark = True
        self.model_path = model_path
        self.device = device
        self.cfg_path = model_path.replace(ModelFile.TORCH_MODEL_FILE,
                                           ModelFile.CONFIGURATION)
        self.net = VGG('VGG19', cfg_path=self.cfg_path)
        self.load_model()
        self.net = self.net.to(device)
        self.transform_test = transforms.Compose([
            transforms.TenCrop(44),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
        ])
        self.mean = np.array([[104, 117, 123]])

    def load_model(self, load_to_cpu=False):
        pretrained_dict = torch.load(
            self.model_path, map_location=torch.device('cpu'))
        self.net.load_state_dict(pretrained_dict['net'], strict=True)
        self.net.eval()

    def forward(self, input):
        img = input['img']
        # grayscale, resize to 48x48 and replicate to 3 channels
        img = cv2.cvtColor(img.cpu().numpy(), cv2.COLOR_BGR2GRAY)
        img = cv2.resize(img, (48, 48))
        img = img[:, :, np.newaxis]
        img = np.concatenate((img, img, img), axis=2)
        img = Image.fromarray(np.uint8(img))
        inputs = self.transform_test(img)
        ncrops, c, h, w = inputs.shape
        inputs = inputs.view(-1, c, h, w)
        inputs = inputs.to(self.device)
        # gradients are already disabled globally above; the deprecated
        # Variable(..., volatile=True) wrapper is not needed
        outputs = self.net(inputs)
        outputs_avg = outputs.view(ncrops, -1).mean(0)  # avg over crops
        score = F.softmax(outputs_avg, dim=0)
        _, predicted = torch.max(outputs_avg.data, 0)
        return score, predicted
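The ten-crop averaging in forward() above can be checked in isolation; a minimal sketch with random tensors standing in for the network (the shapes are the only assumption):

import torch
import torch.nn.functional as F

ncrops, c, h, w = 10, 3, 44, 44
inputs = torch.rand(ncrops, c, h, w)  # the ten 44x44 crops of one face
outputs = torch.rand(ncrops, 7)       # stand-in for self.net(inputs), 7 expression classes
outputs_avg = outputs.view(ncrops, -1).mean(0)  # average the logits over the ten crops
score = F.softmax(outputs_avg, dim=0)
predicted = torch.argmax(outputs_avg)
print(score.shape, predicted.item())  # torch.Size([7]) and a class index in [0, 6]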
@@ -0,0 +1,118 @@
# The implementation is based on Facial-Expression-Recognition, available at
# https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch
import numbers
import types

import numpy as np
import torch
from PIL import Image


def to_tensor(pic):
    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def center_crop(img, output_size):
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    w, h = img.size
    th, tw = output_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return img.crop((j, i, j + tw, i + th))


def five_crop(img, size):
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(
            size) == 2, 'Please provide only two dimensions (h, w) for size.'
    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError(
            'Requested crop size {} is bigger than input size {}'.format(
                size, (h, w)))
    tl = img.crop((0, 0, crop_w, crop_h))
    tr = img.crop((w - crop_w, 0, w, crop_h))
    bl = img.crop((0, h - crop_h, crop_w, h))
    br = img.crop((w - crop_w, h - crop_h, w, h))
    center = center_crop(img, (crop_h, crop_w))
    return (tl, tr, bl, br, center)


class TenCrop(object):

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(
                size
            ) == 2, 'Please provide only two dimensions (h, w) for size.'
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        first_five = five_crop(img, self.size)
        if self.vertical_flip:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
        else:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        second_five = five_crop(img, self.size)
        return first_five + second_five


class Compose(object):

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):

    def __call__(self, pic):
        return to_tensor(pic)


class Lambda(object):

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)
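As a quick sanity check of the transform chain the model builds above, Compose/TenCrop/Lambda/ToTensor turn one PIL image into a (10, 3, 44, 44) tensor; a minimal sketch with a synthetic 48x48 crop, using the classes defined in this file:

import numpy as np
import torch
from PIL import Image

transform_test = Compose([
    TenCrop(44),
    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])),
])
img = Image.fromarray(np.uint8(np.random.rand(48, 48, 3) * 255))  # stand-in face crop
batch = transform_test(img)
print(batch.shape)  # torch.Size([10, 3, 44, 44])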
@@ -0,0 +1,40 @@
# The implementation is based on Facial-Expression-Recognition, available at
# https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.utils.config import Config


class VGG(nn.Module):

    def __init__(self, vgg_name, cfg_path):
        super(VGG, self).__init__()
        # the layer spec (e.g. the 'VGG19' entry) is read from configuration.json
        model_cfg = Config.from_file(cfg_path)['models']
        self.features = self._make_layers(model_cfg[vgg_name])
        self.classifier = nn.Linear(512, 7)  # 7 facial expression classes

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = F.dropout(out, p=0.5, training=self.training)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [
                    nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                    nn.BatchNorm2d(x),
                    nn.ReLU(inplace=True)
                ]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
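The configuration.json read via cfg_path is not part of this diff, but _make_layers expects a torchvision-style spec list. A hedged sketch assuming the standard VGG19 layout (the actual committed config may differ); it also confirms that a 44x44 crop ends up as a 512-dim feature, matching nn.Linear(512, 7):

import torch
import torch.nn as nn

# assumed VGG19 spec; numbers are conv widths, 'M' is a 2x2 max-pool
vgg19_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
             512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']

def make_layers(cfg, in_channels=3):
    layers = []
    for x in cfg:
        if x == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            layers += [
                nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                nn.BatchNorm2d(x),
                nn.ReLU(inplace=True)
            ]
            in_channels = x
    return nn.Sequential(*layers)

features = make_layers(vgg19_cfg)
out = features(torch.rand(1, 3, 44, 44))
print(out.shape)  # torch.Size([1, 512, 1, 1]) -> flattens to 512 for the classifier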
@@ -85,6 +85,14 @@ TASK_OUTPUTS = {
    Tasks.face_detection:
    [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS],

    # facial expression recognition result for single sample
    # {
    #     "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02],
    #     "labels": ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
    # }
    Tasks.facial_expression_recognition:
    [OutputKeys.SCORES, OutputKeys.LABELS],

    # face recognition result for single sample
    # {
    #     "img_embedding": np.array with shape [1, D],
@@ -103,6 +103,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                             'damo/cv_resnet_facedetection_scrfd10gkps'),
    Tasks.face_recognition: (Pipelines.face_recognition,
                             'damo/cv_ir101_facerecognition_cfglint'),
    Tasks.facial_expression_recognition:
    (Pipelines.facial_expression_recognition,
     'damo/cv_vgg19_facial-expression-recognition_fer'),
    Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints,
                              'damo/cv_mobilenet_face-2d-keypoints_alignment'),
    Tasks.video_multi_modal_embedding:
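With the default-model entry above, callers can omit the model id; the returned index can then be mapped to a readable label. A minimal sketch (label order taken from the TASK_OUTPUTS comment above and draw_facial_expression_result below):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

fer = pipeline(Tasks.facial_expression_recognition)  # resolves to damo/cv_vgg19_facial-expression-recognition_fer
result = fer('data/test/images/facial_expression_recognition.jpg')
label_map = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
print(label_map[result[OutputKeys.LABELS]], max(result[OutputKeys.SCORES]))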
@@ -0,0 +1,128 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.models.cv.facial_expression_recognition.fer.facial_expression_recognition import \
    FacialExpressionRecognition
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.facial_expression_recognition,
    module_name=Pipelines.facial_expression_recognition)
class FacialExpressionRecognitionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a facial expression recognition pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {ckpt_path}')
        device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        fer = FacialExpressionRecognition(model_path=ckpt_path, device=device)
        self.fer = fer
        self.device = device
        logger.info('load model done')
        # face detection pipeline used to locate and align the face
        det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'
        self.face_detection = pipeline(
            Tasks.face_detection, model=det_model_id)

    def _choose_face(self,
                     det_result,
                     min_face=10,
                     top_face=1,
                     center_face=False):
        '''
        choose the face with the maximum area
        Args:
            det_result: output of the face detection pipeline
            min_face: minimum width/height of a valid face
            top_face: keep the faces with the top-N largest areas
            center_face: choose the most centered face among multiple faces,
                only valid if top_face > 1
        '''
        bboxes = np.array(det_result[OutputKeys.BOXES])
        landmarks = np.array(det_result[OutputKeys.KEYPOINTS])
        if bboxes.shape[0] == 0:
            logger.info('Warning: No face detected!')
            return None
        # indices of faces that are large enough
        face_idx = []
        for i in range(bboxes.shape[0]):
            box = bboxes[i]
            if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face:
                face_idx += [i]
        if len(face_idx) == 0:
            logger.info(
                f'Warning: No face larger than {min_face}x{min_face} detected!'
            )
            return None
        bboxes = bboxes[face_idx]
        landmarks = landmarks[face_idx]
        # find the largest faces
        boxes = np.array(bboxes)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sort_idx = np.argsort(area)[-top_face:]
        # find the most centered face
        # NOTE: `img` is not defined in this scope; the image (or its shape)
        # must be passed in for the center_face branch to work
        if top_face > 1 and center_face and bboxes.shape[0] > 1:
            img_center = [img.shape[1] // 2, img.shape[0] // 2]
            min_dist = float('inf')
            sel_idx = -1
            for _idx in sort_idx:
                box = boxes[_idx]
                dist = np.square(
                    np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square(
                        np.abs((box[1] + box[3]) / 2 - img_center[1]))
                if dist < min_dist:
                    min_dist = dist
                    sel_idx = _idx
            sort_idx = [sel_idx]
        main_idx = sort_idx[-1]
        return bboxes[main_idx], landmarks[main_idx]

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img[:, :, ::-1]  # reverse channel order; the detector and FER model expect BGR
        det_result = self.face_detection(img.copy())
        rtn = self._choose_face(det_result)
        face_img = None
        if rtn is not None:
            _, face_lmks = rtn
            face_lmks = face_lmks.reshape(5, 2)
            face_img, _ = align_face(img, (112, 112), face_lmks)
            face_img = face_img.astype(np.float32)
        result = {}
        result['img'] = face_img
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = self.fer(input)
        assert result is not None
        scores = result[0].tolist()
        labels = result[1].tolist()
        return {
            OutputKeys.SCORES: scores,
            OutputKeys.LABELS: labels,
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
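To probe the preprocessing step of this pipeline in isolation, the same detect-then-align flow can be reproduced with the public pieces it uses; a minimal sketch assuming at least one face is detected (the image path reuses the test asset in this CR):

import numpy as np
from modelscope.models.cv.face_recognition.align_face import align_face
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks

img = LoadImage.convert_to_ndarray(
    'data/test/images/facial_expression_recognition.jpg')[:, :, ::-1]
det = pipeline(
    Tasks.face_detection,
    model='damo/cv_resnet_facedetection_scrfd10gkps')(img.copy())
lmks = np.array(det[OutputKeys.KEYPOINTS])[0].reshape(5, 2)  # first detected face
face_112, _ = align_face(img, (112, 112), lmks)
print(face_112.shape)  # e.g. (112, 112, 3), the aligned crop fed to the FER model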
@@ -20,6 +20,7 @@ class CVTasks(object):
    animal_recognition = 'animal-recognition'
    face_detection = 'face-detection'
    face_recognition = 'face-recognition'
    facial_expression_recognition = 'facial-expression-recognition'
    face_2d_keypoints = 'face-2d-keypoints'
    human_detection = 'human-detection'
    human_object_interaction = 'human-object-interaction'
@@ -89,6 +89,26 @@ def draw_keypoints(output, original_image):
    return image


def draw_facial_expression_result(img_path, facial_expression_result):
    label_idx = facial_expression_result[OutputKeys.LABELS]
    map_list = [
        'Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral'
    ]
    label = map_list[label_idx]
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"
    cv2.putText(
        img,
        'facial expression: {}'.format(label), (10, 10),
        1,
        1.0, (0, 255, 0),
        thickness=1,
        lineType=8)
    print('facial expression: {}'.format(label))
    return img


def draw_face_detection_result(img_path, detection_result):
    bboxes = np.array(detection_result[OutputKeys.BOXES])
    kpss = np.array(detection_result[OutputKeys.KEYPOINTS])
@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2
import numpy as np

from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import draw_facial_expression_result
from modelscope.utils.test_utils import test_level


class FacialExpressionRecognitionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_vgg19_facial-expression-recognition_fer'

    def show_result(self, img_path, facial_expression_result):
        img = draw_facial_expression_result(img_path,
                                            facial_expression_result)
        cv2.imwrite('result.png', img)
        print(f'output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        fer = pipeline(
            Tasks.facial_expression_recognition, model=self.model_id)
        img_path = 'data/test/images/facial_expression_recognition.jpg'
        result = fer(img_path)
        self.show_result(img_path, result)


if __name__ == '__main__':
    unittest.main()