diff --git a/data/test/images/face_emotion.jpg b/data/test/images/face_emotion.jpg new file mode 100644 index 00000000..54f22280 --- /dev/null +++ b/data/test/images/face_emotion.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:712b5525e37080d33f62d6657609dbef20e843ccc04ee5c788ea11aa7c08545e +size 123341 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 54e09f7a..ae8b5297 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -41,6 +41,7 @@ class Models(object): video_inpainting = 'video-inpainting' hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' # EasyCV models yolox = 'YOLOX' @@ -183,6 +184,7 @@ class Pipelines(object): pst_action_recognition = 'patchshift-action-recognition' hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/face_emotion/__init__.py b/modelscope/models/cv/face_emotion/__init__.py new file mode 100644 index 00000000..2a13ea42 --- /dev/null +++ b/modelscope/models/cv/face_emotion/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .emotion_model import EfficientNetForFaceEmotion + +else: + _import_structure = {'emotion_model': ['EfficientNetForFaceEmotion']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_emotion/efficient/__init__.py b/modelscope/models/cv/face_emotion/efficient/__init__.py new file mode 100644 index 00000000..e8fc91a4 --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/__init__.py @@ -0,0 +1,6 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +from .model import VALID_MODELS, EfficientNet +from .utils import (BlockArgs, BlockDecoder, GlobalParams, efficientnet, + get_model_params) diff --git a/modelscope/models/cv/face_emotion/efficient/model.py b/modelscope/models/cv/face_emotion/efficient/model.py new file mode 100644 index 00000000..db303016 --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/model.py @@ -0,0 +1,380 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import (MemoryEfficientSwish, Swish, calculate_output_image_size, + drop_connect, efficientnet_params, get_model_params, + get_same_padding_conv2d, load_pretrained_weights, + round_filters, round_repeats) + +VALID_MODELS = ('efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', + 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5', + 'efficientnet-b6', 'efficientnet-b7', 'efficientnet-b8', + 'efficientnet-l2') + + +class MBConvBlock(nn.Module): + + def __init__(self, block_args, global_params, image_size=None): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio + 
is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip + + inp = self._block_args.input_filters + oup = self._block_args.input_filters * self._block_args.expand_ratio + if self._block_args.expand_ratio != 1: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d( + in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d( + num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + k = self._block_args.kernel_size + s = self._block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=oup, + out_channels=oup, + groups=oup, + kernel_size=k, + stride=s, + bias=False) + self._bn1 = nn.BatchNorm2d( + num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + image_size = calculate_output_image_size(image_size, s) + + if self.has_se: + Conv2d = get_same_padding_conv2d(image_size=(1, 1)) + num_squeezed_channels = max( + 1, + int(self._block_args.input_filters + * self._block_args.se_ratio)) + self._se_reduce = Conv2d( + in_channels=oup, + out_channels=num_squeezed_channels, + kernel_size=1) + self._se_expand = Conv2d( + in_channels=num_squeezed_channels, + out_channels=oup, + kernel_size=1) + + final_oup = self._block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d( + in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d( + num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """MBConvBlock's forward function. + Args: + inputs (tensor): Input tensor. + drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). + Returns: + Output of this block after processing. + """ + + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + x = self._project_conv(x) + x = self._bn2(x) + + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + x = drop_connect( + x, p=drop_connect_rate, training=self.training) + x = x + inputs + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """EfficientNet model. + Most easily loaded with the .from_name or .from_pretrained methods. + Args: + blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. + global_params (namedtuple): A set of GlobalParams shared between blocks. 
+ References: + [1] https://arxiv.org/abs/1905.11946 (EfficientNet) + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> model.eval() + >>> outputs = model(inputs) + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + in_channels = 3 + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d( + num_features=out_channels, momentum=bn_mom, eps=bn_eps) + image_size = calculate_output_image_size(image_size, 2) + + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, + self._global_params), + output_filters=round_filters(block_args.output_filters, + self._global_params), + num_repeat=round_repeats(block_args.num_repeat, + self._global_params)) + + self._blocks.append( + MBConvBlock( + block_args, self._global_params, image_size=image_size)) + image_size = calculate_output_image_size(image_size, + block_args.stride) + if block_args.num_repeat > 1: + block_args = block_args._replace( + input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append( + MBConvBlock( + block_args, self._global_params, + image_size=image_size)) + + in_channels = block_args.output_filters + out_channels = round_filters(1280, self._global_params) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = Conv2d( + in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d( + num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + if self._global_params.include_top: + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_endpoints(self, inputs): + """Use convolution layer to extract features + from reduction levels i in [1, 2, 3, 4, 5]. + Args: + inputs (tensor): Input tensor. + Returns: + Dictionary of last intermediate features + with reduction levels i in [1, 2, 3, 4, 5]. 
+ Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> endpoints = model.extract_endpoints(inputs) + >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) + >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) + >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) + >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) + >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) + >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) + """ + endpoints = dict() + + x = self._swish(self._bn0(self._conv_stem(inputs))) + prev_x = x + + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len( + self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + if prev_x.size(2) > x.size(2): + endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x + elif idx == len(self._blocks) - 1: + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + prev_x = x + + x = self._swish(self._bn1(self._conv_head(x))) + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + + return endpoints + + def extract_features(self, inputs): + """use convolution layer to extract feature . + Args: + inputs (tensor): Input tensor. + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + x = self._swish(self._bn0(self._conv_stem(inputs))) + + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """EfficientNet's forward function. + Calls extract_features to extract features, applies final linear layer, and returns logits. + Args: + inputs (tensor): Input tensor. + Returns: + Output of this model after processing. + """ + x = self.extract_features(inputs) + x = self._avg_pooling(x) + if self._global_params.include_top: + x = x.flatten(start_dim=1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, in_channels=3, **override_params): + """Create an efficientnet model according to name. + Args: + model_name (str): Name for efficientnet. + in_channels (int): Input data's channel number. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + Returns: + An efficientnet model. + """ + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, + override_params) + model = cls(blocks_args, global_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def from_pretrained(cls, + model_name, + weights_path=None, + advprop=False, + in_channels=3, + num_classes=1000, + **override_params): + """Create an efficientnet model according to name. + Args: + model_name (str): Name for efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. 
+ None: use pretrained weights downloaded from the Internet. + advprop (bool): + Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + in_channels (int): Input data's channel number. + num_classes (int): + Number of categories for classification. + It controls the output size for final linear layer. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + Returns: + A pretrained efficientnet model. + """ + model = cls.from_name( + model_name, num_classes=num_classes, **override_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def get_image_size(cls, model_name): + """Get the input image size for a given efficientnet model. + Args: + model_name (str): Name for efficientnet. + Returns: + Input image size (resolution). + """ + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """Validates model name. + Args: + model_name (str): Name for efficientnet. + Returns: + bool: Is a valid name or not. + """ + if model_name not in VALID_MODELS: + raise ValueError('model_name should be one of: ' + + ', '.join(VALID_MODELS)) + + def _change_in_channels(self, in_channels): + """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. + Args: + in_channels (int): Input data's channel number. + """ + if in_channels != 3: + Conv2d = get_same_padding_conv2d( + image_size=self._global_params.image_size) + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False) diff --git a/modelscope/models/cv/face_emotion/efficient/utils.py b/modelscope/models/cv/face_emotion/efficient/utils.py new file mode 100644 index 00000000..6cae70fc --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/utils.py @@ -0,0 +1,559 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +import collections +import math +import re +from functools import partial + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo + +GlobalParams = collections.namedtuple('GlobalParams', [ + 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', + 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top' +]) + +BlockArgs = collections.namedtuple('BlockArgs', [ + 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', 'input_filters', + 'output_filters', 'se_ratio', 'id_skip' +]) + +GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields) + +if hasattr(nn, 'SiLU'): + Swish = nn.SiLU +else: + + class Swish(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(x) + + +class SwishImplementation(torch.autograd.Function): + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * 
(sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + + def forward(self, x): + return SwishImplementation.apply(x) + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, + int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """Calculate module's repeat number of a block based on depth multiplier. + Use depth_coefficient of global_params. + Args: + repeats (int): num_repeat to be calculated. + global_params (namedtuple): Global params of the model. + Returns: + new repeat: New repeat number after calculating. + """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """Drop connect. + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, 'p must be in range of [0,1]' + + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], + dtype=inputs.dtype, + device=inputs.device) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + +def get_width_and_height_from_size(x): + """Obtain height and width from x. + Args: + x (int, tuple or list): Data size. + Returns: + size: A tuple or list (H,W). + """ + if isinstance(x, int): + return x, x + if isinstance(x, list) or isinstance(x, tuple): + return x + else: + raise TypeError() + + +def calculate_output_image_size(input_image_size, stride): + """Calculates the output image size when using Conv2dSamePadding with a stride. + Necessary for static padding. Thanks to mannatsingh for pointing this out. + Args: + input_image_size (int, tuple or list): Size of input image. + stride (int, tuple or list): Conv2d operation's stride. + Returns: + output_image_size: A list [H,W]. + """ + if input_image_size is None: + return None + image_height, image_width = get_width_and_height_from_size( + input_image_size) + stride = stride if isinstance(stride, int) else stride[0] + image_height = int(math.ceil(image_height / stride)) + image_width = int(math.ceil(image_width / stride)) + return [image_height, image_width] + + +def get_same_padding_conv2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: + image_size (int or tuple): Size of the image. + Returns: + Conv2dDynamicSamePadding or Conv2dStaticSamePadding. 
+ """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow, for a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, + dilation, groups, bias) + self.stride = self.stride if len( + self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + a1 = (oh - 1) * self.stride[0] + pad_h = max(a1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + a2 = (ow - 1) * self.stride[1] + pad_w = max(a2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + image_size=None, + **kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, + **kwargs) + self.stride = self.stride if len( + self.stride) == 2 else [self.stride[0]] * 2 + + assert image_size is not None + ih, iw = (image_size, + image_size) if isinstance(image_size, int) else image_size + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + b1 = (oh - 1) * self.stride[0] + pad_h = max(b1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + b2 = (ow - 1) * self.stride[1] + pad_w = max(b2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d( + (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + return x + + +def get_same_padding_maxPool2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: + image_size (int or tuple): Size of the image. + Returns: + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + """ + if image_size is None: + return MaxPool2dDynamicSamePadding + else: + return partial(MaxPool2dStaticSamePadding, image_size=image_size) + + +class MaxPool2dDynamicSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. + The padding is operated in forward function by calculating dynamically. 
+ """ + + def __init__(self, + kernel_size, + stride, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, + return_indices, ceil_mode) + self.stride = [self.stride] * 2 if isinstance(self.stride, + int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance( + self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance( + self.dilation, int) else self.dilation + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + c1 = (oh - 1) * self.stride[0] + pad_h = max(c1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + c2 = (ow - 1) * self.stride[1] + pad_w = max(c2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + + +class MaxPool2dStaticSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, kernel_size, stride, image_size=None, **kwargs): + super().__init__(kernel_size, stride, **kwargs) + self.stride = [self.stride] * 2 if isinstance(self.stride, + int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance( + self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance( + self.dilation, int) else self.dilation + + assert image_size is not None + ih, iw = (image_size, + image_size) if isinstance(image_size, int) else image_size + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + d1 = (oh - 1) * self.stride[0] + pad_h = max(d1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + d2 = (ow - 1) * self.stride[1] + pad_w = max(d2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d( + (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + return x + + +class BlockDecoder(object): + """Block Decoder for readability, + straight from the official TensorFlow repository. + """ + + @staticmethod + def _decode_block_string(block_string): + """Get a block through a string notation of arguments. + Args: + block_string (str): A string notation of arguments. + Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + Returns: + BlockArgs: The namedtuple defined at the top of this file. 
+ """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) + or (len(options['s']) == 2 + and options['s'][0] == options['s'][1])) + + return BlockArgs( + num_repeat=int(options['r']), + kernel_size=int(options['k']), + stride=[int(options['s'][0])], + expand_ratio=int(options['e']), + input_filters=int(options['i']), + output_filters=int(options['o']), + se_ratio=float(options['se']) if 'se' in options else None, + id_skip=('noskip' not in block_string)) + + @staticmethod + def _encode_block_string(block): + """Encode a block to a string. + Args: + block (namedtuple): A BlockArgs type argument. + Returns: + block_string: A String form of BlockArgs. + """ + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """Decode a list of string notations to specify blocks inside the network. + Args: + string_list (list[str]): A list of strings, each string is a notation of block. + Returns: + blocks_args: A list of BlockArgs namedtuples of block args. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """Encode a list of BlockArgs to a list of strings. + Args: + blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + Returns: + block_strings: A list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet_params(model_name): + """Map EfficientNet model name to parameter coefficients. + Args: + model_name (str): Model name to be queried. + Returns: + params_dict[model_name]: A (width,depth,res,dropout) tuple. + """ + params_dict = { + 'efficientnet-b0': (1.0, 1.0, 112, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +def efficientnet(width_coefficient=None, + depth_coefficient=None, + image_size=None, + dropout_rate=0.2, + drop_connect_rate=0.2, + num_classes=1000, + include_top=True): + """Create BlockArgs and GlobalParams for efficientnet model. + Args: + width_coefficient (float) + depth_coefficient (float) + image_size (int) + dropout_rate (float) + drop_connect_rate (float) + num_classes (int) + Meaning as the name suggests. + Returns: + blocks_args, global_params. 
+ """ + + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', + 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', + 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', + 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + image_size=image_size, + dropout_rate=dropout_rate, + num_classes=num_classes, + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + drop_connect_rate=drop_connect_rate, + depth_divisor=8, + min_depth=None, + include_top=include_top, + ) + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """Get the block args and global params for a given model name. + Args: + model_name (str): Model's name. + override_params (dict): A dict to modify global_params. + Returns: + blocks_args, global_params + """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + blocks_args, global_params = efficientnet( + width_coefficient=w, + depth_coefficient=d, + dropout_rate=p, + image_size=s) + else: + raise NotImplementedError( + 'model name is not pre-defined: {}'.format(model_name)) + if override_params: + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +def load_pretrained_weights(model, + model_name, + weights_path=None, + load_fc=True, + advprop=False, + verbose=True): + """Loads pretrained weights from weights path or download using url. + Args: + model (Module): The whole model of efficientnet. + model_name (str): Model name of efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. + advprop (bool): Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + """ + if isinstance(weights_path, str): + state_dict = torch.load(weights_path) + else: + url_map_ = url_map_advprop if advprop else url_map + state_dict = model_zoo.load_url(url_map_[model_name]) + + if load_fc: + ret = model.load_state_dict(state_dict, strict=False) + assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format( + ret.missing_keys) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + ret = model.load_state_dict(state_dict, strict=False) + assert set(ret.missing_keys) == set([ + '_fc.weight', '_fc.bias' + ]), 'Missing keys when loading pretrained weights: {}'.format( + ret.missing_keys) + assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format( + ret.unexpected_keys) + + if verbose: + print('Loaded pretrained weights for {}'.format(model_name)) diff --git a/modelscope/models/cv/face_emotion/emotion_infer.py b/modelscope/models/cv/face_emotion/emotion_infer.py new file mode 100644 index 00000000..e3398592 --- /dev/null +++ b/modelscope/models/cv/face_emotion/emotion_infer.py @@ -0,0 +1,67 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
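Illustrative sketch (not part of the change): the block-string notation parsed by BlockDecoder in utils.py above packs each MBConv stage's hyper-parameters into one short token string, and get_model_params maps a variant name to its width/depth/resolution/dropout coefficients. A minimal reading of both, assuming the package path introduced by this diff:

# Assumes the module layout added by this change; shown for clarity only.
from modelscope.models.cv.face_emotion.efficient.utils import (
    BlockDecoder, get_model_params)

# 'r1_k3_s11_e1_i32_o16_se0.25': repeat 1x, 3x3 kernel, stride 1, expand
# ratio 1, 32 input filters, 16 output filters, squeeze-excite ratio 0.25.
args = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])[0]
print(args.num_repeat, args.kernel_size, args.stride, args.expand_ratio)

# Named variant -> GlobalParams, with any overrides applied on top.
blocks_args, global_params = get_model_params('efficientnet-b0',
                                              {'num_classes': 7})
print(global_params.image_size, global_params.num_classes)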
+import torch
+from PIL import Image
+from torch import nn
+from torchvision import transforms
+
+from modelscope.utils.logger import get_logger
+from .face_alignment.face_align import face_detection_PIL_v2
+
+logger = get_logger()
+
+
+def transform_PIL(img_pil):
+    val_transforms = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    return val_transforms(img_pil)
+
+
+index2AU = [1, 2, 4, 6, 7, 10, 12, 15, 23, 24, 25, 26]
+emotion_list = [
+    'Neutral', 'Anger', 'Disgust', 'Fear', 'Happiness', 'Sadness', 'Surprise'
+]
+
+
+def inference(image_path, model, face_model, score_thre=0.5, GPU=0):
+    image = Image.open(image_path).convert('RGB')
+
+    face, bbox = face_detection_PIL_v2(image, face_model)
+    if bbox is None:
+        logger.warning('no face detected!')
+        # Return the same (emotion, bbox) tuple shape as the detected case.
+        return None, None
+
+    face = transform_PIL(face)
+    face = face.unsqueeze(0)
+    if torch.cuda.is_available():
+        face = face.cuda(GPU)
+    logits_AU, logits_emotion = model(face)
+    logits_AU = torch.sigmoid(logits_AU)
+    logits_emotion = nn.functional.softmax(logits_emotion, 1)
+
+    _, index_list = logits_emotion.max(1)
+    emotion_index = index_list[0].data.item()
+    prob = logits_emotion[0][emotion_index]
+    if prob > score_thre and emotion_index != 3:
+        cur_emotion = emotion_list[emotion_index]
+    else:
+        cur_emotion = 'Neutral'
+
+    logits_AU = logits_AU[0]
+    au_output = torch.zeros_like(logits_AU)
+    au_output[logits_AU >= score_thre] = 1
+    au_output[logits_AU < score_thre] = 0
+
+    au_output = au_output.int()
+
+    cur_au_list = []
+    for idx in range(au_output.shape[0]):
+        if au_output[idx] == 1:
+            au = index2AU[idx]
+            cur_au_list.append(au)
+    cur_au_list.sort()
+    result = (cur_emotion, bbox)
+    return result
diff --git a/modelscope/models/cv/face_emotion/emotion_model.py b/modelscope/models/cv/face_emotion/emotion_model.py
new file mode 100644
index 00000000..f8df9c37
--- /dev/null
+++ b/modelscope/models/cv/face_emotion/emotion_model.py
@@ -0,0 +1,96 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
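Illustrative sketch (not part of the change): the emotion decision rule in inference() above takes the softmax arg-max over the seven classes, keeps it only if its probability clears score_thre, and otherwise falls back to 'Neutral'; note that, as written, an arg-max of index 3 ('Fear') is also remapped to 'Neutral'. A minimal standalone version with dummy logits:

import torch

emotion_list = [
    'Neutral', 'Anger', 'Disgust', 'Fear', 'Happiness', 'Sadness', 'Surprise'
]


def decide_emotion(logits_emotion, score_thre=0.5):
    # Mirrors the rule in inference(): softmax, arg-max class, fall back to
    # 'Neutral' when the probability does not exceed the threshold or when
    # the winner is index 3 ('Fear').
    probs = torch.nn.functional.softmax(logits_emotion, dim=1)
    prob, idx = probs[0].max(0)
    if prob.item() > score_thre and idx.item() != 3:
        return emotion_list[idx.item()]
    return 'Neutral'


# A logit vector that favours class 4 -> prints 'Happiness'.
print(decide_emotion(torch.tensor([[0.1, 0.2, 0.1, 0.1, 2.5, 0.1, 0.1]])))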
+import os +import sys + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.face_emotion.efficient import EfficientNet +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@MODELS.register_module(Tasks.face_emotion, module_name=Models.face_emotion) +class EfficientNetForFaceEmotion(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.model = FaceEmotionModel( + name='efficientnet-b0', num_embed=512, num_au=12, num_emotion=7) + + if torch.cuda.is_available(): + self.device = 'cuda' + logger.info('Use GPU') + else: + self.device = 'cpu' + logger.info('Use CPU') + pretrained_params = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=self.device) + + state_dict = pretrained_params['model'] + new_state = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + k = k[7:] + new_state[k] = v + + self.model.load_state_dict(new_state) + self.model.eval() + self.model.to(self.device) + + def forward(self, x): + logits_au, logits_emotion = self.model(x) + return logits_au, logits_emotion + + +class FaceEmotionModel(nn.Module): + + def __init__(self, + name='efficientnet-b0', + num_embed=512, + num_au=12, + num_emotion=7): + super(FaceEmotionModel, self).__init__() + self.backbone = EfficientNet.from_pretrained( + name, weights_path=None, advprop=True) + self.average_pool = nn.AdaptiveAvgPool2d(1) + self.embed = nn.Linear(self.backbone._fc.weight.data.shape[1], + num_embed) + self.features = nn.BatchNorm1d(num_embed) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + self.fc_au = nn.Sequential( + nn.Dropout(0.6), + nn.Linear(num_embed, num_au), + ) + self.fc_emotion = nn.Sequential( + nn.Dropout(0.6), + nn.Linear(num_embed, num_emotion), + ) + + def feat_single_img(self, x): + x = self.backbone.extract_features(x) + x = self.average_pool(x) + x = x.flatten(1) + x = self.embed(x) + x = self.features(x) + return x + + def forward(self, x): + x = self.feat_single_img(x) + logits_au = self.fc_au(x) + att_au = torch.sigmoid(logits_au).unsqueeze(-1) + x = x.unsqueeze(1) + emotion_vec_list = torch.matmul(att_au, x) + emotion_vec = emotion_vec_list.sum(1) + logits_emotion = self.fc_emotion(emotion_vec) + return logits_au, logits_emotion diff --git a/modelscope/models/cv/face_emotion/face_alignment/__init__.py b/modelscope/models/cv/face_emotion/face_alignment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_emotion/face_alignment/face.py b/modelscope/models/cv/face_emotion/face_alignment/face.py new file mode 100644 index 00000000..a362bddc --- /dev/null +++ b/modelscope/models/cv/face_emotion/face_alignment/face.py @@ -0,0 +1,79 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
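Illustrative sketch (not part of the change): a shape-only walk-through of the AU-attention step in FaceEmotionModel.forward above, with dummy tensors rather than the real backbone. Each of the 12 AU sigmoids scales a copy of the 512-d face embedding, and the scaled copies are summed before the emotion head:

import torch

# Dummy tensors with the dimensions used by FaceEmotionModel
# (num_embed=512, num_au=12); shapes only, no trained weights.
batch, num_embed, num_au = 2, 512, 12
x = torch.randn(batch, num_embed)        # embedding from feat_single_img
logits_au = torch.randn(batch, num_au)   # AU logits from fc_au

att_au = torch.sigmoid(logits_au).unsqueeze(-1)  # (B, 12, 1) AU attention
x = x.unsqueeze(1)                               # (B, 1, 512)
emotion_vec_list = torch.matmul(att_au, x)       # (B, 12, 512) scaled copies
emotion_vec = emotion_vec_list.sum(1)            # (B, 512) fed to fc_emotion
print(emotion_vec.shape)                         # torch.Size([2, 512])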
+import os + +import cv2 +import numpy as np +import tensorflow as tf + + +def init(mod): + PATH_TO_CKPT = mod + net = tf.Graph() + with net.as_default(): + od_graph_def = tf.GraphDef() + config = tf.ConfigProto() + config.gpu_options.per_process_gpu_memory_fraction = 0.6 + with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: + serialized_graph = fid.read() + od_graph_def.ParseFromString(serialized_graph) + tf.import_graph_def(od_graph_def, name='') + sess = tf.Session(graph=net, config=config) + return sess, net + + +def filter_bboxes_confs(shape, + imgsBboxes, + imgsConfs, + single=False, + thresh=0.5): + [w, h] = shape + if single: + bboxes, confs = [], [] + for y in range(len(imgsBboxes)): + if imgsConfs[y] >= thresh: + [x1, y1, x2, y2] = list(imgsBboxes[y]) + x1, y1, x2, y2 = int(w * x1), int(h * y1), int(w * x2), int( + h * y2) + bboxes.append([y1, x1, y2, x2]) + confs.append(imgsConfs[y]) + return bboxes, confs + else: + retImgsBboxes, retImgsConfs = [], [] + for x in range(len(imgsBboxes)): + bboxes, confs = [], [] + for y in range(len(imgsBboxes[x])): + if imgsConfs[x][y] >= thresh: + [x1, y1, x2, y2] = list(imgsBboxes[x][y]) + x1, y1, x2, y2 = int(w * x1), int(h * y1), int( + w * x2), int(h * y2) + bboxes.append([y1, x1, y2, x2]) + confs.append(imgsConfs[x][y]) + retImgsBboxes.append(bboxes) + retImgsConfs.append(confs) + return retImgsBboxes, retImgsConfs + + +def detect(im, sess, net): + image_np = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + image_np_expanded = np.expand_dims(image_np, axis=0) + image_tensor = net.get_tensor_by_name('image_tensor:0') + bboxes = net.get_tensor_by_name('detection_boxes:0') + dConfs = net.get_tensor_by_name('detection_scores:0') + classes = net.get_tensor_by_name('detection_classes:0') + num_detections = net.get_tensor_by_name('num_detections:0') + (bboxes, dConfs, classes, + num_detections) = sess.run([bboxes, dConfs, classes, num_detections], + feed_dict={image_tensor: image_np_expanded}) + w, h, _ = im.shape + bboxes, confs = filter_bboxes_confs([w, h], bboxes[0], dConfs[0], True) + return bboxes, confs + + +class FaceDetector: + + def __init__(self, mod): + self.sess, self.net = init(mod) + + def do_detect(self, im): + bboxes, confs = detect(im, self.sess, self.net) + return bboxes, confs diff --git a/modelscope/models/cv/face_emotion/face_alignment/face_align.py b/modelscope/models/cv/face_emotion/face_alignment/face_align.py new file mode 100644 index 00000000..71282b12 --- /dev/null +++ b/modelscope/models/cv/face_emotion/face_alignment/face_align.py @@ -0,0 +1,59 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
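Illustrative sketch (not part of the change): as I read detect() and filter_bboxes_confs above, the TensorFlow detector emits normalized [ymin, xmin, ymax, xmax] boxes and the helper returns pixel [x1, y1, x2, y2] boxes, despite the swapped variable names inside it. A standalone version of that conversion (names here are mine, not the repo's):

def to_pixel_box(norm_box, height, width):
    # norm_box follows the TF Object Detection convention:
    # [ymin, xmin, ymax, xmax], each value normalized to [0, 1].
    ymin, xmin, ymax, xmax = norm_box
    x1, y1 = int(width * xmin), int(height * ymin)
    x2, y2 = int(width * xmax), int(height * ymax)
    return [x1, y1, x2, y2]


print(to_pixel_box([0.25, 0.10, 0.75, 0.60], height=480, width=640))
# -> [64, 120, 384, 360]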
+import os +import sys + +import cv2 +import numpy as np +from PIL import Image, ImageFile + +from .face import FaceDetector + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def adjust_bx_v2(box, w, h): + x1, y1, x2, y2 = box[0], box[1], box[2], box[3] + box_w = x2 - x1 + box_h = y2 - y1 + delta = abs(box_w - box_h) + if box_w > box_h: + if y1 >= delta: + y1 = y1 - delta + else: + delta_y1 = y1 + y1 = 0 + delta_y2 = delta - delta_y1 + y2 = y2 + delta_y2 if y2 < h - delta_y2 else h - 1 + else: + if x1 >= delta / 2 and x2 <= w - delta / 2: + x1 = x1 - delta / 2 + x2 = x2 + delta / 2 + elif x1 < delta / 2 and x2 <= w - delta / 2: + delta_x1 = x1 + x1 = 0 + delta_x2 = delta - delta_x1 + x2 = x2 + delta_x2 if x2 < w - delta_x2 else w - 1 + elif x1 >= delta / 2 and x2 > w - delta / 2: + delta_x2 = w - x2 + x2 = w - 1 + delta_x1 = delta - x1 + x1 = x1 - delta_x1 if x1 >= delta_x1 else 0 + + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + return [x1, y1, x2, y2] + + +def face_detection_PIL_v2(image, face_model): + crop_size = 112 + face_detector = FaceDetector(face_model) + img = np.array(image) + h, w = img.shape[0:2] + bxs, conf = face_detector.do_detect(img) + bx = bxs[0] + bx = adjust_bx_v2(bx, w, h) + x1, y1, x2, y2 = bx + image = img[y1:y2, x1:x2, :] + img = Image.fromarray(image) + img = img.resize((crop_size, crop_size)) + bx = tuple(bx) + return img, bx diff --git a/modelscope/outputs.py b/modelscope/outputs.py index f13bbed9..9811811e 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -664,11 +664,15 @@ TASK_OUTPUTS = { # } Tasks.hand_static: [OutputKeys.OUTPUT], - # { # 'output': [ # [2, 75, 287, 240, 510, 0.8335018754005432], # [1, 127, 83, 332, 366, 0.9175254702568054], # [0, 0, 0, 367, 639, 0.9693422317504883]] # } Tasks.face_human_hand_detection: [OutputKeys.OUTPUT], + + # { + # {'output': 'Happiness', 'boxes': (203, 104, 663, 564)} + # } + Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES] } diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index a14b07a6..ff56658f 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -186,6 +186,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.face_human_hand_detection: (Pipelines.face_human_hand_detection, 'damo/cv_nanodet_face-human-hand-detection'), + Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'), } diff --git a/modelscope/pipelines/cv/face_emotion_pipeline.py b/modelscope/pipelines/cv/face_emotion_pipeline.py new file mode 100644 index 00000000..249493b6 --- /dev/null +++ b/modelscope/pipelines/cv/face_emotion_pipeline.py @@ -0,0 +1,39 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_emotion import emotion_infer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_emotion, module_name=Pipelines.face_emotion) +class FaceEmotionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create face emotion pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + + super().__init__(model=model, **kwargs) + self.face_model = model + '/' + ModelFile.TF_GRAPH_FILE + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result, bbox = emotion_infer.inference(input['img_path'], self.model, + self.face_model) + return {OutputKeys.OUTPUT: result, OutputKeys.BOXES: bbox} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index ac6846e4..4aff1d05 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -44,6 +44,7 @@ class CVTasks(object): shop_segmentation = 'shop-segmentation' hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' # image editing skin_retouching = 'skin-retouching' diff --git a/tests/pipelines/test_face_emotion.py b/tests/pipelines/test_face_emotion.py new file mode 100644 index 00000000..907e15ee --- /dev/null +++ b/tests/pipelines/test_face_emotion.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class FaceEmotionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model = 'damo/cv_face-emotion' + self.img = {'img_path': 'data/test/images/face_emotion.jpg'} + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_emotion = pipeline(Tasks.face_emotion, model=self.model) + self.pipeline_inference(face_emotion, self.img) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + face_emotion = pipeline(Tasks.face_emotion) + self.pipeline_inference(face_emotion, self.img) + + +if __name__ == '__main__': + unittest.main()
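Illustrative sketch (not part of the change): the end-to-end call pattern exercised by the test above, together with the output layout documented in the outputs.py hunk; the emotion string and box values below are examples, not guaranteed results.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_emotion = pipeline(Tasks.face_emotion, model='damo/cv_face-emotion')
# The pipeline's forward() reads input['img_path'], so the input is a dict.
result = face_emotion({'img_path': 'data/test/images/face_emotion.jpg'})
print(result)
# e.g. {'output': 'Happiness', 'boxes': (203, 104, 663, 564)}
# When no face is detected, both fields come back as None.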