diff --git a/data/test/images/hand_static.jpg b/data/test/images/hand_static.jpg new file mode 100644 index 00000000..43ae28b1 --- /dev/null +++ b/data/test/images/hand_static.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94b8e281d77ee6d3ea2a8a0c9408ecdbd29fe75f33ea5399b6ea00070ba77bd6 +size 13090 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 29a35fbe..5870ebe3 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -39,6 +39,7 @@ class Models(object): mtcnn = 'mtcnn' ulfd = 'ulfd' video_inpainting = 'video-inpainting' + hand_static = 'hand-static' # EasyCV models yolox = 'YOLOX' @@ -173,6 +174,7 @@ class Pipelines(object): movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' shop_segmentation = 'shop-segmentation' video_inpainting = 'video-inpainting' + hand_static = 'hand-static' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/hand_static/__init__.py b/modelscope/models/cv/hand_static/__init__.py new file mode 100644 index 00000000..654d2acb --- /dev/null +++ b/modelscope/models/cv/hand_static/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .hand_model import HandStatic + +else: + _import_structure = {'hand_model': ['HandStatic']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/hand_static/hand_model.py b/modelscope/models/cv/hand_static/hand_model.py new file mode 100644 index 00000000..38517307 --- /dev/null +++ b/modelscope/models/cv/hand_static/hand_model.py @@ -0,0 +1,93 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import sys + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torch import nn +from torchvision.transforms import transforms + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .networks import StaticGestureNet + +logger = get_logger() + +map_idx = { + 0: 'unrecog', + 1: 'one', + 2: 'two', + 3: 'bixin', + 4: 'yaogun', + 5: 'zan', + 6: 'fist', + 7: 'ok', + 8: 'tuoju', + 9: 'd_bixin', + 10: 'd_fist_left', + 11: 'd_fist_right', + 12: 'd_hand', + 13: 'fashe', + 14: 'five', + 15: 'nohand' +} + +img_size = [112, 112] + +spatial_transform = transforms.Compose([ + transforms.Resize(img_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) +]) + + +@MODELS.register_module(Tasks.hand_static, module_name=Models.hand_static) +class HandStatic(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = StaticGestureNet() + if torch.cuda.is_available(): + self.device = 'cuda' + else: + self.device = 'cpu' + self.params = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=self.device) + + self.model.load_state_dict(self.params) + self.model.to(self.device) + self.model.eval() + self.device_id = device_id + if self.device_id >= 0 and self.device == 'cuda': + self.model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + + def forward(self, x): + pred_result = self.model(x) + return pred_result + + +def infer(img_path, model, device): + + img = Image.open(img_path) + clip = spatial_transform(img) + clip = clip.unsqueeze(0).to(device).float() + outputs = model(clip) + predicted = int(outputs.max(1)[1]) + pred_result = map_idx.get(predicted) + logger.info('pred result: {}'.format(pred_result)) + + return pred_result diff --git a/modelscope/models/cv/hand_static/networks.py b/modelscope/models/cv/hand_static/networks.py new file mode 100644 index 00000000..6cf46f5d --- /dev/null +++ b/modelscope/models/cv/hand_static/networks.py @@ -0,0 +1,358 @@ +""" HandStatic +The implementation here is modified based on MobileFaceNet, +originally Apache 2.0 License and publicly avaialbe at https://github.com/xuexingyu24/MobileFaceNet_Tutorial_Pytorch +""" + +import os + +import torch +import torch.nn as nn +import torchvision +import torchvision.models as models +from torch.nn import (AdaptiveAvgPool2d, BatchNorm1d, BatchNorm2d, Conv2d, + Dropout, Linear, MaxPool2d, Module, PReLU, ReLU, + Sequential, Sigmoid) + + +class StaticGestureNet(torch.nn.Module): + + def __init__(self, train=True): + super().__init__() + + model = MobileFaceNet(512) + self.feature_extractor = model + self.fc_layer = torch.nn.Sequential( + nn.Linear(512, 128), nn.Softplus(), nn.Linear(128, 15)) + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs): + out = self.feature_extractor(inputs) + out = self.fc_layer(out) + out = self.sigmoid(out) + return out + + +class Flatten(Module): + + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class SEModule(Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d( + channels, + channels // reduction, + kernel_size=1, + padding=0, + bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d( + channels // reduction, + channels, + kernel_size=1, + padding=0, + bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class BottleneckIR(Module): + + def __init__(self, in_channel, depth, stride): + super(BottleneckIR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class BottleneckIRSE(Module): + + def __init__(self, in_channel, depth, stride): + super(BottleneckIRSE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), SEModule(depth, 16)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride) + ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + return blocks + + +class Backbone(Module): + + def __init__(self, num_layers, drop_ratio, mode='ir'): + super(Backbone, self).__init__() + assert num_layers in [50, 100, + 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + self.input_layer = Sequential( + Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), + PReLU(64)) + self.output_layer = Sequential( + BatchNorm2d(512), Dropout(drop_ratio), Flatten(), + Linear(512 * 7 * 7, 512), BatchNorm1d(512)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +class ConvBlock(Module): + + def __init__(self, + in_c, + out_c, + kernel=(1, 1), + stride=(1, 1), + padding=(0, 0), + groups=1): + super(ConvBlock, self).__init__() + self.conv = Conv2d( + in_c, + out_channels=out_c, + kernel_size=kernel, + groups=groups, + stride=stride, + padding=padding, + bias=False) + self.bn = BatchNorm2d(out_c) + self.prelu = PReLU(out_c) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.prelu(x) + return x + + +class LinearBlock(Module): + + def __init__(self, + in_c, + out_c, + kernel=(1, 1), + stride=(1, 1), + padding=(0, 0), + groups=1): + super(LinearBlock, self).__init__() + self.conv = Conv2d( + in_c, + out_channels=out_c, + kernel_size=kernel, + groups=groups, + stride=stride, + padding=padding, + bias=False) + self.bn = BatchNorm2d(out_c) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class DepthWise(Module): + + def __init__(self, + in_c, + out_c, + residual=False, + kernel=(3, 3), + stride=(2, 2), + padding=(1, 1), + groups=1): + super(DepthWise, self).__init__() + self.conv = ConvBlock( + in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + self.conv_dw = ConvBlock( + groups, + groups, + groups=groups, + kernel=kernel, + padding=padding, + stride=stride) + self.project = LinearBlock( + groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + self.residual = residual + + def forward(self, x): + if self.residual: + short_cut = x + x = self.conv(x) + x = self.conv_dw(x) + x = self.project(x) + if self.residual: + output = short_cut + x + else: + output = x + return output + + +class Residual(Module): + + def __init__(self, + c, + num_block, + groups, + kernel=(3, 3), + stride=(1, 1), + padding=(1, 1)): + super(Residual, self).__init__() + modules = [] + for _ in range(num_block): + modules.append( + DepthWise( + c, + c, + residual=True, + kernel=kernel, + padding=padding, + stride=stride, + groups=groups)) + self.model = Sequential(*modules) + + def forward(self, x): + return self.model(x) + + +class MobileFaceNet(Module): + + def __init__(self, embedding_size): + super(MobileFaceNet, self).__init__() + self.conv1 = ConvBlock( + 3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) + self.conv2_dw = ConvBlock( + 64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) + self.conv_23 = DepthWise( + 64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128) + self.conv_3 = Residual( + 64, + num_block=4, + groups=128, + kernel=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.conv_34 = DepthWise( + 64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256) + self.conv_4 = Residual( + 128, + num_block=6, + groups=256, + kernel=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.conv_45 = DepthWise( + 128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512) + self.conv_5 = Residual( + 128, + num_block=2, + groups=256, + kernel=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.conv_6_sep = ConvBlock( + 128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) + self.conv_6_dw = LinearBlock( + 512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)) + self.conv_6_flatten = Flatten() + self.linear = Linear(512, embedding_size, bias=False) + self.bn = BatchNorm1d(embedding_size) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2_dw(out) + out = self.conv_23(out) + out = self.conv_3(out) + out = self.conv_34(out) + out = self.conv_4(out) + out = self.conv_45(out) + out = self.conv_5(out) + out = self.conv_6_sep(out) + out = self.conv_6_dw(out) + out = self.conv_6_flatten(out) + out = self.linear(out) + return l2_norm(out) diff --git a/modelscope/outputs.py b/modelscope/outputs.py index b96f38d3..ce9e8d07 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -632,5 +632,9 @@ TASK_OUTPUTS = { # { # 'output': ['Done' / 'Decode_Error'] # } - Tasks.video_inpainting: [OutputKeys.OUTPUT] + Tasks.video_inpainting: [OutputKeys.OUTPUT], + # { + # 'output': ['bixin'] + # } + Tasks.hand_static: [OutputKeys.OUTPUT] } diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 5e244b27..51d50d51 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -178,6 +178,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_vitb16_segmentation_shop-seg'), Tasks.video_inpainting: (Pipelines.video_inpainting, 'damo/cv_video-inpainting'), + Tasks.hand_static: (Pipelines.hand_static, + 'damo/cv_mobileface_hand-static'), } diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index a9dc05f2..55bad09a 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -52,7 +52,8 @@ if TYPE_CHECKING: from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline - from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline + from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipelin + from .hand_static_pipeline import HandStaticPipeline else: _import_structure = { @@ -119,6 +120,7 @@ else: 'facial_expression_recognition_pipelin': ['FacialExpressionRecognitionPipeline'], 'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], + 'hand_static_pipeline': ['HandStaticPipeline'], } import sys diff --git a/modelscope/pipelines/cv/hand_static_pipeline.py b/modelscope/pipelines/cv/hand_static_pipeline.py new file mode 100644 index 00000000..1219c873 --- /dev/null +++ b/modelscope/pipelines/cv/hand_static_pipeline.py @@ -0,0 +1,37 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.hand_static import hand_model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.hand_static, module_name=Pipelines.hand_static) +class HandStaticPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create hand static pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = hand_model.infer(input['img_path'], self.model, self.device) + return {OutputKeys.OUTPUT: result} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index de3d933f..75add1d9 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -42,6 +42,7 @@ class CVTasks(object): portrait_matting = 'portrait-matting' text_driven_segmentation = 'text-driven-segmentation' shop_segmentation = 'shop-segmentation' + hand_static = 'hand-static' # image editing skin_retouching = 'skin-retouching' diff --git a/tests/pipelines/test_hand_static.py b/tests/pipelines/test_hand_static.py new file mode 100644 index 00000000..37181899 --- /dev/null +++ b/tests/pipelines/test_hand_static.py @@ -0,0 +1,32 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class HandStaticTest(unittest.TestCase): + + def setUp(self) -> None: + self.model = 'damo/cv_mobileface_hand-static' + self.input = {'img_path': 'data/test/images/hand_static.jpg'} + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + hand_static = pipeline(Tasks.hand_static, model=self.model) + self.pipeline_inference(hand_static, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + hand_static = pipeline(Tasks.hand_static) + self.pipeline_inference(hand_static, self.input) + + +if __name__ == '__main__': + unittest.main()